fix format
This commit is contained in:
1
main.py
1
main.py
@@ -476,6 +476,7 @@ def main():
|
|||||||
base_url=args.base_url,
|
base_url=args.base_url,
|
||||||
model_name=args.model,
|
model_name=args.model,
|
||||||
api_key=args.apikey,
|
api_key=args.apikey,
|
||||||
|
lang=args.lang,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ def parse_action(response: str) -> dict[str, Any]:
|
|||||||
if response.startswith("do"):
|
if response.startswith("do"):
|
||||||
# Use AST parsing instead of eval for safety
|
# Use AST parsing instead of eval for safety
|
||||||
try:
|
try:
|
||||||
tree = ast.parse(response, mode='eval')
|
tree = ast.parse(response, mode="eval")
|
||||||
if not isinstance(tree.body, ast.Call):
|
if not isinstance(tree.body, ast.Call):
|
||||||
raise ValueError("Expected a function call")
|
raise ValueError("Expected a function call")
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ MESSAGES_ZH = {
|
|||||||
"step": "步骤",
|
"step": "步骤",
|
||||||
"task": "任务",
|
"task": "任务",
|
||||||
"result": "结果",
|
"result": "结果",
|
||||||
|
"performance_metrics": "性能指标",
|
||||||
|
"time_to_first_token": "首 Token 延迟 (TTFT)",
|
||||||
|
"time_to_thinking_end": "思考完成延迟",
|
||||||
|
"total_inference_time": "总推理时间",
|
||||||
}
|
}
|
||||||
|
|
||||||
# English messages
|
# English messages
|
||||||
@@ -40,6 +44,10 @@ MESSAGES_EN = {
|
|||||||
"step": "Step",
|
"step": "Step",
|
||||||
"task": "Task",
|
"task": "Task",
|
||||||
"result": "Result",
|
"result": "Result",
|
||||||
|
"performance_metrics": "Performance Metrics",
|
||||||
|
"time_to_first_token": "Time to First Token (TTFT)",
|
||||||
|
"time_to_thinking_end": "Time to Thinking End",
|
||||||
|
"total_inference_time": "Total Inference Time",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
"""Model client for AI inference using OpenAI-compatible API."""
|
"""Model client for AI inference using OpenAI-compatible API."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from phone_agent.config.i18n import get_message
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
@@ -19,6 +22,7 @@ class ModelConfig:
|
|||||||
top_p: float = 0.85
|
top_p: float = 0.85
|
||||||
frequency_penalty: float = 0.2
|
frequency_penalty: float = 0.2
|
||||||
extra_body: dict[str, Any] = field(default_factory=dict)
|
extra_body: dict[str, Any] = field(default_factory=dict)
|
||||||
|
lang: str = "cn" # Language for UI messages: 'cn' or 'en'
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -28,6 +32,10 @@ class ModelResponse:
|
|||||||
thinking: str
|
thinking: str
|
||||||
action: str
|
action: str
|
||||||
raw_content: str
|
raw_content: str
|
||||||
|
# Performance metrics
|
||||||
|
time_to_first_token: float | None = None # Time to first token (seconds)
|
||||||
|
time_to_thinking_end: float | None = None # Time to thinking end (seconds)
|
||||||
|
total_time: float | None = None # Total inference time (seconds)
|
||||||
|
|
||||||
|
|
||||||
class ModelClient:
|
class ModelClient:
|
||||||
@@ -55,6 +63,11 @@ class ModelClient:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the response cannot be parsed.
|
ValueError: If the response cannot be parsed.
|
||||||
"""
|
"""
|
||||||
|
# Start timing
|
||||||
|
start_time = time.time()
|
||||||
|
time_to_first_token = None
|
||||||
|
time_to_thinking_end = None
|
||||||
|
|
||||||
stream = self.client.chat.completions.create(
|
stream = self.client.chat.completions.create(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=self.config.model_name,
|
model=self.config.model_name,
|
||||||
@@ -70,6 +83,7 @@ class ModelClient:
|
|||||||
buffer = "" # Buffer to hold content that might be part of a marker
|
buffer = "" # Buffer to hold content that might be part of a marker
|
||||||
action_markers = ["finish(message=", "do(action="]
|
action_markers = ["finish(message=", "do(action="]
|
||||||
in_action_phase = False # Track if we've entered the action phase
|
in_action_phase = False # Track if we've entered the action phase
|
||||||
|
first_token_received = False
|
||||||
|
|
||||||
for chunk in stream:
|
for chunk in stream:
|
||||||
if len(chunk.choices) == 0:
|
if len(chunk.choices) == 0:
|
||||||
@@ -78,6 +92,11 @@ class ModelClient:
|
|||||||
content = chunk.choices[0].delta.content
|
content = chunk.choices[0].delta.content
|
||||||
raw_content += content
|
raw_content += content
|
||||||
|
|
||||||
|
# Record time to first token
|
||||||
|
if not first_token_received:
|
||||||
|
time_to_first_token = time.time() - start_time
|
||||||
|
first_token_received = True
|
||||||
|
|
||||||
if in_action_phase:
|
if in_action_phase:
|
||||||
# Already in action phase, just accumulate content without printing
|
# Already in action phase, just accumulate content without printing
|
||||||
continue
|
continue
|
||||||
@@ -94,6 +113,11 @@ class ModelClient:
|
|||||||
print() # Print newline after thinking is complete
|
print() # Print newline after thinking is complete
|
||||||
in_action_phase = True
|
in_action_phase = True
|
||||||
marker_found = True
|
marker_found = True
|
||||||
|
|
||||||
|
# Record time to thinking end
|
||||||
|
if time_to_thinking_end is None:
|
||||||
|
time_to_thinking_end = time.time() - start_time
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if marker_found:
|
if marker_found:
|
||||||
@@ -115,10 +139,39 @@ class ModelClient:
|
|||||||
print(buffer, end="", flush=True)
|
print(buffer, end="", flush=True)
|
||||||
buffer = ""
|
buffer = ""
|
||||||
|
|
||||||
|
# Calculate total time
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
|
||||||
# Parse thinking and action from response
|
# Parse thinking and action from response
|
||||||
thinking, action = self._parse_response(raw_content)
|
thinking, action = self._parse_response(raw_content)
|
||||||
|
|
||||||
return ModelResponse(thinking=thinking, action=action, raw_content=raw_content)
|
# Print performance metrics
|
||||||
|
lang = self.config.lang
|
||||||
|
print()
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"⏱️ {get_message('performance_metrics', lang)}:")
|
||||||
|
print("-" * 50)
|
||||||
|
if time_to_first_token is not None:
|
||||||
|
print(
|
||||||
|
f"{get_message('time_to_first_token', lang)}: {time_to_first_token:.3f}s"
|
||||||
|
)
|
||||||
|
if time_to_thinking_end is not None:
|
||||||
|
print(
|
||||||
|
f"{get_message('time_to_thinking_end', lang)}: {time_to_thinking_end:.3f}s"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"{get_message('total_inference_time', lang)}: {total_time:.3f}s"
|
||||||
|
)
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
return ModelResponse(
|
||||||
|
thinking=thinking,
|
||||||
|
action=action,
|
||||||
|
raw_content=raw_content,
|
||||||
|
time_to_first_token=time_to_first_token,
|
||||||
|
time_to_thinking_end=time_to_thinking_end,
|
||||||
|
total_time=total_time,
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_response(self, content: str) -> tuple[str, str]:
|
def _parse_response(self, content: str) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -41,19 +41,31 @@ Usage examples:
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-tokens", type=int, default=3000, help="Maximum generation tokens (default: 3000)"
|
"--max-tokens",
|
||||||
|
type=int,
|
||||||
|
default=3000,
|
||||||
|
help="Maximum generation tokens (default: 3000)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temperature", type=float, default=0.0, help="Sampling temperature (default: 0.0)"
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Sampling temperature (default: 0.0)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--top_p", type=float, default=0.85, help="Nucleus sampling parameter (default: 0.85)"
|
"--top_p",
|
||||||
|
type=float,
|
||||||
|
default=0.85,
|
||||||
|
help="Nucleus sampling parameter (default: 0.85)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--frequency_penalty", type=float, default=0.2, help="Frequency penalty parameter (default: 0.2)"
|
"--frequency_penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.2,
|
||||||
|
help="Frequency penalty parameter (default: 0.2)",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -103,7 +115,9 @@ Usage examples:
|
|||||||
print(f" - Completion tokens: {response.usage.completion_tokens}")
|
print(f" - Completion tokens: {response.usage.completion_tokens}")
|
||||||
print(f" - Total tokens: {response.usage.total_tokens}")
|
print(f" - Total tokens: {response.usage.total_tokens}")
|
||||||
|
|
||||||
print(f"\nPlease evaluate the above inference result to determine if the model deployment meets expectations.")
|
print(
|
||||||
|
f"\nPlease evaluate the above inference result to determine if the model deployment meets expectations."
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\nError occurred while calling API:")
|
print(f"\nError occurred while calling API:")
|
||||||
|
|||||||
Reference in New Issue
Block a user