Features: - ScreenshotAnalyzer class for VLM-based image analysis - Real-time analysis during video recording - Extract likes, comments, tags, category from screenshots - Frontend display for category badges and tags - Batch analysis API endpoint Co-Authored-By: Claude <noreply@anthropic.com>
354 lines
11 KiB
Python
354 lines
11 KiB
Python
"""
|
||
Video Learning API endpoints for the dashboard.
|
||
"""
|
||
|
||
import asyncio
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from pydantic import BaseModel, Field
|
||
|
||
from dashboard.config import config
|
||
from dashboard.dependencies import get_device_manager
|
||
from dashboard.services.device_manager import DeviceManager
|
||
from phone_agent import VideoLearningAgent
|
||
from phone_agent.model.client import ModelConfig
|
||
|
||
# Router for every endpoint in this module: all routes declared on it share
# the "/api/video-learning" path prefix and carry the "video-learning" tag.
router = APIRouter(prefix="/api/video-learning", tags=["video-learning"])
|
||
|
||
|
||
class SessionCreateRequest(BaseModel):
    """Payload accepted when creating a new learning session.

    Only ``device_id`` is required; every other field has a default.
    """

    # Device the session will run on (required).
    device_id: str = Field(
        ...,
        description="Target device ID",
    )
    # Short platform key; not validated against a fixed set at this level.
    platform: str = Field(
        default="douyin",
        description="Platform name (douyin, kuaishou, tiktok)",
    )
    # How many videos to watch; constrained to 1..100 inclusive.
    target_count: int = Field(
        default=10,
        description="Number of videos to watch",
        ge=1,
        le=100,
    )
    # Optional category to target; None means no category filter.
    category: Optional[str] = Field(
        default=None,
        description="Target category filter",
    )
    # Seconds to stay on each video; constrained to 1.0..30.0 inclusive.
    watch_duration: float = Field(
        default=3.0,
        description="Watch duration per video (seconds)",
        ge=1.0,
        le=30.0,
    )
|
||
|
||
|
||
class SessionControlRequest(BaseModel):
    """Payload naming the control action to apply to a session."""

    # Plain string (required). The model itself does not restrict the value;
    # the expected values are listed in the description.
    action: str = Field(
        ...,
        description="Action: pause, resume, stop",
    )
|
||
|
||
|
||
class SessionStatus(BaseModel):
    """Response body describing the state of one learning session."""

    # Identity and configuration of the session.
    session_id: str
    platform: str
    target_count: int

    # Progress counters.
    watched_count: int
    progress_percent: float

    # Lifecycle flags.
    is_active: bool
    is_paused: bool

    # Total duration so far (unit not visible here -- presumably seconds;
    # TODO confirm against the agent's progress report).
    total_duration: float

    # Summary of the most recently recorded video, or None when absent.
    current_video: Optional[Dict] = Field(default=None)
|
||
|
||
|
||
class VideoInfo(BaseModel):
    """Response item describing a single watched video."""

    # Position of the video within its session and when it was recorded.
    sequence_id: int
    timestamp: str

    # Path of the screenshot captured for this video, if there is one.
    screenshot_path: Optional[str] = Field(default=None)

    # How long the video was watched.
    watch_duration: float

    # Extracted metadata; each value stays None until it is known.
    description: Optional[str] = Field(default=None)
    likes: Optional[int] = Field(default=None)
    comments: Optional[int] = Field(default=None)

    # Pydantic copies field defaults per instance, so this list literal is
    # not shared between models the way a plain mutable default would be.
    tags: List[str] = []
    category: Optional[str] = Field(default=None)
