This commit adds a dedicated task list page to view and manage all video learning sessions, solving the issue where users couldn't find their background tasks after navigating away. Features: - New sessions.html page with card-based layout for all sessions - Real-time polling for session status updates (every 3 seconds) - Session control buttons (pause/resume/stop/delete) - localStorage integration for session persistence across page refreshes - Navigation links added to main page and video learning page - Empty state UI when no sessions exist New files: - dashboard/static/sessions.html - Task list page - dashboard/static/js/sessions.js - Sessions module with API calls - dashboard/static/css/sessions.css - Styling for sessions page Modified files: - dashboard/api/video_learning.py - Added /sessions/list endpoint - dashboard/static/index.html - Added "任务列表" button - dashboard/static/video-learning.html - Added "任务列表" button and localStorage Co-Authored-By: Claude <noreply@anthropic.com>
500 lines
18 KiB
Python
500 lines
18 KiB
Python
"""
Video Learning API endpoints for the dashboard.
"""
|
||
|
||
import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field

from dashboard.config import config
from dashboard.dependencies import get_device_manager
from dashboard.services.device_manager import DeviceManager
from phone_agent import VideoLearningAgent
from phone_agent.model.client import ModelConfig
|
||
|
||
# Router that owns every endpoint in this module; all routes are mounted under
# the /api/video-learning prefix and tagged "video-learning".
router = APIRouter(prefix="/api/video-learning", tags=["video-learning"])
|
||
|
||
|
||
class SessionCreateRequest(BaseModel):
    """Body of the request that creates a new learning session.

    Only ``device_id`` is required; every other field has a default.
    """

    device_id: str = Field(
        ...,
        description="Target device ID",
    )
    platform: str = Field(
        default="douyin",
        description="Platform name (douyin, kuaishou, tiktok)",
    )
    # Bounded to 1..100 videos per session.
    target_count: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Number of videos to watch",
    )
    category: Optional[str] = Field(
        default=None,
        description="Target category filter",
    )
    # Bounded to 1.0..30.0 seconds per video.
    watch_duration: float = Field(
        default=3.0,
        ge=1.0,
        le=30.0,
        description="Watch duration per video (seconds)",
    )
    enable_analysis: bool = Field(
        default=True,
        description="Enable VLM screenshot analysis",
    )
|
||
|
||
|
||
class SessionControlRequest(BaseModel):
    """Body of the request that pauses, resumes or stops a session."""

    # Free-form string; the control endpoint rejects anything other than
    # "pause", "resume" or "stop" with HTTP 400.
    action: str = Field(
        ...,
        description="Action: pause, resume, stop",
    )
|
||
|
||
|
||
class SessionStatus(BaseModel):
    """Session status response.

    Response model of the ``/sessions/{session_id}/status`` endpoint. Every
    field except ``current_video`` is copied from the dictionary returned by
    the agent's ``get_session_progress()``.
    """

    session_id: str
    platform: str
    target_count: int
    watched_count: int
    progress_percent: float
    is_active: bool
    is_paused: bool
    total_duration: float
    # Summary of the most recent record of the session (sequence_id, timestamp,
    # screenshot_path, description, likes, comments); None while the session
    # has no records.
    current_video: Optional[Dict] = None
|
||
|
||
|
||
class VideoInfo(BaseModel):
    """Information about a watched video.

    One instance is built per session record by the
    ``/sessions/{session_id}/videos`` endpoint; each field is copied from the
    record attribute of the same name.
    """

    sequence_id: int
    timestamp: str
    screenshot_path: Optional[str] = None
    watch_duration: float
    description: Optional[str] = None
    # ``likes`` being None is what the analyze endpoint treats as
    # "not analysed yet".
    likes: Optional[int] = None
    comments: Optional[int] = None
    # NOTE(review): a literal list default is presumably safe here because
    # pydantic copies field defaults per instance - confirm for the pydantic
    # version in use.
    tags: List[str] = []
    category: Optional[str] = None
