Features: - ScreenshotAnalyzer class for VLM-based image analysis - Real-time analysis during video recording - Extract likes, comments, tags, category from screenshots - Frontend display for category badges and tags - Batch analysis API endpoint Co-Authored-By: Claude <noreply@anthropic.com>
658 lines
22 KiB
Python
658 lines
22 KiB
Python
"""
|
||
Video Learning Agent for AutoGLM
|
||
|
||
This agent learns from short video platforms (like Douyin/TikTok)
|
||
by watching videos and collecting information.
|
||
|
||
MVP Features:
|
||
- Automatic video scrolling
|
||
- Play/Pause control
|
||
- Screenshot capture for each video
|
||
- Basic data collection (likes, comments, etc.)
|
||
"""
|
||
|
||
import hashlib
|
||
import json
|
||
import os
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Callable, Dict, List, Optional, Any
|
||
|
||
from phone_agent import PhoneAgent, AgentConfig
|
||
from phone_agent.agent import StepResult
|
||
from phone_agent.model.client import ModelConfig
|
||
from phone_agent.device_factory import get_device_factory
|
||
|
||
|
||
@dataclass
class VideoRecord:
    """A single video observed during a learning session.

    Captures where its screenshot was stored, the nominal watch time, and any
    details (caption, counts, tags, category, visual elements) that were
    extracted for it.
    """

    sequence_id: int
    timestamp: str
    screenshot_path: Optional[str] = None
    watch_duration: float = 0.0  # seconds

    # Basic info (extracted via OCR/analysis)
    description: Optional[str] = None  # Video caption/text
    likes: Optional[int] = None
    comments: Optional[int] = None
    shares: Optional[int] = None

    # Content analysis (category, tags, visual elements)
    tags: List[str] = field(default_factory=list)
    category: Optional[str] = None
    elements: List[str] = field(default_factory=list)

    # Metadata
    position_in_session: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the record's fields as a plain dict, in declaration order."""
        ordered_keys = (
            "sequence_id",
            "timestamp",
            "screenshot_path",
            "watch_duration",
            "description",
            "likes",
            "comments",
            "shares",
            "tags",
            "category",
            "elements",
            "position_in_session",
        )
        return {key: getattr(self, key) for key in ordered_keys}
