- Change video detection from screenshot hash to action-based (Swipe detection)
- Add enable_analysis toggle to disable VLM screenshot analysis
- Improve task prompt to prevent VLM from stopping prematurely
- Add debug logging for action detection troubleshooting
- Fix ModelResponse attribute error (content -> raw_content)

Co-Authored-By: Claude <noreply@anthropic.com>
364 lines
12 KiB
Python
364 lines
12 KiB
Python
"""
|
|
Video Learning API endpoints for the dashboard.
|
|
"""
|
|
|
|
import asyncio
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
|
|
from dashboard.config import config
|
|
from dashboard.dependencies import get_device_manager
|
|
from dashboard.services.device_manager import DeviceManager
|
|
from phone_agent import VideoLearningAgent
|
|
from phone_agent.model.client import ModelConfig
|
|
|
|
# Every endpoint declared on this router is served under the
# /api/video-learning prefix and carries the "video-learning" tag.
router = APIRouter(prefix="/api/video-learning", tags=["video-learning"])
|
|
|
|
|
|
class SessionCreateRequest(BaseModel):
    """Request body for creating a learning session.

    Range limits on the numeric fields are declared through the ``ge``/``le``
    arguments of each ``Field``.
    """

    # Required: the device the session should run on.
    device_id: str = Field(
        ...,
        description="Target device ID",
    )
    # Which short-video app to drive.
    platform: str = Field(
        default="douyin",
        description="Platform name (douyin, kuaishou, tiktok)",
    )
    # How many videos the session aims to watch (1-100).
    target_count: int = Field(
        default=10,
        description="Number of videos to watch",
        ge=1,
        le=100,
    )
    # Optional category to search for before watching.
    category: Optional[str] = Field(
        default=None,
        description="Target category filter",
    )
    # Seconds to stay on each video (1.0-30.0).
    watch_duration: float = Field(
        default=3.0,
        description="Watch duration per video (seconds)",
        ge=1.0,
        le=30.0,
    )
    # Forwarded to VideoLearningAgent(enable_analysis=...).
    enable_analysis: bool = Field(
        default=True,
        description="Enable VLM screenshot analysis",
    )
|
|
|
|
|
|
class SessionControlRequest(BaseModel):
    """Request body for the session control endpoint."""

    # Free-form string; the control endpoint accepts "pause", "resume" and
    # "stop" and rejects anything else with HTTP 400.
    action: str = Field(
        ...,
        description="Action: pause, resume, stop",
    )
|
|
|
|
|
|
class SessionStatus(BaseModel):
    """Response model describing the current state of a learning session.

    All fields except ``current_video`` are copied from the dictionary
    returned by the agent's ``get_session_progress()``.
    """

    # --- identity -----------------------------------------------------------
    session_id: str
    platform: str

    # --- progress -----------------------------------------------------------
    target_count: int
    watched_count: int
    progress_percent: float

    # --- lifecycle flags ----------------------------------------------------
    is_active: bool
    is_paused: bool

    # --- timing -------------------------------------------------------------
    total_duration: float

    # Summary of the most recent record (sequence_id, timestamp,
    # screenshot_path, description, likes, comments); None when the session
    # has no records yet.
    current_video: Optional[Dict] = None
|
|
|
|
|
|
class VideoInfo(BaseModel):
    """Response model for a single watched-video record."""

    # Position of the record and when it was captured.
    sequence_id: int
    timestamp: str

    # Where the screenshot for this video was saved, if any.
    screenshot_path: Optional[str] = None

    # How long the video was watched.
    watch_duration: float

    # Metadata about the video; each value may be absent.
    description: Optional[str] = None
    likes: Optional[int] = None
    comments: Optional[int] = None

    # NOTE(review): a literal list default is normally a shared-state hazard,
    # but pydantic is documented to copy field defaults per instance - confirm
    # against the pydantic version in use.
    tags: List[str] = []

    category: Optional[str] = None