|
||
|
||
|
||
# Global session storage (in production, use database)
|
||
# Maps session ID -> the VideoLearningAgent driving that session. Entries are
# added by create_session and removed when a session is stopped, deleted, or
# reports completion; nothing here survives a process restart.
_active_sessions: Dict[str, VideoLearningAgent] = {}
|
||
|
||
|
||
@router.post("/sessions", response_model=Dict[str, str])
async def create_session(
    request: SessionCreateRequest,
    device_manager: DeviceManager = Depends(get_device_manager),
) -> Dict[str, str]:
    """Create a new video learning session and register it as active.

    The target device is validated first, then a ``VideoLearningAgent`` is
    built from the dashboard configuration, its callbacks are wired up, and
    the agent's own ``start_session`` is called to obtain the session ID.
    Running the learning task itself is a separate step (the ``/start``
    endpoint).

    Args:
        request: Session parameters (device, platform, count, category,
            per-video watch duration).
        device_manager: Injected device manager used to look up the device.

    Returns:
        ``{"session_id": <id>, "status": "created"}``.

    Raises:
        HTTPException: 404 if the device is unknown, 400 if it is not
            connected, 409 if it is busy.
    """
    # Check device availability
    device = await device_manager.get_device(request.device_id)
    if not device:
        raise HTTPException(status_code=404, detail="Device not found")
    if not device.is_connected:
        raise HTTPException(status_code=400, detail="Device not connected")
    if device.status == "busy":
        raise HTTPException(status_code=409, detail="Device is busy")

    # Create model config from environment
    model_config = ModelConfig(
        base_url=config.MODEL_BASE_URL,
        model_name=config.MODEL_NAME,
        api_key=config.MODEL_API_KEY,
        max_tokens=config.MAX_TOKENS,
        temperature=config.TEMPERATURE,
        top_p=config.TOP_P,
        frequency_penalty=config.FREQUENCY_PENALTY,
        lang="cn",
    )

    # Create video learning agent
    agent = VideoLearningAgent(
        model_config=model_config,
        platform=request.platform,
        output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
    )

    # The callbacks below close over ``session_id``; it is only assigned once
    # the agent has started the session, so they must tolerate None.
    session_id = None

    def on_video_watched(record):
        """Callback when a video is watched."""
        # Placeholder: broadcasting over WebSocket is not wired up yet.

    def on_progress_update(current, total):
        """Callback for progress updates."""
        # Placeholder: progress broadcasting is not wired up yet.

    def on_session_complete(session):
        """Callback when session completes: drop it from the registry."""
        # pop() instead of check-then-del: a stop/delete request may remove
        # the same entry, and a missing key must not raise here.
        if session_id:
            _active_sessions.pop(session_id, None)

    agent.on_video_watched = on_video_watched
    agent.on_progress_update = on_progress_update
    agent.on_session_complete = on_session_complete

    # Start session (this is the agent's method, not the /start endpoint)
    session_id = agent.start_session(
        device_id=request.device_id,
        target_count=request.target_count,
        category=request.category,
        watch_duration=request.watch_duration,
        max_steps=500,
    )

    # Store session
    _active_sessions[session_id] = agent

    return {"session_id": session_id, "status": "created"}
|
||
|
||
|
||
# Background learning runs keyed by session ID. The event loop only keeps weak
# references to tasks, so a strong reference is held here for as long as the
# run is in flight; it also lets a repeated start request be detected.
_running_tasks: Dict[str, "asyncio.Task[Any]"] = {}


