"""
WebSocket manager for real-time updates.
"""

import asyncio
import json
import uuid
from datetime import datetime
from typing import Any, Dict, Optional

from fastapi import WebSocket

from dashboard.models.ws_messages import (
    ScreenshotUpdate,
    WSMessage,
    WSMessageType,
)


class WebSocketManager:
    """Manage WebSocket connections for real-time updates.

    Connections are keyed by a client ID. Each connected client also has a
    set of device IDs it is subscribed to; the special device ID ``"*"``
    subscribes a client to updates for every device.

    Delivery is best-effort: if sending to a client raises, that client is
    disconnected and the error is not propagated to the caller.
    """

    def __init__(self):
        """Initialize the WebSocket manager."""
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> set of device_ids ("*" means all devices)
        self.client_subscriptions: Dict[str, set[str]] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_message(message: WSMessage | Dict[str, Any]) -> WSMessage:
        """Return ``message`` as a WSMessage, building one from a dict if needed.

        For dict input, missing keys default to type ``"error"``, empty
        ``data``, and the current local time as ``timestamp``. A WSMessage
        instance is returned unchanged.

        Args:
            message: A WSMessage, or a dict with optional "type", "data"
                and "timestamp" keys.

        Returns:
            The message as a WSMessage.
        """
        if isinstance(message, dict):
            return WSMessage(
                type=WSMessageType(message.get("type", "error")),
                data=message.get("data", {}),
                timestamp=message.get("timestamp", datetime.now()),
            )
        return message

    async def _send_payload(
        self, client_id: str, websocket: WebSocket, payload: Dict[str, Any]
    ) -> None:
        """Send an already-serialized payload, dropping the client on failure.

        Args:
            client_id: Client ID owning ``websocket``.
            websocket: Connection to send on.
            payload: JSON-ready dict (output of ``model_dump(mode="json")``).
        """
        try:
            await websocket.send_json(payload)
        except Exception:
            # Connection may be closed; treat delivery as best-effort and
            # forget the client instead of propagating the error.
            self.disconnect(client_id)

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, client_id: Optional[str] = None) -> str:
        """Accept and store connection.

        Args:
            websocket: WebSocket connection
            client_id: Optional client ID (auto-generated if not provided)

        Returns:
            Client ID
        """
        await websocket.accept()

        if client_id is None:
            client_id = f"client_{uuid.uuid4().hex[:8]}"

        self.active_connections[client_id] = websocket
        # A (re)connecting client starts with no device subscriptions.
        self.client_subscriptions[client_id] = set()

        return client_id

    def disconnect(self, client_id: str) -> None:
        """Remove connection and its subscriptions.

        Safe to call for unknown or already-removed client IDs.

        Args:
            client_id: Client ID
        """
        self.active_connections.pop(client_id, None)
        self.client_subscriptions.pop(client_id, None)

    # ------------------------------------------------------------------
    # Sending
    # ------------------------------------------------------------------

    async def send_to_client(
        self, client_id: str, message: WSMessage | Dict[str, Any]
    ) -> None:
        """Send message to specific client.

        Does nothing if the client is not connected. If sending fails the
        client is disconnected.

        Args:
            client_id: Client ID
            message: Message to send (WSMessage or dict; see ``_to_message``)
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return

        message = self._to_message(message)
        await self._send_payload(client_id, websocket, message.model_dump(mode="json"))

    async def broadcast(self, message: WSMessage | Dict[str, Any]) -> None:
        """Broadcast message to all connected clients.

        Clients whose send fails are disconnected.

        Args:
            message: Message to broadcast (WSMessage or dict)
        """
        # Serialize once; the payload is identical for every recipient.
        payload = self._to_message(message).model_dump(mode="json")

        # Snapshot the connections: failed sends call disconnect(), and other
        # coroutines may connect/disconnect while we are awaiting.
        for client_id, websocket in list(self.active_connections.items()):
            await self._send_payload(client_id, websocket, payload)

    async def broadcast_to_device_subscribers(
        self, device_id: str, message: WSMessage | Dict[str, Any]
    ) -> None:
        """Broadcast message to clients subscribed to a device.

        A client receives the message if it subscribed to ``device_id`` or to
        the wildcard ``"*"``.

        Args:
            device_id: Device ID
            message: Message to broadcast (WSMessage or dict)
        """
        payload = self._to_message(message).model_dump(mode="json")

        # Collect recipients up front. Iterating client_subscriptions live
        # across the awaits below would raise "dictionary changed size during
        # iteration" whenever a send fails (disconnect() deletes the entry) or
        # another coroutine connects/disconnects a client meanwhile.
        subscribers = [
            client_id
            for client_id, subscriptions in self.client_subscriptions.items()
            if device_id in subscriptions or "*" in subscriptions
        ]

        for client_id in subscribers:
            # The client may have disconnected during an earlier await.
            websocket = self.active_connections.get(client_id)
            if websocket is not None:
                await self._send_payload(client_id, websocket, payload)

    # ------------------------------------------------------------------
    # Subscriptions
    # ------------------------------------------------------------------

    def subscribe_to_device(self, client_id: str, device_id: str) -> None:
        """Subscribe client to device updates.

        Ignored if the client is not connected.

        Args:
            client_id: Client ID
            device_id: Device ID (use "*" for all devices)
        """
        if client_id in self.client_subscriptions:
            self.client_subscriptions[client_id].add(device_id)

    def unsubscribe_from_device(self, client_id: str, device_id: str) -> None:
        """Unsubscribe client from device updates.

        Ignored if the client is not connected or was not subscribed.

        Args:
            client_id: Client ID
            device_id: Device ID
        """
        if client_id in self.client_subscriptions:
            self.client_subscriptions[client_id].discard(device_id)

    # ------------------------------------------------------------------
    # Convenience methods for specific message types
    # ------------------------------------------------------------------

    async def broadcast_device_update(self, device_id: str, device_data: Dict[str, Any]) -> None:
        """Broadcast device update to that device's subscribers.

        Args:
            device_id: Device ID
            device_data: Device data (merged into the message data after
                "device_id", so a "device_id" key here overrides it)
        """
        await self.broadcast_to_device_subscribers(
            device_id,
            WSMessage(
                type=WSMessageType.DEVICE_UPDATE,
                data={"device_id": device_id, **device_data},
            ),
        )

    async def broadcast_task_started(self, task_id: str, task_data: Dict[str, Any]) -> None:
        """Broadcast task started to all clients.

        Args:
            task_id: Task ID
            task_data: Task data (merged into the message data after "task_id")
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.TASK_STARTED,
                data={"task_id": task_id, **task_data},
            )
        )

    async def broadcast_step_update(
        self,
        task_id: str,
        device_id: str,
        step: int,
        action: Optional[Dict] = None,
        thinking: Optional[str] = None,
        finished: bool = False,
        success: bool = True,
        message: Optional[str] = None,
    ) -> None:
        """Broadcast task step update to the device's subscribers.

        Args:
            task_id: Task ID
            device_id: Device ID
            step: Step number
            action: Current action
            thinking: AI reasoning
            finished: Whether task is finished
            success: Whether step succeeded
            message: Status message
        """
        await self.broadcast_to_device_subscribers(
            device_id,
            WSMessage(
                type=WSMessageType.TASK_STEP,
                data={
                    "task_id": task_id,
                    "device_id": device_id,
                    "step": step,
                    "action": action,
                    "thinking": thinking,
                    "finished": finished,
                    "success": success,
                    "message": message,
                },
            ),
        )

    async def broadcast_task_completed(
        self, task_id: str, device_id: str, status: str, message: Optional[str] = None
    ) -> None:
        """Broadcast task completed to all clients.

        Args:
            task_id: Task ID
            device_id: Device ID
            status: Task status
            message: Completion message
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.TASK_COMPLETED,
                data={
                    "task_id": task_id,
                    "device_id": device_id,
                    "status": status,
                    "message": message,
                },
            )
        )

    async def broadcast_task_failed(self, task_id: str, device_id: str, error: str) -> None:
        """Broadcast task failed to all clients.

        Args:
            task_id: Task ID
            device_id: Device ID
            error: Error message
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.TASK_FAILED,
                data={
                    "task_id": task_id,
                    "device_id": device_id,
                    "error": error,
                },
            )
        )

    async def broadcast_task_stopped(self, task_id: str, device_id: str) -> None:
        """Broadcast task stopped to all clients.

        Args:
            task_id: Task ID
            device_id: Device ID
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.TASK_STOPPED,
                data={
                    "task_id": task_id,
                    "device_id": device_id,
                },
            )
        )

    async def broadcast_screenshot(
        self, device_id: str, screenshot: str, width: int, height: int
    ) -> None:
        """Broadcast screenshot update to the device's subscribers.

        Args:
            device_id: Device ID
            screenshot: Base64 encoded screenshot
            width: Image width
            height: Image height
        """
        await self.broadcast_to_device_subscribers(
            device_id,
            WSMessage(
                type=WSMessageType.SCREENSHOT,
                data={
                    "device_id": device_id,
                    "screenshot": screenshot,
                    "width": width,
                    "height": height,
                },
            ),
        )

    async def broadcast_error(self, error: str, details: Optional[Dict] = None) -> None:
        """Broadcast error to all clients.

        Args:
            error: Error message
            details: Additional error details
        """
        await self.broadcast(
            WSMessage(
                type=WSMessageType.ERROR,
                data={
                    "error": error,
                    "details": details,
                },
            )
        )

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_connection_count(self) -> int:
        """Get number of active connections.

        Returns:
            Connection count
        """
        return len(self.active_connections)

    def get_client_ids(self) -> list[str]:
        """Get list of connected client IDs.

        Returns:
            List of client IDs (a new list; safe to mutate)
        """
        return list(self.active_connections.keys())