Add Video Learning Agent for short video platforms
Features: - VideoLearningAgent for automated video watching on Douyin/Kuaishou/TikTok - Web dashboard UI for video learning sessions - Real-time progress tracking with screenshot capture - App detection using get_current_app() for accurate recording - Session management with pause/resume/stop controls Technical improvements: - Simplified video detection logic using direct app detection - Full base64 hash for sensitive screenshot change detection - Immediate stop when target video count is reached - Fixed circular import issues with ModelConfig Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ using AI models for visual understanding and decision making.
|
||||
|
||||
from phone_agent.agent import AgentConfig, PhoneAgent, StepResult
|
||||
from phone_agent.agent_ios import IOSAgentConfig, IOSPhoneAgent
|
||||
from phone_agent.video_learning import VideoLearningAgent, VideoRecord, LearningSession
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__all__ = [
|
||||
@@ -15,4 +16,7 @@ __all__ = [
|
||||
"AgentConfig",
|
||||
"IOSAgentConfig",
|
||||
"StepResult",
|
||||
"VideoLearningAgent",
|
||||
"VideoRecord",
|
||||
"LearningSession",
|
||||
]
|
||||
|
||||
561
phone_agent/video_learning.py
Normal file
561
phone_agent/video_learning.py
Normal file
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
Video Learning Agent for AutoGLM
|
||||
|
||||
This agent learns from short video platforms (like Douyin/TikTok)
|
||||
by watching videos and collecting information.
|
||||
|
||||
MVP Features:
|
||||
- Automatic video scrolling
|
||||
- Play/Pause control
|
||||
- Screenshot capture for each video
|
||||
- Basic data collection (likes, comments, etc.)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Any
|
||||
|
||||
from phone_agent import PhoneAgent, AgentConfig
|
||||
from phone_agent.agent import StepResult
|
||||
from phone_agent.model.client import ModelConfig
|
||||
from phone_agent.device_factory import get_device_factory
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecord:
|
||||
"""Record of a watched video."""
|
||||
|
||||
sequence_id: int
|
||||
timestamp: str
|
||||
screenshot_path: Optional[str] = None
|
||||
watch_duration: float = 0.0 # seconds
|
||||
|
||||
# Basic info (extracted via OCR/analysis)
|
||||
description: Optional[str] = None # Video caption/text
|
||||
likes: Optional[int] = None
|
||||
comments: Optional[int] = None
|
||||
shares: Optional[int] = None
|
||||
|
||||
# Content analysis (for future expansion)
|
||||
tags: List[str] = field(default_factory=list)
|
||||
category: Optional[str] = None
|
||||
elements: List[str] = field(default_factory=list)
|
||||
|
||||
# Metadata
|
||||
position_in_session: int = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"sequence_id": self.sequence_id,
|
||||
"timestamp": self.timestamp,
|
||||
"screenshot_path": self.screenshot_path,
|
||||
"watch_duration": self.watch_duration,
|
||||
"description": self.description,
|
||||
"likes": self.likes,
|
||||
"comments": self.comments,
|
||||
"shares": self.shares,
|
||||
"tags": self.tags,
|
||||
"category": self.category,
|
||||
"elements": self.elements,
|
||||
"position_in_session": self.position_in_session,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LearningSession:
|
||||
"""A learning session with multiple videos."""
|
||||
|
||||
session_id: str
|
||||
start_time: str
|
||||
platform: str # "douyin", "tiktok", etc.
|
||||
target_category: Optional[str] = None
|
||||
target_count: int = 10
|
||||
records: List[VideoRecord] = field(default_factory=list)
|
||||
|
||||
# Control flags
|
||||
is_active: bool = True
|
||||
is_paused: bool = False
|
||||
|
||||
# Statistics
|
||||
total_videos: int = 0
|
||||
total_duration: float = 0.0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"start_time": self.start_time,
|
||||
"platform": self.platform,
|
||||
"target_category": self.target_category,
|
||||
"target_count": self.target_count,
|
||||
"is_active": self.is_active,
|
||||
"is_paused": self.is_paused,
|
||||
"total_videos": self.total_videos,
|
||||
"total_duration": self.total_duration,
|
||||
"records": [r.to_dict() for r in self.records],
|
||||
}
|
||||
|
||||
|
||||
class VideoLearningAgent:
|
||||
"""
|
||||
Agent for learning from short video platforms.
|
||||
|
||||
MVP Capabilities:
|
||||
- Navigate to video platform
|
||||
- Watch videos automatically
|
||||
- Capture screenshots
|
||||
- Collect basic information
|
||||
- Export learning data
|
||||
"""
|
||||
|
||||
# Platform-specific configurations
|
||||
PLATFORM_CONFIGS = {
|
||||
"douyin": {
|
||||
"package_name": "com.ss.android.ugc.aweme",
|
||||
"activity_hint": "aweme",
|
||||
"scroll_gesture": "up",
|
||||
"like_position": {"x": 0.9, "y": 0.8}, # Relative coordinates
|
||||
"comment_position": {"x": 0.9, "y": 0.7},
|
||||
},
|
||||
"kuaishou": {
|
||||
"package_name": "com.smile.gifmaker",
|
||||
"activity_hint": "gifmaker",
|
||||
"scroll_gesture": "up",
|
||||
"like_position": {"x": 0.9, "y": 0.8},
|
||||
},
|
||||
"tiktok": {
|
||||
"package_name": "com.zhiliaoapp.musically",
|
||||
"activity_hint": "musically",
|
||||
"scroll_gesture": "up",
|
||||
"like_position": {"x": 0.9, "y": 0.8},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
platform: str = "douyin",
|
||||
output_dir: str = "./video_learning_data",
|
||||
):
|
||||
"""
|
||||
Initialize Video Learning Agent.
|
||||
|
||||
Args:
|
||||
model_config: Model configuration for VLM
|
||||
platform: Platform name (douyin, kuaishou, tiktok)
|
||||
output_dir: Directory to save screenshots and data
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.platform = platform
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create screenshots subdirectory
|
||||
self.screenshot_dir = self.output_dir / "screenshots"
|
||||
self.screenshot_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Current session
|
||||
self.current_session: Optional[LearningSession] = None
|
||||
self.video_counter = 0
|
||||
|
||||
# Agent will be created when starting a session
|
||||
self.agent: Optional[PhoneAgent] = None
|
||||
|
||||
# Callbacks for external control
|
||||
self.on_video_watched: Optional[Callable[[VideoRecord], None]] = None
|
||||
self.on_session_complete: Optional[Callable[[LearningSession], None]] = None
|
||||
self.on_progress_update: Optional[Callable[[int, int], None]] = None
|
||||
|
||||
# Video detection: track screenshot changes (simplified)
|
||||
self._last_screenshot_hash: Optional[str] = None
|
||||
|
||||
def start_session(
|
||||
self,
|
||||
device_id: str,
|
||||
target_count: int = 10,
|
||||
category: Optional[str] = None,
|
||||
watch_duration: float = 3.0,
|
||||
max_steps: int = 500,
|
||||
) -> str:
|
||||
"""
|
||||
Start a learning session.
|
||||
|
||||
Args:
|
||||
device_id: Target device ID
|
||||
target_count: Number of videos to watch
|
||||
category: Target category (e.g., "美食", "旅行")
|
||||
watch_duration: How long to watch each video (seconds)
|
||||
max_steps: Maximum execution steps
|
||||
|
||||
Returns:
|
||||
Session ID
|
||||
"""
|
||||
# Create new session
|
||||
session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
self.current_session = LearningSession(
|
||||
session_id=session_id,
|
||||
start_time=datetime.now().isoformat(),
|
||||
platform=self.platform,
|
||||
target_category=category,
|
||||
target_count=target_count,
|
||||
)
|
||||
|
||||
# Configure agent with callbacks
|
||||
agent_config = AgentConfig(
|
||||
device_id=device_id,
|
||||
max_steps=max_steps,
|
||||
lang="cn",
|
||||
step_callback=self._on_step,
|
||||
before_action_callback=self._before_action,
|
||||
)
|
||||
|
||||
# Create phone agent
|
||||
self.agent = PhoneAgent(
|
||||
model_config=self.model_config,
|
||||
agent_config=agent_config,
|
||||
)
|
||||
|
||||
# Store parameters for the task
|
||||
self._watch_duration = watch_duration
|
||||
self._device_id = device_id
|
||||
|
||||
# Reset video detection tracking (simplified)
|
||||
self._last_screenshot_hash = None
|
||||
self.video_counter = 0
|
||||
|
||||
return session_id
|
||||
|
||||
def run_learning_task(self, task: str) -> bool:
|
||||
"""
|
||||
Run the learning task.
|
||||
|
||||
Args:
|
||||
task: Natural language task description
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
if not self.agent or not self.current_session:
|
||||
raise RuntimeError("Session not started. Call start_session() first.")
|
||||
|
||||
try:
|
||||
result = self.agent.run(task)
|
||||
# Mark session as inactive after task completes
|
||||
if self.current_session:
|
||||
self.current_session.is_active = False
|
||||
self._save_session()
|
||||
print(f"[VideoLearning] Session completed. Recorded {self.video_counter} videos.")
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
print(f"Error during learning: {e}")
|
||||
if self.current_session:
|
||||
self.current_session.is_active = False
|
||||
return False
|
||||
|
||||
def stop_session(self):
|
||||
"""Stop the current learning session."""
|
||||
if self.current_session:
|
||||
self.current_session.is_active = False
|
||||
|
||||
if self.agent:
|
||||
# Agent will stop on next callback check
|
||||
pass
|
||||
|
||||
def pause_session(self):
|
||||
"""Pause the current session (can be resumed)."""
|
||||
if self.current_session:
|
||||
self.current_session.is_paused = True
|
||||
|
||||
def resume_session(self):
|
||||
"""Resume a paused session."""
|
||||
if self.current_session:
|
||||
self.current_session.is_paused = False
|
||||
|
||||
def _on_step(self, result: StepResult) -> Optional[str]:
|
||||
"""
|
||||
Callback after each step.
|
||||
|
||||
Simplified logic:
|
||||
1. Check if we're in the target app using get_current_app()
|
||||
2. Detect screenshot changes
|
||||
3. Record video when screenshot changes
|
||||
|
||||
Args:
|
||||
result: Step execution result
|
||||
|
||||
Returns:
|
||||
"stop" to end session, new task to switch, None to continue
|
||||
"""
|
||||
if not self.current_session:
|
||||
return None
|
||||
|
||||
# Check if session should stop
|
||||
if not self.current_session.is_active:
|
||||
self._save_session()
|
||||
if self.on_session_complete:
|
||||
self.on_session_complete(self.current_session)
|
||||
return "stop"
|
||||
|
||||
# Check if paused
|
||||
if self.current_session.is_paused:
|
||||
return None
|
||||
|
||||
# Check if we've watched enough videos
|
||||
if self.video_counter >= self.current_session.target_count:
|
||||
self.current_session.is_active = False
|
||||
self._save_session()
|
||||
if self.on_session_complete:
|
||||
self.on_session_complete(self.current_session)
|
||||
return "stop"
|
||||
|
||||
try:
|
||||
# Use get_current_app() to detect if we're in target app
|
||||
current_app = get_device_factory().get_current_app(self._device_id)
|
||||
|
||||
# Platform-specific package names
|
||||
platform_packages = {
|
||||
"douyin": ["aweme", "抖音", "douyin"],
|
||||
"kuaishou": ["gifmaker", "快手", "kuaishou"],
|
||||
"tiktok": ["musically", "tiktok"],
|
||||
}
|
||||
packages = platform_packages.get(self.platform, ["aweme"])
|
||||
|
||||
# Check if in target app
|
||||
is_in_target = any(pkg.lower() in current_app.lower() for pkg in packages)
|
||||
|
||||
if not is_in_target:
|
||||
print(f"[VideoLearning] Not in target app: {current_app} (step {result.step_count})")
|
||||
return None
|
||||
|
||||
# Get screenshot
|
||||
screenshot = get_device_factory().get_screenshot(self._device_id)
|
||||
|
||||
# Use full base64 data for hash (more sensitive)
|
||||
current_hash = hashlib.md5(screenshot.base64_data.encode()).hexdigest()
|
||||
|
||||
# Detect screenshot change and record video
|
||||
if self._last_screenshot_hash is None:
|
||||
# First screenshot in target app - record first video
|
||||
self._last_screenshot_hash = current_hash
|
||||
self._record_video_from_screenshot(screenshot)
|
||||
print(f"[VideoLearning] ✓ Recorded video {self.video_counter}/{self.current_session.target_count}")
|
||||
|
||||
# Check if we've reached target after recording
|
||||
if self.video_counter >= self.current_session.target_count:
|
||||
print(f"[VideoLearning] ✓ Target reached! Stopping...")
|
||||
self.current_session.is_active = False
|
||||
self._save_session()
|
||||
return "stop"
|
||||
|
||||
elif current_hash != self._last_screenshot_hash:
|
||||
# Screenshot changed - record new video
|
||||
self._last_screenshot_hash = current_hash
|
||||
self._record_video_from_screenshot(screenshot)
|
||||
print(f"[VideoLearning] ✓ Recorded video {self.video_counter}/{self.current_session.target_count}")
|
||||
|
||||
# Check if we've reached target after recording
|
||||
if self.video_counter >= self.current_session.target_count:
|
||||
print(f"[VideoLearning] ✓ Target reached! Stopping...")
|
||||
self.current_session.is_active = False
|
||||
self._save_session()
|
||||
return "stop"
|
||||
|
||||
except Exception as e:
|
||||
print(f"[VideoLearning] Warning: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _record_video_from_screenshot(self, screenshot):
|
||||
"""Helper method to record video from screenshot."""
|
||||
import base64
|
||||
screenshot_bytes = base64.b64decode(screenshot.base64_data)
|
||||
self.record_video(
|
||||
screenshot=screenshot_bytes,
|
||||
description=f"Video #{self.video_counter + 1}",
|
||||
)
|
||||
|
||||
def _before_action(self, action: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Callback before executing an action.
|
||||
|
||||
Args:
|
||||
action: Action to execute
|
||||
|
||||
Returns:
|
||||
Modified action or None
|
||||
"""
|
||||
# Could be used for action logging or modification
|
||||
return None
|
||||
|
||||
def record_video(
|
||||
self,
|
||||
screenshot: Optional[bytes] = None,
|
||||
description: Optional[str] = None,
|
||||
likes: Optional[int] = None,
|
||||
comments: Optional[int] = None,
|
||||
) -> VideoRecord:
|
||||
"""
|
||||
Record a watched video.
|
||||
|
||||
Args:
|
||||
screenshot: Screenshot image data
|
||||
description: Video description/caption
|
||||
likes: Number of likes
|
||||
comments: Number of comments
|
||||
|
||||
Returns:
|
||||
VideoRecord object
|
||||
"""
|
||||
self.video_counter += 1
|
||||
|
||||
# Save screenshot if provided
|
||||
screenshot_path = None
|
||||
if screenshot:
|
||||
screenshot_filename = f"{self.current_session.session_id}_video_{self.video_counter}.png"
|
||||
screenshot_full_path = self.screenshot_dir / screenshot_filename
|
||||
# Store relative path for web access: /video-learning-data/screenshots/filename.png
|
||||
screenshot_path = f"/video-learning-data/screenshots/{screenshot_filename}"
|
||||
with open(str(screenshot_full_path), "wb") as f:
|
||||
f.write(screenshot)
|
||||
|
||||
# Create record
|
||||
record = VideoRecord(
|
||||
sequence_id=self.video_counter,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
screenshot_path=screenshot_path,
|
||||
watch_duration=self._watch_duration,
|
||||
description=description,
|
||||
likes=likes,
|
||||
comments=comments,
|
||||
position_in_session=self.video_counter,
|
||||
)
|
||||
|
||||
# Add to session
|
||||
if self.current_session:
|
||||
self.current_session.records.append(record)
|
||||
self.current_session.total_videos = self.video_counter
|
||||
self.current_session.total_duration += self._watch_duration
|
||||
|
||||
# Notify callback
|
||||
if self.on_video_watched:
|
||||
self.on_video_watched(record)
|
||||
|
||||
# Notify progress
|
||||
if self.on_progress_update:
|
||||
self.on_progress_update(self.video_counter, self.current_session.target_count)
|
||||
|
||||
return record
|
||||
|
||||
def _save_session(self):
|
||||
"""Save session data to JSON file."""
|
||||
if not self.current_session:
|
||||
return
|
||||
|
||||
session_file = self.output_dir / f"{self.current_session.session_id}.json"
|
||||
with open(session_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.current_session.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"Session saved to {session_file}")
|
||||
|
||||
def export_data(self, format: str = "json") -> str:
|
||||
"""
|
||||
Export session data.
|
||||
|
||||
Args:
|
||||
format: Export format (json, csv)
|
||||
|
||||
Returns:
|
||||
Path to exported file
|
||||
"""
|
||||
if not self.current_session:
|
||||
raise RuntimeError("No session to export")
|
||||
|
||||
if format == "json":
|
||||
return self._export_json()
|
||||
elif format == "csv":
|
||||
return self._export_csv()
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
def _export_json(self) -> str:
|
||||
"""Export as JSON."""
|
||||
output_file = self.output_dir / f"{self.current_session.session_id}_export.json"
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self.current_session.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
return str(output_file)
|
||||
|
||||
def _export_csv(self) -> str:
|
||||
"""Export as CSV."""
|
||||
import csv
|
||||
|
||||
output_file = self.output_dir / f"{self.current_session.session_id}_export.csv"
|
||||
with open(output_file, "w", encoding="utf-8", newline="") as f:
|
||||
if not self.current_session.records:
|
||||
return str(output_file)
|
||||
|
||||
writer = csv.DictWriter(f, fieldnames=self.current_session.records[0].to_dict().keys())
|
||||
writer.writeheader()
|
||||
for record in self.current_session.records:
|
||||
writer.writerow(record.to_dict())
|
||||
|
||||
return str(output_file)
|
||||
|
||||
def get_session_progress(self) -> Dict[str, Any]:
|
||||
"""Get current session progress."""
|
||||
if not self.current_session:
|
||||
return {"status": "no_session"}
|
||||
|
||||
return {
|
||||
"session_id": self.current_session.session_id,
|
||||
"platform": self.current_session.platform,
|
||||
"target_count": self.current_session.target_count,
|
||||
"watched_count": self.video_counter,
|
||||
"progress_percent": (self.video_counter / self.current_session.target_count * 100)
|
||||
if self.current_session.target_count > 0
|
||||
else 0,
|
||||
"is_active": self.current_session.is_active,
|
||||
"is_paused": self.current_session.is_paused,
|
||||
"total_duration": self.current_session.total_duration,
|
||||
}
|
||||
|
||||
|
||||
# Convenience function for standalone usage
|
||||
def create_video_learning_agent(
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
model_name: str = "autoglm-phone-9b",
|
||||
platform: str = "douyin",
|
||||
output_dir: str = "./video_learning_data",
|
||||
**model_kwargs,
|
||||
) -> VideoLearningAgent:
|
||||
"""
|
||||
Create a Video Learning Agent with standard configuration.
|
||||
|
||||
Args:
|
||||
base_url: Model API base URL
|
||||
api_key: API key
|
||||
model_name: Model name
|
||||
platform: Platform name
|
||||
output_dir: Output directory
|
||||
**model_kwargs: Additional model parameters
|
||||
|
||||
Returns:
|
||||
VideoLearningAgent instance
|
||||
"""
|
||||
model_config = ModelConfig(
|
||||
base_url=base_url,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return VideoLearningAgent(
|
||||
model_config=model_config,
|
||||
platform=platform,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
Reference in New Issue
Block a user