Files
Open-AutoGLM/phone_agent/video_learning.py
let5sne.win10 b97d3f3a9f Improve Video Learning Agent with action-based detection and analysis toggle
- Change video detection from screenshot hash to action-based (Swipe detection)
- Add enable_analysis toggle to disable VLM screenshot analysis
- Improve task prompt to prevent VLM from stopping prematurely
- Add debug logging for action detection troubleshooting
- Fix ModelResponse attribute error (content -> raw_content)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-10 01:47:09 +08:00

694 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Video Learning Agent for AutoGLM
This agent learns from short video platforms (like Douyin/TikTok)
by watching videos and collecting information.
MVP Features:
- Automatic video scrolling
- Play/Pause control
- Screenshot capture for each video
- Basic data collection (likes, comments, etc.)
"""
import hashlib
import json
import os
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Optional, Any
from phone_agent import PhoneAgent, AgentConfig
from phone_agent.agent import StepResult
from phone_agent.model.client import ModelConfig
from phone_agent.device_factory import get_device_factory
@dataclass
class VideoRecord:
    """A single watched video and the data collected about it."""

    sequence_id: int                    # 1-based index within the session
    timestamp: str                      # ISO-format time the record was created
    screenshot_path: Optional[str] = None
    watch_duration: float = 0.0         # seconds
    # Basic info (extracted via OCR/analysis)
    description: Optional[str] = None   # Video caption/text
    likes: Optional[int] = None
    comments: Optional[int] = None
    shares: Optional[int] = None
    # Content analysis (for future expansion)
    tags: List[str] = field(default_factory=list)
    category: Optional[str] = None
    elements: List[str] = field(default_factory=list)
    # Metadata
    position_in_session: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the record into a plain dictionary for JSON/CSV export."""
        names = (
            "sequence_id", "timestamp", "screenshot_path", "watch_duration",
            "description", "likes", "comments", "shares", "tags",
            "category", "elements", "position_in_session",
        )
        return {name: getattr(self, name) for name in names}
@dataclass
class LearningSession:
    """A learning session spanning multiple watched videos."""

    session_id: str
    start_time: str
    platform: str                       # "douyin", "tiktok", etc.
    target_category: Optional[str] = None
    target_count: int = 10
    records: List["VideoRecord"] = field(default_factory=list)
    # Control flags
    is_active: bool = True
    is_paused: bool = False
    # Statistics
    total_videos: int = 0
    total_duration: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the session (including all of its records) to a dictionary."""
        scalar_names = (
            "session_id", "start_time", "platform", "target_category",
            "target_count", "is_active", "is_paused",
            "total_videos", "total_duration",
        )
        payload = {name: getattr(self, name) for name in scalar_names}
        payload["records"] = [record.to_dict() for record in self.records]
        return payload
class ScreenshotAnalyzer:
    """Analyze a short-video screenshot with a VLM and extract content info."""

    # Prompt sent with the screenshot; instructs the model to return JSON only.
    # (Runtime string — kept verbatim; the VLM is prompted in Chinese.)
    ANALYSIS_PROMPT = """分析这张短视频截图只返回JSON不要任何其他文字。
