Add multi-device support for Video Learning Agent

- Add device-session bidirectional mapping (_device_sessions)
- Integrate DeviceManager lock mechanism (acquire/release)
- Add device-level API endpoints:
  - GET /devices/{device_id}/sessions - List sessions on device
  - POST /devices/{device_id}/stop-all - Stop all sessions on device
  - GET /devices/sessions - Overview of all devices
- Update session cleanup to maintain mapping consistency
- Prevent concurrent sessions on same device

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
let5sne.win10
2026-01-10 01:55:16 +08:00
parent b97d3f3a9f
commit a223d63088

View File

@@ -64,7 +64,11 @@ class VideoInfo(BaseModel):
# Global session storage (in production, use database)
# 会话存储session_id -> VideoLearningAgent
_active_sessions: Dict[str, VideoLearningAgent] = {}
# 设备-会话映射device_id -> {session_ids}
# 支持多设备同时运行,每个设备可以有多个会话(但同一时间只能有一个活跃的)
_device_sessions: Dict[str, set[str]] = {}
@router.post("/sessions", response_model=Dict[str, str])
@@ -81,67 +85,86 @@ async def create_session(
if not device.is_connected:
raise HTTPException(status_code=400, detail="Device not connected")
if device.status == "busy":
raise HTTPException(status_code=409, detail="Device is busy")
# 使用设备锁机制确保同一设备不会同时运行多个会话
if not device_manager.acquire_device(request.device_id):
raise HTTPException(status_code=409, detail="Device is busy - another session is running on this device")
# Create model config from environment
model_config = ModelConfig(
base_url=config.MODEL_BASE_URL,
model_name=config.MODEL_NAME,
api_key=config.MODEL_API_KEY,
max_tokens=config.MAX_TOKENS,
temperature=config.TEMPERATURE,
top_p=config.TOP_P,
frequency_penalty=config.FREQUENCY_PENALTY,
lang="cn",
)
try:
# Create model config from environment
model_config = ModelConfig(
base_url=config.MODEL_BASE_URL,
model_name=config.MODEL_NAME,
api_key=config.MODEL_API_KEY,
max_tokens=config.MAX_TOKENS,
temperature=config.TEMPERATURE,
top_p=config.TOP_P,
frequency_penalty=config.FREQUENCY_PENALTY,
lang="cn",
)
# Create video learning agent
agent = VideoLearningAgent(
model_config=model_config,
platform=request.platform,
output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
enable_analysis=request.enable_analysis,
)
# Create video learning agent
agent = VideoLearningAgent(
model_config=model_config,
platform=request.platform,
output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
enable_analysis=request.enable_analysis,
)
# Setup callbacks for real-time updates
session_id = None
# Setup callbacks for real-time updates
session_id = None
def on_video_watched(record):
"""Callback when a video is watched."""
# Broadcast via WebSocket
if session_id:
# This would be integrated with WebSocket manager
pass
def on_video_watched(record):
"""Callback when a video is watched."""
if session_id:
pass # WebSocket broadcast would go here
def on_progress_update(current, total):
"""Callback for progress updates."""
if session_id:
# Broadcast progress
pass
def on_progress_update(current, total):
"""Callback for progress updates."""
if session_id:
pass # Progress broadcast would go here
def on_session_complete(session):
"""Callback when session completes."""
if session_id and session_id in _active_sessions:
del _active_sessions[session_id]
def on_session_complete(session):
"""Callback when session completes - release device lock and clean up."""
if session_id:
# 从设备-会话映射中移除
if request.device_id in _device_sessions:
_device_sessions[request.device_id].discard(session_id)
if not _device_sessions[request.device_id]:
del _device_sessions[request.device_id]
agent.on_video_watched = on_video_watched
agent.on_progress_update = on_progress_update
agent.on_session_complete = on_session_complete
# 从活跃会话中移除(但保留历史记录可查询)
if session_id in _active_sessions:
# 标记为完成但不删除,允许用户查看结果
pass
# Start session
session_id = agent.start_session(
device_id=request.device_id,
target_count=request.target_count,
category=request.category,
watch_duration=request.watch_duration,
max_steps=500,
)
# 释放设备锁
device_manager.release_device(request.device_id)
# Store session
_active_sessions[session_id] = agent
agent.on_video_watched = on_video_watched
agent.on_progress_update = on_progress_update
agent.on_session_complete = on_session_complete
return {"session_id": session_id, "status": "created"}
# Start session
session_id = agent.start_session(
device_id=request.device_id,
target_count=request.target_count,
category=request.category,
watch_duration=request.watch_duration,
max_steps=500,
)
# Store session in both mappings
_active_sessions[session_id] = agent
if request.device_id not in _device_sessions:
_device_sessions[request.device_id] = set()
_device_sessions[request.device_id].add(session_id)
return {"session_id": session_id, "status": "created"}
except Exception as e:
# 创建失败时释放设备锁
device_manager.release_device(request.device_id)
raise HTTPException(status_code=500, detail=f"Failed to create session: {str(e)}")
@router.post("/sessions/{session_id}/start", response_model=Dict[str, str])
@@ -330,11 +353,26 @@ async def list_sessions() -> List[str]:
@router.delete("/sessions/{session_id}", response_model=Dict[str, str])
async def delete_session(session_id: str) -> Dict[str, str]:
"""Delete a session."""
"""Delete a session and clean up device mapping."""
if session_id not in _active_sessions:
raise HTTPException(status_code=404, detail="Session not found")
# 获取设备ID以便清理映射
agent = _active_sessions[session_id]
device_id = None
if agent.current_session:
# 从 session 中获取 device_id
device_id = agent._device_id
# 删除会话
del _active_sessions[session_id]
# 清理设备-会话映射
if device_id and device_id in _device_sessions:
_device_sessions[device_id].discard(session_id)
if not _device_sessions[device_id]:
del _device_sessions[device_id]
return {"session_id": session_id, "status": "deleted"}
@@ -361,3 +399,78 @@ async def analyze_session(session_id: str) -> Dict[str, Any]:
"analyzed_count": analyzed_count,
"status": "analysis_triggered"
}
# ==================== 设备级别的 API ====================
@router.get("/devices/{device_id}/sessions")
async def get_device_sessions(device_id: str) -> List[Dict[str, Any]]:
"""获取指定设备上的所有会话"""
if device_id not in _device_sessions:
return []
session_ids = _device_sessions[device_id]
result = []
for sid in session_ids:
if sid in _active_sessions:
agent = _active_sessions[sid]
progress = agent.get_session_progress()
result.append({
"session_id": sid,
"platform": progress.get("platform", ""),
"target_count": progress.get("target_count", 0),
"watched_count": progress.get("watched_count", 0),
"progress_percent": progress.get("progress_percent", 0),
"is_active": progress.get("is_active", False),
"is_paused": progress.get("is_paused", False),
})
return result
@router.post("/devices/{device_id}/stop-all", response_model=Dict[str, Any])
async def stop_all_device_sessions(device_id: str) -> Dict[str, Any]:
"""停止指定设备上的所有活跃会话"""
if device_id not in _device_sessions:
return {"device_id": device_id, "stopped": 0, "message": "No sessions on this device"}
count = 0
session_ids = list(_device_sessions[device_id])
for session_id in session_ids:
if session_id in _active_sessions:
agent = _active_sessions[session_id]
# 只停止活跃的会话
if agent.current_session and agent.current_session.is_active:
agent.stop_session()
count += 1
return {
"device_id": device_id,
"stopped": count,
"message": f"Stopped {count} active session(s) on device {device_id}"
}
@router.get("/devices/sessions", response_model=Dict[str, Any])
async def get_all_device_sessions() -> Dict[str, Any]:
"""获取所有设备上的会话概览"""
result = {
"devices": {},
"total_devices": len(_device_sessions),
"total_sessions": len(_active_sessions)
}
for device_id, session_ids in _device_sessions.items():
active_count = 0
for sid in session_ids:
if sid in _active_sessions:
agent = _active_sessions[sid]
if agent.current_session and agent.current_session.is_active:
active_count += 1
result["devices"][device_id] = {
"total_sessions": len(session_ids),
"active_sessions": active_count,
"session_ids": list(session_ids)
}
return result