From a223d63088247f55cdc66d21fe219e9f1d8e8f25 Mon Sep 17 00:00:00 2001 From: "let5sne.win10" Date: Sat, 10 Jan 2026 01:55:16 +0800 Subject: [PATCH] Add multi-device support for Video Learning Agent - Add device-session bidirectional mapping (_device_sessions) - Integrate DeviceManager lock mechanism (acquire/release) - Add device-level API endpoints: - GET /devices/{device_id}/sessions - List sessions on device - POST /devices/{device_id}/stop-all - Stop all sessions on device - GET /devices/sessions - Overview of all devices - Update session cleanup to maintain mapping consistency - Prevent concurrent sessions on same device Co-Authored-By: Claude --- dashboard/api/video_learning.py | 217 ++++++++++++++++++++++++-------- 1 file changed, 165 insertions(+), 52 deletions(-) diff --git a/dashboard/api/video_learning.py b/dashboard/api/video_learning.py index f235bee..4438df2 100644 --- a/dashboard/api/video_learning.py +++ b/dashboard/api/video_learning.py @@ -64,7 +64,11 @@ class VideoInfo(BaseModel): # Global session storage (in production, use database) +# 会话存储:session_id -> VideoLearningAgent _active_sessions: Dict[str, VideoLearningAgent] = {} +# 设备-会话映射:device_id -> {session_ids} +# 支持多设备同时运行,每个设备可以有多个会话(但同一时间只能有一个活跃的) +_device_sessions: Dict[str, set[str]] = {} @router.post("/sessions", response_model=Dict[str, str]) @@ -81,67 +85,86 @@ async def create_session( if not device.is_connected: raise HTTPException(status_code=400, detail="Device not connected") - if device.status == "busy": - raise HTTPException(status_code=409, detail="Device is busy") + # 使用设备锁机制确保同一设备不会同时运行多个会话 + if not device_manager.acquire_device(request.device_id): + raise HTTPException(status_code=409, detail="Device is busy - another session is running on this device") - # Create model config from environment - model_config = ModelConfig( - base_url=config.MODEL_BASE_URL, - model_name=config.MODEL_NAME, - api_key=config.MODEL_API_KEY, - max_tokens=config.MAX_TOKENS, - temperature=config.TEMPERATURE, - top_p=config.TOP_P, - frequency_penalty=config.FREQUENCY_PENALTY, - lang="cn", - ) + try: + # Create model config from environment + model_config = ModelConfig( + base_url=config.MODEL_BASE_URL, + model_name=config.MODEL_NAME, + api_key=config.MODEL_API_KEY, + max_tokens=config.MAX_TOKENS, + temperature=config.TEMPERATURE, + top_p=config.TOP_P, + frequency_penalty=config.FREQUENCY_PENALTY, + lang="cn", + ) - # Create video learning agent - agent = VideoLearningAgent( - model_config=model_config, - platform=request.platform, - output_dir=config.VIDEO_LEARNING_OUTPUT_DIR, - enable_analysis=request.enable_analysis, - ) + # Create video learning agent + agent = VideoLearningAgent( + model_config=model_config, + platform=request.platform, + output_dir=config.VIDEO_LEARNING_OUTPUT_DIR, + enable_analysis=request.enable_analysis, + ) - # Setup callbacks for real-time updates - session_id = None + # Setup callbacks for real-time updates + session_id = None - def on_video_watched(record): - """Callback when a video is watched.""" - # Broadcast via WebSocket - if session_id: - # This would be integrated with WebSocket manager - pass + def on_video_watched(record): + """Callback when a video is watched.""" + if session_id: + pass # WebSocket broadcast would go here - def on_progress_update(current, total): - """Callback for progress updates.""" - if session_id: - # Broadcast progress - pass + def on_progress_update(current, total): + """Callback for progress updates.""" + if session_id: + pass # Progress broadcast would go here - def on_session_complete(session): - """Callback when session completes.""" - if session_id and session_id in _active_sessions: - del _active_sessions[session_id] + def on_session_complete(session): + """Callback when session completes - release device lock and clean up.""" + if session_id: + # 从设备-会话映射中移除 + if request.device_id in _device_sessions: + _device_sessions[request.device_id].discard(session_id) + if not _device_sessions[request.device_id]: + del _device_sessions[request.device_id] - agent.on_video_watched = on_video_watched - agent.on_progress_update = on_progress_update - agent.on_session_complete = on_session_complete + # 从活跃会话中移除(但保留历史记录可查询) + if session_id in _active_sessions: + # 标记为完成但不删除,允许用户查看结果 + pass - # Start session - session_id = agent.start_session( - device_id=request.device_id, - target_count=request.target_count, - category=request.category, - watch_duration=request.watch_duration, - max_steps=500, - ) + # 释放设备锁 + device_manager.release_device(request.device_id) - # Store session - _active_sessions[session_id] = agent + agent.on_video_watched = on_video_watched + agent.on_progress_update = on_progress_update + agent.on_session_complete = on_session_complete - return {"session_id": session_id, "status": "created"} + # Start session + session_id = agent.start_session( + device_id=request.device_id, + target_count=request.target_count, + category=request.category, + watch_duration=request.watch_duration, + max_steps=500, + ) + + # Store session in both mappings + _active_sessions[session_id] = agent + if request.device_id not in _device_sessions: + _device_sessions[request.device_id] = set() + _device_sessions[request.device_id].add(session_id) + + return {"session_id": session_id, "status": "created"} + + except Exception as e: + # 创建失败时释放设备锁 + device_manager.release_device(request.device_id) + raise HTTPException(status_code=500, detail=f"Failed to create session: {str(e)}") @router.post("/sessions/{session_id}/start", response_model=Dict[str, str]) @@ -330,11 +353,26 @@ async def list_sessions() -> List[str]: @router.delete("/sessions/{session_id}", response_model=Dict[str, str]) async def delete_session(session_id: str) -> Dict[str, str]: - """Delete a session.""" + """Delete a session and clean up device mapping.""" if session_id not in _active_sessions: raise HTTPException(status_code=404, detail="Session not found") + # 获取设备ID以便清理映射 + agent = _active_sessions[session_id] + device_id = None + if agent.current_session: + # 从 session 中获取 device_id + device_id = agent._device_id + + # 删除会话 del _active_sessions[session_id] + + # 清理设备-会话映射 + if device_id and device_id in _device_sessions: + _device_sessions[device_id].discard(session_id) + if not _device_sessions[device_id]: + del _device_sessions[device_id] + return {"session_id": session_id, "status": "deleted"} @@ -361,3 +399,78 @@ async def analyze_session(session_id: str) -> Dict[str, Any]: "analyzed_count": analyzed_count, "status": "analysis_triggered" } + + +# ==================== 设备级别的 API ==================== + +@router.get("/devices/{device_id}/sessions") +async def get_device_sessions(device_id: str) -> List[Dict[str, Any]]: + """获取指定设备上的所有会话""" + if device_id not in _device_sessions: + return [] + + session_ids = _device_sessions[device_id] + result = [] + for sid in session_ids: + if sid in _active_sessions: + agent = _active_sessions[sid] + progress = agent.get_session_progress() + result.append({ + "session_id": sid, + "platform": progress.get("platform", ""), + "target_count": progress.get("target_count", 0), + "watched_count": progress.get("watched_count", 0), + "progress_percent": progress.get("progress_percent", 0), + "is_active": progress.get("is_active", False), + "is_paused": progress.get("is_paused", False), + }) + return result + + +@router.post("/devices/{device_id}/stop-all", response_model=Dict[str, Any]) +async def stop_all_device_sessions(device_id: str) -> Dict[str, Any]: + """停止指定设备上的所有活跃会话""" + if device_id not in _device_sessions: + return {"device_id": device_id, "stopped": 0, "message": "No sessions on this device"} + + count = 0 + session_ids = list(_device_sessions[device_id]) + for session_id in session_ids: + if session_id in _active_sessions: + agent = _active_sessions[session_id] + # 只停止活跃的会话 + if agent.current_session and agent.current_session.is_active: + agent.stop_session() + count += 1 + + return { + "device_id": device_id, + "stopped": count, + "message": f"Stopped {count} active session(s) on device {device_id}" + } + + +@router.get("/devices/sessions", response_model=Dict[str, Any]) +async def get_all_device_sessions() -> Dict[str, Any]: + """获取所有设备上的会话概览""" + result = { + "devices": {}, + "total_devices": len(_device_sessions), + "total_sessions": len(_active_sessions) + } + + for device_id, session_ids in _device_sessions.items(): + active_count = 0 + for sid in session_ids: + if sid in _active_sessions: + agent = _active_sessions[sid] + if agent.current_session and agent.current_session.is_active: + active_count += 1 + + result["devices"][device_id] = { + "total_sessions": len(session_ids), + "active_sessions": active_count, + "session_ids": list(session_ids) + } + + return result