格式:{"description":"描述","likes":数字,"comments":数字,"tags":["标签"],"category":"类型","elements":["元素"]}
示例:{"description":"美食探店","likes":12000,"comments":500,"tags":["美食"],"category":"美食","elements":["食物"]}
注意:数字如"1.2万"转为12000无法识别则用null。只返回JSON"""

    def __init__(self, model_config: "ModelConfig"):
        """Initialize the analyzer with its own model client.

        Args:
            model_config: Configuration used to construct the ModelClient.
        """
        from phone_agent.model.client import ModelClient
        self.model_client = ModelClient(model_config)

    def analyze(self, screenshot_base64: str) -> Dict[str, Any]:
        """Send the screenshot to the VLM and return the extracted fields.

        Args:
            screenshot_base64: Base64-encoded screenshot image.

        Returns:
            Parsed analysis dict, or {} on any request/parsing failure.
        """
        from phone_agent.model.client import MessageBuilder
        # Build the single user message carrying prompt + image.
        messages = [
            MessageBuilder.create_user_message(
                text=self.ANALYSIS_PROMPT,
                image_base64=screenshot_base64
            )
        ]
        try:
            # Call the VLM.
            response = self.model_client.request(messages)
            # ModelResponse exposes raw_content, not content.
            result_text = response.raw_content.strip()
            # Extract and normalize the JSON payload.
            return self._parse_result(result_text)
        except Exception as e:
            # Best-effort: a failed analysis must not break the session loop.
            print(f"[ScreenshotAnalyzer] Error: {e}")
            return {}

    def _parse_result(self, text: str) -> Dict[str, Any]:
        """Parse the JSON object out of the VLM response text.

        Tolerates extra prose around the JSON by grabbing the outermost
        {...} span. Numeric fields (likes/comments/shares) are coerced to
        int, or set to None when not convertible.

        Returns:
            Parsed dict, or {} when no valid JSON is found.
        """
        import re
        # Debug log for troubleshooting malformed model output.
        print(f"[ScreenshotAnalyzer] Raw response: {text[:200]}...")
        json_match = re.search(r'\{[\s\S]*\}', text)
        if not json_match:
            print("[ScreenshotAnalyzer] No JSON found in response")
            return {}
        try:
            result = json.loads(json_match.group())
            print(f"[ScreenshotAnalyzer] Parsed: {result}")
            # Coerce numeric fields to int. Loop variable renamed from
            # `field` to `key`: `field` shadowed the module-level
            # dataclasses.field import.
            for key in ('likes', 'comments', 'shares'):
                if key in result and result[key] is not None:
                    try:
                        result[key] = int(result[key])
                    except (ValueError, TypeError):
                        result[key] = None
            return result
        except json.JSONDecodeError as e:
            print(f"[ScreenshotAnalyzer] JSON parse error: {e}")
            return {}
class VideoLearningAgent:
    """
    Agent for learning from short video platforms.

    MVP Capabilities:
    - Navigate to video platform
    - Watch videos automatically
    - Capture screenshots
    - Collect basic information
    - Export learning data
    """

    # Platform-specific configurations: package name for app detection,
    # scroll direction, and rough relative coordinates of on-screen controls.
    PLATFORM_CONFIGS = {
        "douyin": {
            "package_name": "com.ss.android.ugc.aweme",
            "activity_hint": "aweme",
            "scroll_gesture": "up",
            "like_position": {"x": 0.9, "y": 0.8},  # Relative coordinates
            "comment_position": {"x": 0.9, "y": 0.7},
        },
        "kuaishou": {
            "package_name": "com.smile.gifmaker",
            "activity_hint": "gifmaker",
            "scroll_gesture": "up",
            "like_position": {"x": 0.9, "y": 0.8},
        },
        "tiktok": {
            "package_name": "com.zhiliaoapp.musically",
            "activity_hint": "musically",
            "scroll_gesture": "up",
            "like_position": {"x": 0.9, "y": 0.8},
        },
    }

    def __init__(
        self,
        model_config: "ModelConfig",
        platform: str = "douyin",
        output_dir: str = "./video_learning_data",
        enable_analysis: bool = True,
    ):
        """
        Initialize Video Learning Agent.

        Args:
            model_config: Model configuration for VLM
            platform: Platform name (douyin, kuaishou, tiktok)
            output_dir: Directory to save screenshots and data
            enable_analysis: Whether to enable VLM screenshot analysis
        """
        self.model_config = model_config
        self.platform = platform
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.enable_analysis = enable_analysis  # toggle for VLM screenshot analysis
        # Create screenshots subdirectory
        self.screenshot_dir = self.output_dir / "screenshots"
        self.screenshot_dir.mkdir(exist_ok=True)
        # Current session
        self.current_session: Optional[LearningSession] = None
        self.video_counter = 0
        # Agent will be created when starting a session
        self.agent: Optional[PhoneAgent] = None
        # Callbacks for external control
        self.on_video_watched: Optional[Callable[[VideoRecord], None]] = None
        self.on_session_complete: Optional[Callable[[LearningSession], None]] = None
        self.on_progress_update: Optional[Callable[[int, int], None]] = None
        # Video detection is action-based: a Swipe action marks a new video.
        self._first_video_recorded: bool = False  # first video already recorded?
        # Skip app startup screens
        self._in_app_steps: int = 0
        self._warmup_steps: int = 3  # Skip first 3 steps after entering app
        # Screenshot analyzer for content extraction (only if enabled)
        self._analyzer: Optional[ScreenshotAnalyzer] = None
        if self.enable_analysis:
            try:
                self._analyzer = ScreenshotAnalyzer(model_config)
                print("[VideoLearning] Screenshot analyzer initialized")
            except Exception as e:
                # Analysis is optional; keep running without it.
                print(f"[VideoLearning] Analyzer init failed: {e}")
        else:
            print("[VideoLearning] Screenshot analysis disabled")

    def start_session(
        self,
        device_id: str,
        target_count: int = 10,
        category: Optional[str] = None,
        watch_duration: float = 3.0,
        max_steps: int = 500,
    ) -> str:
        """
        Start a learning session.

        Args:
            device_id: Target device ID
            target_count: Number of videos to watch
            category: Target category (e.g., "美食", "旅行")
            watch_duration: How long to watch each video (seconds)
            max_steps: Maximum execution steps

        Returns:
            Session ID
        """
        # Create new session
        session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.current_session = LearningSession(
            session_id=session_id,
            start_time=datetime.now().isoformat(),
            platform=self.platform,
            target_category=category,
            target_count=target_count,
        )
        # Configure agent with callbacks
        agent_config = AgentConfig(
            device_id=device_id,
            max_steps=max_steps,
            lang="cn",
            step_callback=self._on_step,
            before_action_callback=self._before_action,
        )
        # Create phone agent
        self.agent = PhoneAgent(
            model_config=self.model_config,
            agent_config=agent_config,
        )
        # Store parameters for the task
        self._watch_duration = watch_duration
        self._device_id = device_id
        # Reset per-session detection state so a reused agent instance starts
        # clean. (Previously this reset a dead `_last_screenshot_hash`
        # attribute and left the live warmup/first-video flags stale.)
        self._in_app_steps = 0
        self._first_video_recorded = False
        self.video_counter = 0
        return session_id

    def run_learning_task(self, task: str) -> bool:
        """
        Run the learning task.

        Args:
            task: Natural language task description

        Returns:
            True if successful

        Raises:
            RuntimeError: If start_session() was not called first.
        """
        if not self.agent or not self.current_session:
            raise RuntimeError("Session not started. Call start_session() first.")
        try:
            result = self.agent.run(task)
            # Mark session as inactive after task completes
            if self.current_session:
                self.current_session.is_active = False
                self._save_session()
            print(f"[VideoLearning] Session completed. Recorded {self.video_counter} videos.")
            return bool(result)
        except Exception as e:
            print(f"Error during learning: {e}")
            if self.current_session:
                self.current_session.is_active = False
                # Persist whatever was collected before the failure.
                self._save_session()
            return False

    def stop_session(self):
        """Stop the current learning session."""
        if self.current_session:
            self.current_session.is_active = False
        # The running agent notices is_active on its next _on_step callback.

    def pause_session(self):
        """Pause the current session (can be resumed)."""
        if self.current_session:
            self.current_session.is_paused = True

    def resume_session(self):
        """Resume a paused session."""
        if self.current_session:
            self.current_session.is_paused = False

    def _on_step(self, result: "StepResult") -> Optional[str]:
        """
        Callback after each step.

        Action-based detection logic:
        1. Check whether we are inside the target app
        2. Skip steps during the warmup phase (startup screens)
        3. Record the first video once warmup ends
        4. Detect swipe actions; each swipe records a new video

        Args:
            result: Step execution result

        Returns:
            "stop" to end session, new task to switch, None to continue
        """
        if not self.current_session:
            return None
        # Check if session should stop
        if not self.current_session.is_active:
            self._save_session()
            if self.on_session_complete:
                self.on_session_complete(self.current_session)
            return "stop"
        # Check if paused
        if self.current_session.is_paused:
            return None
        # Check if we've watched enough videos
        if self.video_counter >= self.current_session.target_count:
            self.current_session.is_active = False
            self._save_session()
            if self.on_session_complete:
                self.on_session_complete(self.current_session)
            return "stop"
        try:
            # Use get_current_app() to detect if we're in target app
            current_app = get_device_factory().get_current_app(self._device_id)
            # Platform-specific package names
            platform_packages = {
                "douyin": ["aweme", "抖音", "douyin"],
                "kuaishou": ["gifmaker", "快手", "kuaishou"],
                "tiktok": ["musically", "tiktok"],
            }
            packages = platform_packages.get(self.platform, ["aweme"])
            # Check if in target app
            is_in_target = any(pkg.lower() in current_app.lower() for pkg in packages)
            if not is_in_target:
                # Reset counters when leaving app
                self._in_app_steps = 0
                self._first_video_recorded = False
                print(f"[VideoLearning] Not in target app: {current_app}")
                return None
            # Warmup: skip first few steps to avoid startup screens
            self._in_app_steps += 1
            if self._in_app_steps <= self._warmup_steps:
                print(f"[VideoLearning] Warmup step {self._in_app_steps}/{self._warmup_steps}, skipping...")
                return None
            # Grab a screenshot to attach to the record.
            screenshot = get_device_factory().get_screenshot(self._device_id)
            # First video: record it immediately once warmup is over.
            if not self._first_video_recorded:
                self._first_video_recorded = True
                self._record_video_from_screenshot(screenshot)
                print(f"[VideoLearning] ✓ Recorded first video {self.video_counter}/{self.current_session.target_count}")
                return self._check_target_reached()
            # Action-based detection: look for swipe actions.
            action = result.action
            action_type = action.get("action") if action else None
            # Debug log: show the current action.
            if action_type:
                print(f"[VideoLearning] Current action: {action_type}")
            # Check for a swipe action (case-insensitive).
            if action and action_type and action_type.lower() == "swipe":
                # The VLM swiped, so a new video is on screen.
                print("[VideoLearning] Detected swipe action, recording new video...")
                self._record_video_from_screenshot(screenshot)
                print(f"[VideoLearning] ✓ Recorded video {self.video_counter}/{self.current_session.target_count}")
                return self._check_target_reached()
            # Not a swipe: log and keep waiting.
            if action_type and action_type.lower() != "swipe":
                print(f"[VideoLearning] Non-swipe action detected ({action_type}), waiting for swipe...")
        except Exception as e:
            # Detection problems must not abort the agent loop.
            print(f"[VideoLearning] Warning: {e}")
            import traceback
            traceback.print_exc()
        return None

    def _check_target_reached(self) -> Optional[str]:
        """Return "stop" (and persist the session) once the target count is hit."""
        if self.video_counter >= self.current_session.target_count:
            print("[VideoLearning] ✓ Target reached! Stopping...")
            self.current_session.is_active = False
            self._save_session()
            return "stop"
        return None

    def _record_video_from_screenshot(self, screenshot):
        """Record a video from a screenshot, optionally running VLM analysis."""
        import base64
        screenshot_bytes = base64.b64decode(screenshot.base64_data)
        # Analyze the screenshot content (best-effort; failures are logged).
        analysis_result = {}
        if self._analyzer:
            try:
                print("[VideoLearning] Analyzing screenshot...")
                analysis_result = self._analyzer.analyze(screenshot.base64_data)
                if analysis_result:
                    print(f"[VideoLearning] Analysis: {analysis_result.get('category', 'N/A')}")
            except Exception as e:
                print(f"[VideoLearning] Analysis failed: {e}")
        # Record the video with whatever the analysis produced.
        self.record_video(
            screenshot=screenshot_bytes,
            description=analysis_result.get('description', f"Video #{self.video_counter + 1}"),
            likes=analysis_result.get('likes'),
            comments=analysis_result.get('comments'),
            shares=analysis_result.get('shares'),
            tags=analysis_result.get('tags', []),
            category=analysis_result.get('category'),
            elements=analysis_result.get('elements', []),
        )

    def _before_action(self, action: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Callback before executing an action.

        Args:
            action: Action to execute

        Returns:
            Modified action or None (None = execute unchanged)
        """
        # Could be used for action logging or modification.
        return None

    def record_video(
        self,
        screenshot: Optional[bytes] = None,
        description: Optional[str] = None,
        likes: Optional[int] = None,
        comments: Optional[int] = None,
        shares: Optional[int] = None,
        tags: Optional[List[str]] = None,
        category: Optional[str] = None,
        elements: Optional[List[str]] = None,
    ) -> "VideoRecord":
        """
        Record a watched video.

        Note: must be called within a started session — reads
        self._watch_duration set by start_session().

        Args:
            screenshot: Screenshot image data (PNG bytes); saved to disk if given
            description: Video description/caption
            likes: Number of likes
            comments: Number of comments
            shares: Number of shares
            tags: Content tags
            category: Content category
            elements: Visual elements detected in the frame

        Returns:
            VideoRecord object
        """
        self.video_counter += 1
        # Save screenshot if provided
        screenshot_path = None
        if screenshot:
            screenshot_filename = f"{self.current_session.session_id}_video_{self.video_counter}.png"
            screenshot_full_path = self.screenshot_dir / screenshot_filename
            # Store relative path for web access: /video-learning-data/screenshots/filename.png
            screenshot_path = f"/video-learning-data/screenshots/{screenshot_filename}"
            with open(str(screenshot_full_path), "wb") as f:
                f.write(screenshot)
        # Create record
        record = VideoRecord(
            sequence_id=self.video_counter,
            timestamp=datetime.now().isoformat(),
            screenshot_path=screenshot_path,
            watch_duration=self._watch_duration,
            description=description,
            likes=likes,
            comments=comments,
            shares=shares,
            tags=tags or [],
            category=category,
            elements=elements or [],
            position_in_session=self.video_counter,
        )
        # Add to session
        if self.current_session:
            self.current_session.records.append(record)
            self.current_session.total_videos = self.video_counter
            self.current_session.total_duration += self._watch_duration
        # Notify callback
        if self.on_video_watched:
            self.on_video_watched(record)
        # Notify progress
        if self.on_progress_update:
            self.on_progress_update(self.video_counter, self.current_session.target_count)
        return record

    def _save_session(self):
        """Save session data to a JSON file in the output directory."""
        if not self.current_session:
            return
        session_file = self.output_dir / f"{self.current_session.session_id}.json"
        with open(session_file, "w", encoding="utf-8") as f:
            json.dump(self.current_session.to_dict(), f, ensure_ascii=False, indent=2)
        print(f"Session saved to {session_file}")

    def export_data(self, format: str = "json") -> str:
        """
        Export session data.

        Args:
            format: Export format (json, csv)

        Returns:
            Path to exported file

        Raises:
            RuntimeError: If no session exists.
            ValueError: On an unsupported format.
        """
        if not self.current_session:
            raise RuntimeError("No session to export")
        if format == "json":
            return self._export_json()
        elif format == "csv":
            return self._export_csv()
        else:
            raise ValueError(f"Unsupported format: {format}")

    def _export_json(self) -> str:
        """Export the current session as JSON; returns the file path."""
        output_file = self.output_dir / f"{self.current_session.session_id}_export.json"
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(self.current_session.to_dict(), f, ensure_ascii=False, indent=2)
        return str(output_file)

    def _export_csv(self) -> str:
        """Export the current session's records as CSV; returns the file path."""
        import csv
        output_file = self.output_dir / f"{self.current_session.session_id}_export.csv"
        with open(output_file, "w", encoding="utf-8", newline="") as f:
            if not self.current_session.records:
                # No records: leave an empty file rather than fail.
                return str(output_file)
            writer = csv.DictWriter(f, fieldnames=self.current_session.records[0].to_dict().keys())
            writer.writeheader()
            for record in self.current_session.records:
                writer.writerow(record.to_dict())
        return str(output_file)

    def get_session_progress(self) -> Dict[str, Any]:
        """Get current session progress as a plain dict (for UI/polling)."""
        if not self.current_session:
            return {"status": "no_session"}
        return {
            "session_id": self.current_session.session_id,
            "platform": self.current_session.platform,
            "target_count": self.current_session.target_count,
            "watched_count": self.video_counter,
            "progress_percent": (self.video_counter / self.current_session.target_count * 100)
            if self.current_session.target_count > 0
            else 0,
            "is_active": self.current_session.is_active,
            "is_paused": self.current_session.is_paused,
            "total_duration": self.current_session.total_duration,
        }
# Convenience function for standalone usage
def create_video_learning_agent(
    base_url: str,
    api_key: str,
    model_name: str = "autoglm-phone-9b",
    platform: str = "douyin",
    output_dir: str = "./video_learning_data",
    **model_kwargs,
) -> VideoLearningAgent:
    """
    Create a Video Learning Agent with standard configuration.

    Args:
        base_url: Model API base URL
        api_key: API key
        model_name: Model name
        platform: Platform name
        output_dir: Output directory
        **model_kwargs: Additional model parameters

    Returns:
        VideoLearningAgent instance
    """
    config = ModelConfig(
        base_url=base_url,
        model_name=model_name,
        api_key=api_key,
        **model_kwargs,
    )
    agent = VideoLearningAgent(
        model_config=config,
        platform=platform,
        output_dir=output_dir,
    )
    return agent