|
||
|
||
|
||
# Global session storage (in production, use a database).
# Session registry: session_id -> VideoLearningAgent
_active_sessions: Dict[str, VideoLearningAgent] = {}
# Device-to-session mapping: device_id -> {session_ids}
# Several devices may run at the same time; a device may own multiple sessions,
# but only one of them may be active at any given moment.
_device_sessions: Dict[str, set[str]] = {}
|
||
|
||
|
||
@router.post("/sessions", response_model=Dict[str, str])
async def create_session(
    request: SessionCreateRequest,
    device_manager: DeviceManager = Depends(get_device_manager),
) -> Dict[str, str]:
    """Create a new video learning session.

    Verifies that the target device exists and is connected, takes the
    per-device lock, builds a ``VideoLearningAgent`` from the dashboard
    configuration, starts a session on it and records the agent in the
    module-level registries.

    Args:
        request: Session parameters supplied by the client.
        device_manager: Device manager injected by FastAPI.

    Returns:
        ``{"session_id": <id>, "status": "created"}``.

    Raises:
        HTTPException: 404 if the device is unknown, 400 if it is not
            connected, 409 if the device lock cannot be acquired, and 500 if
            anything fails while building or starting the session (the device
            lock is released before the 500 is raised).
    """
    # Check device availability.
    device = await device_manager.get_device(request.device_id)
    if not device:
        raise HTTPException(status_code=404, detail="Device not found")

    if not device.is_connected:
        raise HTTPException(status_code=400, detail="Device not connected")

    # The device lock guarantees that one device never runs two sessions at
    # the same time.
    if not device_manager.acquire_device(request.device_id):
        raise HTTPException(status_code=409, detail="Device is busy - another session is running on this device")

    try:
        # Model configuration comes from the dashboard config.
        model_config = ModelConfig(
            base_url=config.MODEL_BASE_URL,
            model_name=config.MODEL_NAME,
            api_key=config.MODEL_API_KEY,
            max_tokens=config.MAX_TOKENS,
            temperature=config.TEMPERATURE,
            top_p=config.TOP_P,
            frequency_penalty=config.FREQUENCY_PENALTY,
            lang="cn",
        )

        agent = VideoLearningAgent(
            model_config=model_config,
            platform=request.platform,
            output_dir=config.VIDEO_LEARNING_OUTPUT_DIR,
            enable_analysis=request.enable_analysis,
        )

        # Read by the callbacks below through the closure; it stays None until
        # agent.start_session() has returned the real id.
        session_id = None

        def on_video_watched(record):
            """Callback when a video is watched (placeholder, does nothing).

            A WebSocket broadcast would go here.
            """

        def on_progress_update(current, total):
            """Callback for progress updates (placeholder, does nothing).

            A progress broadcast would go here.
            """

        def on_session_complete(session):
            """Callback when the session completes: clean up and release the device lock."""
            if session_id:
                # Drop the session from the device-to-session mapping.
                device_session_ids = _device_sessions.get(request.device_id)
                if device_session_ids is not None:
                    device_session_ids.discard(session_id)
                    if not device_session_ids:
                        _device_sessions.pop(request.device_id, None)
                # The agent deliberately stays in _active_sessions so that the
                # results of a finished session can still be queried.

            # Release the device lock regardless of the bookkeeping above, so a
            # finished session can never leave the device marked as busy.
            device_manager.release_device(request.device_id)

        agent.on_video_watched = on_video_watched
        agent.on_progress_update = on_progress_update
        agent.on_session_complete = on_session_complete

        # Start the session.
        session_id = agent.start_session(
            device_id=request.device_id,
            target_count=request.target_count,
            category=request.category,
            watch_duration=request.watch_duration,
            max_steps=500,
        )

        # Record the session in both registries.
        _active_sessions[session_id] = agent
        _device_sessions.setdefault(request.device_id, set()).add(session_id)

        return {"session_id": session_id, "status": "created"}

    except Exception as e:
        # Creation failed: give the device lock back before reporting the error.
        device_manager.release_device(request.device_id)
        # Chain the original exception so the real cause stays in the traceback.
        raise HTTPException(status_code=500, detail=f"Failed to create session: {e}") from e
|
||
|
||
|
||
# Strong references to the background learning tasks that are still running.
# The event loop itself only keeps weak references to tasks, so a task whose
# result is discarded may be garbage-collected before it finishes.
_background_learning_tasks: set = set()


