""" Video Learning API endpoints for the dashboard. """ import asyncio from datetime import datetime from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field from dashboard.config import config from dashboard.dependencies import get_device_manager from dashboard.services.device_manager import DeviceManager from phone_agent import VideoLearningAgent from phone_agent.model.client import ModelConfig router = APIRouter(prefix="/api/video-learning", tags=["video-learning"]) class SessionCreateRequest(BaseModel): """Request to create a new learning session.""" device_id: str = Field(..., description="Target device ID") platform: str = Field("douyin", description="Platform name (douyin, kuaishou, tiktok)") target_count: int = Field(10, description="Number of videos to watch", ge=1, le=100) category: Optional[str] = Field(None, description="Target category filter") watch_duration: float = Field(3.0, description="Watch duration per video (seconds)", ge=1.0, le=30.0) enable_analysis: bool = Field(True, description="Enable VLM screenshot analysis") class SessionControlRequest(BaseModel): """Request to control a session.""" action: str = Field(..., description="Action: pause, resume, stop") class SessionStatus(BaseModel): """Session status response.""" session_id: str platform: str target_count: int watched_count: int progress_percent: float is_active: bool is_paused: bool total_duration: float current_video: Optional[Dict] = None class VideoInfo(BaseModel): """Information about a watched video.""" sequence_id: int timestamp: str screenshot_path: Optional[str] = None watch_duration: float description: Optional[str] = None likes: Optional[int] = None comments: Optional[int] = None tags: List[str] = [] category: Optional[str] = None # Global session storage (in production, use database) # 会话存储:session_id -> VideoLearningAgent _active_sessions: Dict[str, VideoLearningAgent] = {} # 设备-会话映射:device_id -> {session_ids} # 支持多设备同时运行,每个设备可以有多个会话(但同一时间只能有一个活跃的) _device_sessions: Dict[str, set[str]] = {} @router.post("/sessions", response_model=Dict[str, str]) async def create_session( request: SessionCreateRequest, device_manager: DeviceManager = Depends(get_device_manager), ) -> Dict[str, str]: """Create a new video learning session.""" # Check device availability device = await device_manager.get_device(request.device_id) if not device: raise HTTPException(status_code=404, detail="Device not found") if not device.is_connected: raise HTTPException(status_code=400, detail="Device not connected") # 使用设备锁机制确保同一设备不会同时运行多个会话 if not device_manager.acquire_device(request.device_id): raise HTTPException(status_code=409, detail="Device is busy - another session is running on this device") try: # Create model config from environment model_config = ModelConfig( base_url=config.MODEL_BASE_URL, model_name=config.MODEL_NAME, api_key=config.MODEL_API_KEY, max_tokens=config.MAX_TOKENS, temperature=config.TEMPERATURE, top_p=config.TOP_P, frequency_penalty=config.FREQUENCY_PENALTY, lang="cn", ) # Create video learning agent agent = VideoLearningAgent( model_config=model_config, platform=request.platform, output_dir=config.VIDEO_LEARNING_OUTPUT_DIR, enable_analysis=request.enable_analysis, ) # Setup callbacks for real-time updates session_id = None def on_video_watched(record): """Callback when a video is watched.""" if session_id: pass # WebSocket broadcast would go here def on_progress_update(current, total): """Callback for progress updates.""" if session_id: pass # Progress broadcast would go here def on_session_complete(session): """Callback when session completes - release device lock and clean up.""" if session_id: # 从设备-会话映射中移除 if request.device_id in _device_sessions: _device_sessions[request.device_id].discard(session_id) if not _device_sessions[request.device_id]: del _device_sessions[request.device_id] # 从活跃会话中移除(但保留历史记录可查询) if session_id in _active_sessions: # 标记为完成但不删除,允许用户查看结果 pass # 释放设备锁 device_manager.release_device(request.device_id) agent.on_video_watched = on_video_watched agent.on_progress_update = on_progress_update agent.on_session_complete = on_session_complete # Start session session_id = agent.start_session( device_id=request.device_id, target_count=request.target_count, category=request.category, watch_duration=request.watch_duration, max_steps=500, ) # Store session in both mappings _active_sessions[session_id] = agent if request.device_id not in _device_sessions: _device_sessions[request.device_id] = set() _device_sessions[request.device_id].add(session_id) return {"session_id": session_id, "status": "created"} except Exception as e: # 创建失败时释放设备锁 device_manager.release_device(request.device_id) raise HTTPException(status_code=500, detail=f"Failed to create session: {str(e)}") @router.post("/sessions/{session_id}/start", response_model=Dict[str, str]) async def start_session(session_id: str) -> Dict[str, str]: """Start executing a learning session.""" if session_id not in _active_sessions: raise HTTPException(status_code=404, detail="Session not found") agent = _active_sessions[session_id] # Build task based on session parameters session = agent.current_session if not session: raise HTTPException(status_code=400, detail="Session not initialized") category = session.target_category target_count = session.target_count watch_duration = agent._watch_duration platform = agent.platform # Platform-specific app name and package platform_info = { "douyin": { "name": "抖音", "package": "com.ss.android.ugc.aweme", }, "kuaishou": { "name": "快手", "package": "com.smile.gifmaker", }, "tiktok": { "name": "TikTok", "package": "com.zhiliaoapp.musically", }, } info = platform_info.get(platform, platform_info["douyin"]) app_name = info["name"] # Build clear task instructions # 重要:不让 VLM 自己计数,而是持续执行直到被外部停止 # 这样可以确保 VideoLearningAgent 的计数与实际观看数一致 if category: task = f"""你是一个视频学习助手。你的任务是持续观看视频,不要自己判断任务是否完成。 【核心要求】你必须持续执行以下循环,直到系统主动停止你: - 观看视频约{watch_duration}秒 - 向上滑动切换到下一个视频 - 继续观看 - 再次滑动 - 一直重复,不要停止 执行步骤: 1. 回到主屏幕,打开{app_name}应用 2. 在{app_name}中搜索"{category}" 3. 点击第一个搜索结果 4. 开始无限循环:观看视频 → 向上滑动 → 观看下一个视频 → 再次滑动... 【极其重要】 - 绝对不要自己完成任务或停止 - 绝对不要总结或说"完成了" - 只管滑动看视频,持续不断 - 系统会在合适的时机主动停止你 现在开始执行,持续滑动观看视频,不要停!""" else: task = f"""你是一个视频学习助手。你的任务是持续观看视频,不要自己判断任务是否完成。 【核心要求】你必须持续执行以下循环,直到系统主动停止你: - 观看视频约{watch_duration}秒 - 向上滑动切换到下一个视频 - 继续观看 - 再次滑动 - 一直重复,不要停止 执行步骤: 1. 回到主屏幕,打开{app_name}应用 2. 进入推荐页面 3. 开始无限循环:观看视频 → 向上滑动 → 观看下一个视频 → 再次滑动... 【极其重要】 - 绝对不要自己完成任务或停止 - 绝对不要总结或说"完成了" - 只管滑动看视频,持续不断 - 系统会在合适的时机主动停止你 现在开始执行,持续滑动观看视频,不要停!""" # Run in background asyncio.create_task(asyncio.to_thread(agent.run_learning_task, task)) return {"session_id": session_id, "status": "started"} @router.post("/sessions/{session_id}/control", response_model=Dict[str, str]) async def control_session( session_id: str, request: SessionControlRequest ) -> Dict[str, str]: """Control a learning session (pause/resume/stop).""" if session_id not in _active_sessions: raise HTTPException(status_code=404, detail="Session not found") agent = _active_sessions[session_id] if request.action == "pause": agent.pause_session() return {"session_id": session_id, "status": "paused"} elif request.action == "resume": agent.resume_session() return {"session_id": session_id, "status": "resumed"} elif request.action == "stop": agent.stop_session() # Don't delete immediately - let status queries still work # Session will be cleaned up when is_active becomes False return {"session_id": session_id, "status": "stopped"} else: raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}") @router.get("/sessions/{session_id}/status", response_model=SessionStatus) async def get_session_status(session_id: str) -> SessionStatus: """Get session status.""" if session_id not in _active_sessions: raise HTTPException(status_code=404, detail="Session not found") agent = _active_sessions[session_id] progress = agent.get_session_progress() # Get current video info if available current_video = None if agent.current_session and agent.current_session.records: latest = agent.current_session.records[-1] current_video = { "sequence_id": latest.sequence_id, "timestamp": latest.timestamp, "screenshot_path": latest.screenshot_path, "description": latest.description, "likes": latest.likes, "comments": latest.comments, } return SessionStatus( session_id=progress["session_id"], platform=progress["platform"], target_count=progress["target_count"], watched_count=progress["watched_count"], progress_percent=progress["progress_percent"], is_active=progress["is_active"], is_paused=progress["is_paused"], total_duration=progress["total_duration"], current_video=current_video, ) @router.get("/sessions/{session_id}/videos", response_model=List[VideoInfo]) async def get_session_videos(session_id: str) -> List[VideoInfo]: """Get all videos from a session.""" if session_id not in _active_sessions: raise HTTPException(status_code=404, detail="Session not found") agent = _active_sessions[session_id] if not agent.current_session: return [] return [ VideoInfo( sequence_id=r.sequence_id, timestamp=r.timestamp, screenshot_path=r.screenshot_path, watch_duration=r.watch_duration, description=r.description, likes=r.likes, comments=r.comments, tags=r.tags, category=r.category, ) for r in agent.current_session.records ] @router.get("/sessions", response_model=List[str]) async def list_sessions() -> List[str]: """List all active session IDs.""" return list(_active_sessions.keys()) @router.get("/sessions/list") async def list_all_sessions() -> List[Dict[str, Any]]: """获取所有会话的详细信息(包含进度、状态等)""" result = [] for session_id, agent in _active_sessions.items(): try: progress = agent.get_session_progress() result.append({ "session_id": session_id, "platform": progress.get("platform", ""), "target_count": progress.get("target_count", 0), "watched_count": progress.get("watched_count", 0), "progress_percent": progress.get("progress_percent", 0), "is_active": progress.get("is_active", False), "is_paused": progress.get("is_paused", False), "total_duration": progress.get("total_duration", 0), }) except Exception as e: # 跳过出错的会话 continue return result @router.delete("/sessions/{session_id}", response_model=Dict[str, str]) async def delete_session(session_id: str) -> Dict[str, str]: """Delete a session and clean up device mapping.""" if session_id not in _active_sessions: raise HTTPException(status_code=404, detail="Session not found") # 获取设备ID以便清理映射 agent = _active_sessions[session_id] device_id = None if agent.current_session: # 从 session 中获取 device_id device_id = agent._device_id # 删除会话 del _active_sessions[session_id] # 清理设备-会话映射 if device_id and device_id in _device_sessions: _device_sessions[device_id].discard(session_id) if not _device_sessions[device_id]: del _device_sessions[device_id] return {"session_id": session_id, "status": "deleted"} @router.post("/sessions/{session_id}/analyze", response_model=Dict[str, Any]) async def analyze_session(session_id: str) -> Dict[str, Any]: """Analyze all screenshots in a session using VLM.""" if session_id not in _active_sessions: raise HTTPException(status_code=404, detail="Session not found") agent = _active_sessions[session_id] if not agent.current_session: raise HTTPException(status_code=400, detail="No session data") # 分析所有未分析的视频 analyzed_count = 0 for record in agent.current_session.records: if record.likes is None and record.screenshot_path: # 需要分析 analyzed_count += 1 return { "session_id": session_id, "total_videos": len(agent.current_session.records), "analyzed_count": analyzed_count, "status": "analysis_triggered" } # ==================== 设备级别的 API ==================== @router.get("/devices/{device_id}/sessions") async def get_device_sessions(device_id: str) -> List[Dict[str, Any]]: """获取指定设备上的所有会话""" if device_id not in _device_sessions: return [] session_ids = _device_sessions[device_id] result = [] for sid in session_ids: if sid in _active_sessions: agent = _active_sessions[sid] progress = agent.get_session_progress() result.append({ "session_id": sid, "platform": progress.get("platform", ""), "target_count": progress.get("target_count", 0), "watched_count": progress.get("watched_count", 0), "progress_percent": progress.get("progress_percent", 0), "is_active": progress.get("is_active", False), "is_paused": progress.get("is_paused", False), }) return result @router.post("/devices/{device_id}/stop-all", response_model=Dict[str, Any]) async def stop_all_device_sessions(device_id: str) -> Dict[str, Any]: """停止指定设备上的所有活跃会话""" if device_id not in _device_sessions: return {"device_id": device_id, "stopped": 0, "message": "No sessions on this device"} count = 0 session_ids = list(_device_sessions[device_id]) for session_id in session_ids: if session_id in _active_sessions: agent = _active_sessions[session_id] # 只停止活跃的会话 if agent.current_session and agent.current_session.is_active: agent.stop_session() count += 1 return { "device_id": device_id, "stopped": count, "message": f"Stopped {count} active session(s) on device {device_id}" } @router.get("/devices/sessions", response_model=Dict[str, Any]) async def get_all_device_sessions() -> Dict[str, Any]: """获取所有设备上的会话概览""" result = { "devices": {}, "total_devices": len(_device_sessions), "total_sessions": len(_active_sessions) } for device_id, session_ids in _device_sessions.items(): active_count = 0 for sid in session_ids: if sid in _active_sessions: agent = _active_sessions[sid] if agent.current_session and agent.current_session.is_active: active_count += 1 result["devices"][device_id] = { "total_sessions": len(session_ids), "active_sessions": active_count, "session_ids": list(session_ids) } return result