Add multi-device support for Video Learning Agent
- Add device-session bidirectional mapping (_device_sessions)
- Integrate DeviceManager lock mechanism (acquire/release)
- Add device-level API endpoints:
- GET /devices/{device_id}/sessions - List sessions on device
- POST /devices/{device_id}/stop-all - Stop all sessions on device
- GET /devices/sessions - Overview of all devices
- Update session cleanup to maintain mapping consistency
- Prevent concurrent sessions on same device
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -64,7 +64,11 @@ class VideoInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
# Global session storage (in production, use database)
|
# Global session storage (in production, use database)
|
||||||
|
# 会话存储:session_id -> VideoLearningAgent
|
||||||
_active_sessions: Dict[str, VideoLearningAgent] = {}
|
_active_sessions: Dict[str, VideoLearningAgent] = {}
|
||||||
|
# 设备-会话映射:device_id -> {session_ids}
|
||||||
|
# 支持多设备同时运行,每个设备可以有多个会话(但同一时间只能有一个活跃的)
|
||||||
|
_device_sessions: Dict[str, set[str]] = {}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions", response_model=Dict[str, str])
|
@router.post("/sessions", response_model=Dict[str, str])
|
||||||
@@ -81,67 +85,86 @@ async def create_session(
|
|||||||
if not device.is_connected:
|
if not device.is_connected:
|
||||||
raise HTTPException(status_code=400, detail="Device not connected")
|
raise HTTPException(status_code=400, detail="Device not connected")
|
||||||
|
|
||||||
if device.status == "busy":
|
# 使用设备锁机制确保同一设备不会同时运行多个会话
|
||||||
raise HTTPException(status_code=409, detail="Device is busy")
|
if not device_manager.acquire_device(request.device_id):
|
||||||
|
raise HTTPException(status_code=409, detail="Device is busy - another session is running on this device")
|
||||||
|
|
||||||
# Create model config from environment
|
try:
|
||||||
model_config = ModelConfig(
|
# Create model config from environment
|
||||||
base_url=config.MODEL_BASE_URL,
|
model_config = ModelConfig(
|
||||||
model_name=config.MODEL_NAME,
|
base_url=config.MODEL_BASE_URL,
|
||||||
api_key=config.MODEL_API_KEY,
|
model_name=config.MODEL_NAME,
|
||||||
max_tokens=config.MAX_TOKENS,
|
api_key=config.MODEL_API_KEY,
|
||||||
temperature=config.TEMPERATURE,
|
max_tokens=config.MAX_TOKENS,
|
||||||
top_p=config.TOP_P,
|
temperature=config.TEMPERATURE,
|
||||||
frequency_penalty=config.FREQUENCY_PENALTY,
|
top_p=config.TOP_P,
|
||||||
lang="cn",
|
frequency_penalty=config.FREQUENCY_PENALTY,
|
||||||
)
|
lang="cn",
|
||||||
|
)
|
||||||
|
|
||||||
# Create video learning agent
|
# Create video learning agent
|
||||||
agent = VideoLearningAgent(
|
agent = VideoLearningAgent(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
platform=request.platform,
|
platform=request.platform,
|
||||||
output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
|
output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
|
||||||
enable_analysis=request.enable_analysis,
|
enable_analysis=request.enable_analysis,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup callbacks for real-time updates
|
# Setup callbacks for real-time updates
|
||||||
session_id = None
|
session_id = None
|
||||||
|
|
||||||
def on_video_watched(record):
|
def on_video_watched(record):
|
||||||
"""Callback when a video is watched."""
|
"""Callback when a video is watched."""
|
||||||
# Broadcast via WebSocket
|
if session_id:
|
||||||
if session_id:
|
pass # WebSocket broadcast would go here
|
||||||
# This would be integrated with WebSocket manager
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_progress_update(current, total):
|
def on_progress_update(current, total):
|
||||||
"""Callback for progress updates."""
|
"""Callback for progress updates."""
|
||||||
if session_id:
|
if session_id:
|
||||||
# Broadcast progress
|
pass # Progress broadcast would go here
|
||||||
pass
|
|
||||||
|
|
||||||
def on_session_complete(session):
|
def on_session_complete(session):
|
||||||
"""Callback when session completes."""
|
"""Callback when session completes - release device lock and clean up."""
|
||||||
if session_id and session_id in _active_sessions:
|
if session_id:
|
||||||
del _active_sessions[session_id]
|
# 从设备-会话映射中移除
|
||||||
|
if request.device_id in _device_sessions:
|
||||||
|
_device_sessions[request.device_id].discard(session_id)
|
||||||
|
if not _device_sessions[request.device_id]:
|
||||||
|
del _device_sessions[request.device_id]
|
||||||
|
|
||||||
agent.on_video_watched = on_video_watched
|
# 从活跃会话中移除(但保留历史记录可查询)
|
||||||
agent.on_progress_update = on_progress_update
|
if session_id in _active_sessions:
|
||||||
agent.on_session_complete = on_session_complete
|
# 标记为完成但不删除,允许用户查看结果
|
||||||
|
pass
|
||||||
|
|
||||||
# Start session
|
# 释放设备锁
|
||||||
session_id = agent.start_session(
|
device_manager.release_device(request.device_id)
|
||||||
device_id=request.device_id,
|
|
||||||
target_count=request.target_count,
|
|
||||||
category=request.category,
|
|
||||||
watch_duration=request.watch_duration,
|
|
||||||
max_steps=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store session
|
agent.on_video_watched = on_video_watched
|
||||||
_active_sessions[session_id] = agent
|
agent.on_progress_update = on_progress_update
|
||||||
|
agent.on_session_complete = on_session_complete
|
||||||
|
|
||||||
return {"session_id": session_id, "status": "created"}
|
# Start session
|
||||||
|
session_id = agent.start_session(
|
||||||
|
device_id=request.device_id,
|
||||||
|
target_count=request.target_count,
|
||||||
|
category=request.category,
|
||||||
|
watch_duration=request.watch_duration,
|
||||||
|
max_steps=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store session in both mappings
|
||||||
|
_active_sessions[session_id] = agent
|
||||||
|
if request.device_id not in _device_sessions:
|
||||||
|
_device_sessions[request.device_id] = set()
|
||||||
|
_device_sessions[request.device_id].add(session_id)
|
||||||
|
|
||||||
|
return {"session_id": session_id, "status": "created"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 创建失败时释放设备锁
|
||||||
|
device_manager.release_device(request.device_id)
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to create session: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sessions/{session_id}/start", response_model=Dict[str, str])
|
@router.post("/sessions/{session_id}/start", response_model=Dict[str, str])
|
||||||
@@ -330,11 +353,26 @@ async def list_sessions() -> List[str]:
|
|||||||
|
|
||||||
@router.delete("/sessions/{session_id}", response_model=Dict[str, str])
|
@router.delete("/sessions/{session_id}", response_model=Dict[str, str])
|
||||||
async def delete_session(session_id: str) -> Dict[str, str]:
|
async def delete_session(session_id: str) -> Dict[str, str]:
|
||||||
"""Delete a session."""
|
"""Delete a session and clean up device mapping."""
|
||||||
if session_id not in _active_sessions:
|
if session_id not in _active_sessions:
|
||||||
raise HTTPException(status_code=404, detail="Session not found")
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
|
|
||||||
|
# 获取设备ID以便清理映射
|
||||||
|
agent = _active_sessions[session_id]
|
||||||
|
device_id = None
|
||||||
|
if agent.current_session:
|
||||||
|
# 从 session 中获取 device_id
|
||||||
|
device_id = agent._device_id
|
||||||
|
|
||||||
|
# 删除会话
|
||||||
del _active_sessions[session_id]
|
del _active_sessions[session_id]
|
||||||
|
|
||||||
|
# 清理设备-会话映射
|
||||||
|
if device_id and device_id in _device_sessions:
|
||||||
|
_device_sessions[device_id].discard(session_id)
|
||||||
|
if not _device_sessions[device_id]:
|
||||||
|
del _device_sessions[device_id]
|
||||||
|
|
||||||
return {"session_id": session_id, "status": "deleted"}
|
return {"session_id": session_id, "status": "deleted"}
|
||||||
|
|
||||||
|
|
||||||
@@ -361,3 +399,78 @@ async def analyze_session(session_id: str) -> Dict[str, Any]:
|
|||||||
"analyzed_count": analyzed_count,
|
"analyzed_count": analyzed_count,
|
||||||
"status": "analysis_triggered"
|
"status": "analysis_triggered"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 设备级别的 API ====================
|
||||||
|
|
||||||
|
@router.get("/devices/{device_id}/sessions")
|
||||||
|
async def get_device_sessions(device_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取指定设备上的所有会话"""
|
||||||
|
if device_id not in _device_sessions:
|
||||||
|
return []
|
||||||
|
|
||||||
|
session_ids = _device_sessions[device_id]
|
||||||
|
result = []
|
||||||
|
for sid in session_ids:
|
||||||
|
if sid in _active_sessions:
|
||||||
|
agent = _active_sessions[sid]
|
||||||
|
progress = agent.get_session_progress()
|
||||||
|
result.append({
|
||||||
|
"session_id": sid,
|
||||||
|
"platform": progress.get("platform", ""),
|
||||||
|
"target_count": progress.get("target_count", 0),
|
||||||
|
"watched_count": progress.get("watched_count", 0),
|
||||||
|
"progress_percent": progress.get("progress_percent", 0),
|
||||||
|
"is_active": progress.get("is_active", False),
|
||||||
|
"is_paused": progress.get("is_paused", False),
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/devices/{device_id}/stop-all", response_model=Dict[str, Any])
|
||||||
|
async def stop_all_device_sessions(device_id: str) -> Dict[str, Any]:
|
||||||
|
"""停止指定设备上的所有活跃会话"""
|
||||||
|
if device_id not in _device_sessions:
|
||||||
|
return {"device_id": device_id, "stopped": 0, "message": "No sessions on this device"}
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
session_ids = list(_device_sessions[device_id])
|
||||||
|
for session_id in session_ids:
|
||||||
|
if session_id in _active_sessions:
|
||||||
|
agent = _active_sessions[session_id]
|
||||||
|
# 只停止活跃的会话
|
||||||
|
if agent.current_session and agent.current_session.is_active:
|
||||||
|
agent.stop_session()
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"device_id": device_id,
|
||||||
|
"stopped": count,
|
||||||
|
"message": f"Stopped {count} active session(s) on device {device_id}"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/devices/sessions", response_model=Dict[str, Any])
|
||||||
|
async def get_all_device_sessions() -> Dict[str, Any]:
|
||||||
|
"""获取所有设备上的会话概览"""
|
||||||
|
result = {
|
||||||
|
"devices": {},
|
||||||
|
"total_devices": len(_device_sessions),
|
||||||
|
"total_sessions": len(_active_sessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
for device_id, session_ids in _device_sessions.items():
|
||||||
|
active_count = 0
|
||||||
|
for sid in session_ids:
|
||||||
|
if sid in _active_sessions:
|
||||||
|
agent = _active_sessions[sid]
|
||||||
|
if agent.current_session and agent.current_session.is_active:
|
||||||
|
active_count += 1
|
||||||
|
|
||||||
|
result["devices"][device_id] = {
|
||||||
|
"total_sessions": len(session_ids),
|
||||||
|
"active_sessions": active_count,
|
||||||
|
"session_ids": list(session_ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
Reference in New Issue
Block a user