def _build_learning_task(app_name: str, watch_duration: Any, category: Optional[str]) -> str:
    """Return the instruction prompt handed to the agent.

    The prompt tells the model to keep watching and swiping until it is stopped
    from outside, rather than counting videos itself, so that the count kept by
    the VideoLearningAgent matches what was actually watched. With a truthy
    ``category`` the steps search for that category first; otherwise they go to
    the recommendation feed.
    """
    if category:
        steps = (
            f'2. 在{app_name}中搜索"{category}"\n'
            "3. 点击第一个搜索结果\n"
            "4. 开始无限循环:观看视频 → 向上滑动 → 观看下一个视频 → 再次滑动..."
        )
    else:
        steps = (
            "2. 进入推荐页面\n"
            "3. 开始无限循环:观看视频 → 向上滑动 → 观看下一个视频 → 再次滑动..."
        )

    return f"""你是一个视频学习助手。你的任务是持续观看视频,不要自己判断任务是否完成。

【核心要求】你必须持续执行以下循环,直到系统主动停止你:
- 观看视频约{watch_duration}秒
- 向上滑动切换到下一个视频
- 继续观看
- 再次滑动
- 一直重复,不要停止

执行步骤:
1. 回到主屏幕,打开{app_name}应用
{steps}

【极其重要】
- 绝对不要自己完成任务或停止
- 绝对不要总结或说"完成了"
- 只管滑动看视频,持续不断
- 系统会在合适的时机主动停止你

现在开始执行,持续滑动观看视频,不要停!"""


@router.post("/sessions/{session_id}/start", response_model=Dict[str, str])
async def start_session(session_id: str) -> Dict[str, str]:
    """Start executing a learning session in the background.

    Builds the task prompt from the session's category, the agent's watch
    duration and the platform's app name, then runs
    ``agent.run_learning_task`` in a worker thread without waiting for it.

    Args:
        session_id: Id of a session previously stored in the registry.

    Returns:
        ``{"session_id": <id>, "status": "started"}``.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if the agent has no
            current session.
    """
    if session_id not in _active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    agent = _active_sessions[session_id]

    session = agent.current_session
    if not session:
        raise HTTPException(status_code=400, detail="Session not initialized")

    category = session.target_category
    watch_duration = agent._watch_duration
    platform = agent.platform

    # Platform-specific app name and package. Only "name" is used for the
    # prompt; "package" is currently informational.
    platform_info = {
        "douyin": {
            "name": "抖音",
            "package": "com.ss.android.ugc.aweme",
        },
        "kuaishou": {
            "name": "快手",
            "package": "com.smile.gifmaker",
        },
        "tiktok": {
            "name": "TikTok",
            "package": "com.zhiliaoapp.musically",
        },
    }

    # Unknown platforms fall back to the douyin entry.
    info = platform_info.get(platform, platform_info["douyin"])
    app_name = info["name"]

    task_prompt = _build_learning_task(app_name, watch_duration, category)

    # Run in the background. Keep a strong reference until the task is done so
    # it cannot be garbage-collected mid-execution.
    background_task = asyncio.create_task(
        asyncio.to_thread(agent.run_learning_task, task_prompt)
    )
    _background_learning_tasks.add(background_task)
    background_task.add_done_callback(_background_learning_tasks.discard)

    return {"session_id": session_id, "status": "started"}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/control", response_model=Dict[str, str])
async def control_session(
    session_id: str, request: SessionControlRequest
) -> Dict[str, str]:
    """Pause, resume or stop a learning session.

    Args:
        session_id: Id of the session to control.
        request: Carries the requested ``action``.

    Returns:
        ``{"session_id": <id>, "status": <"paused" | "resumed" | "stopped">}``.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if the action is not
            one of ``pause``, ``resume`` or ``stop``.
    """
    if session_id not in _active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    agent = _active_sessions[session_id]

    # action -> (agent method to call, status reported back to the client)
    dispatch = {
        "pause": ("pause_session", "paused"),
        "resume": ("resume_session", "resumed"),
        "stop": ("stop_session", "stopped"),
    }

    entry = dispatch.get(request.action)
    if entry is None:
        raise HTTPException(status_code=400, detail=f"Invalid action: {request.action}")

    method_name, reported_status = entry
    # A stopped session is intentionally left in the registry so that status
    # queries keep working afterwards.
    getattr(agent, method_name)()
    return {"session_id": session_id, "status": reported_status}
