Add multi-device support for Video Learning Agent
- Add device-session bidirectional mapping (_device_sessions)
- Integrate DeviceManager lock mechanism (acquire/release)
- Add device-level API endpoints:
- GET /devices/{device_id}/sessions - List sessions on device
- POST /devices/{device_id}/stop-all - Stop all sessions on device
- GET /devices/sessions - Overview of all devices
- Update session cleanup to maintain mapping consistency
- Prevent concurrent sessions on same device
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -64,7 +64,11 @@ class VideoInfo(BaseModel):
|
||||
|
||||
|
||||
# Global session storage (in production, use database)
|
||||
# 会话存储:session_id -> VideoLearningAgent
|
||||
_active_sessions: Dict[str, VideoLearningAgent] = {}
|
||||
# 设备-会话映射:device_id -> {session_ids}
|
||||
# 支持多设备同时运行,每个设备可以有多个会话(但同一时间只能有一个活跃的)
|
||||
_device_sessions: Dict[str, set[str]] = {}
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=Dict[str, str])
|
||||
@@ -81,67 +85,86 @@ async def create_session(
|
||||
if not device.is_connected:
|
||||
raise HTTPException(status_code=400, detail="Device not connected")
|
||||
|
||||
if device.status == "busy":
|
||||
raise HTTPException(status_code=409, detail="Device is busy")
|
||||
# 使用设备锁机制确保同一设备不会同时运行多个会话
|
||||
if not device_manager.acquire_device(request.device_id):
|
||||
raise HTTPException(status_code=409, detail="Device is busy - another session is running on this device")
|
||||
|
||||
# Create model config from environment
|
||||
model_config = ModelConfig(
|
||||
base_url=config.MODEL_BASE_URL,
|
||||
model_name=config.MODEL_NAME,
|
||||
api_key=config.MODEL_API_KEY,
|
||||
max_tokens=config.MAX_TOKENS,
|
||||
temperature=config.TEMPERATURE,
|
||||
top_p=config.TOP_P,
|
||||
frequency_penalty=config.FREQUENCY_PENALTY,
|
||||
lang="cn",
|
||||
)
|
||||
try:
|
||||
# Create model config from environment
|
||||
model_config = ModelConfig(
|
||||
base_url=config.MODEL_BASE_URL,
|
||||
model_name=config.MODEL_NAME,
|
||||
api_key=config.MODEL_API_KEY,
|
||||
max_tokens=config.MAX_TOKENS,
|
||||
temperature=config.TEMPERATURE,
|
||||
top_p=config.TOP_P,
|
||||
frequency_penalty=config.FREQUENCY_PENALTY,
|
||||
lang="cn",
|
||||
)
|
||||
|
||||
# Create video learning agent
|
||||
agent = VideoLearningAgent(
|
||||
model_config=model_config,
|
||||
platform=request.platform,
|
||||
output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
|
||||
enable_analysis=request.enable_analysis,
|
||||
)
|
||||
# Create video learning agent
|
||||
agent = VideoLearningAgent(
|
||||
model_config=model_config,
|
||||
platform=request.platform,
|
||||
output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
|
||||
enable_analysis=request.enable_analysis,
|
||||
)
|
||||
|
||||
# Setup callbacks for real-time updates
|
||||
session_id = None
|
||||
# Setup callbacks for real-time updates
|
||||
session_id = None
|
||||
|
||||
def on_video_watched(record):
|
||||
"""Callback when a video is watched."""
|
||||
# Broadcast via WebSocket
|
||||
if session_id:
|
||||
# This would be integrated with WebSocket manager
|
||||
pass
|
||||
def on_video_watched(record):
|
||||
"""Callback when a video is watched."""
|
||||
if session_id:
|
||||
pass # WebSocket broadcast would go here
|
||||
|
||||
def on_progress_update(current, total):
|
||||
"""Callback for progress updates."""
|
||||
if session_id:
|
||||
# Broadcast progress
|
||||
pass
|
||||
def on_progress_update(current, total):
|
||||
"""Callback for progress updates."""
|
||||
if session_id:
|
||||
pass # Progress broadcast would go here
|
||||
|
||||
def on_session_complete(session):
|
||||
"""Callback when session completes."""
|
||||
if session_id and session_id in _active_sessions:
|
||||
del _active_sessions[session_id]
|
||||
def on_session_complete(session):
|
||||
"""Callback when session completes - release device lock and clean up."""
|
||||
if session_id:
|
||||
# 从设备-会话映射中移除
|
||||
if request.device_id in _device_sessions:
|
||||
_device_sessions[request.device_id].discard(session_id)
|
||||
if not _device_sessions[request.device_id]:
|
||||
del _device_sessions[request.device_id]
|
||||
|
||||
agent.on_video_watched = on_video_watched
|
||||
agent.on_progress_update = on_progress_update
|
||||
agent.on_session_complete = on_session_complete
|
||||
# 从活跃会话中移除(但保留历史记录可查询)
|
||||
if session_id in _active_sessions:
|
||||
# 标记为完成但不删除,允许用户查看结果
|
||||
pass
|
||||
|
||||
# Start session
|
||||
session_id = agent.start_session(
|
||||
device_id=request.device_id,
|
||||
target_count=request.target_count,
|
||||
category=request.category,
|
||||
watch_duration=request.watch_duration,
|
||||
max_steps=500,
|
||||
)
|
||||
# 释放设备锁
|
||||
device_manager.release_device(request.device_id)
|
||||
|
||||
# Store session
|
||||
_active_sessions[session_id] = agent
|
||||
agent.on_video_watched = on_video_watched
|
||||
agent.on_progress_update = on_progress_update
|
||||
agent.on_session_complete = on_session_complete
|
||||
|
||||
return {"session_id": session_id, "status": "created"}
|
||||
# Start session
|
||||
session_id = agent.start_session(
|
||||
device_id=request.device_id,
|
||||
target_count=request.target_count,
|
||||
category=request.category,
|
||||
watch_duration=request.watch_duration,
|
||||
max_steps=500,
|
||||
)
|
||||
|
||||
# Store session in both mappings
|
||||
_active_sessions[session_id] = agent
|
||||
if request.device_id not in _device_sessions:
|
||||
_device_sessions[request.device_id] = set()
|
||||
_device_sessions[request.device_id].add(session_id)
|
||||
|
||||
return {"session_id": session_id, "status": "created"}
|
||||
|
||||
except Exception as e:
|
||||
# 创建失败时释放设备锁
|
||||
device_manager.release_device(request.device_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create session: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/start", response_model=Dict[str, str])
|
||||
@@ -330,11 +353,26 @@ async def list_sessions() -> List[str]:
|
||||
|
||||
@router.delete("/sessions/{session_id}", response_model=Dict[str, str])
|
||||
async def delete_session(session_id: str) -> Dict[str, str]:
|
||||
"""Delete a session."""
|
||||
"""Delete a session and clean up device mapping."""
|
||||
if session_id not in _active_sessions:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# 获取设备ID以便清理映射
|
||||
agent = _active_sessions[session_id]
|
||||
device_id = None
|
||||
if agent.current_session:
|
||||
# 从 session 中获取 device_id
|
||||
device_id = agent._device_id
|
||||
|
||||
# 删除会话
|
||||
del _active_sessions[session_id]
|
||||
|
||||
# 清理设备-会话映射
|
||||
if device_id and device_id in _device_sessions:
|
||||
_device_sessions[device_id].discard(session_id)
|
||||
if not _device_sessions[device_id]:
|
||||
del _device_sessions[device_id]
|
||||
|
||||
return {"session_id": session_id, "status": "deleted"}
|
||||
|
||||
|
||||
@@ -361,3 +399,78 @@ async def analyze_session(session_id: str) -> Dict[str, Any]:
|
||||
"analyzed_count": analyzed_count,
|
||||
"status": "analysis_triggered"
|
||||
}
|
||||
|
||||
|
||||
# ==================== 设备级别的 API ====================
|
||||
|
||||
@router.get("/devices/{device_id}/sessions")
|
||||
async def get_device_sessions(device_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取指定设备上的所有会话"""
|
||||
if device_id not in _device_sessions:
|
||||
return []
|
||||
|
||||
session_ids = _device_sessions[device_id]
|
||||
result = []
|
||||
for sid in session_ids:
|
||||
if sid in _active_sessions:
|
||||
agent = _active_sessions[sid]
|
||||
progress = agent.get_session_progress()
|
||||
result.append({
|
||||
"session_id": sid,
|
||||
"platform": progress.get("platform", ""),
|
||||
"target_count": progress.get("target_count", 0),
|
||||
"watched_count": progress.get("watched_count", 0),
|
||||
"progress_percent": progress.get("progress_percent", 0),
|
||||
"is_active": progress.get("is_active", False),
|
||||
"is_paused": progress.get("is_paused", False),
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/devices/{device_id}/stop-all", response_model=Dict[str, Any])
|
||||
async def stop_all_device_sessions(device_id: str) -> Dict[str, Any]:
|
||||
"""停止指定设备上的所有活跃会话"""
|
||||
if device_id not in _device_sessions:
|
||||
return {"device_id": device_id, "stopped": 0, "message": "No sessions on this device"}
|
||||
|
||||
count = 0
|
||||
session_ids = list(_device_sessions[device_id])
|
||||
for session_id in session_ids:
|
||||
if session_id in _active_sessions:
|
||||
agent = _active_sessions[session_id]
|
||||
# 只停止活跃的会话
|
||||
if agent.current_session and agent.current_session.is_active:
|
||||
agent.stop_session()
|
||||
count += 1
|
||||
|
||||
return {
|
||||
"device_id": device_id,
|
||||
"stopped": count,
|
||||
"message": f"Stopped {count} active session(s) on device {device_id}"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/devices/sessions", response_model=Dict[str, Any])
|
||||
async def get_all_device_sessions() -> Dict[str, Any]:
|
||||
"""获取所有设备上的会话概览"""
|
||||
result = {
|
||||
"devices": {},
|
||||
"total_devices": len(_device_sessions),
|
||||
"total_sessions": len(_active_sessions)
|
||||
}
|
||||
|
||||
for device_id, session_ids in _device_sessions.items():
|
||||
active_count = 0
|
||||
for sid in session_ids:
|
||||
if sid in _active_sessions:
|
||||
agent = _active_sessions[sid]
|
||||
if agent.current_session and agent.current_session.is_active:
|
||||
active_count += 1
|
||||
|
||||
result["devices"][device_id] = {
|
||||
"total_sessions": len(session_ids),
|
||||
"active_sessions": active_count,
|
||||
"session_ids": list(session_ids)
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user