fix format
This commit is contained in:
1
main.py
1
main.py
@@ -476,6 +476,7 @@ def main():
|
||||
base_url=args.base_url,
|
||||
model_name=args.model,
|
||||
api_key=args.apikey,
|
||||
lang=args.lang,
|
||||
)
|
||||
|
||||
agent_config = AgentConfig(
|
||||
|
||||
@@ -285,7 +285,7 @@ def parse_action(response: str) -> dict[str, Any]:
|
||||
if response.startswith("do"):
|
||||
# Use AST parsing instead of eval for safety
|
||||
try:
|
||||
tree = ast.parse(response, mode='eval')
|
||||
tree = ast.parse(response, mode="eval")
|
||||
if not isinstance(tree.body, ast.Call):
|
||||
raise ValueError("Expected a function call")
|
||||
|
||||
|
||||
@@ -19,6 +19,10 @@ MESSAGES_ZH = {
|
||||
"step": "步骤",
|
||||
"task": "任务",
|
||||
"result": "结果",
|
||||
"performance_metrics": "性能指标",
|
||||
"time_to_first_token": "首 Token 延迟 (TTFT)",
|
||||
"time_to_thinking_end": "思考完成延迟",
|
||||
"total_inference_time": "总推理时间",
|
||||
}
|
||||
|
||||
# English messages
|
||||
@@ -40,6 +44,10 @@ MESSAGES_EN = {
|
||||
"step": "Step",
|
||||
"task": "Task",
|
||||
"result": "Result",
|
||||
"performance_metrics": "Performance Metrics",
|
||||
"time_to_first_token": "Time to First Token (TTFT)",
|
||||
"time_to_thinking_end": "Time to Thinking End",
|
||||
"total_inference_time": "Total Inference Time",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
"""Model client for AI inference using OpenAI-compatible API."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from phone_agent.config.i18n import get_message
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
@@ -19,6 +22,7 @@ class ModelConfig:
|
||||
top_p: float = 0.85
|
||||
frequency_penalty: float = 0.2
|
||||
extra_body: dict[str, Any] = field(default_factory=dict)
|
||||
lang: str = "cn" # Language for UI messages: 'cn' or 'en'
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -28,6 +32,10 @@ class ModelResponse:
|
||||
thinking: str
|
||||
action: str
|
||||
raw_content: str
|
||||
# Performance metrics
|
||||
time_to_first_token: float | None = None # Time to first token (seconds)
|
||||
time_to_thinking_end: float | None = None # Time to thinking end (seconds)
|
||||
total_time: float | None = None # Total inference time (seconds)
|
||||
|
||||
|
||||
class ModelClient:
|
||||
@@ -55,6 +63,11 @@ class ModelClient:
|
||||
Raises:
|
||||
ValueError: If the response cannot be parsed.
|
||||
"""
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
time_to_first_token = None
|
||||
time_to_thinking_end = None
|
||||
|
||||
stream = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.config.model_name,
|
||||
@@ -70,6 +83,7 @@ class ModelClient:
|
||||
buffer = "" # Buffer to hold content that might be part of a marker
|
||||
action_markers = ["finish(message=", "do(action="]
|
||||
in_action_phase = False # Track if we've entered the action phase
|
||||
first_token_received = False
|
||||
|
||||
for chunk in stream:
|
||||
if len(chunk.choices) == 0:
|
||||
@@ -78,6 +92,11 @@ class ModelClient:
|
||||
content = chunk.choices[0].delta.content
|
||||
raw_content += content
|
||||
|
||||
# Record time to first token
|
||||
if not first_token_received:
|
||||
time_to_first_token = time.time() - start_time
|
||||
first_token_received = True
|
||||
|
||||
if in_action_phase:
|
||||
# Already in action phase, just accumulate content without printing
|
||||
continue
|
||||
@@ -94,6 +113,11 @@ class ModelClient:
|
||||
print() # Print newline after thinking is complete
|
||||
in_action_phase = True
|
||||
marker_found = True
|
||||
|
||||
# Record time to thinking end
|
||||
if time_to_thinking_end is None:
|
||||
time_to_thinking_end = time.time() - start_time
|
||||
|
||||
break
|
||||
|
||||
if marker_found:
|
||||
@@ -115,10 +139,39 @@ class ModelClient:
|
||||
print(buffer, end="", flush=True)
|
||||
buffer = ""
|
||||
|
||||
# Calculate total time
|
||||
total_time = time.time() - start_time
|
||||
|
||||
# Parse thinking and action from response
|
||||
thinking, action = self._parse_response(raw_content)
|
||||
|
||||
return ModelResponse(thinking=thinking, action=action, raw_content=raw_content)
|
||||
# Print performance metrics
|
||||
lang = self.config.lang
|
||||
print()
|
||||
print("=" * 50)
|
||||
print(f"⏱️ {get_message('performance_metrics', lang)}:")
|
||||
print("-" * 50)
|
||||
if time_to_first_token is not None:
|
||||
print(
|
||||
f"{get_message('time_to_first_token', lang)}: {time_to_first_token:.3f}s"
|
||||
)
|
||||
if time_to_thinking_end is not None:
|
||||
print(
|
||||
f"{get_message('time_to_thinking_end', lang)}: {time_to_thinking_end:.3f}s"
|
||||
)
|
||||
print(
|
||||
f"{get_message('total_inference_time', lang)}: {total_time:.3f}s"
|
||||
)
|
||||
print("=" * 50)
|
||||
|
||||
return ModelResponse(
|
||||
thinking=thinking,
|
||||
action=action,
|
||||
raw_content=raw_content,
|
||||
time_to_first_token=time_to_first_token,
|
||||
time_to_thinking_end=time_to_thinking_end,
|
||||
total_time=total_time,
|
||||
)
|
||||
|
||||
def _parse_response(self, content: str) -> tuple[str, str]:
|
||||
"""
|
||||
|
||||
@@ -41,19 +41,31 @@ Usage examples:
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-tokens", type=int, default=3000, help="Maximum generation tokens (default: 3000)"
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="Maximum generation tokens (default: 3000)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.0, help="Sampling temperature (default: 0.0)"
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Sampling temperature (default: 0.0)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--top_p", type=float, default=0.85, help="Nucleus sampling parameter (default: 0.85)"
|
||||
"--top_p",
|
||||
type=float,
|
||||
default=0.85,
|
||||
help="Nucleus sampling parameter (default: 0.85)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frequency_penalty", type=float, default=0.2, help="Frequency penalty parameter (default: 0.2)"
|
||||
"--frequency_penalty",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Frequency penalty parameter (default: 0.2)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -103,7 +115,9 @@ Usage examples:
|
||||
print(f" - Completion tokens: {response.usage.completion_tokens}")
|
||||
print(f" - Total tokens: {response.usage.total_tokens}")
|
||||
|
||||
print(f"\nPlease evaluate the above inference result to determine if the model deployment meets expectations.")
|
||||
print(
|
||||
f"\nPlease evaluate the above inference result to determine if the model deployment meets expectations."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nError occurred while calling API:")
|
||||
|
||||
Reference in New Issue
Block a user