Features:
- Web Dashboard: FastAPI-based dashboard with a Vue.js frontend
  - Multi-device support (ADB, HDC, iOS)
  - Real-time WebSocket updates for task progress
  - Device management with status tracking
  - Task queue with execution controls (start/stop/re-execute)
  - Detailed task information display (thinking, actions, completion messages)
  - Screenshot viewing per device
  - LAN deployment support with configurable CORS
- Callback Hooks: interrupt and modify task execution
  - step_callback: called after each step with a StepResult
  - before_action_callback: called before an action is executed
  - Support for task interruption and dynamic task switching
  - Example scripts demonstrating callback usage
- Configuration: environment-based configuration
  - .env file support for all settings
  - .env.example template with documentation
  - Model API configuration (base URL, model name, API key)
  - Dashboard configuration (host, port, CORS, device type)
  - Phone agent configuration (delays, max steps, language)

Technical improvements:
- Fixed forward-reference issue with StepResult
- Added package exports for callback types and configs
- Enhanced dependencies with FastAPI and WebSocket support
- Thread-safe task execution with device locking
- Async WebSocket broadcasting from a sync thread pool

Co-Authored-By: Claude <noreply@anthropic.com>
339 lines
9.8 KiB
Python
339 lines
9.8 KiB
Python
"""
|
|
WebSocket manager for real-time updates.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any, Dict, Optional
|
|
|
|
from fastapi import WebSocket
|
|
|
|
from dashboard.models.ws_messages import (
|
|
ScreenshotUpdate,
|
|
WSMessage,
|
|
WSMessageType,
|
|
)
|
|
|
|
|
|
class WebSocketManager:
    """Manage WebSocket connections for real-time updates.

    Keeps one WebSocket per client ID and, per client, the set of device IDs
    that client has subscribed to. The special device ID ``"*"`` subscribes a
    client to every device.

    Sending is best effort: if writing to a client's socket raises, the client
    is treated as gone and removed via :meth:`disconnect` instead of the error
    reaching the caller.
    """

    def __init__(self):
        """Initialize the manager with no connections or subscriptions."""
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> set of subscribed device_ids ("*" means all devices)
        self.client_subscriptions: Dict[str, set[str]] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_message(message: WSMessage | Dict[str, Any]) -> WSMessage:
        """Return *message* as a WSMessage, building one from a plain dict.

        A WSMessage is returned unchanged. For a dict, missing keys default to
        type ``"error"``, empty ``data`` and ``datetime.now()`` (local time) as
        ``timestamp``; the ``"type"`` value is passed to ``WSMessageType(...)``.
        """
        if not isinstance(message, dict):
            return message
        return WSMessage(
            type=WSMessageType(message.get("type", "error")),
            data=message.get("data", {}),
            timestamp=message.get("timestamp", datetime.now()),
        )

    async def _send_payload(self, client_id: str, payload: Dict[str, Any]) -> None:
        """Send an already-serialized payload to one client, best effort.

        Unknown client IDs are ignored. If sending raises, the connection is
        assumed to be closed and the client is dropped.
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(payload)
        except Exception:
            # Most likely the connection is closed; forget the client rather
            # than propagating a transport error to whoever is broadcasting.
            self.disconnect(client_id)

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, client_id: Optional[str] = None) -> str:
        """Accept a WebSocket and register it under a client ID.

        Args:
            websocket: WebSocket connection to accept.
            client_id: Optional client ID. When omitted, an ID of the form
                ``client_<8 hex chars>`` is generated. An existing entry with
                the same ID is replaced and its subscriptions are reset.

        Returns:
            The client ID the connection was registered under.
        """
        await websocket.accept()

        if client_id is None:
            client_id = f"client_{uuid.uuid4().hex[:8]}"

        self.active_connections[client_id] = websocket
        self.client_subscriptions[client_id] = set()
        return client_id

    def disconnect(self, client_id: str) -> None:
        """Forget a client's connection and subscriptions.

        The WebSocket itself is not closed here. Unknown IDs are ignored.

        Args:
            client_id: Client ID.
        """
        self.active_connections.pop(client_id, None)
        self.client_subscriptions.pop(client_id, None)

    # ------------------------------------------------------------------
    # Sending
    # ------------------------------------------------------------------

    async def send_to_client(
        self, client_id: str, message: WSMessage | Dict[str, Any]
    ) -> None:
        """Send a message to a specific client.

        Does nothing if the client is not connected. Serialization errors
        propagate to the caller; send errors drop the client.

        Args:
            client_id: Client ID.
            message: WSMessage, or a dict with optional ``type``/``data``/
                ``timestamp`` keys.
        """
        if client_id not in self.active_connections:
            return
        payload = self._to_message(message).model_dump(mode="json")
        await self._send_payload(client_id, payload)

    async def broadcast(self, message: WSMessage | Dict[str, Any]) -> None:
        """Broadcast a message to all connected clients.

        The message is serialized once and the same payload is sent to every
        client. Clients whose send fails are dropped.

        Args:
            message: WSMessage, or a dict with optional ``type``/``data``/
                ``timestamp`` keys.
        """
        payload = self._to_message(message).model_dump(mode="json")

        # Snapshot the IDs: a failed send disconnects (mutates the dict), and
        # other coroutines may (dis)connect while we are awaiting.
        for client_id in list(self.active_connections):
            await self._send_payload(client_id, payload)

    async def broadcast_to_device_subscribers(
        self, device_id: str, message: WSMessage | Dict[str, Any]
    ) -> None:
        """Broadcast a message to clients subscribed to a device.

        A client receives the message if it subscribed to *device_id* or to
        ``"*"``.

        Args:
            device_id: Device ID.
            message: WSMessage, or a dict with optional ``type``/``data``/
                ``timestamp`` keys.
        """
        payload = self._to_message(message).model_dump(mode="json")

        # Collect recipients before awaiting anything: iterating
        # client_subscriptions live while sends can call disconnect() would
        # raise "dictionary changed size during iteration".
        recipients = [
            client_id
            for client_id, subscriptions in self.client_subscriptions.items()
            if device_id in subscriptions or "*" in subscriptions
        ]
        for client_id in recipients:
            await self._send_payload(client_id, payload)

    # ------------------------------------------------------------------
    # Subscriptions
    # ------------------------------------------------------------------

    def subscribe_to_device(self, client_id: str, device_id: str) -> None:
        """Subscribe a connected client to device updates.

        Unknown client IDs are ignored.

        Args:
            client_id: Client ID.
            device_id: Device ID (use "*" for all devices).
        """
        if client_id in self.client_subscriptions:
            self.client_subscriptions[client_id].add(device_id)

    def unsubscribe_from_device(self, client_id: str, device_id: str) -> None:
        """Unsubscribe a client from device updates.

        Unknown client IDs and devices the client was not subscribed to are
        ignored.

        Args:
            client_id: Client ID.
            device_id: Device ID.
        """
        if client_id in self.client_subscriptions:
            self.client_subscriptions[client_id].discard(device_id)

    # ------------------------------------------------------------------
    # Convenience methods for specific message types
    # ------------------------------------------------------------------

    async def broadcast_device_update(self, device_id: str, device_data: Dict[str, Any]) -> None:
        """Send a DEVICE_UPDATE message to the device's subscribers.

        Args:
            device_id: Device ID.
            device_data: Device data, merged after ``device_id`` in the payload.
        """
        await self.broadcast_to_device_subscribers(
            device_id,
            WSMessage(
                type=WSMessageType.DEVICE_UPDATE,
                data={"device_id": device_id, **device_data},
            ),
        )

    async def broadcast_task_started(self, task_id: str, task_data: Dict[str, Any]) -> None:
        """Send a TASK_STARTED message to all clients.

        Args:
            task_id: Task ID.
            task_data: Task data, merged after ``task_id`` in the payload.
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.TASK_STARTED,
                data={"task_id": task_id, **task_data},
            )
        )

    async def broadcast_step_update(
        self,
        task_id: str,
        device_id: str,
        step: int,
        action: Optional[Dict] = None,
        thinking: Optional[str] = None,
        finished: bool = False,
        success: bool = True,
        message: Optional[str] = None,
    ) -> None:
        """Send a TASK_STEP message to the device's subscribers.

        Args:
            task_id: Task ID.
            device_id: Device ID.
            step: Step number.
            action: Current action.
            thinking: AI reasoning.
            finished: Whether the task is finished.
            success: Whether the step succeeded.
            message: Status message.
        """
        await self.broadcast_to_device_subscribers(
            device_id,
            WSMessage(
                type=WSMessageType.TASK_STEP,
                data={
                    "task_id": task_id,
                    "device_id": device_id,
                    "step": step,
                    "action": action,
                    "thinking": thinking,
                    "finished": finished,
                    "success": success,
                    "message": message,
                },
            ),
        )

    async def broadcast_task_completed(
        self, task_id: str, device_id: str, status: str, message: Optional[str] = None
    ) -> None:
        """Send a TASK_COMPLETED message to all clients.

        Args:
            task_id: Task ID.
            device_id: Device ID.
            status: Task status.
            message: Completion message.
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.TASK_COMPLETED,
                data={
                    "task_id": task_id,
                    "device_id": device_id,
                    "status": status,
                    "message": message,
                },
            )
        )

    async def broadcast_task_failed(self, task_id: str, device_id: str, error: str) -> None:
        """Send a TASK_FAILED message to all clients.

        Args:
            task_id: Task ID.
            device_id: Device ID.
            error: Error message.
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.TASK_FAILED,
                data={
                    "task_id": task_id,
                    "device_id": device_id,
                    "error": error,
                },
            )
        )

    async def broadcast_task_stopped(self, task_id: str, device_id: str) -> None:
        """Send a TASK_STOPPED message to all clients.

        Args:
            task_id: Task ID.
            device_id: Device ID.
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.TASK_STOPPED,
                data={
                    "task_id": task_id,
                    "device_id": device_id,
                },
            )
        )

    async def broadcast_screenshot(
        self, device_id: str, screenshot: str, width: int, height: int
    ) -> None:
        """Send a SCREENSHOT message to the device's subscribers.

        Args:
            device_id: Device ID.
            screenshot: Base64-encoded screenshot.
            width: Image width.
            height: Image height.
        """
        await self.broadcast_to_device_subscribers(
            device_id,
            WSMessage(
                type=WSMessageType.SCREENSHOT,
                data={
                    "device_id": device_id,
                    "screenshot": screenshot,
                    "width": width,
                    "height": height,
                },
            ),
        )

    async def broadcast_error(self, error: str, details: Optional[Dict] = None) -> None:
        """Send an ERROR message to all clients.

        Args:
            error: Error message.
            details: Additional error details.
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.ERROR,
                data={
                    "error": error,
                    "details": details,
                },
            )
        )

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_connection_count(self) -> int:
        """Return the number of active connections."""
        return len(self.active_connections)

    def get_client_ids(self) -> list[str]:
        """Return the connected client IDs in registration order."""
        return list(self.active_connections.keys())
|