|
||
|
||
|
||
@router.get("/sessions/{session_id}/status", response_model=SessionStatus)
async def get_session_status(session_id: str) -> SessionStatus:
    """Return the progress of a session plus a summary of its latest record.

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    if session_id not in _active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    agent = _active_sessions[session_id]
    progress = agent.get_session_progress()

    # Summarise the most recent record, if the session has one.
    session = agent.current_session
    if session and session.records:
        newest = session.records[-1]
        current_video = {
            "sequence_id": newest.sequence_id,
            "timestamp": newest.timestamp,
            "screenshot_path": newest.screenshot_path,
            "description": newest.description,
            "likes": newest.likes,
            "comments": newest.comments,
        }
    else:
        current_video = None

    # These keys are copied verbatim from the progress dictionary.
    copied_fields = (
        "session_id",
        "platform",
        "target_count",
        "watched_count",
        "progress_percent",
        "is_active",
        "is_paused",
        "total_duration",
    )
    return SessionStatus(
        current_video=current_video,
        **{name: progress[name] for name in copied_fields},
    )
|
||
|
||
|
||
@router.get("/sessions/{session_id}/videos", response_model=List[VideoInfo])
async def get_session_videos(session_id: str) -> List[VideoInfo]:
    """Return one ``VideoInfo`` per record of the session, in record order.

    An agent without a current session yields an empty list.

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    if session_id not in _active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    agent = _active_sessions[session_id]
    if not agent.current_session:
        return []

    videos: List[VideoInfo] = []
    for record in agent.current_session.records:
        videos.append(
            VideoInfo(
                sequence_id=record.sequence_id,
                timestamp=record.timestamp,
                screenshot_path=record.screenshot_path,
                watch_duration=record.watch_duration,
                description=record.description,
                likes=record.likes,
                comments=record.comments,
                tags=record.tags,
                category=record.category,
            )
        )
    return videos
|
||
|
||
|
||
@router.get("/sessions", response_model=List[str])
async def list_sessions() -> List[str]:
    """Return the ids of every session currently held in the registry."""
    # Iterating a dict yields its keys.
    return list(_active_sessions)
|
||
|
||
|
||
@router.get("/sessions/list")
async def list_all_sessions() -> List[Dict[str, Any]]:
    """Return details (progress, state, ...) for every registered session.

    This is best effort: a session whose progress cannot be read is left out of
    the result, and the failure is logged instead of being silently dropped.

    Returns:
        One dictionary per session with the keys ``session_id``, ``platform``,
        ``target_count``, ``watched_count``, ``progress_percent``,
        ``is_active``, ``is_paused`` and ``total_duration``.
    """
    result = []
    for session_id, agent in _active_sessions.items():
        try:
            progress = agent.get_session_progress()
            result.append({
                "session_id": session_id,
                "platform": progress.get("platform", ""),
                "target_count": progress.get("target_count", 0),
                "watched_count": progress.get("watched_count", 0),
                "progress_percent": progress.get("progress_percent", 0),
                "is_active": progress.get("is_active", False),
                "is_paused": progress.get("is_paused", False),
                "total_duration": progress.get("total_duration", 0),
            })
        except Exception:
            # Skip the broken session so the rest of the list is still served,
            # but leave a trace of what went wrong.
            logging.getLogger(__name__).warning(
                "Skipping session %s: failed to read its progress",
                session_id,
                exc_info=True,
            )
            continue
    return result