|
|
|
|
|
|
# Global session storage (in production, use database)
|
|
# Maps session_id -> the VideoLearningAgent driving that session.  Entries are
# added by create_session and removed either by delete_session or by the
# agent's on_session_complete callback.
_active_sessions: Dict[str, VideoLearningAgent] = {}
|
|
|
|
|
|
@router.post("/sessions", response_model=Dict[str, str])
async def create_session(
    request: SessionCreateRequest,
    device_manager: DeviceManager = Depends(get_device_manager),
) -> Dict[str, str]:
    """Create a new video learning session and register it in memory.

    The target device is validated first, then a ``VideoLearningAgent`` is
    built from the dashboard configuration, its callbacks are wired up, and
    ``agent.start_session`` is called to obtain the session ID.  The agent is
    stored in ``_active_sessions`` under that ID.  The learning task itself
    is not launched here.

    Args:
        request: Session parameters supplied by the client.
        device_manager: Injected device manager used to look up the device.

    Returns:
        ``{"session_id": <id>, "status": "created"}``.

    Raises:
        HTTPException: 404 if the device is unknown, 400 if it is not
            connected, 409 if its status is ``"busy"``.
    """
    # Check device availability.
    device = await device_manager.get_device(request.device_id)
    if not device:
        raise HTTPException(status_code=404, detail="Device not found")
    if not device.is_connected:
        raise HTTPException(status_code=400, detail="Device not connected")
    if device.status == "busy":
        raise HTTPException(status_code=409, detail="Device is busy")

    # Model settings come from the dashboard configuration.
    model_config = ModelConfig(
        base_url=config.MODEL_BASE_URL,
        model_name=config.MODEL_NAME,
        api_key=config.MODEL_API_KEY,
        max_tokens=config.MAX_TOKENS,
        temperature=config.TEMPERATURE,
        top_p=config.TOP_P,
        frequency_penalty=config.FREQUENCY_PENALTY,
        lang="cn",
    )

    agent = VideoLearningAgent(
        model_config=model_config,
        platform=request.platform,
        output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
        enable_analysis=request.enable_analysis,
    )

    # Assigned after agent.start_session() returns.  The callbacks below are
    # closures, so they see the final value at the time they are invoked.
    session_id = None

    def on_video_watched(record):
        """Placeholder hook for a watched video; currently does nothing."""
        # TODO: broadcast the record through the WebSocket manager.

    def on_progress_update(current, total):
        """Placeholder hook for progress updates; currently does nothing."""
        # TODO: broadcast progress through the WebSocket manager.

    def on_session_complete(session):
        """Drop the finished session from the in-memory registry."""
        if session_id:
            # pop() instead of "check then del": delete_session may remove
            # the same key concurrently, and a missing key must not raise
            # inside the agent's callback.
            _active_sessions.pop(session_id, None)

    agent.on_video_watched = on_video_watched
    agent.on_progress_update = on_progress_update
    agent.on_session_complete = on_session_complete

    # Initialise the session on the agent; this yields the session ID.
    session_id = agent.start_session(
        device_id=request.device_id,
        target_count=request.target_count,
        category=request.category,
        watch_duration=request.watch_duration,
        max_steps=500,
    )

    _active_sessions[session_id] = agent

    return {"session_id": session_id, "status": "created"}
|
|
|
|
|
|
# Strong references to in-flight background runs.  The event loop keeps only
# weak references to tasks, so a task nobody else references may be garbage
# collected before it finishes.
_background_tasks = set()

# Display name and Android package for each supported platform.
_PLATFORM_INFO = {
    "douyin": {
        "name": "抖音",
        "package": "com.ss.android.ugc.aweme",
    },
    "kuaishou": {
        "name": "快手",
        "package": "com.smile.gifmaker",
    },
    "tiktok": {
        "name": "TikTok",
        "package": "com.zhiliaoapp.musically",
    },
}