def _build_learning_task(platform, category, target_count, watch_duration) -> str:
    """Return the step-by-step prompt handed to the agent for one session."""
    # Platform-specific app name and package
    platform_info = {
        "douyin": {
            "name": "抖音",
            "package": "com.ss.android.ugc.aweme",
        },
        "kuaishou": {
            "name": "快手",
            "package": "com.smile.gifmaker",
        },
        "tiktok": {
            "name": "TikTok",
            "package": "com.zhiliaoapp.musically",
        },
    }

    # Unknown platforms fall back to the douyin entry.
    info = platform_info.get(platform, platform_info["douyin"])
    app_name = info["name"]

    # With a category the agent searches first; without one it watches the
    # recommendation feed.
    if category:
        return f"""你是一个视频学习助手。请严格按照以下步骤执行:

步骤1:启动应用
- 回到主屏幕
- 打开{app_name}应用

步骤2:搜索内容
- 在{app_name}中搜索"{category}"
- 点击第一个搜索结果或进入相关页面

步骤3:观看视频
- 观看视频,每个视频停留约{watch_duration}秒
- 记录视频的描述、点赞数、评论数
- 向上滑动切换到下一个视频
- 重复观看和记录,直到完成{target_count}个视频

步骤4:完成任务
- 完成观看{target_count}个视频后,总结所有视频信息

请现在开始执行。"""

    return f"""你是一个视频学习助手。请严格按照以下步骤执行:

步骤1:启动应用
- 回到主屏幕
- 打开{app_name}应用

步骤2:观看推荐视频
- 进入{app_name}的推荐页面
- 观看推荐视频,每个视频停留约{watch_duration}秒
- 记录视频的描述、点赞数、评论数
- 向上滑动切换到下一个视频
- 重复观看和记录,直到完成{target_count}个视频

步骤3:完成任务
- 完成观看{target_count}个视频后,总结所有视频信息

请现在开始执行。"""


@router.post("/sessions/{session_id}/start", response_model=Dict[str, str])
async def start_session(session_id: str) -> Dict[str, str]:
    """Start executing a learning session in the background.

    Builds the task prompt from the session's parameters and runs the agent's
    blocking ``run_learning_task`` in a worker thread. The request returns
    immediately; it does not wait for the run to finish.

    Args:
        session_id: ID returned when the session was created.

    Returns:
        ``{"session_id": <id>, "status": "started"}``.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if the agent has no
            current session, 409 if a run for this session is still in
            progress.
    """
    import logging

    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    session = agent.current_session
    if not session:
        raise HTTPException(status_code=400, detail="Session not initialized")

    # Two concurrent runs on the same agent would fight over one device.
    in_flight = _running_tasks.get(session_id)
    if in_flight is not None and not in_flight.done():
        raise HTTPException(status_code=409, detail="Session already started")

    task_prompt = _build_learning_task(
        platform=agent.platform,
        category=session.target_category,
        target_count=session.target_count,
        watch_duration=agent._watch_duration,
    )

    def _on_run_finished(finished: "asyncio.Task[Any]") -> None:
        """Release the task reference and surface any failure in the log."""
        # Only drop the entry if it still belongs to this run.
        if _running_tasks.get(session_id) is finished:
            del _running_tasks[session_id]
        if finished.cancelled():
            return
        error = finished.exception()
        if error is not None:
            logging.getLogger(__name__).error(
                "Learning task for session %s failed",
                session_id,
                exc_info=error,
            )

    # Run in background; keep the reference so the task cannot be collected
    # before it completes.
    run_task = asyncio.create_task(
        asyncio.to_thread(agent.run_learning_task, task_prompt)
    )
    _running_tasks[session_id] = run_task
    run_task.add_done_callback(_on_run_finished)

    return {"session_id": session_id, "status": "started"}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/control", response_model=Dict[str, str])
async def control_session(
    session_id: str, request: SessionControlRequest
) -> Dict[str, str]:
    """Control a learning session (pause/resume/stop).

    Args:
        session_id: ID of the session to control.
        request: Carries the action name: "pause", "resume" or "stop".

    Returns:
        ``{"session_id": <id>, "status": <"paused"|"resumed"|"stopped">}``.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if the action is
            not one of the supported values.
    """
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    action = request.action

    if action == "pause":
        agent.pause_session()
        return {"session_id": session_id, "status": "paused"}

    if action == "resume":
        agent.resume_session()
        return {"session_id": session_id, "status": "resumed"}

    if action == "stop":
        agent.stop_session()
        # Remove from active sessions. The completion callback registered at
        # creation time also removes this entry, so it may already be gone
        # once stop_session() returns; pop() tolerates that where del would
        # raise KeyError and turn a successful stop into a 500.
        _active_sessions.pop(session_id, None)
        return {"session_id": session_id, "status": "stopped"}

    raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}")
