Add Web Dashboard with multi-device control and callback hooks
Features: - Web Dashboard: FastAPI-based dashboard with Vue.js frontend - Multi-device support (ADB, HDC, iOS) - Real-time WebSocket updates for task progress - Device management with status tracking - Task queue with execution controls (start/stop/re-execute) - Detailed task information display (thinking, actions, completion messages) - Screenshot viewing per device - LAN deployment support with configurable CORS - Callback Hooks: Interrupt and modify task execution - step_callback: Called after each step with StepResult - before_action_callback: Called before executing action - Support for task interruption and dynamic task switching - Example scripts demonstrating callback usage - Configuration: Environment-based configuration - .env file support for all settings - .env.example template with documentation - Model API configuration (base URL, model name, API key) - Dashboard configuration (host, port, CORS, device type) - Phone agent configuration (delays, max steps, language) Technical improvements: - Fixed forward reference issue with StepResult - Added package exports for callback types and configs - Enhanced dependencies with FastAPI, WebSocket support - Thread-safe task execution with device locking - Async WebSocket broadcasting from sync thread pool Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
15
dashboard/services/__init__.py
Normal file
15
dashboard/services/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Services for the dashboard.
|
||||
|
||||
Includes device management, task execution, and WebSocket management.
|
||||
"""
|
||||
|
||||
from dashboard.services.device_manager import DeviceManager
|
||||
from dashboard.services.task_executor import TaskExecutor
|
||||
from dashboard.services.websocket_manager import WebSocketManager
|
||||
|
||||
__all__ = [
|
||||
"DeviceManager",
|
||||
"TaskExecutor",
|
||||
"WebSocketManager",
|
||||
]
|
||||
336
dashboard/services/device_manager.py
Normal file
336
dashboard/services/device_manager.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
Device manager for handling device pool and connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from phone_agent.adb.connection import ADBConnection
|
||||
from phone_agent.adb.screenshot import get_screenshot as adb_screenshot
|
||||
from phone_agent.device_factory import (
|
||||
DeviceFactory,
|
||||
DeviceType,
|
||||
get_device_factory,
|
||||
set_device_type,
|
||||
)
|
||||
from phone_agent.hdc.connection import HDCConnection
|
||||
from phone_agent.hdc.screenshot import get_screenshot as hdc_screenshot
|
||||
|
||||
from dashboard.models.device import DeviceInfo, DeviceStatus, DeviceSchema
|
||||
|
||||
|
||||
class DeviceConnectionError(Exception):
|
||||
"""Raised when device connection fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DeviceManager:
|
||||
"""Manage device pool and connections."""
|
||||
|
||||
def __init__(self, device_type: str = "adb"):
|
||||
"""Initialize the device manager.
|
||||
|
||||
Args:
|
||||
device_type: Default device type (adb, hdc, ios)
|
||||
"""
|
||||
self._devices: Dict[str, DeviceInfo] = {}
|
||||
self._device_locks: Dict[str, Lock] = {}
|
||||
self._device_type = self._parse_device_type(device_type)
|
||||
self._screenshot_cache: Dict[str, tuple[str, datetime]] = {} # (base64, timestamp)
|
||||
self._cache_ttl_seconds = 2.0 # Cache screenshots for 2 seconds
|
||||
|
||||
# Set the global device type
|
||||
set_device_type(self._device_type)
|
||||
|
||||
def _parse_device_type(self, device_type: str) -> DeviceType:
|
||||
"""Parse device type string to enum."""
|
||||
try:
|
||||
return DeviceType[device_type.upper()]
|
||||
except KeyError:
|
||||
return DeviceType.ADB
|
||||
|
||||
@property
|
||||
def factory(self) -> DeviceFactory:
|
||||
"""Get the device factory instance."""
|
||||
return get_device_factory()
|
||||
|
||||
async def refresh_devices(self) -> List[DeviceInfo]:
|
||||
"""Scan and return all connected devices.
|
||||
|
||||
Returns:
|
||||
List of device info objects
|
||||
"""
|
||||
try:
|
||||
# Run device listing in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
devices = await loop.run_in_executor(None, self.factory.list_devices)
|
||||
|
||||
# Update device cache
|
||||
current_time = datetime.now()
|
||||
for device in devices:
|
||||
device_id = device.device_id
|
||||
|
||||
# Initialize lock if needed
|
||||
if device_id not in self._device_locks:
|
||||
self._device_locks[device_id] = Lock()
|
||||
|
||||
# Update or create device info
|
||||
if device_id in self._devices:
|
||||
# Update existing device
|
||||
self._devices[device_id].last_seen = current_time
|
||||
self._devices[device_id].is_connected = True
|
||||
self._devices[device_id].status = (
|
||||
DeviceStatus.BUSY
|
||||
if not self._is_device_available(device_id)
|
||||
else DeviceStatus.ONLINE
|
||||
)
|
||||
else:
|
||||
# New device
|
||||
self._devices[device_id] = DeviceInfo(
|
||||
device_id=device_id,
|
||||
status=DeviceStatus.ONLINE,
|
||||
device_type=self._device_type_to_schema(device.connection_type),
|
||||
model=device.model,
|
||||
android_version=device.android_version,
|
||||
current_app=None,
|
||||
last_seen=current_time,
|
||||
is_connected=True,
|
||||
)
|
||||
|
||||
# Mark disconnected devices
|
||||
connected_ids = {d.device_id for d in devices}
|
||||
for device_id, device_info in self._devices.items():
|
||||
if device_id not in connected_ids:
|
||||
device_info.is_connected = False
|
||||
device_info.status = DeviceStatus.OFFLINE
|
||||
|
||||
return list(self._devices.values())
|
||||
|
||||
except Exception as e:
|
||||
raise DeviceConnectionError(f"Failed to refresh devices: {e}")
|
||||
|
||||
async def get_device(self, device_id: str) -> Optional[DeviceInfo]:
|
||||
"""Get device info by ID.
|
||||
|
||||
Args:
|
||||
device_id: Device identifier
|
||||
|
||||
Returns:
|
||||
Device info or None if not found
|
||||
"""
|
||||
return self._devices.get(device_id)
|
||||
|
||||
def acquire_device(self, device_id: str) -> bool:
|
||||
"""Acquire lock on device.
|
||||
|
||||
Args:
|
||||
device_id: Device identifier
|
||||
|
||||
Returns:
|
||||
True if acquired, False if device is busy
|
||||
"""
|
||||
if device_id not in self._device_locks:
|
||||
return False
|
||||
|
||||
lock = self._device_locks[device_id]
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
if acquired:
|
||||
# Update device status
|
||||
if device_id in self._devices:
|
||||
self._devices[device_id].status = DeviceStatus.BUSY
|
||||
|
||||
return acquired
|
||||
|
||||
def release_device(self, device_id: str):
|
||||
"""Release device lock.
|
||||
|
||||
Args:
|
||||
device_id: Device identifier
|
||||
"""
|
||||
if device_id in self._device_locks:
|
||||
self._device_locks[device_id].release()
|
||||
|
||||
# Update device status
|
||||
if device_id in self._devices:
|
||||
device = self._devices[device_id]
|
||||
if device.is_connected:
|
||||
device.status = DeviceStatus.ONLINE
|
||||
else:
|
||||
device.status = DeviceStatus.OFFLINE
|
||||
|
||||
def is_device_available(self, device_id: str) -> bool:
|
||||
"""Check if device is available for task execution.
|
||||
|
||||
Args:
|
||||
device_id: Device identifier
|
||||
|
||||
Returns:
|
||||
True if available, False if busy or offline
|
||||
"""
|
||||
if device_id not in self._devices:
|
||||
return False
|
||||
|
||||
device = self._devices[device_id]
|
||||
if not device.is_connected:
|
||||
return False
|
||||
|
||||
if device.status == DeviceStatus.BUSY:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_device_available(self, device_id: str) -> bool:
|
||||
"""Internal check if device lock is available."""
|
||||
if device_id not in self._device_locks:
|
||||
return True
|
||||
lock = self._device_locks[device_id]
|
||||
return not lock.locked()
|
||||
|
||||
async def get_screenshot(self, device_id: str) -> Optional[str]:
|
||||
"""Get screenshot for device.
|
||||
|
||||
Args:
|
||||
device_id: Device identifier
|
||||
|
||||
Returns:
|
||||
Base64 encoded screenshot or None
|
||||
"""
|
||||
# Check cache
|
||||
if device_id in self._screenshot_cache:
|
||||
screenshot, timestamp = self._screenshot_cache[device_id]
|
||||
age = (datetime.now() - timestamp).total_seconds()
|
||||
if age < self._cache_ttl_seconds:
|
||||
return screenshot
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if self._device_type == DeviceType.HDC:
|
||||
result = await loop.run_in_executor(
|
||||
None, hdc_screenshot, device_id, 10
|
||||
)
|
||||
else: # ADB or default
|
||||
result = await loop.run_in_executor(
|
||||
None, adb_screenshot, device_id, 10
|
||||
)
|
||||
|
||||
if result:
|
||||
# Cache the screenshot
|
||||
self._screenshot_cache[device_id] = (result.base64_data, datetime.now())
|
||||
return result.base64_data
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
async def get_current_app(self, device_id: str) -> Optional[str]:
|
||||
"""Get current app for device.
|
||||
|
||||
Args:
|
||||
device_id: Device identifier
|
||||
|
||||
Returns:
|
||||
Current app package name or None
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
app = await loop.run_in_executor(
|
||||
None, self.factory.get_current_app, device_id
|
||||
)
|
||||
|
||||
# Update device info
|
||||
if device_id in self._devices and app:
|
||||
self._devices[device_id].current_app = app
|
||||
|
||||
return app
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def connect_device(self, address: str) -> bool:
|
||||
"""Connect to device via WiFi.
|
||||
|
||||
Args:
|
||||
address: Device address (IP:PORT)
|
||||
|
||||
Returns:
|
||||
True if connected successfully
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if self._device_type == DeviceType.HDC:
|
||||
conn = HDCConnection()
|
||||
else:
|
||||
conn = ADBConnection()
|
||||
|
||||
success, _ = await loop.run_in_executor(None, conn.connect, address)
|
||||
|
||||
if success:
|
||||
# Refresh devices after connecting
|
||||
await self.refresh_devices()
|
||||
|
||||
return success
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def disconnect_device(self, address: str) -> bool:
|
||||
"""Disconnect from device.
|
||||
|
||||
Args:
|
||||
address: Device address (IP:PORT)
|
||||
|
||||
Returns:
|
||||
True if disconnected successfully
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if self._device_type == DeviceType.HDC:
|
||||
conn = HDCConnection()
|
||||
else:
|
||||
conn = ADBConnection()
|
||||
|
||||
success, _ = await loop.run_in_executor(None, conn.disconnect, address)
|
||||
|
||||
if success:
|
||||
# Refresh devices after disconnecting
|
||||
await self.refresh_devices()
|
||||
|
||||
return success
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _device_type_to_schema(self, connection_type) -> "DeviceSchema":
|
||||
"""Convert connection type to DeviceType enum."""
|
||||
from dashboard.models.device import DeviceType as SchemaDeviceType
|
||||
|
||||
# Mapping from phone_agent connection types to schema types
|
||||
type_map = {
|
||||
"USB": SchemaDeviceType.ADB,
|
||||
"WIFI": SchemaDeviceType.ADB,
|
||||
"REMOTE": SchemaDeviceType.ADB,
|
||||
}
|
||||
|
||||
# Try to get string value from enum
|
||||
try:
|
||||
type_str = connection_type.value if hasattr(connection_type, "value") else str(connection_type)
|
||||
return type_map.get(type_str.upper(), SchemaDeviceType.ADB)
|
||||
except (AttributeError, KeyError):
|
||||
return SchemaDeviceType.ADB
|
||||
|
||||
def list_all_devices(self) -> List[DeviceInfo]:
|
||||
"""Get all cached devices.
|
||||
|
||||
Returns:
|
||||
List of all device info
|
||||
"""
|
||||
return list(self._devices.values())
|
||||
410
dashboard/services/task_executor.py
Normal file
410
dashboard/services/task_executor.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Task executor for running tasks on devices with thread pool.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor, Future
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from phone_agent import AgentConfig, PhoneAgent
|
||||
from phone_agent.agent import StepResult
|
||||
from phone_agent.model import ModelConfig
|
||||
|
||||
from dashboard.config import config
|
||||
from dashboard.models.task import TaskCreateRequest, TaskSchema, TaskStatus
|
||||
from dashboard.services.device_manager import DeviceManager
|
||||
from dashboard.services.websocket_manager import WebSocketManager
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run async coroutine from sync context and wait for completion."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# We're already in an async context, use run_coroutine_threadsafe
|
||||
import concurrent.futures
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
# Wait for completion with timeout to avoid hanging
|
||||
try:
|
||||
future.result(timeout=5)
|
||||
except concurrent.futures.TimeoutError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except RuntimeError:
|
||||
# No running loop, use asyncio.run
|
||||
asyncio.run(coro)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
"""An active task being executed."""
|
||||
|
||||
task_id: str
|
||||
device_id: str
|
||||
task: str
|
||||
status: TaskStatus
|
||||
current_step: int
|
||||
max_steps: int
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
finished_at: Optional[datetime] = None
|
||||
future: Optional[Future] = None
|
||||
stop_requested: bool = False
|
||||
current_action: Optional[dict] = None
|
||||
thinking: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
completion_message: Optional[str] = None
|
||||
|
||||
|
||||
class TaskExecutor:
|
||||
"""Execute tasks on devices with thread pool."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_manager: DeviceManager,
|
||||
ws_manager: Optional[WebSocketManager] = None,
|
||||
max_workers: int = 10,
|
||||
):
|
||||
"""Initialize the task executor.
|
||||
|
||||
Args:
|
||||
device_manager: Device manager instance
|
||||
ws_manager: WebSocket manager for real-time updates (optional)
|
||||
max_workers: Maximum number of concurrent tasks
|
||||
"""
|
||||
self.device_manager = device_manager
|
||||
self.ws_manager = ws_manager
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.active_tasks: Dict[str, ActiveTask] = {}
|
||||
self.task_history: Dict[str, TaskSchema] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set_ws_manager(self, ws_manager: WebSocketManager):
|
||||
"""Set the WebSocket manager.
|
||||
|
||||
Args:
|
||||
ws_manager: WebSocket manager instance
|
||||
"""
|
||||
self.ws_manager = ws_manager
|
||||
|
||||
async def execute_task(self, request: TaskCreateRequest) -> str:
|
||||
"""Execute task on device.
|
||||
|
||||
Args:
|
||||
request: Task creation request
|
||||
|
||||
Returns:
|
||||
Task ID
|
||||
"""
|
||||
task_id = f"task_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create active task
|
||||
active_task = ActiveTask(
|
||||
task_id=task_id,
|
||||
device_id=request.device_id,
|
||||
task=request.task,
|
||||
status=TaskStatus.RUNNING,
|
||||
current_step=0,
|
||||
max_steps=request.max_steps,
|
||||
started_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self.active_tasks[task_id] = active_task
|
||||
|
||||
# Notify WebSocket
|
||||
if self.ws_manager:
|
||||
await self.ws_manager.broadcast_task_started(task_id, request.dict())
|
||||
|
||||
# Submit to thread pool
|
||||
future = self.executor.submit(
|
||||
self._run_task,
|
||||
task_id,
|
||||
request,
|
||||
)
|
||||
active_task.future = future
|
||||
|
||||
return task_id
|
||||
|
||||
def _run_task(self, task_id: str, request: TaskCreateRequest):
|
||||
"""Run task in thread pool.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
request: Task creation request
|
||||
"""
|
||||
import os
|
||||
|
||||
# Get model config from request
|
||||
model_config = ModelConfig(
|
||||
base_url=request.base_url,
|
||||
model_name=request.model_name,
|
||||
api_key=request.api_key,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
lang=request.lang,
|
||||
)
|
||||
|
||||
# Get agent config
|
||||
agent_config = AgentConfig(
|
||||
max_steps=request.max_steps,
|
||||
device_id=request.device_id,
|
||||
lang=request.lang,
|
||||
step_callback=lambda result: self._step_callback(task_id, result),
|
||||
before_action_callback=lambda action: self._before_action_callback(
|
||||
task_id, action
|
||||
),
|
||||
)
|
||||
|
||||
# Create agent
|
||||
agent = PhoneAgent(model_config=model_config, agent_config=agent_config)
|
||||
|
||||
try:
|
||||
# Acquire device
|
||||
if not self.device_manager.acquire_device(request.device_id):
|
||||
raise Exception(f"Device {request.device_id} is not available")
|
||||
|
||||
try:
|
||||
# Run task
|
||||
result = agent.run(request.task)
|
||||
|
||||
# Update task status
|
||||
with self._lock:
|
||||
if task_id in self.active_tasks:
|
||||
task = self.active_tasks[task_id]
|
||||
task.status = (
|
||||
TaskStatus.COMPLETED if result else TaskStatus.FAILED
|
||||
)
|
||||
task.finished_at = datetime.now()
|
||||
task.updated_at = datetime.now()
|
||||
|
||||
finally:
|
||||
# Release device
|
||||
self.device_manager.release_device(request.device_id)
|
||||
|
||||
# Broadcast device status update
|
||||
if self.ws_manager:
|
||||
device = self.device_manager._devices.get(request.device_id)
|
||||
if device:
|
||||
_run_async(
|
||||
self.ws_manager.broadcast_device_update(
|
||||
request.device_id,
|
||||
{
|
||||
"status": device.status.value,
|
||||
"is_connected": device.is_connected,
|
||||
"current_app": device.current_app,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Notify completion
|
||||
if self.ws_manager:
|
||||
with self._lock:
|
||||
task = self.active_tasks.get(task_id)
|
||||
task_status = task.status if task else TaskStatus.COMPLETED
|
||||
message = result.message if hasattr(result, "message") else None
|
||||
|
||||
_run_async(
|
||||
self.ws_manager.broadcast_task_completed(
|
||||
task_id,
|
||||
request.device_id,
|
||||
task_status,
|
||||
message,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Release device on error
|
||||
self.device_manager.release_device(request.device_id)
|
||||
|
||||
# Broadcast device status update
|
||||
if self.ws_manager:
|
||||
device = self.device_manager._devices.get(request.device_id)
|
||||
if device:
|
||||
_run_async(
|
||||
self.ws_manager.broadcast_device_update(
|
||||
request.device_id,
|
||||
{
|
||||
"status": device.status.value,
|
||||
"is_connected": device.is_connected,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Update task status
|
||||
with self._lock:
|
||||
if task_id in self.active_tasks:
|
||||
task = self.active_tasks[task_id]
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = str(e)
|
||||
task.finished_at = datetime.now()
|
||||
task.updated_at = datetime.now()
|
||||
|
||||
# Notify error
|
||||
if self.ws_manager:
|
||||
_run_async(
|
||||
self.ws_manager.broadcast_task_failed(
|
||||
task_id, request.device_id, str(e)
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
# Move to history
|
||||
with self._lock:
|
||||
if task_id in self.active_tasks:
|
||||
active_task = self.active_tasks.pop(task_id)
|
||||
task_schema = self._active_task_to_schema(active_task)
|
||||
self.task_history[task_id] = task_schema
|
||||
|
||||
# Trim history
|
||||
if len(self.task_history) > config.MAX_TASK_HISTORY:
|
||||
oldest = min(self.task_history.items(), key=lambda x: x[1].started_at)
|
||||
del self.task_history[oldest[0]]
|
||||
|
||||
def _step_callback(
|
||||
self, task_id: str, result: StepResult
|
||||
) -> Optional[str]:
|
||||
"""Callback after each step.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
result: Step result
|
||||
|
||||
Returns:
|
||||
"stop" to interrupt, new task to switch, None to continue
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id not in self.active_tasks:
|
||||
return None
|
||||
|
||||
task = self.active_tasks[task_id]
|
||||
|
||||
# Check if stop was requested
|
||||
if task.stop_requested:
|
||||
return "stop"
|
||||
|
||||
# Update task state
|
||||
task.current_step = result.step_count
|
||||
task.updated_at = datetime.now()
|
||||
task.thinking = result.thinking
|
||||
task.current_action = result.action
|
||||
|
||||
# Store completion message when finished
|
||||
if result.finished and result.message:
|
||||
task.completion_message = result.message
|
||||
|
||||
# Notify WebSocket
|
||||
if self.ws_manager:
|
||||
_run_async(
|
||||
self.ws_manager.broadcast_step_update(
|
||||
task_id=task_id,
|
||||
device_id=task.device_id,
|
||||
step=result.step_count,
|
||||
action=result.action,
|
||||
thinking=result.thinking,
|
||||
finished=result.finished,
|
||||
success=result.success,
|
||||
message=result.message,
|
||||
)
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _before_action_callback(self, task_id: str, action: dict) -> Optional[dict]:
|
||||
"""Callback before executing action.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
action: Action to execute
|
||||
|
||||
Returns:
|
||||
Modified action or None to proceed as-is
|
||||
"""
|
||||
# Can be used for action validation/logging
|
||||
return None
|
||||
|
||||
async def stop_task(self, task_id: str):
|
||||
"""Stop running task.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id not in self.active_tasks:
|
||||
return
|
||||
|
||||
task = self.active_tasks[task_id]
|
||||
task.stop_requested = True
|
||||
|
||||
async def get_task_status(self, task_id: str) -> Optional[TaskSchema]:
|
||||
"""Get task status.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Task schema or None if not found
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self.active_tasks:
|
||||
return self._active_task_to_schema(self.active_tasks[task_id])
|
||||
elif task_id in self.task_history:
|
||||
return self.task_history[task_id]
|
||||
|
||||
return None
|
||||
|
||||
async def list_tasks(self) -> list[TaskSchema]:
|
||||
"""List all tasks (active and recent).
|
||||
|
||||
Returns:
|
||||
List of task schemas
|
||||
"""
|
||||
with self._lock:
|
||||
active_schemas = [
|
||||
self._active_task_to_schema(t) for t in self.active_tasks.values()
|
||||
]
|
||||
history_schemas = list(self.task_history.values())
|
||||
|
||||
return active_schemas + history_schemas
|
||||
|
||||
def _active_task_to_schema(self, active_task: ActiveTask) -> TaskSchema:
|
||||
"""Convert ActiveTask to TaskSchema.
|
||||
|
||||
Args:
|
||||
active_task: Active task
|
||||
|
||||
Returns:
|
||||
Task schema
|
||||
"""
|
||||
return TaskSchema(
|
||||
task_id=active_task.task_id,
|
||||
device_id=active_task.device_id,
|
||||
task=active_task.task,
|
||||
status=active_task.status,
|
||||
current_step=active_task.current_step,
|
||||
max_steps=active_task.max_steps,
|
||||
current_action=active_task.current_action,
|
||||
thinking=active_task.thinking,
|
||||
started_at=active_task.started_at,
|
||||
updated_at=active_task.updated_at,
|
||||
finished_at=active_task.finished_at,
|
||||
error=active_task.error,
|
||||
completion_message=active_task.completion_message,
|
||||
)
|
||||
|
||||
def get_active_task_count(self) -> int:
|
||||
"""Get number of active tasks.
|
||||
|
||||
Returns:
|
||||
Active task count
|
||||
"""
|
||||
with self._lock:
|
||||
return len(self.active_tasks)
|
||||
338
dashboard/services/websocket_manager.py
Normal file
338
dashboard/services/websocket_manager.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
WebSocket manager for real-time updates.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from dashboard.models.ws_messages import (
|
||||
ScreenshotUpdate,
|
||||
WSMessage,
|
||||
WSMessageType,
|
||||
)
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
"""Manage WebSocket connections for real-time updates."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the WebSocket manager."""
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.client_subscriptions: Dict[str, set[str]] = {} # client_id -> set of device_ids
|
||||
|
||||
async def connect(self, websocket: WebSocket, client_id: Optional[str] = None) -> str:
|
||||
"""Accept and store connection.
|
||||
|
||||
Args:
|
||||
websocket: WebSocket connection
|
||||
client_id: Optional client ID (auto-generated if not provided)
|
||||
|
||||
Returns:
|
||||
Client ID
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
if client_id is None:
|
||||
client_id = f"client_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
self.active_connections[client_id] = websocket
|
||||
self.client_subscriptions[client_id] = set()
|
||||
|
||||
return client_id
|
||||
|
||||
def disconnect(self, client_id: str):
|
||||
"""Remove connection.
|
||||
|
||||
Args:
|
||||
client_id: Client ID
|
||||
"""
|
||||
if client_id in self.active_connections:
|
||||
del self.active_connections[client_id]
|
||||
|
||||
if client_id in self.client_subscriptions:
|
||||
del self.client_subscriptions[client_id]
|
||||
|
||||
async def send_to_client(
|
||||
self, client_id: str, message: WSMessage | Dict[str, Any]
|
||||
):
|
||||
"""Send message to specific client.
|
||||
|
||||
Args:
|
||||
client_id: Client ID
|
||||
message: Message to send
|
||||
"""
|
||||
if client_id not in self.active_connections:
|
||||
return
|
||||
|
||||
websocket = self.active_connections[client_id]
|
||||
|
||||
# Convert dict to WSMessage if needed
|
||||
if isinstance(message, dict):
|
||||
message = WSMessage(
|
||||
type=WSMessageType(message.get("type", "error")),
|
||||
data=message.get("data", {}),
|
||||
timestamp=message.get("timestamp", datetime.now()),
|
||||
)
|
||||
|
||||
try:
|
||||
await websocket.send_json(message.model_dump(mode="json"))
|
||||
except Exception:
|
||||
# Connection may be closed
|
||||
self.disconnect(client_id)
|
||||
|
||||
async def broadcast(self, message: WSMessage | Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients.
|
||||
|
||||
Args:
|
||||
message: Message to broadcast
|
||||
"""
|
||||
# Convert dict to WSMessage if needed
|
||||
if isinstance(message, dict):
|
||||
message = WSMessage(
|
||||
type=WSMessageType(message.get("type", "error")),
|
||||
data=message.get("data", {}),
|
||||
timestamp=message.get("timestamp", datetime.now()),
|
||||
)
|
||||
|
||||
# Create list of clients to avoid modification during iteration
|
||||
clients = list(self.active_connections.items())
|
||||
|
||||
for client_id, websocket in clients:
|
||||
try:
|
||||
await websocket.send_json(message.model_dump(mode="json"))
|
||||
except Exception:
|
||||
self.disconnect(client_id)
|
||||
|
||||
async def broadcast_to_device_subscribers(
|
||||
self, device_id: str, message: WSMessage | Dict[str, Any]
|
||||
):
|
||||
"""Broadcast message to clients subscribed to a device.
|
||||
|
||||
Args:
|
||||
device_id: Device ID
|
||||
message: Message to broadcast
|
||||
"""
|
||||
# Convert dict to WSMessage if needed
|
||||
if isinstance(message, dict):
|
||||
message = WSMessage(
|
||||
type=WSMessageType(message.get("type", "error")),
|
||||
data=message.get("data", {}),
|
||||
timestamp=message.get("timestamp", datetime.now()),
|
||||
)
|
||||
|
||||
# Find clients subscribed to this device
|
||||
for client_id, subscriptions in self.client_subscriptions.items():
|
||||
if device_id in subscriptions or "*" in subscriptions:
|
||||
await self.send_to_client(client_id, message)
|
||||
|
||||
def subscribe_to_device(self, client_id: str, device_id: str):
|
||||
"""Subscribe client to device updates.
|
||||
|
||||
Args:
|
||||
client_id: Client ID
|
||||
device_id: Device ID (use "*" for all devices)
|
||||
"""
|
||||
if client_id in self.client_subscriptions:
|
||||
self.client_subscriptions[client_id].add(device_id)
|
||||
|
||||
def unsubscribe_from_device(self, client_id: str, device_id: str):
|
||||
"""Unsubscribe client from device updates.
|
||||
|
||||
Args:
|
||||
client_id: Client ID
|
||||
device_id: Device ID
|
||||
"""
|
||||
if client_id in self.client_subscriptions:
|
||||
self.client_subscriptions[client_id].discard(device_id)
|
||||
|
||||
# Convenience methods for specific message types
|
||||
|
||||
async def broadcast_device_update(self, device_id: str, device_data: Dict[str, Any]):
|
||||
"""Broadcast device update.
|
||||
|
||||
Args:
|
||||
device_id: Device ID
|
||||
device_data: Device data
|
||||
"""
|
||||
await self.broadcast_to_device_subscribers(
|
||||
device_id,
|
||||
WSMessage(
|
||||
type=WSMessageType.DEVICE_UPDATE,
|
||||
data={"device_id": device_id, **device_data},
|
||||
),
|
||||
)
|
||||
|
||||
async def broadcast_task_started(self, task_id: str, task_data: Dict[str, Any]):
|
||||
"""Broadcast task started.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
task_data: Task data
|
||||
"""
|
||||
await self.broadcast(
|
||||
WSMessage(
|
||||
type=WSMessageType.TASK_STARTED,
|
||||
data={"task_id": task_id, **task_data},
|
||||
)
|
||||
)
|
||||
|
||||
async def broadcast_step_update(
|
||||
self,
|
||||
task_id: str,
|
||||
device_id: str,
|
||||
step: int,
|
||||
action: Optional[Dict] = None,
|
||||
thinking: Optional[str] = None,
|
||||
finished: bool = False,
|
||||
success: bool = True,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
"""Broadcast task step update.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
device_id: Device ID
|
||||
step: Step number
|
||||
action: Current action
|
||||
thinking: AI reasoning
|
||||
finished: Whether task is finished
|
||||
success: Whether step succeeded
|
||||
message: Status message
|
||||
"""
|
||||
await self.broadcast_to_device_subscribers(
|
||||
device_id,
|
||||
WSMessage(
|
||||
type=WSMessageType.TASK_STEP,
|
||||
data={
|
||||
"task_id": task_id,
|
||||
"device_id": device_id,
|
||||
"step": step,
|
||||
"action": action,
|
||||
"thinking": thinking,
|
||||
"finished": finished,
|
||||
"success": success,
|
||||
"message": message,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
async def broadcast_task_completed(
|
||||
self, task_id: str, device_id: str, status: str, message: Optional[str] = None
|
||||
):
|
||||
"""Broadcast task completed.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
device_id: Device ID
|
||||
status: Task status
|
||||
message: Completion message
|
||||
"""
|
||||
await self.broadcast(
|
||||
WSMessage(
|
||||
type=WSMessageType.TASK_COMPLETED,
|
||||
data={
|
||||
"task_id": task_id,
|
||||
"device_id": device_id,
|
||||
"status": status,
|
||||
"message": message,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def broadcast_task_failed(self, task_id: str, device_id: str, error: str):
|
||||
"""Broadcast task failed.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
device_id: Device ID
|
||||
error: Error message
|
||||
"""
|
||||
await self.broadcast(
|
||||
WSMessage(
|
||||
type=WSMessageType.TASK_FAILED,
|
||||
data={
|
||||
"task_id": task_id,
|
||||
"device_id": device_id,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def broadcast_task_stopped(self, task_id: str, device_id: str):
|
||||
"""Broadcast task stopped.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
device_id: Device ID
|
||||
"""
|
||||
await self.broadcast(
|
||||
WSMessage(
|
||||
type=WSMessageType.TASK_STOPPED,
|
||||
data={
|
||||
"task_id": task_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
async def broadcast_screenshot(
|
||||
self, device_id: str, screenshot: str, width: int, height: int
|
||||
):
|
||||
"""Broadcast screenshot update.
|
||||
|
||||
Args:
|
||||
device_id: Device ID
|
||||
screenshot: Base64 encoded screenshot
|
||||
width: Image width
|
||||
height: Image height
|
||||
"""
|
||||
await self.broadcast_to_device_subscribers(
|
||||
device_id,
|
||||
WSMessage(
|
||||
type=WSMessageType.SCREENSHOT,
|
||||
data={
|
||||
"device_id": device_id,
|
||||
"screenshot": screenshot,
|
||||
"width": width,
|
||||
"height": height,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
async def broadcast_error(self, error: str, details: Optional[Dict] = None):
|
||||
"""Broadcast error.
|
||||
|
||||
Args:
|
||||
error: Error message
|
||||
details: Additional error details
|
||||
"""
|
||||
await self.broadcast(
|
||||
WSMessage(
|
||||
type=WSMessageType.ERROR,
|
||||
data={
|
||||
"error": error,
|
||||
"details": details,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def get_connection_count(self) -> int:
|
||||
"""Get number of active connections.
|
||||
|
||||
Returns:
|
||||
Connection count
|
||||
"""
|
||||
return len(self.active_connections)
|
||||
|
||||
def get_client_ids(self) -> list[str]:
|
||||
"""Get list of connected client IDs.
|
||||
|
||||
Returns:
|
||||
List of client IDs
|
||||
"""
|
||||
return list(self.active_connections.keys())
|
||||
Reference in New Issue
Block a user