def _build_learning_task_prompt(
    app_name: str, watch_duration: float, category: Optional[str]
) -> str:
    """Return the instruction text handed to the agent for a learning run.

    The prompt deliberately tells the model NOT to count videos or decide on
    its own that the task is finished; it should keep swiping until it is
    stopped from outside, so that the agent's own count stays consistent with
    the number of videos actually watched.

    Args:
        app_name: Display name of the app to open.
        watch_duration: Approximate seconds to watch each video.
        category: Search term to look up first; when falsy the prompt tells
            the model to use the recommendation feed instead.
    """
    if category:
        steps = [
            f'2. 在{app_name}中搜索"{category}"',
            "3. 点击第一个搜索结果",
            "4. 开始无限循环:观看视频 → 向上滑动 → 观看下一个视频 → 再次滑动...",
        ]
    else:
        steps = [
            "2. 进入推荐页面",
            "3. 开始无限循环:观看视频 → 向上滑动 → 观看下一个视频 → 再次滑动...",
        ]

    lines = [
        "你是一个视频学习助手。你的任务是持续观看视频,不要自己判断任务是否完成。",
        "",
        "【核心要求】你必须持续执行以下循环,直到系统主动停止你:",
        f"- 观看视频约{watch_duration}秒",
        "- 向上滑动切换到下一个视频",
        "- 继续观看",
        "- 再次滑动",
        "- 一直重复,不要停止",
        "",
        "执行步骤:",
        f"1. 回到主屏幕,打开{app_name}应用",
        *steps,
        "",
        "【极其重要】",
        "- 绝对不要自己完成任务或停止",
        '- 绝对不要总结或说"完成了"',
        "- 只管滑动看视频,持续不断",
        "- 系统会在合适的时机主动停止你",
        "",
        "现在开始执行,持续滑动观看视频,不要停!",
    ]
    return "\n".join(lines)


def _on_learning_task_done(task) -> None:
    """Release a finished background run and log any failure it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        import logging

        logging.getLogger(__name__).error(
            "Video learning task failed", exc_info=error
        )


@router.post("/sessions/{session_id}/start", response_model=Dict[str, str])
async def start_session(session_id: str) -> Dict[str, str]:
    """Start executing a previously created learning session.

    Builds the task prompt from the session's parameters and runs
    ``agent.run_learning_task`` in a worker thread so the request returns
    immediately.

    Args:
        session_id: ID returned by the create-session endpoint.

    Returns:
        ``{"session_id": <id>, "status": "started"}``.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if the agent has
            no current session.
    """
    # Single lookup: the session may be removed concurrently, so a separate
    # membership test followed by indexing could raise KeyError.
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    session = agent.current_session
    if not session:
        raise HTTPException(status_code=400, detail="Session not initialized")

    # Unknown platforms fall back to the douyin entry.
    info = _PLATFORM_INFO.get(agent.platform, _PLATFORM_INFO["douyin"])

    task = _build_learning_task_prompt(
        app_name=info["name"],
        watch_duration=agent._watch_duration,
        category=session.target_category,
    )

    # Run in the background, keeping a reference until the run completes.
    background = asyncio.create_task(
        asyncio.to_thread(agent.run_learning_task, task)
    )
    _background_tasks.add(background)
    background.add_done_callback(_on_learning_task_done)

    return {"session_id": session_id, "status": "started"}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/control", response_model=Dict[str, str])
async def control_session(
    session_id: str, request: SessionControlRequest
) -> Dict[str, str]:
    """Pause, resume or stop a learning session.

    Args:
        session_id: ID of the session to control.
        request: Carries the ``action`` to apply.

    Returns:
        ``{"session_id": <id>, "status": <"paused" | "resumed" | "stopped">}``.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if the action is
            not one of ``pause``, ``resume`` or ``stop``.
    """
    # Single lookup: the session may be removed concurrently, so a separate
    # membership test followed by indexing could raise KeyError.
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # action -> (agent method to call, status string reported back).
    handlers = {
        "pause": (agent.pause_session, "paused"),
        "resume": (agent.resume_session, "resumed"),
        # A stopped session is intentionally left in _active_sessions so that
        # status queries keep working; removal is left to the completion
        # callback or an explicit delete.
        "stop": (agent.stop_session, "stopped"),
    }

    selected = handlers.get(request.action)
    if selected is None:
        raise HTTPException(
            status_code=400, detail=f"Invalid action: {request.action}"
        )

    apply_action, status = selected
    apply_action()
    return {"session_id": session_id, "status": status}
|
|
|
|
|
|
@router.get("/sessions/{session_id}/status", response_model=SessionStatus)
async def get_session_status(session_id: str) -> SessionStatus:
    """Return progress information for a session.

    Args:
        session_id: ID of the session to inspect.

    Returns:
        A ``SessionStatus`` built from ``agent.get_session_progress()``, plus
        a summary of the latest record when the session has any.

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    # Single lookup: the session may be removed concurrently, so a separate
    # membership test followed by indexing could raise KeyError.
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    progress = agent.get_session_progress()

    # Read current_session once so the check and the access below observe
    # the same object.
    session = agent.current_session
    current_video = None
    if session and session.records:
        latest = session.records[-1]
        current_video = {
            "sequence_id": latest.sequence_id,
            "timestamp": latest.timestamp,
            "screenshot_path": latest.screenshot_path,
            "description": latest.description,
            "likes": latest.likes,
            "comments": latest.comments,
        }

    return SessionStatus(
        session_id=progress["session_id"],
        platform=progress["platform"],
        target_count=progress["target_count"],
        watched_count=progress["watched_count"],
        progress_percent=progress["progress_percent"],
        is_active=progress["is_active"],
        is_paused=progress["is_paused"],
        total_duration=progress["total_duration"],
        current_video=current_video,
    )