|
||
|
||
|
||
@router.get("/sessions/{session_id}/status", response_model=SessionStatus)
async def get_session_status(session_id: str) -> SessionStatus:
    """Report progress for an active session.

    The response combines the agent's progress report with a short summary of
    the most recently recorded video, when there is one.

    Raises:
        HTTPException: 404 if the session is not in the active registry.
    """
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    progress = agent.get_session_progress()

    # Summarise the newest record, if any have been captured yet.
    session = agent.current_session
    current_video = None
    if session and session.records:
        newest = session.records[-1]
        current_video = {
            "sequence_id": newest.sequence_id,
            "timestamp": newest.timestamp,
            "screenshot_path": newest.screenshot_path,
            "description": newest.description,
            "likes": newest.likes,
            "comments": newest.comments,
        }

    # Copy exactly these keys from the progress report into the response.
    copied_keys = (
        "session_id",
        "platform",
        "target_count",
        "watched_count",
        "progress_percent",
        "is_active",
        "is_paused",
        "total_duration",
    )
    status_fields = {key: progress[key] for key in copied_keys}
    return SessionStatus(current_video=current_video, **status_fields)
|
||
|
||
|
||
@router.get("/sessions/{session_id}/videos", response_model=List[VideoInfo])
async def get_session_videos(session_id: str) -> List[VideoInfo]:
    """Return every video recorded so far in an active session.

    An agent without a current session yields an empty list.

    Raises:
        HTTPException: 404 if the session is not in the active registry.
    """
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    session = agent.current_session
    if not session:
        return []

    videos: List[VideoInfo] = []
    for record in session.records:
        videos.append(
            VideoInfo(
                sequence_id=record.sequence_id,
                timestamp=record.timestamp,
                screenshot_path=record.screenshot_path,
                watch_duration=record.watch_duration,
                description=record.description,
                likes=record.likes,
                comments=record.comments,
                tags=record.tags,
                category=record.category,
            )
        )
    return videos
|
||
|
||
|
||
@router.get("/sessions", response_model=List[str])
async def list_sessions() -> List[str]:
    """Return the IDs of all sessions currently in the active registry."""
    # Iterating a dict yields its keys, i.e. the session IDs.
    return list(_active_sessions)
|
||
|
||
|
||
@router.delete("/sessions/{session_id}", response_model=Dict[str, str])
async def delete_session(session_id: str) -> Dict[str, str]:
    """Remove a session from the active registry.

    Only the registry entry is dropped; the agent is not told to stop here.
    NOTE(review): a run still in progress would keep going untracked --
    confirm whether callers are expected to stop the session first.

    Args:
        session_id: ID of the session to remove.

    Returns:
        ``{"session_id": <id>, "status": "deleted"}``.

    Raises:
        HTTPException: 404 if the session is not in the active registry.
    """
    # A single pop() replaces the membership test plus del: the completion
    # callback can remove the same entry, and a two-step check-then-delete
    # could then raise KeyError instead of reporting a clean 404.
    if _active_sessions.pop(session_id, None) is None:
        raise HTTPException(status_code=404, detail="Session not found")

    return {"session_id": session_id, "status": "deleted"}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/analyze", response_model=Dict[str, Any])
async def analyze_session(session_id: str) -> Dict[str, Any]:
    """Report how many recorded videos in a session still await analysis.

    A record counts as pending when it has a screenshot path but no like
    count yet. NOTE(review): this handler only counts such records; no
    analyzer is invoked here even though the returned status reads
    "analysis_triggered".

    Raises:
        HTTPException: 404 if the session is not in the active registry,
            400 if the agent has no current session.
    """
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    session = agent.current_session
    if not session:
        raise HTTPException(status_code=400, detail="No session data")

    # Count every video that has not been analyzed yet.
    pending = sum(
        1
        for record in session.records
        if record.likes is None and record.screenshot_path
    )

    return {
        "session_id": session_id,
        "total_videos": len(session.records),
        "analyzed_count": pending,
        "status": "analysis_triggered",
    }
|