|
||
|
||
|
||
@router.delete("/sessions/{session_id}", response_model=Dict[str, str])
async def delete_session(session_id: str) -> Dict[str, str]:
    """Delete a session and clean up the device mapping.

    A session that is still active is stopped first. Without that, the session
    would disappear from the registry while its agent keeps running, and no
    endpoint could reach it any more.

    Returns:
        ``{"session_id": <id>, "status": "deleted"}``.

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    if session_id not in _active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    agent = _active_sessions[session_id]

    # Stop the agent before it becomes unreachable.
    # NOTE(review): this presumably also leads to the completion callback that
    # releases the device lock - confirm against VideoLearningAgent.
    session = agent.current_session
    if session and session.is_active:
        agent.stop_session()

    # Remove the session from the registry.
    _active_sessions.pop(session_id, None)

    # Clean the device-to-session mapping by looking the id up in the mapping
    # itself, so this also works when the agent has no current session. A
    # snapshot is iterated because entries are removed along the way.
    for device_id, session_ids in list(_device_sessions.items()):
        if session_id in session_ids:
            session_ids.discard(session_id)
            if not session_ids:
                _device_sessions.pop(device_id, None)

    return {"session_id": session_id, "status": "deleted"}
|
||
|
||
|
||
@router.post("/sessions/{session_id}/analyze", response_model=Dict[str, Any])
async def analyze_session(session_id: str) -> Dict[str, Any]:
    """Report how many records of a session still await VLM analysis.

    A record counts when its ``likes`` is None and it has a screenshot path.
    This function only counts such records; it does not run any analysis
    itself, although the returned status reads ``analysis_triggered``.

    Raises:
        HTTPException: 404 if the session is unknown, 400 if the agent has no
            current session.
    """
    if session_id not in _active_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    agent = _active_sessions[session_id]
    if not agent.current_session:
        raise HTTPException(status_code=400, detail="No session data")

    # Count every record that has a screenshot but no analysis result yet.
    pending = sum(
        1
        for record in agent.current_session.records
        if record.likes is None and record.screenshot_path
    )

    response: Dict[str, Any] = {"session_id": session_id}
    response["total_videos"] = len(agent.current_session.records)
    response["analyzed_count"] = pending
    response["status"] = "analysis_triggered"
    return response
|
||
|
||
|
||
# ==================== 设备级别的 API ====================
|
||
|
||
@router.get("/devices/{device_id}/sessions")
async def get_device_sessions(device_id: str) -> List[Dict[str, Any]]:
    """Return progress details for every session mapped to one device.

    Args:
        device_id: Device whose sessions are requested.

    Returns:
        One dictionary per session that is both mapped to the device and still
        in the session registry; an empty list for a device with no mapping.
    """
    # Work on a snapshot of the id set. The session-complete callback removes
    # ids from the live set (and may run outside this coroutine), and iterating
    # a set while it changes size raises RuntimeError. Using .get() also avoids
    # a KeyError if the device entry vanishes between a membership check and
    # the lookup.
    session_ids = list(_device_sessions.get(device_id, ()))

    result = []
    for sid in session_ids:
        agent = _active_sessions.get(sid)
        if agent is None:
            continue
        progress = agent.get_session_progress()
        result.append({
            "session_id": sid,
            "platform": progress.get("platform", ""),
            "target_count": progress.get("target_count", 0),
            "watched_count": progress.get("watched_count", 0),
            "progress_percent": progress.get("progress_percent", 0),
            "is_active": progress.get("is_active", False),
            "is_paused": progress.get("is_paused", False),
        })
    return result
|
||
|
||
|
||
@router.post("/devices/{device_id}/stop-all", response_model=Dict[str, Any])
async def stop_all_device_sessions(device_id: str) -> Dict[str, Any]:
    """Stop every active session mapped to the given device.

    Returns:
        A dictionary with ``device_id``, the number of sessions ``stopped`` and
        a human-readable ``message``.
    """
    if device_id not in _device_sessions:
        return {"device_id": device_id, "stopped": 0, "message": "No sessions on this device"}

    stopped = 0
    # Iterate a copy of the id set rather than the live set.
    for sid in list(_device_sessions[device_id]):
        if sid not in _active_sessions:
            continue
        agent = _active_sessions[sid]
        # Only sessions that are still active need stopping.
        if not (agent.current_session and agent.current_session.is_active):
            continue
        agent.stop_session()
        stopped += 1

    summary: Dict[str, Any] = {"device_id": device_id, "stopped": stopped}
    summary["message"] = f"Stopped {stopped} active session(s) on device {device_id}"
    return summary
|
||
|
||
|
||
@router.get("/devices/sessions", response_model=Dict[str, Any])
async def get_all_device_sessions() -> Dict[str, Any]:
    """Return an overview of the sessions on every device.

    Returns:
        A dictionary with ``total_devices``, ``total_sessions`` and a
        ``devices`` mapping of device id to ``total_sessions``,
        ``active_sessions`` and ``session_ids``.
    """
    result = {
        "devices": {},
        "total_devices": len(_device_sessions),
        "total_sessions": len(_active_sessions)
    }

    # Iterate snapshots of the mapping and of each id set. The session-complete
    # callback mutates both (and may run outside this coroutine); iterating a
    # dict or set while it changes size raises RuntimeError.
    for device_id, live_session_ids in list(_device_sessions.items()):
        session_ids = list(live_session_ids)

        active_count = 0
        for sid in session_ids:
            agent = _active_sessions.get(sid)
            if agent is None:
                continue
            if agent.current_session and agent.current_session.is_active:
                active_count += 1

        result["devices"][device_id] = {
            "total_sessions": len(session_ids),
            "active_sessions": active_count,
            "session_ids": session_ids
        }

    return result
|