|
|
|
|
|
|
@router.get("/sessions/{session_id}/videos", response_model=List[VideoInfo])
async def get_session_videos(session_id: str) -> List[VideoInfo]:
    """Return every video record collected so far in a session.

    Args:
        session_id: ID of the session to inspect.

    Returns:
        One ``VideoInfo`` per record, in record order; an empty list when the
        agent has no current session.

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    # Single lookup: the session may be removed concurrently, so a separate
    # membership test followed by indexing could raise KeyError.
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # Read current_session once so the emptiness check and the iteration
    # below use the same object.
    session = agent.current_session
    if not session:
        return []

    return [
        VideoInfo(
            sequence_id=record.sequence_id,
            timestamp=record.timestamp,
            screenshot_path=record.screenshot_path,
            watch_duration=record.watch_duration,
            description=record.description,
            likes=record.likes,
            comments=record.comments,
            tags=record.tags,
            category=record.category,
        )
        for record in session.records
    ]
|
|
|
|
|
|
@router.get("/sessions", response_model=List[str])
async def list_sessions() -> List[str]:
    """Return the IDs of all sessions currently held in memory."""
    session_ids = list(_active_sessions)
    return session_ids
|
|
|
|
|
|
@router.delete("/sessions/{session_id}", response_model=Dict[str, str])
async def delete_session(session_id: str) -> Dict[str, str]:
    """Remove a session from the in-memory registry.

    Args:
        session_id: ID of the session to remove.

    Returns:
        ``{"session_id": <id>, "status": "deleted"}``.

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    # pop() performs the existence check and the removal in one step.  The
    # completion callback can remove the same key, so "check then del" could
    # raise KeyError between the two operations.
    missing = object()
    if _active_sessions.pop(session_id, missing) is missing:
        raise HTTPException(status_code=404, detail="Session not found")

    # NOTE(review): the agent is not stopped here; if a learning task is
    # still running it presumably keeps going unobserved - confirm whether
    # stop_session() should be called first.
    return {"session_id": session_id, "status": "deleted"}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/analyze", response_model=Dict[str, Any])
async def analyze_session(session_id: str) -> Dict[str, Any]:
    """Report how many records in a session still lack analysis results.

    A record counts as un-analysed when its ``likes`` is ``None`` and it has
    a screenshot path.

    NOTE: despite the endpoint name and the ``"analysis_triggered"`` status,
    this handler only counts such records; it does not invoke any analysis
    itself.

    Args:
        session_id: ID of the session to inspect.

    Returns:
        A dict with ``session_id``, ``total_videos`` (number of records),
        ``analyzed_count`` (number of records matching the condition above)
        and ``status``.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if the agent has
            no current session.
    """
    # Single lookup: the session may be removed concurrently, so a separate
    # membership test followed by indexing could raise KeyError.
    agent = _active_sessions.get(session_id)
    if agent is None:
        raise HTTPException(status_code=404, detail="Session not found")

    session = agent.current_session
    if not session:
        raise HTTPException(status_code=400, detail="No session data")

    # Count every video that has a screenshot but no analysis result yet.
    pending_count = sum(
        1
        for record in session.records
        if record.likes is None and record.screenshot_path
    )

    return {
        "session_id": session_id,
        "total_videos": len(session.records),
        "analyzed_count": pending_count,
        "status": "analysis_triggered",
    }
|