|
||
|
||
|
||
@dataclass
class LearningSession:
    """One learning run: its settings, control flags, statistics and records."""

    session_id: str
    start_time: str
    platform: str  # "douyin", "tiktok", etc.
    target_category: Optional[str] = None
    target_count: int = 10
    records: List[VideoRecord] = field(default_factory=list)

    # Control flags
    is_active: bool = True
    is_paused: bool = False

    # Statistics
    total_videos: int = 0
    total_duration: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain dict of the session; records are serialized last."""
        scalar_keys = (
            "session_id",
            "start_time",
            "platform",
            "target_category",
            "target_count",
            "is_active",
            "is_paused",
            "total_videos",
            "total_duration",
        )
        payload: Dict[str, Any] = {key: getattr(self, key) for key in scalar_keys}
        # Each record serializes itself via its own to_dict().
        payload["records"] = [record.to_dict() for record in self.records]
        return payload
|
||
|
||
|
||
class ScreenshotAnalyzer:
    """Analyze short-video screenshots with a VLM and extract content details.

    The model is asked (via ANALYSIS_PROMPT) to answer with a JSON object
    containing the caption, engagement counts, tags, category and visual
    elements. ``_parse_result`` turns that reply into a normalized dict:
    counts become ``int`` or ``None``, and ``tags``/``elements`` (when present)
    become lists of strings.
    """

    ANALYSIS_PROMPT = """分析这张短视频截图,提取以下信息并以JSON格式返回:
{
"description": "视频描述文案(屏幕上显示的文字,如果有的话)",
"likes": 点赞数(纯数字,如12000,没有则为null),
"comments": 评论数(纯数字,没有则为null),
"shares": 分享数(纯数字,没有则为null),
"tags": ["标签1", "标签2"],
"category": "视频类型(美食/旅行/搞笑/知识/生活/音乐/舞蹈/其他)",
"elements": ["画面中的主要元素,如:人物、食物、风景等"]
}
注意:
1. 只返回JSON,不要其他文字
2. 数字不要带单位,如"1.2万"应转为12000
3. 如果无法识别某项,设为null或空数组"""

    # Unit suffixes the model may still attach to counts despite the prompt
    # (e.g. "1.2万"), with their multipliers. Checked in order.
    _COUNT_SUFFIXES = (
        ("亿", 100_000_000),
        ("万", 10_000),
        ("w", 10_000),
        ("W", 10_000),
        ("k", 1_000),
        ("K", 1_000),
    )

    def __init__(self, model_config: "ModelConfig"):
        """Create the analyzer and its model client.

        Args:
            model_config: Model configuration passed to ``ModelClient``.
        """
        from phone_agent.model.client import ModelClient

        self.model_client = ModelClient(model_config)

    def analyze(self, screenshot_base64: str) -> Dict[str, Any]:
        """Send a screenshot to the VLM and return the extracted fields.

        Args:
            screenshot_base64: Base64-encoded screenshot image.

        Returns:
            The normalized dict produced by ``_parse_result``, or ``{}`` if the
            request or parsing fails (the error is printed, not raised).
        """
        from phone_agent.model.client import MessageBuilder

        messages = [
            MessageBuilder.create_user_message(
                text=self.ANALYSIS_PROMPT,
                image_base64=screenshot_base64,
            )
        ]

        try:
            response = self.model_client.request(messages)
            # Treat a missing content payload as an empty reply.
            result_text = (response.content or "").strip()
            return self._parse_result(result_text)
        except Exception as e:
            # Best-effort: analysis failure must not break video recording.
            print(f"[ScreenshotAnalyzer] Error: {e}")
            return {}

    def _parse_result(self, text: str) -> Dict[str, Any]:
        """Extract and normalize the JSON object from the VLM reply.

        The span from the first ``{`` to the last ``}`` is parsed, which also
        tolerates replies wrapped in markdown code fences.

        Args:
            text: Raw model reply.

        Returns:
            The parsed dict with ``likes``/``comments``/``shares`` converted to
            ``int`` or ``None`` and ``tags``/``elements`` converted to lists of
            strings (when those keys are present); ``{}`` if no valid JSON
            object is found.
        """
        import re

        json_match = re.search(r'\{[\s\S]*\}', text)
        if not json_match:
            return {}

        try:
            result = json.loads(json_match.group())
        except json.JSONDecodeError:
            return {}
        if not isinstance(result, dict):
            return {}

        for key in ("likes", "comments", "shares"):
            if key in result:
                result[key] = self._parse_count(result[key])

        # The prompt allows null for unrecognized lists; keep list fields lists.
        for key in ("tags", "elements"):
            if key in result:
                result[key] = self._as_str_list(result[key])

        return result

    @classmethod
    def _parse_count(cls, value: Any) -> Optional[int]:
        """Convert a count reported by the model into an int.

        Accepts plain numbers (truncated to int) and strings such as "12000",
        "3,456", "1.2万" or "10万+". Returns ``None`` for null, booleans and
        anything unparseable.
        """
        # bool is a subclass of int; a true/false count is meaningless.
        if value is None or isinstance(value, bool):
            return None
        try:
            if isinstance(value, (int, float)):
                return int(value)
            if not isinstance(value, str):
                return None
            text = value.replace(",", "").replace("，", "").strip().rstrip("+").strip()
            multiplier = 1
            for suffix, factor in cls._COUNT_SUFFIXES:
                if text.endswith(suffix):
                    text = text[: -len(suffix)].strip()
                    multiplier = factor
                    break
            if not text:
                return None
            # round() absorbs float error, e.g. 0.29 * 10000 == 2899.9999...
            return int(round(float(text) * multiplier))
        except (ValueError, OverflowError):
            # Non-numeric text, NaN or infinity.
            return None

    @staticmethod
    def _as_str_list(value: Any) -> List[str]:
        """Coerce a model-provided list field into a list of non-empty strings."""
        if value is None:
            return []
        if isinstance(value, str):
            value = [value]
        elif not isinstance(value, (list, tuple)):
            return []
        items = (str(item).strip() for item in value if item is not None)
        return [item for item in items if item]
|
||
|
||
|
||
class VideoLearningAgent:
    """
    Agent for learning from short video platforms.

    MVP Capabilities:
    - Navigate to video platform
    - Watch videos automatically
    - Capture screenshots
    - Collect basic information (via ScreenshotAnalyzer when available)
    - Export learning data

    Videos are recorded from the PhoneAgent step callback (``_on_step``):
    whenever the screenshot hash changes while the target app is in the
    foreground, the screenshot is saved and a VideoRecord is added to the
    current session.
    """

    # Platform-specific configurations
    PLATFORM_CONFIGS = {
        "douyin": {
            "package_name": "com.ss.android.ugc.aweme",
            "activity_hint": "aweme",
            "scroll_gesture": "up",
            "like_position": {"x": 0.9, "y": 0.8},  # Relative coordinates
            "comment_position": {"x": 0.9, "y": 0.7},
        },
        "kuaishou": {
            "package_name": "com.smile.gifmaker",
            "activity_hint": "gifmaker",
            "scroll_gesture": "up",
            "like_position": {"x": 0.9, "y": 0.8},
        },
        "tiktok": {
            "package_name": "com.zhiliaoapp.musically",
            "activity_hint": "musically",
            "scroll_gesture": "up",
            "like_position": {"x": 0.9, "y": 0.8},
        },
    }

    # Substrings matched case-insensitively against the foreground app name
    # reported by the device to decide whether the target app is open.
    # Unknown platforms fall back to ("aweme",) in _on_step.
    PLATFORM_APP_HINTS = {
        "douyin": ("aweme", "抖音", "douyin"),
        "kuaishou": ("gifmaker", "快手", "kuaishou"),
        "tiktok": ("musically", "tiktok"),
    }

    def __init__(
        self,
        model_config: ModelConfig,
        platform: str = "douyin",
        output_dir: str = "./video_learning_data",
    ):
        """
        Initialize Video Learning Agent.

        Args:
            model_config: Model configuration for VLM (used for both the
                PhoneAgent and the ScreenshotAnalyzer)
            platform: Platform name (douyin, kuaishou, tiktok)
            output_dir: Directory to save screenshots and data
        """
        self.model_config = model_config
        self.platform = platform
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Create screenshots subdirectory
        self.screenshot_dir = self.output_dir / "screenshots"
        self.screenshot_dir.mkdir(exist_ok=True)

        # Current session
        self.current_session: Optional[LearningSession] = None
        self.video_counter = 0

        # Agent will be created when starting a session
        self.agent: Optional[PhoneAgent] = None

        # Per-session parameters, set by start_session(). Initialized here so
        # methods reading them never hit AttributeError.
        self._watch_duration: float = 0.0
        self._device_id: Optional[str] = None

        # Ensures on_session_complete fires at most once per session.
        self._completion_notified = False

        # Callbacks for external control
        self.on_video_watched: Optional[Callable[[VideoRecord], None]] = None
        self.on_session_complete: Optional[Callable[[LearningSession], None]] = None
        self.on_progress_update: Optional[Callable[[int, int], None]] = None

        # Video detection: track screenshot changes (simplified)
        self._last_screenshot_hash: Optional[str] = None

        # Screenshot analyzer for content extraction. Optional: recording
        # still works (without extracted details) if it cannot be created.
        self._analyzer: Optional[ScreenshotAnalyzer] = None
        try:
            self._analyzer = ScreenshotAnalyzer(model_config)
            print("[VideoLearning] Screenshot analyzer initialized")
        except Exception as e:
            print(f"[VideoLearning] Analyzer init failed: {e}")

    def start_session(
        self,
        device_id: str,
        target_count: int = 10,
        category: Optional[str] = None,
        watch_duration: float = 3.0,
        max_steps: int = 500,
    ) -> str:
        """
        Start a learning session.

        Creates a new LearningSession and PhoneAgent and resets the video
        detection state. Call run_learning_task() afterwards.

        Args:
            device_id: Target device ID
            target_count: Number of videos to record before stopping
            category: Target category (e.g., "美食", "旅行"); stored on the session
            watch_duration: Nominal watch time per video (seconds); stored on
                each record and summed into the session's total_duration
            max_steps: Maximum execution steps for the PhoneAgent

        Returns:
            Session ID
        """
        # Session IDs have one-second resolution.
        session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.current_session = LearningSession(
            session_id=session_id,
            start_time=datetime.now().isoformat(),
            platform=self.platform,
            target_category=category,
            target_count=target_count,
        )

        # Configure agent with callbacks
        agent_config = AgentConfig(
            device_id=device_id,
            max_steps=max_steps,
            lang="cn",
            step_callback=self._on_step,
            before_action_callback=self._before_action,
        )

        self.agent = PhoneAgent(
            model_config=self.model_config,
            agent_config=agent_config,
        )

        # Store parameters for the task
        self._watch_duration = watch_duration
        self._device_id = device_id

        # Reset per-session detection / completion state
        self._last_screenshot_hash = None
        self.video_counter = 0
        self._completion_notified = False

        return session_id

    def run_learning_task(self, task: str) -> bool:
        """
        Run the learning task with the PhoneAgent.

        The session is marked inactive and saved when the agent returns, and
        also (best effort) when it raises, so collected records are kept.

        Args:
            task: Natural language task description

        Returns:
            True if the agent returned a truthy result, False otherwise or on error

        Raises:
            RuntimeError: If start_session() has not been called.
        """
        if not self.agent or not self.current_session:
            raise RuntimeError("Session not started. Call start_session() first.")

        try:
            result = self.agent.run(task)
            self.current_session.is_active = False
            self._save_session()
            print(f"[VideoLearning] Session completed. Recorded {self.video_counter} videos.")
            return bool(result)
        except Exception as e:
            print(f"Error during learning: {e}")
            self.current_session.is_active = False
            # Persist what was collected before the failure.
            try:
                self._save_session()
            except (OSError, TypeError, ValueError) as save_error:
                print(f"[VideoLearning] Failed to save session: {save_error}")
            return False

    def stop_session(self):
        """Request the current session to stop.

        Only clears the active flag; the running PhoneAgent is told to stop
        ("stop" return value) the next time _on_step runs.
        """
        if self.current_session:
            self.current_session.is_active = False

    def pause_session(self):
        """Pause recording for the current session (can be resumed).

        While paused, _on_step skips detection and recording; it does not
        halt the PhoneAgent itself.
        """
        if self.current_session:
            self.current_session.is_paused = True

    def resume_session(self):
        """Resume a paused session."""
        if self.current_session:
            self.current_session.is_paused = False

    def _complete_session(self) -> None:
        """Mark the session inactive, save it, and notify on_session_complete once."""
        if not self.current_session:
            return
        self.current_session.is_active = False
        self._save_session()
        if self._completion_notified:
            return
        self._completion_notified = True
        if self.on_session_complete:
            self.on_session_complete(self.current_session)

    def _on_step(self, result: StepResult) -> Optional[str]:
        """
        Callback after each PhoneAgent step.

        Simplified logic:
        1. Check if we're in the target app using get_current_app()
        2. Detect screenshot changes (MD5 of the base64 screenshot)
        3. Record a video whenever the screenshot changes

        Args:
            result: Step execution result

        Returns:
            "stop" to end session, new task to switch, None to continue
        """
        session = self.current_session
        if not session:
            return None

        # Stop requested externally (stop_session) or by an earlier step.
        if not session.is_active:
            self._complete_session()
            return "stop"

        if session.is_paused:
            return None

        if self.video_counter >= session.target_count:
            self._complete_session()
            return "stop"

        try:
            device_factory = get_device_factory()
            current_app = device_factory.get_current_app(self._device_id)

            hints = self.PLATFORM_APP_HINTS.get(self.platform, ("aweme",))
            current_app_lower = (current_app or "").lower()
            if not any(hint.lower() in current_app_lower for hint in hints):
                print(f"[VideoLearning] Not in target app: {current_app} (step {result.step_count})")
                return None

            screenshot = device_factory.get_screenshot(self._device_id)
            # Hash the full base64 data so any pixel change is detected.
            current_hash = hashlib.md5(screenshot.base64_data.encode()).hexdigest()

            # Unchanged screen: nothing new to record. (The first screenshot
            # always differs from the initial None.)
            if current_hash == self._last_screenshot_hash:
                return None

            self._last_screenshot_hash = current_hash
            self._record_video_from_screenshot(screenshot)
            print(f"[VideoLearning] ✓ Recorded video {self.video_counter}/{session.target_count}")

            if self.video_counter >= session.target_count:
                print("[VideoLearning] ✓ Target reached! Stopping...")
                self._complete_session()
                return "stop"

        except Exception as e:
            # Best-effort detection: a failed step must not abort the agent.
            print(f"[VideoLearning] Warning: {e}")

        return None

    def _record_video_from_screenshot(self, screenshot):
        """Analyze a device screenshot (if an analyzer exists) and record it.

        Args:
            screenshot: Device screenshot object exposing ``base64_data``.
        """
        import base64

        screenshot_bytes = base64.b64decode(screenshot.base64_data)

        # Analyze screenshot content; failures leave the fields empty.
        analysis_result: Dict[str, Any] = {}
        if self._analyzer:
            try:
                print("[VideoLearning] Analyzing screenshot...")
                analysis_result = self._analyzer.analyze(screenshot.base64_data) or {}
                if analysis_result:
                    print(f"[VideoLearning] Analysis: {analysis_result.get('category', 'N/A')}")
            except Exception as e:
                print(f"[VideoLearning] Analysis failed: {e}")

        # Record the video. `or` also covers keys present with a null value.
        self.record_video(
            screenshot=screenshot_bytes,
            description=analysis_result.get('description') or f"Video #{self.video_counter + 1}",
            likes=analysis_result.get('likes'),
            comments=analysis_result.get('comments'),
            shares=analysis_result.get('shares'),
            tags=analysis_result.get('tags') or [],
            category=analysis_result.get('category'),
            elements=analysis_result.get('elements') or [],
        )

    def _before_action(self, action: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Callback before executing an action.

        Args:
            action: Action to execute

        Returns:
            Modified action or None (currently always None)
        """
        # Could be used for action logging or modification
        return None

    def record_video(
        self,
        screenshot: Optional[bytes] = None,
        description: Optional[str] = None,
        likes: Optional[int] = None,
        comments: Optional[int] = None,
        shares: Optional[int] = None,
        tags: Optional[List[str]] = None,
        category: Optional[str] = None,
        elements: Optional[List[str]] = None,
    ) -> VideoRecord:
        """
        Record a watched video in the current session.

        Args:
            screenshot: Screenshot image data (saved as a .png file if given)
            description: Video description/caption
            likes: Number of likes
            comments: Number of comments
            shares: Number of shares
            tags: Content tags
            category: Content category
            elements: Main visual elements

        Returns:
            VideoRecord object

        Raises:
            RuntimeError: If no session has been started.
        """
        session = self.current_session
        if not session:
            raise RuntimeError("Session not started. Call start_session() first.")

        sequence_id = self.video_counter + 1

        # Write the screenshot before advancing the counter so a failed write
        # does not leave the counter ahead of the records.
        screenshot_path = None
        if screenshot:
            screenshot_filename = f"{session.session_id}_video_{sequence_id}.png"
            (self.screenshot_dir / screenshot_filename).write_bytes(screenshot)
            # Store relative path for web access: /video-learning-data/screenshots/filename.png
            screenshot_path = f"/video-learning-data/screenshots/{screenshot_filename}"

        self.video_counter = sequence_id

        record = VideoRecord(
            sequence_id=sequence_id,
            timestamp=datetime.now().isoformat(),
            screenshot_path=screenshot_path,
            watch_duration=self._watch_duration,
            description=description,
            likes=likes,
            comments=comments,
            shares=shares,
            tags=list(tags) if tags else [],
            category=category,
            elements=list(elements) if elements else [],
            position_in_session=sequence_id,
        )

        session.records.append(record)
        session.total_videos = self.video_counter
        session.total_duration += self._watch_duration

        if self.on_video_watched:
            self.on_video_watched(record)

        if self.on_progress_update:
            self.on_progress_update(self.video_counter, session.target_count)

        return record

    def _save_session(self):
        """Save session data to <output_dir>/<session_id>.json."""
        if not self.current_session:
            return

        session_file = self.output_dir / f"{self.current_session.session_id}.json"
        with open(session_file, "w", encoding="utf-8") as f:
            json.dump(self.current_session.to_dict(), f, ensure_ascii=False, indent=2)

        print(f"Session saved to {session_file}")

    def export_data(self, format: str = "json") -> str:
        """
        Export session data.

        Args:
            format: Export format (json, csv)

        Returns:
            Path to exported file

        Raises:
            RuntimeError: If there is no session.
            ValueError: If the format is not supported.
        """
        if not self.current_session:
            raise RuntimeError("No session to export")

        if format == "json":
            return self._export_json()
        elif format == "csv":
            return self._export_csv()
        else:
            raise ValueError(f"Unsupported format: {format}")

    def _export_json(self) -> str:
        """Export as JSON to <output_dir>/<session_id>_export.json."""
        output_file = self.output_dir / f"{self.current_session.session_id}_export.json"
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(self.current_session.to_dict(), f, ensure_ascii=False, indent=2)
        return str(output_file)

    def _export_csv(self) -> str:
        """Export records as CSV; an empty file is written when there are none."""
        import csv

        output_file = self.output_dir / f"{self.current_session.session_id}_export.csv"
        with open(output_file, "w", encoding="utf-8", newline="") as f:
            if not self.current_session.records:
                return str(output_file)

            writer = csv.DictWriter(f, fieldnames=self.current_session.records[0].to_dict().keys())
            writer.writeheader()
            for record in self.current_session.records:
                writer.writerow(record.to_dict())

        return str(output_file)

    def get_session_progress(self) -> Dict[str, Any]:
        """Get current session progress (or {"status": "no_session"})."""
        if not self.current_session:
            return {"status": "no_session"}

        return {
            "session_id": self.current_session.session_id,
            "platform": self.current_session.platform,
            "target_count": self.current_session.target_count,
            "watched_count": self.video_counter,
            "progress_percent": (self.video_counter / self.current_session.target_count * 100)
            if self.current_session.target_count > 0
            else 0,
            "is_active": self.current_session.is_active,
            "is_paused": self.current_session.is_paused,
            "total_duration": self.current_session.total_duration,
        }
|
||
|
||
|
||
# Convenience function for standalone usage
def create_video_learning_agent(
    base_url: str,
    api_key: str,
    model_name: str = "autoglm-phone-9b",
    platform: str = "douyin",
    output_dir: str = "./video_learning_data",
    **model_kwargs,
) -> VideoLearningAgent:
    """
    Build a ModelConfig from the given endpoint settings and wrap it in a
    VideoLearningAgent.

    Args:
        base_url: Model API base URL
        api_key: API key
        model_name: Model name
        platform: Platform name (douyin, kuaishou, tiktok)
        output_dir: Directory for screenshots and session data
        **model_kwargs: Extra keyword arguments forwarded to ModelConfig

    Returns:
        VideoLearningAgent instance
    """
    config = ModelConfig(
        base_url=base_url,
        api_key=api_key,
        model_name=model_name,
        **model_kwargs,
    )
    agent = VideoLearningAgent(
        model_config=config,
        platform=platform,
        output_dir=output_dir,
    )
    return agent
|