feat: Add comprehensive timeline editor with frame editing and regeneration capabilities
This commit is contained in:
@@ -73,6 +73,11 @@ class StoryboardFrame:
|
||||
duration: float = 0.0 # Frame duration (seconds, from audio or video)
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
# Quality tracking (added for quality assurance)
|
||||
quality_score: Optional[float] = None # Overall quality score (0.0-1.0)
|
||||
quality_issues: Optional[List[str]] = None # List of detected quality issues
|
||||
retry_count: int = 0 # Number of generation retries
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.now()
|
||||
|
||||
@@ -176,6 +176,16 @@ class StandardPipeline(LinearVideoPipeline):
|
||||
min_words = ctx.params.get("min_image_prompt_words", 30)
|
||||
max_words = ctx.params.get("max_image_prompt_words", 60)
|
||||
|
||||
# Auto-detect characters from narrations if CharacterMemory is enabled
|
||||
character_memory = ctx.params.get("character_memory")
|
||||
if character_memory and character_memory.config.auto_detect_characters:
|
||||
logger.info("🔍 Auto-detecting characters from narrations...")
|
||||
for i, narration in enumerate(ctx.narrations):
|
||||
await character_memory.detect_characters_from_narration(narration, frame_index=i)
|
||||
if character_memory.characters:
|
||||
detected_names = [c.name for c in character_memory.characters]
|
||||
logger.info(f"✅ Detected {len(character_memory.characters)} characters: {detected_names}")
|
||||
|
||||
# Override prompt_prefix if provided
|
||||
original_prefix = None
|
||||
if prompt_prefix is not None:
|
||||
@@ -205,6 +215,16 @@ class StandardPipeline(LinearVideoPipeline):
|
||||
progress_callback=image_prompt_progress
|
||||
)
|
||||
|
||||
# Apply character memory enhancement (if available)
|
||||
character_memory = ctx.params.get("character_memory")
|
||||
if character_memory and character_memory.characters:
|
||||
active_characters = [c for c in character_memory.characters if getattr(c, 'is_active', True)]
|
||||
if active_characters:
|
||||
logger.info(f"🎭 Applying {len(active_characters)} active characters to {len(base_image_prompts)} prompts")
|
||||
base_image_prompts = [
|
||||
character_memory.apply_to_prompt(p) for p in base_image_prompts
|
||||
]
|
||||
|
||||
# Apply prompt prefix
|
||||
image_config = self.core.config.get("comfyui", {}).get("image", {})
|
||||
prompt_prefix_to_use = prompt_prefix if prompt_prefix is not None else image_config.get("prompt_prefix", "")
|
||||
|
||||
@@ -27,19 +27,41 @@ from loguru import logger
|
||||
|
||||
from pixelle_video.models.progress import ProgressEvent
|
||||
from pixelle_video.models.storyboard import Storyboard, StoryboardFrame, StoryboardConfig
|
||||
from pixelle_video.services.quality import (
|
||||
QualityGate,
|
||||
QualityConfig,
|
||||
RetryManager,
|
||||
RetryConfig,
|
||||
QualityError,
|
||||
)
|
||||
|
||||
|
||||
class FrameProcessor:
|
||||
"""Frame processor"""
|
||||
|
||||
def __init__(self, pixelle_video_core):
|
||||
def __init__(
|
||||
self,
|
||||
pixelle_video_core,
|
||||
quality_config: Optional[QualityConfig] = None,
|
||||
retry_config: Optional[RetryConfig] = None,
|
||||
enable_quality_check: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize
|
||||
|
||||
Args:
|
||||
pixelle_video_core: PixelleVideoCore instance
|
||||
quality_config: Quality evaluation configuration
|
||||
retry_config: Retry behavior configuration
|
||||
enable_quality_check: Whether to enable quality checking
|
||||
"""
|
||||
self.core = pixelle_video_core
|
||||
self.enable_quality_check = enable_quality_check
|
||||
self.quality_gate = QualityGate(
|
||||
llm_service=pixelle_video_core.llm if hasattr(pixelle_video_core, 'llm') else None,
|
||||
config=quality_config or QualityConfig()
|
||||
)
|
||||
self.retry_manager = RetryManager(config=retry_config or RetryConfig())
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
@@ -199,11 +221,14 @@ class FrameProcessor:
|
||||
frame: StoryboardFrame,
|
||||
config: StoryboardConfig
|
||||
):
|
||||
"""Step 2: Generate media (image or video) using ComfyKit"""
|
||||
"""
|
||||
Step 2: Generate media (image or video) using ComfyKit
|
||||
|
||||
Enhanced with quality evaluation and retry logic.
|
||||
"""
|
||||
logger.debug(f" 2/4: Generating media for frame {frame.index}...")
|
||||
|
||||
# Determine media type based on workflow
|
||||
# video_ prefix in workflow name indicates video generation
|
||||
workflow_name = config.media_workflow or ""
|
||||
is_video_workflow = "video_" in workflow_name.lower()
|
||||
media_type = "video" if is_video_workflow else "image"
|
||||
@@ -213,57 +238,87 @@ class FrameProcessor:
|
||||
# Build media generation parameters
|
||||
media_params = {
|
||||
"prompt": frame.image_prompt,
|
||||
"workflow": config.media_workflow, # Pass workflow from config (None = use default)
|
||||
"workflow": config.media_workflow,
|
||||
"media_type": media_type,
|
||||
"width": config.media_width,
|
||||
"height": config.media_height,
|
||||
"index": frame.index + 1, # 1-based index for workflow
|
||||
"index": frame.index + 1,
|
||||
}
|
||||
|
||||
# For video workflows: pass audio duration as target video duration
|
||||
# This ensures video length matches audio length from the source
|
||||
if is_video_workflow and frame.duration:
|
||||
media_params["duration"] = frame.duration
|
||||
logger.info(f" → Generating video with target duration: {frame.duration:.2f}s (from TTS audio)")
|
||||
logger.info(f" → Generating video with target duration: {frame.duration:.2f}s")
|
||||
|
||||
# Call Media generation
|
||||
media_result = await self.core.media(**media_params)
|
||||
# Define generation operation
|
||||
async def generate_and_download():
|
||||
media_result = await self.core.media(**media_params)
|
||||
local_path = await self._download_media(
|
||||
media_result.url,
|
||||
frame.index,
|
||||
config.task_id,
|
||||
media_type=media_result.media_type
|
||||
)
|
||||
return (media_result, local_path)
|
||||
|
||||
# Store media type
|
||||
# Define quality evaluator
|
||||
async def evaluate_quality(result):
|
||||
media_result, local_path = result
|
||||
if media_result.is_video:
|
||||
return await self.quality_gate.evaluate_video(
|
||||
local_path, frame.image_prompt, frame.narration
|
||||
)
|
||||
else:
|
||||
return await self.quality_gate.evaluate_image(
|
||||
local_path, frame.image_prompt, frame.narration
|
||||
)
|
||||
|
||||
# Execute with retry and quality check
|
||||
if self.enable_quality_check:
|
||||
try:
|
||||
retry_result = await self.retry_manager.execute_with_retry(
|
||||
operation=generate_and_download,
|
||||
quality_evaluator=evaluate_quality,
|
||||
operation_name=f"frame_{frame.index}_media",
|
||||
)
|
||||
media_result, local_path = retry_result.result
|
||||
|
||||
# Store quality metrics on frame
|
||||
if retry_result.quality_score:
|
||||
frame.quality_score = retry_result.quality_score.overall_score
|
||||
frame.quality_issues = retry_result.quality_score.issues
|
||||
frame.retry_count = retry_result.attempts - 1 # first attempt is not a retry
|
||||
|
||||
except QualityError as e:
|
||||
logger.warning(f" ⚠ Quality check failed after retries: {e}")
|
||||
# Still try to use the last result if available
|
||||
media_result, local_path = await generate_and_download()
|
||||
frame.quality_issues = [str(e)]
|
||||
else:
|
||||
# Quality check disabled - just generate
|
||||
media_result, local_path = await generate_and_download()
|
||||
|
||||
# Store results on frame
|
||||
frame.media_type = media_result.media_type
|
||||
|
||||
if media_result.is_image:
|
||||
# Download image to local (pass task_id)
|
||||
local_path = await self._download_media(
|
||||
media_result.url,
|
||||
frame.index,
|
||||
config.task_id,
|
||||
media_type="image"
|
||||
)
|
||||
frame.image_path = local_path
|
||||
logger.debug(f" ✓ Image generated: {local_path}")
|
||||
|
||||
elif media_result.is_video:
|
||||
# Download video to local (pass task_id)
|
||||
local_path = await self._download_media(
|
||||
media_result.url,
|
||||
frame.index,
|
||||
config.task_id,
|
||||
media_type="video"
|
||||
)
|
||||
frame.video_path = local_path
|
||||
|
||||
# Update duration from video if available
|
||||
if media_result.duration:
|
||||
frame.duration = media_result.duration
|
||||
logger.debug(f" ✓ Video generated: {local_path} (duration: {frame.duration:.2f}s)")
|
||||
else:
|
||||
# Get video duration from file
|
||||
frame.duration = await self._get_video_duration(local_path)
|
||||
logger.debug(f" ✓ Video generated: {local_path} (duration: {frame.duration:.2f}s)")
|
||||
|
||||
logger.debug(f" ✓ Video generated: {local_path} (duration: {frame.duration:.2f}s)")
|
||||
else:
|
||||
raise ValueError(f"Unknown media type: {media_result.media_type}")
|
||||
|
||||
# Log quality result
|
||||
if frame.quality_score is not None:
|
||||
logger.info(
|
||||
f" 📊 Quality: {frame.quality_score:.2f} "
|
||||
f"(retries: {frame.retry_count}, issues: {len(frame.quality_issues or [])})"
|
||||
)
|
||||
|
||||
async def _step_compose_frame(
|
||||
self,
|
||||
|
||||
134
pixelle_video/services/publishing/__init__.py
Normal file
134
pixelle_video/services/publishing/__init__.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Publishing service for multi-platform video distribution.
|
||||
|
||||
Supports:
|
||||
- Format conversion + export (Douyin/Kuaishou)
|
||||
- API-based upload (Bilibili/YouTube)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Platform(Enum):
    """Supported publishing platforms.

    Each member's value is the lowercase platform identifier string.
    """

    EXPORT = "export"      # Format conversion only
    DOUYIN = "douyin"      # Douyin (via export or CDP)
    KUAISHOU = "kuaishou"  # Kuaishou (via export or CDP)
    BILIBILI = "bilibili"  # Bilibili (API)
    YOUTUBE = "youtube"    # YouTube (API)
|
||||
|
||||
|
||||
class PublishStatus(Enum):
    """Publishing task status.

    Each member's value is the lowercase status string.
    """

    PENDING = "pending"
    CONVERTING = "converting"
    UPLOADING = "uploading"
    PROCESSING = "processing"
    PUBLISHED = "published"
    FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
class VideoMetadata:
    """Video metadata for publishing.

    Only ``title`` is required; every other field has a default.
    """

    title: str
    description: str = ""
    # A fresh list per instance (default_factory), never shared between objects.
    tags: List[str] = field(default_factory=list)
    category: Optional[str] = None
    cover_path: Optional[str] = None
    privacy: str = "public"  # public, private, unlisted

    # Platform-specific options; a fresh dict per instance.
    platform_options: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class PublishResult:
    """Result of a publishing operation.

    ``success``, ``platform`` and ``status`` are required; the remaining
    fields default to ``None`` and are filled in depending on the outcome.
    """

    success: bool
    platform: Platform
    status: PublishStatus

    # On success
    video_url: Optional[str] = None
    platform_video_id: Optional[str] = None

    # On failure
    error_message: Optional[str] = None

    # Export result
    export_path: Optional[str] = None

    # Timestamps
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@dataclass
class PublishTask:
    """A publishing task for background processing.

    ``status`` starts as ``PublishStatus.PENDING`` and ``created_at`` is
    stamped with ``datetime.now()`` when the instance is created.
    """

    id: str
    video_path: str
    platform: Platform
    metadata: VideoMetadata
    status: PublishStatus = PublishStatus.PENDING
    result: Optional[PublishResult] = None
    # default_factory so each task gets its own creation time.
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class Publisher(ABC):
    """Abstract base class for platform publishers.

    Subclasses must implement ``publish`` and ``validate_credentials``.
    ``get_platform_requirements`` provides defaults that subclasses may
    override.
    """

    # Declared here as an annotation only; no value is assigned on the base class.
    platform: Platform

    @abstractmethod
    async def publish(
        self,
        video_path: str,
        metadata: VideoMetadata,
        progress_callback: Optional[callable] = None
    ) -> PublishResult:
        """
        Publish a video to the platform.

        Args:
            video_path: Path to the video file
            metadata: Video metadata (title, description, tags, etc.)
            progress_callback: Optional callback for progress updates

        Returns:
            PublishResult with success/failure details
        """
        pass

    @abstractmethod
    async def validate_credentials(self) -> bool:
        """Check if platform credentials are valid."""
        pass

    def get_platform_requirements(self) -> Dict[str, Any]:
        """Get platform-specific requirements (dimensions, file size, etc.).

        Returns:
            A dict of default limits and recommendations.
        """
        return {
            "max_file_size_mb": 128,
            "max_duration_seconds": 900,  # 15 minutes
            "supported_formats": ["mp4", "webm"],
            "recommended_resolution": (1080, 1920),  # Portrait 9:16
            "recommended_codec": "h264",
        }
|
||||
426
pixelle_video/services/publishing/bilibili_publisher.py
Normal file
426
pixelle_video/services/publishing/bilibili_publisher.py
Normal file
@@ -0,0 +1,426 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Bilibili Publisher - Upload videos to Bilibili using their Open Platform API.
|
||||
|
||||
Flow:
|
||||
1. Get preupload info (upos_uri, auth, chunk_size)
|
||||
2. Upload video chunks (8MB each)
|
||||
3. Merge chunks
|
||||
4. Submit video with metadata
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from pixelle_video.services.publishing import (
|
||||
Publisher,
|
||||
Platform,
|
||||
PublishStatus,
|
||||
VideoMetadata,
|
||||
PublishResult,
|
||||
)
|
||||
|
||||
|
||||
# Bilibili API endpoints
|
||||
BILIBILI_PREUPLOAD_URL = "https://member.bilibili.com/preupload"
|
||||
BILIBILI_SUBMIT_URL = "https://member.bilibili.com/x/vu/web/add"
|
||||
|
||||
# Chunk size: 8MB (recommended by Bilibili)
|
||||
CHUNK_SIZE = 8 * 1024 * 1024
|
||||
|
||||
|
||||
class BilibiliPublisher(Publisher):
    """
    Publisher for Bilibili video platform.

    Upload flow driven by ``publish``:
        1. ``_preupload``    - fetch upload parameters
        2. ``_upload_chunk`` - PUT the file in ``CHUNK_SIZE`` pieces
        3. ``_merge_chunks`` - ask the server to assemble the parts
        4. ``_submit_video`` - submit the video with its metadata

    Credentials (explicit argument first, environment variable as fallback):
        - access_token / BILIBILI_ACCESS_TOKEN
        - refresh_token / BILIBILI_REFRESH_TOKEN (stored, not used in this class)
        - sessdata / BILIBILI_SESSDATA
        - bili_jct / BILIBILI_BILI_JCT

    NOTE(review): the upload state (``_upos_uri`` etc.) is stored on the
    instance and overwritten by every ``_preupload`` call, so overlapping
    ``publish`` calls on one instance would share it - confirm callers use
    one upload at a time per instance.
    """

    platform = Platform.BILIBILI

    def __init__(
        self,
        access_token: Optional[str] = None,
        refresh_token: Optional[str] = None,
        sessdata: Optional[str] = None,  # Alternative: use cookies
        bili_jct: Optional[str] = None,
    ):
        """Store credentials, falling back to ``BILIBILI_*`` environment variables."""
        self.access_token = access_token or os.getenv("BILIBILI_ACCESS_TOKEN")
        self.refresh_token = refresh_token or os.getenv("BILIBILI_REFRESH_TOKEN")
        self.sessdata = sessdata or os.getenv("BILIBILI_SESSDATA")
        self.bili_jct = bili_jct or os.getenv("BILIBILI_BILI_JCT")

        # Upload state, populated by _preupload()
        self._upload_id = None
        self._upos_uri = None
        self._auth = None
        self._endpoint = None

    def _failure(self, error_message: str, started_at: datetime) -> PublishResult:
        """Build a FAILED PublishResult stamped with the current completion time."""
        return PublishResult(
            success=False,
            platform=Platform.BILIBILI,
            status=PublishStatus.FAILED,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(),
        )

    def _upos_url(self) -> str:
        """Build the HTTPS URL used for chunk upload and merge.

        A leading ``upos://`` scheme on the URI is replaced by ``/`` so the
        result is ``https:<endpoint>/<bucket>/<file>``; a URI without that
        scheme is appended to the endpoint unchanged.
        """
        path = self._upos_uri
        scheme = "upos://"
        if path.startswith(scheme):
            path = "/" + path[len(scheme):]
        return f"https:{self._endpoint}{path}"

    async def publish(
        self,
        video_path: str,
        metadata: VideoMetadata,
        progress_callback: Optional[callable] = None
    ) -> PublishResult:
        """Upload and publish video to Bilibili.

        Args:
            video_path: Path to the video file.
            metadata: Title, description, tags and platform options.
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, message)``.

        Returns:
            A PublishResult. Failures, including unexpected exceptions, are
            reported through ``success=False`` / ``error_message`` rather
            than raised.
        """
        started_at = datetime.now()

        def report(progress: float, message: str) -> None:
            # Progress reporting is optional; skip when no callback was given.
            if progress_callback:
                progress_callback(progress, message)

        try:
            if not await self.validate_credentials():
                return self._failure("B站凭证未配置或已过期", started_at)

            video_file = Path(video_path)
            if not video_file.exists():
                return self._failure(f"视频文件不存在: {video_path}", started_at)

            file_size = video_file.stat().st_size
            if file_size == 0:
                # Zero bytes would yield zero chunks and an empty merge request.
                return self._failure(f"视频文件为空: {video_path}", started_at)

            report(0.05, "获取上传信息...")

            # Step 1: Get preupload info
            preupload_info = await self._preupload(video_file.name, file_size)
            if not preupload_info:
                return self._failure("获取上传信息失败", started_at)

            report(0.1, "上传视频分片...")

            # Step 2: Upload chunks
            chunk_count = math.ceil(file_size / CHUNK_SIZE)
            uploaded_chunks = 0

            async with aiohttp.ClientSession() as session:
                with open(video_path, "rb") as f:
                    for chunk_index in range(chunk_count):
                        chunk_data = f.read(CHUNK_SIZE)
                        chunk_start = chunk_index * CHUNK_SIZE
                        chunk_end = min(chunk_start + len(chunk_data), file_size)

                        success = await self._upload_chunk(
                            session,
                            chunk_data,
                            chunk_index,
                            chunk_count,
                            chunk_start,
                            chunk_end,
                            file_size,
                        )

                        if not success:
                            return self._failure(
                                f"分片 {chunk_index + 1}/{chunk_count} 上传失败",
                                started_at,
                            )

                        uploaded_chunks += 1
                        # Chunk upload covers the 0.1 - 0.8 progress range.
                        progress = 0.1 + (0.7 * uploaded_chunks / chunk_count)
                        report(progress, f"上传分片 {uploaded_chunks}/{chunk_count}")

            report(0.85, "合并视频...")

            # Step 3: Merge chunks
            video_filename = await self._merge_chunks(chunk_count, file_size)
            if not video_filename:
                return self._failure("视频合并失败", started_at)

            report(0.9, "提交稿件...")

            # Step 4: Submit video
            bvid = await self._submit_video(video_filename, metadata)
            if not bvid:
                return self._failure("稿件提交失败", started_at)

            report(1.0, "发布成功")

            return PublishResult(
                success=True,
                platform=Platform.BILIBILI,
                status=PublishStatus.PUBLISHED,
                video_url=f"https://www.bilibili.com/video/{bvid}",
                platform_video_id=bvid,
                started_at=started_at,
                completed_at=datetime.now(),
            )

        except Exception as e:
            # Boundary handler: convert any unexpected error into a failed result.
            logger.error(f"Bilibili publish failed: {e}")
            return self._failure(str(e), started_at)

    async def _preupload(self, filename: str, file_size: int) -> Optional[Dict]:
        """Get preupload info from Bilibili.

        On success, stores ``upos_uri``, ``auth``, ``endpoint`` and
        ``upload_id`` from the response on the instance.

        Returns:
            The decoded JSON response, or None on a non-200 status or any error.
        """
        params = {
            "name": filename,
            "size": file_size,
            "r": "upos",
            "profile": "ugcupos/bup",
        }

        headers = self._get_headers()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    BILIBILI_PREUPLOAD_URL,
                    params=params,
                    headers=headers
                ) as resp:
                    if resp.status != 200:
                        logger.error(f"Preupload failed: {resp.status}")
                        return None

                    data = await resp.json()

                    self._upos_uri = data.get("upos_uri")
                    self._auth = data.get("auth")
                    self._endpoint = data.get("endpoint")
                    # NOTE(review): verify the preupload response really
                    # carries "upload_id"; if not, this stays None.
                    self._upload_id = data.get("upload_id")

                    logger.info(f"Preupload success: {self._upos_uri}")
                    return data

        except Exception as e:
            logger.error(f"Preupload error: {e}")
            return None

    async def _upload_chunk(
        self,
        session: aiohttp.ClientSession,
        chunk_data: bytes,
        chunk_index: int,
        chunk_count: int,
        chunk_start: int,
        chunk_end: int,
        total_size: int,
    ) -> bool:
        """Upload a single chunk.

        Args:
            session: Open aiohttp session shared across chunks.
            chunk_data: Raw bytes of this chunk.
            chunk_index: Zero-based chunk index (sent as ``chunk``; the
                one-based value is sent as ``partNumber``).
            chunk_count: Total number of chunks.
            chunk_start: Byte offset of the chunk's first byte.
            chunk_end: Byte offset just past the chunk's last byte.
            total_size: Total file size in bytes.

        Returns:
            True on HTTP 200/201/204, False otherwise or on any error.
        """
        # All three values come from _preupload(); without them no URL can be built.
        if not self._upos_uri or not self._auth or not self._endpoint:
            return False

        # Build upload URL
        upload_url = self._upos_url()

        params = {
            "uploadId": self._upload_id,
            "partNumber": chunk_index + 1,
            "chunk": chunk_index,
            "chunks": chunk_count,
            "size": len(chunk_data),
            "start": chunk_start,
            "end": chunk_end,
            "total": total_size,
        }

        headers = {
            "X-Upos-Auth": self._auth,
            "Content-Type": "application/octet-stream",
        }

        try:
            async with session.put(
                upload_url,
                params=params,
                headers=headers,
                data=chunk_data
            ) as resp:
                if resp.status not in [200, 201, 204]:
                    logger.error(f"Chunk upload failed: {resp.status}")
                    return False
                return True

        except Exception as e:
            logger.error(f"Chunk upload error: {e}")
            return False

    async def _merge_chunks(self, chunk_count: int, file_size: int) -> Optional[str]:
        """Merge uploaded chunks.

        Args:
            chunk_count: Number of parts that were uploaded.
            file_size: Total file size in bytes (currently unused here).

        Returns:
            The last path segment of the upos URI without its extension, or
            None on a non-200 status or any error.
        """
        if not self._upos_uri or not self._endpoint:
            return None

        merge_url = self._upos_url()
        remote_name = self._upos_uri.split("/")[-1]

        params = {
            "output": "json",
            "name": remote_name,
            "profile": "ugcupos/bup",
            "uploadId": self._upload_id,
            "biz_id": "",
        }

        # Build parts list; every part is sent with the literal eTag "etag".
        parts = [{"partNumber": i + 1, "eTag": "etag"} for i in range(chunk_count)]

        headers = {
            "X-Upos-Auth": self._auth,
            "Content-Type": "application/json",
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    merge_url,
                    params=params,
                    headers=headers,
                    json={"parts": parts}
                ) as resp:
                    if resp.status != 200:
                        logger.error(f"Merge failed: {resp.status}")
                        return None

                    # The body is parsed only to confirm it is JSON; a
                    # non-JSON reply falls through to the handler below.
                    await resp.json()
                    return remote_name.split(".")[0]

        except Exception as e:
            logger.error(f"Merge error: {e}")
            return None

    async def _submit_video(
        self,
        video_filename: str,
        metadata: VideoMetadata
    ) -> Optional[str]:
        """Submit video with metadata.

        Args:
            video_filename: Server-side file name returned by ``_merge_chunks``.
            metadata: Title, description, tags, cover and platform options.

        Returns:
            The ``bvid`` from the response when its ``code`` is 0, otherwise None.
        """
        # Default to the "Life" category (tid=160)
        tid = metadata.platform_options.get("tid", 160)

        # NOTE(review): "cover" receives the local cover_path as-is; the
        # endpoint may expect an uploaded image URL - verify.
        data = {
            "copyright": 1,  # 1 = original work
            "videos": [{
                "filename": video_filename,
                "title": metadata.title,
                "desc": metadata.description,
            }],
            "title": metadata.title,
            "desc": metadata.description,
            "tid": tid,
            "tag": ",".join(metadata.tags) if metadata.tags else "",
            "source": "",
            "cover": metadata.cover_path or "",
            "no_reprint": 1,
            "open_elec": 0,
        }

        headers = self._get_headers()
        headers["Content-Type"] = "application/json"

        # NOTE(review): with cookie auth this endpoint may also need a
        # csrf value (bili_jct) in the request - confirm against the API.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    BILIBILI_SUBMIT_URL,
                    headers=headers,
                    json=data
                ) as resp:
                    result = await resp.json()

                    if result.get("code") == 0:
                        bvid = result.get("data", {}).get("bvid")
                        logger.info(f"Video submitted: {bvid}")
                        return bvid
                    else:
                        logger.error(f"Submit failed: {result}")
                        return None

        except Exception as e:
            logger.error(f"Submit error: {e}")
            return None

    def _get_headers(self) -> Dict[str, str]:
        """Get common headers with authentication.

        Adds a Bearer ``Authorization`` header when an access token is set
        and a ``Cookie`` header when SESSDATA is set (with ``bili_jct``
        appended if available).
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
            "Referer": "https://www.bilibili.com/",
        }

        if self.access_token:
            headers["Authorization"] = f"Bearer {self.access_token}"

        if self.sessdata:
            headers["Cookie"] = f"SESSDATA={self.sessdata}"
            # bili_jct is only appended to an existing SESSDATA cookie.
            if self.bili_jct:
                headers["Cookie"] += f"; bili_jct={self.bili_jct}"

        return headers

    async def validate_credentials(self) -> bool:
        """Check if Bilibili credentials are configured.

        Only checks that an access token or SESSDATA value is present; no
        request is made to verify that it is still valid.
        """
        return bool(self.access_token or self.sessdata)

    def get_platform_requirements(self) -> Dict[str, Any]:
        """Return the upload limits and recommendations used for Bilibili."""
        return {
            "max_file_size_mb": 4096,  # 4GB
            "max_duration_seconds": 14400,  # 4 hours
            "supported_formats": ["mp4", "flv", "webm", "mov"],
            "recommended_resolution": (1920, 1080),
            "recommended_codec": "h264",
            "chunk_size_mb": 8,
        }
|
||||
182
pixelle_video/services/publishing/export_publisher.py
Normal file
182
pixelle_video/services/publishing/export_publisher.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Export Publisher - Format conversion and local export for platforms
|
||||
without API access (Douyin, Kuaishou).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
from pixelle_video.services.publishing import (
|
||||
Publisher,
|
||||
Platform,
|
||||
PublishStatus,
|
||||
VideoMetadata,
|
||||
PublishResult,
|
||||
)
|
||||
|
||||
|
||||
class ExportPublisher(Publisher):
|
||||
"""
|
||||
Publisher that converts video to platform-optimized format
|
||||
and exports to local filesystem for manual upload.
|
||||
"""
|
||||
|
||||
platform = Platform.EXPORT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_dir: str = "./output/exports",
|
||||
target_resolution: tuple = (1080, 1920), # Portrait 9:16
|
||||
target_codec: str = "h264",
|
||||
max_file_size_mb: int = 128,
|
||||
):
|
||||
self.output_dir = Path(output_dir)
|
||||
self.target_resolution = target_resolution
|
||||
self.target_codec = target_codec
|
||||
self.max_file_size_mb = max_file_size_mb
|
||||
|
||||
# Ensure output directory exists
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def publish(
|
||||
self,
|
||||
video_path: str,
|
||||
metadata: VideoMetadata,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> PublishResult:
|
||||
"""
|
||||
Convert video to optimized format and export.
|
||||
"""
|
||||
started_at = datetime.now()
|
||||
|
||||
try:
|
||||
if progress_callback:
|
||||
progress_callback(0.1, "分析视频...")
|
||||
|
||||
# Generate output filename
|
||||
safe_title = "".join(c if c.isalnum() or c in " -_" else "" for c in metadata.title)
|
||||
safe_title = safe_title[:50].strip() or "video"
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_filename = f"{safe_title}_{timestamp}.mp4"
|
||||
output_path = self.output_dir / output_filename
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(0.2, "转换格式...")
|
||||
|
||||
# Convert video
|
||||
success = await self._convert_video(
|
||||
video_path,
|
||||
str(output_path),
|
||||
progress_callback
|
||||
)
|
||||
|
||||
if not success:
|
||||
return PublishResult(
|
||||
success=False,
|
||||
platform=Platform.EXPORT,
|
||||
status=PublishStatus.FAILED,
|
||||
error_message="视频转换失败",
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(),
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(1.0, "导出完成")
|
||||
|
||||
# Verify file size
|
||||
file_size_mb = output_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
return PublishResult(
|
||||
success=True,
|
||||
platform=Platform.EXPORT,
|
||||
status=PublishStatus.PUBLISHED,
|
||||
export_path=str(output_path),
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Export failed: {e}")
|
||||
return PublishResult(
|
||||
success=False,
|
||||
platform=Platform.EXPORT,
|
||||
status=PublishStatus.FAILED,
|
||||
error_message=str(e),
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(),
|
||||
)
|
||||
|
||||
async def _convert_video(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> bool:
|
||||
"""Convert video using FFmpeg."""
|
||||
|
||||
width, height = self.target_resolution
|
||||
|
||||
# FFmpeg command for H.264 conversion with size optimization
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", input_path,
|
||||
"-c:v", "libx264",
|
||||
"-preset", "medium",
|
||||
"-crf", "23",
|
||||
"-c:a", "aac",
|
||||
"-b:a", "128k",
|
||||
"-vf", f"scale={width}:{height}:force_original_aspect_ratio=decrease,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2",
|
||||
"-movflags", "+faststart",
|
||||
output_path
|
||||
]
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
logger.error(f"FFmpeg error: {stderr.decode()}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error("FFmpeg not found. Please install FFmpeg.")
|
||||
# Fallback: just copy the file if FFmpeg is not available
|
||||
shutil.copy(input_path, output_path)
|
||||
return True
|
||||
|
||||
async def validate_credentials(self) -> bool:
|
||||
"""No credentials needed for export."""
|
||||
return True
|
||||
|
||||
def get_platform_requirements(self):
    """Describe the export target as a plain dict built from this instance's settings."""
    requirements = dict(
        max_file_size_mb=self.max_file_size_mb,
        recommended_resolution=self.target_resolution,
        recommended_codec=self.target_codec,
    )
    # Fixed values that do not depend on instance configuration.
    requirements["output_format"] = "mp4"
    requirements["platforms"] = ["douyin", "kuaishou"]
    return requirements
|
||||
299
pixelle_video/services/publishing/task_manager.py
Normal file
299
pixelle_video/services/publishing/task_manager.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Publish Task Manager - Background task queue for video publishing.
|
||||
|
||||
Features:
|
||||
- Async task queue with configurable workers
|
||||
- Task persistence (in-memory, Redis optional)
|
||||
- Progress tracking and callbacks
|
||||
- Retry logic for failed tasks
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, List, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
|
||||
from pixelle_video.services.publishing import (
|
||||
Publisher,
|
||||
Platform,
|
||||
PublishStatus,
|
||||
VideoMetadata,
|
||||
PublishResult,
|
||||
PublishTask,
|
||||
)
|
||||
from pixelle_video.services.publishing.export_publisher import ExportPublisher
|
||||
from pixelle_video.services.publishing.bilibili_publisher import BilibiliPublisher
|
||||
from pixelle_video.services.publishing.youtube_publisher import YouTubePublisher
|
||||
|
||||
|
||||
class TaskPriority(Enum):
    """Priority level attached to a queued publish task (larger value = more urgent)."""

    LOW = 0
    NORMAL = 1
    HIGH = 2
|
||||
|
||||
|
||||
@dataclass
class QueuedTask:
    """A PublishTask plus the bookkeeping the task manager keeps while it is queued."""

    # The publish request being tracked.
    task: PublishTask
    priority: TaskPriority = TaskPriority.NORMAL

    # Retry bookkeeping.
    retries: int = 0            # retry attempts made so far
    max_retries: int = 3        # retry attempts allowed before giving up
    retry_delay: float = 5.0    # seconds slept before a failed task is re-queued

    # Lifecycle timestamps.
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Latest values reported through the publisher's progress callback.
    progress: float = 0.0
    progress_message: str = ""
|
||||
|
||||
|
||||
class PublishTaskManager:
    """
    Manage background publishing tasks with an asyncio worker pool.

    Tasks are kept in an in-memory registry and dispatched to workers through
    a priority queue: higher ``TaskPriority`` first, FIFO within the same
    priority.

    Usage:
        manager = PublishTaskManager()
        await manager.start()

        task_id = await manager.enqueue(
            video_path="/path/to/video.mp4",
            platform=Platform.BILIBILI,
            metadata=VideoMetadata(title="My Video")
        )

        status = manager.get_task(task_id)
    """

    def __init__(
        self,
        max_workers: int = 3,
        max_queue_size: int = 100,
    ):
        """
        Args:
            max_workers: Number of worker coroutines started by ``start()``.
            max_queue_size: Capacity of the queue; ``enqueue()`` waits while
                the queue is full.
        """
        self.max_workers = max_workers
        self.max_queue_size = max_queue_size

        # Task storage
        self._tasks: Dict[str, QueuedTask] = {}
        # Created in start(); entries are (-priority, sequence, QueuedTask).
        self._queue: Optional[asyncio.Queue] = None
        # Monotonic tie-breaker so equal-priority entries stay FIFO and
        # QueuedTask objects themselves are never compared.
        self._sequence = 0
        self._workers: List[asyncio.Task] = []
        self._running = False

        # Publishers
        self._publishers: Dict[Platform, Publisher] = {
            Platform.EXPORT: ExportPublisher(),
            Platform.BILIBILI: BilibiliPublisher(),
            Platform.YOUTUBE: YouTubePublisher(),
        }

        # Callbacks
        self._on_complete: Optional[Callable] = None
        self._on_progress: Optional[Callable] = None

    async def start(self):
        """Create a fresh queue and start the worker tasks (no-op if already running)."""
        if self._running:
            return

        self._queue = asyncio.PriorityQueue(maxsize=self.max_queue_size)
        self._running = True

        # Start worker tasks
        for i in range(self.max_workers):
            self._workers.append(asyncio.create_task(self._worker(i)))

        logger.info(f"✅ Publish task manager started with {self.max_workers} workers")

    async def stop(self):
        """
        Cancel all workers and wait for them to exit.

        Tasks still waiting in the queue are not processed. They stay in the
        registry with their current status, and a later ``start()`` begins
        with a new, empty queue.
        """
        self._running = False

        # Cancel all workers
        for worker in self._workers:
            worker.cancel()

        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()

        logger.info("✅ Publish task manager stopped")

    async def enqueue(
        self,
        video_path: str,
        platform: Platform,
        metadata: VideoMetadata,
        priority: TaskPriority = TaskPriority.NORMAL,
    ) -> str:
        """
        Add a publish task to the queue.

        Args:
            video_path: Path of the video to publish.
            platform: Target platform; selects the publisher.
            metadata: Title, description and other publish metadata.
            priority: Higher priorities are dequeued first.

        Returns:
            Task ID for tracking

        Raises:
            RuntimeError: If ``start()`` has not been awaited yet.
        """
        if self._queue is None:
            raise RuntimeError(
                "PublishTaskManager.start() must be awaited before enqueue()"
            )

        # Short IDs are convenient to display; regenerate on the (unlikely)
        # collision so an existing registry entry is never overwritten.
        task_id = str(uuid.uuid4())[:8]
        while task_id in self._tasks:
            task_id = str(uuid.uuid4())[:8]

        task = PublishTask(
            id=task_id,
            video_path=video_path,
            platform=platform,
            metadata=metadata,
            status=PublishStatus.PENDING,
        )

        queued_task = QueuedTask(task=task, priority=priority)
        self._tasks[task_id] = queued_task

        await self._put(queued_task)

        logger.info(f"📥 Queued task {task_id}: {metadata.title} → {platform.value}")

        return task_id

    def get_task(self, task_id: str) -> Optional[QueuedTask]:
        """Get task by ID, or None if the ID is unknown."""
        return self._tasks.get(task_id)

    def get_all_tasks(self) -> List[QueuedTask]:
        """Get all tasks known to the manager, in insertion order."""
        return list(self._tasks.values())

    def get_pending_tasks(self) -> List[QueuedTask]:
        """Get tasks whose status is PENDING."""
        return [t for t in self._tasks.values() if t.task.status == PublishStatus.PENDING]

    def get_active_tasks(self) -> List[QueuedTask]:
        """Get tasks whose status is CONVERTING, UPLOADING or PROCESSING."""
        return [t for t in self._tasks.values() if t.task.status in [
            PublishStatus.CONVERTING,
            PublishStatus.UPLOADING,
            PublishStatus.PROCESSING,
        ]]

    def set_on_complete(self, callback: Callable):
        """Set callback for task completion; called as ``callback(task_id, result)``."""
        self._on_complete = callback

    def set_on_progress(self, callback: Callable):
        """Set callback for progress updates; called as ``callback(task_id, progress, message)``."""
        self._on_progress = callback

    async def _put(self, queued_task: QueuedTask):
        """Insert a task into the priority queue, waiting while the queue is full."""
        self._sequence += 1
        # Negate the priority: PriorityQueue pops the smallest entry first.
        await self._queue.put(
            (-queued_task.priority.value, self._sequence, queued_task)
        )

    async def _worker(self, worker_id: int):
        """Worker coroutine that processes tasks from the queue until stopped."""
        logger.debug(f"Worker {worker_id} started")

        queue = self._queue

        while self._running:
            try:
                # Poll with a timeout so the loop re-checks self._running.
                try:
                    _, _, queued_task = await asyncio.wait_for(
                        queue.get(),
                        timeout=1.0
                    )
                except asyncio.TimeoutError:
                    continue

                try:
                    await self._process_task(queued_task, worker_id)
                finally:
                    # Always balance get() so the queue's unfinished-task
                    # counter stays correct even when processing raises.
                    queue.task_done()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}")

    async def _process_task(self, queued_task: QueuedTask, worker_id: int):
        """
        Run one publish attempt for a task.

        A result with ``success=False`` is retried (after ``retry_delay``
        seconds) until ``max_retries`` is reached; an exception raised by the
        publisher marks the task FAILED immediately.
        """
        task = queued_task.task
        task_id = task.id

        logger.info(f"🔄 Worker {worker_id} processing task {task_id}")

        queued_task.started_at = datetime.now()
        task.status = PublishStatus.UPLOADING

        # Get publisher
        publisher = self._publishers.get(task.platform)
        if publisher is None:
            task.status = PublishStatus.FAILED
            task.result = PublishResult(
                success=False,
                platform=task.platform,
                status=PublishStatus.FAILED,
                error_message=f"No publisher for platform: {task.platform}",
            )
            # This is a terminal state too: record timestamps and notify.
            self._finalize(queued_task)
            return

        # Progress callback
        def progress_callback(progress: float, message: str):
            queued_task.progress = progress
            queued_task.progress_message = message
            if self._on_progress:
                try:
                    self._on_progress(task_id, progress, message)
                except Exception as e:
                    # A faulty observer must not abort the publish itself.
                    logger.error(f"Progress callback failed for task {task_id}: {e}")

        try:
            # Execute publish
            result = await publisher.publish(
                task.video_path,
                task.metadata,
                progress_callback=progress_callback
            )

            task.result = result
            task.status = result.status

            if result.success:
                logger.info(f"✅ Task {task_id} completed: {result.video_url or result.export_path}")
            else:
                logger.warning(f"❌ Task {task_id} failed: {result.error_message}")

                # Retry if applicable
                if queued_task.retries < queued_task.max_retries:
                    queued_task.retries += 1
                    task.status = PublishStatus.PENDING
                    logger.info(f"🔄 Retrying task {task_id} ({queued_task.retries}/{queued_task.max_retries})")
                    await asyncio.sleep(queued_task.retry_delay)
                    await self._put(queued_task)
                    return

        except Exception as e:
            logger.error(f"Task {task_id} exception: {e}")
            task.status = PublishStatus.FAILED
            task.result = PublishResult(
                success=False,
                platform=task.platform,
                status=PublishStatus.FAILED,
                error_message=str(e),
            )

        self._finalize(queued_task)

    def _finalize(self, queued_task: QueuedTask):
        """Stamp completion times on a finished task and fire the completion callback."""
        task = queued_task.task

        queued_task.completed_at = datetime.now()
        task.updated_at = datetime.now()

        # Call completion callback
        if self._on_complete:
            try:
                self._on_complete(task.id, task.result)
            except Exception as e:
                logger.error(f"Completion callback failed for task {task.id}: {e}")
|
||||
|
||||
|
||||
# Module-level manager instance, created lazily by get_publish_manager().
_publish_manager: Optional[PublishTaskManager] = None


def get_publish_manager() -> PublishTaskManager:
    """Return the shared PublishTaskManager, constructing it on the first call."""
    global _publish_manager
    manager = _publish_manager
    if manager is None:
        manager = _publish_manager = PublishTaskManager()
    return manager
|
||||
310
pixelle_video/services/publishing/youtube_publisher.py
Normal file
310
pixelle_video/services/publishing/youtube_publisher.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
YouTube Publisher - Upload videos to YouTube using Data API v3.
|
||||
|
||||
Requires:
|
||||
- Google Cloud project with YouTube Data API v3 enabled
|
||||
- OAuth 2.0 credentials (client_secrets.json)
|
||||
"""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from pixelle_video.services.publishing import (
|
||||
Publisher,
|
||||
Platform,
|
||||
PublishStatus,
|
||||
VideoMetadata,
|
||||
PublishResult,
|
||||
)
|
||||
|
||||
|
||||
# Lower-case category name -> YouTube category ID (IDs are strings).
YOUTUBE_CATEGORIES = dict(
    film="1",
    autos="2",
    music="10",
    pets="15",
    sports="17",
    travel="19",
    gaming="20",
    people="22",
    comedy="23",
    entertainment="24",
    news="25",
    howto="26",
    education="27",
    science="28",
    nonprofits="29",
)
|
||||
|
||||
|
||||
class YouTubePublisher(Publisher):
    """
    Publisher for YouTube video platform.

    Uses Google API Python Client for uploading videos.

    Setup:
    1. Create project in Google Cloud Console
    2. Enable YouTube Data API v3
    3. Create OAuth 2.0 credentials
    4. Download client_secrets.json
    """

    platform = Platform.YOUTUBE

    def __init__(
        self,
        client_secrets_path: Optional[str] = None,
        token_path: Optional[str] = None,
    ):
        """
        Args:
            client_secrets_path: OAuth client secrets file. Falls back to the
                ``YOUTUBE_CLIENT_SECRETS`` environment variable, then to
                ``./config/youtube_client_secrets.json``.
            token_path: File where the authorized token is stored. Falls back
                to ``YOUTUBE_TOKEN_PATH``, then ``./config/youtube_token.pickle``.
        """
        self.client_secrets_path = client_secrets_path or os.getenv(
            "YOUTUBE_CLIENT_SECRETS",
            "./config/youtube_client_secrets.json"
        )
        self.token_path = token_path or os.getenv(
            "YOUTUBE_TOKEN_PATH",
            "./config/youtube_token.pickle"
        )

        self._youtube_service = None

    async def publish(
        self,
        video_path: str,
        metadata: VideoMetadata,
        progress_callback: Optional[callable] = None
    ) -> PublishResult:
        """
        Upload and publish video to YouTube.

        Args:
            video_path: Local path of the video file.
            metadata: Title, description, tags, category, privacy and
                platform options for the upload.
            progress_callback: Optional ``callback(progress, message)``.

        Returns:
            A PublishResult. Failures, including unexpected exceptions, are
            reported as a result with ``success=False`` rather than raised.
        """
        started_at = datetime.now()

        def failure(message: str) -> PublishResult:
            # Every failure path shares the same result shape.
            return PublishResult(
                success=False,
                platform=Platform.YOUTUBE,
                status=PublishStatus.FAILED,
                error_message=message,
                started_at=started_at,
                completed_at=datetime.now(),
            )

        try:
            if not await self.validate_credentials():
                return failure("YouTube 凭证未配置。请配置 client_secrets.json")

            video_file = Path(video_path)
            if not video_file.exists():
                return failure(f"视频文件不存在: {video_path}")

            if progress_callback:
                progress_callback(0.1, "初始化 YouTube API...")

            # Initialize YouTube service
            youtube = await self._get_youtube_service()
            if not youtube:
                return failure("无法初始化 YouTube API 服务")

            if progress_callback:
                progress_callback(0.2, "准备上传...")

            # Prepare video metadata
            category_id = self._get_category_id(metadata.category)
            privacy_status = self._map_privacy(metadata.privacy)

            body = {
                "snippet": {
                    "title": metadata.title,
                    "description": metadata.description,
                    "tags": metadata.tags,
                    "categoryId": category_id,
                },
                "status": {
                    "privacyStatus": privacy_status,
                    "selfDeclaredMadeForKids": False,
                }
            }

            # Check for synthetic media flag (tolerate missing options)
            platform_options = metadata.platform_options or {}
            if platform_options.get("contains_synthetic_media"):
                body["status"]["containsSyntheticMedia"] = True

            if progress_callback:
                progress_callback(0.3, "上传视频...")

            # Upload using resumable upload
            video_id = await self._upload_video(youtube, video_path, body, progress_callback)

            if not video_id:
                return failure("视频上传失败")

            if progress_callback:
                progress_callback(1.0, "发布成功")

            return PublishResult(
                success=True,
                platform=Platform.YOUTUBE,
                status=PublishStatus.PUBLISHED,
                video_url=f"https://www.youtube.com/watch?v={video_id}",
                platform_video_id=video_id,
                started_at=started_at,
                completed_at=datetime.now(),
            )

        except Exception as e:
            logger.error(f"YouTube publish failed: {e}")
            return failure(str(e))

    async def _get_youtube_service(self):
        """
        Build an authenticated YouTube API client.

        Loads a saved token when present, refreshes it if expired, otherwise
        runs the local-server OAuth flow, then saves the token.

        Returns:
            The API client, or None when the Google libraries are missing,
            the client secrets file is absent, or any other error occurs.
        """
        # NOTE(review): the token refresh and run_local_server() calls below are
        # synchronous and run on the event loop thread.
        try:
            # Imported only so a missing google-auth install raises ImportError here.
            from google.oauth2.credentials import Credentials  # noqa: F401
            from google_auth_oauthlib.flow import InstalledAppFlow
            from googleapiclient.discovery import build

            SCOPES = ["https://www.googleapis.com/auth/youtube.upload"]

            creds = None

            # Load saved token
            # SECURITY: unpickling executes whatever the file contains. Keep the
            # token file private to this application.
            if os.path.exists(self.token_path):
                with open(self.token_path, "rb") as token:
                    creds = pickle.load(token)

            # Refresh or get new credentials
            if not creds or not creds.valid:
                if creds and creds.expired and creds.refresh_token:
                    from google.auth.transport.requests import Request
                    creds.refresh(Request())
                else:
                    if not os.path.exists(self.client_secrets_path):
                        logger.error(f"Client secrets not found: {self.client_secrets_path}")
                        return None

                    flow = InstalledAppFlow.from_client_secrets_file(
                        self.client_secrets_path,
                        SCOPES
                    )
                    creds = flow.run_local_server(port=0)

                # Save token; create the directory first so a fresh checkout
                # does not lose the just-obtained authorization.
                token_file = Path(self.token_path)
                token_file.parent.mkdir(parents=True, exist_ok=True)
                with open(token_file, "wb") as token:
                    pickle.dump(creds, token)

            # A new client is built per call rather than cached and shared:
            # uploads run chunk requests in worker threads, and the client's
            # HTTP transport should not be used from several threads at once.
            return build("youtube", "v3", credentials=creds)

        except ImportError:
            logger.error("Google API libraries not installed. Run: pip install google-api-python-client google-auth-oauthlib")
            return None
        except Exception as e:
            logger.error(f"Failed to initialize YouTube service: {e}")
            return None

    async def _upload_video(
        self,
        youtube,
        video_path: str,
        body: dict,
        progress_callback: Optional[callable] = None
    ) -> Optional[str]:
        """
        Upload video using resumable upload.

        Args:
            youtube: API client returned by ``_get_youtube_service``.
            video_path: Local path of the video file.
            body: ``videos.insert`` request body; its top-level keys become
                the ``part`` parameter.
            progress_callback: Optional ``callback(progress, message)``;
                upload progress is mapped onto the 0.3-0.9 range.

        Returns:
            The uploaded video's ID, or None if the upload failed.
        """
        try:
            import asyncio

            from googleapiclient.http import MediaFileUpload

            media = MediaFileUpload(
                video_path,
                chunksize=1024 * 1024,  # 1MB chunks
                resumable=True
            )

            request = youtube.videos().insert(
                part=",".join(body.keys()),
                body=body,
                media_body=media
            )

            loop = asyncio.get_running_loop()
            response = None
            while response is None:
                # next_chunk() performs blocking network I/O; run it in the
                # default executor so other coroutines keep running.
                status, response = await loop.run_in_executor(None, request.next_chunk)
                if status:
                    progress = 0.3 + (0.6 * status.progress())
                    if progress_callback:
                        progress_callback(progress, f"上传 {int(status.progress() * 100)}%")

            video_id = response.get("id")
            logger.info(f"Video uploaded: {video_id}")
            return video_id

        except Exception as e:
            logger.error(f"Upload failed: {e}")
            return None

    def _get_category_id(self, category: Optional[str]) -> str:
        """Map category name (case-insensitive) to YouTube category ID; default "22"."""
        if not category:
            return "22"  # Default: People & Blogs

        return YOUTUBE_CATEGORIES.get(category.lower(), "22")

    def _map_privacy(self, privacy: str) -> str:
        """
        Map privacy setting to YouTube format.

        Matching is case-insensitive and ignores surrounding whitespace, in
        line with ``_get_category_id``. Unknown or empty values fall back to
        "private".
        """
        mapping = {
            "public": "public",
            "private": "private",
            "unlisted": "unlisted",
        }
        key = str(privacy).strip().lower() if privacy else ""
        return mapping.get(key, "private")

    async def validate_credentials(self) -> bool:
        """Check that a client secrets file or a saved token file exists."""
        return os.path.exists(self.client_secrets_path) or os.path.exists(self.token_path)

    def get_platform_requirements(self) -> Dict[str, Any]:
        """Return the upload limits and recommendations used for YouTube."""
        return {
            "max_file_size_mb": 256000,  # 256GB
            "max_duration_seconds": 43200,  # 12 hours
            "supported_formats": ["mp4", "mov", "avi", "webm", "mkv"],
            "recommended_resolution": (1920, 1080),
            "recommended_codec": "h264",
            "quota_cost_per_upload": 100,
        }
|
||||
84
pixelle_video/services/quality/__init__.py
Normal file
84
pixelle_video/services/quality/__init__.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Quality Assurance Services for Pixelle-Video
|
||||
|
||||
This module provides quality control mechanisms for video generation:
|
||||
- QualityGate: Evaluates generated content quality
|
||||
- RetryManager: Smart retry with quality-based decisions
|
||||
- OutputValidator: LLM output validation
|
||||
- StyleGuard: Visual style consistency
|
||||
- ContentFilter: Content moderation
|
||||
- CharacterMemory: Character consistency across frames
|
||||
"""
|
||||
|
||||
from pixelle_video.services.quality.models import (
|
||||
QualityScore,
|
||||
QualityConfig,
|
||||
RetryConfig,
|
||||
QualityError,
|
||||
)
|
||||
from pixelle_video.services.quality.quality_gate import QualityGate
|
||||
from pixelle_video.services.quality.retry_manager import RetryManager
|
||||
from pixelle_video.services.quality.output_validator import (
|
||||
OutputValidator,
|
||||
ValidationConfig,
|
||||
ValidationResult,
|
||||
)
|
||||
from pixelle_video.services.quality.style_guard import (
|
||||
StyleGuard,
|
||||
StyleGuardConfig,
|
||||
StyleAnchor,
|
||||
)
|
||||
from pixelle_video.services.quality.content_filter import (
|
||||
ContentFilter,
|
||||
ContentFilterConfig,
|
||||
FilterResult,
|
||||
FilterCategory,
|
||||
)
|
||||
from pixelle_video.services.quality.character_memory import (
|
||||
CharacterMemory,
|
||||
CharacterMemoryConfig,
|
||||
Character,
|
||||
CharacterType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Quality evaluation
|
||||
"QualityScore",
|
||||
"QualityConfig",
|
||||
"RetryConfig",
|
||||
"QualityError",
|
||||
"QualityGate",
|
||||
"RetryManager",
|
||||
# Output validation
|
||||
"OutputValidator",
|
||||
"ValidationConfig",
|
||||
"ValidationResult",
|
||||
# Style consistency
|
||||
"StyleGuard",
|
||||
"StyleGuardConfig",
|
||||
"StyleAnchor",
|
||||
# Content moderation
|
||||
"ContentFilter",
|
||||
"ContentFilterConfig",
|
||||
"FilterResult",
|
||||
"FilterCategory",
|
||||
# Character memory
|
||||
"CharacterMemory",
|
||||
"CharacterMemoryConfig",
|
||||
"Character",
|
||||
"CharacterType",
|
||||
]
|
||||
|
||||
|
||||
530
pixelle_video/services/quality/character_memory.py
Normal file
530
pixelle_video/services/quality/character_memory.py
Normal file
@@ -0,0 +1,530 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
CharacterMemory - Character consistency and memory system
|
||||
|
||||
Maintains consistent character appearance across video frames by:
|
||||
1. Detecting and registering characters from narrations
|
||||
2. Extracting visual descriptions from first appearances
|
||||
3. Injecting character consistency prompts into subsequent frames
|
||||
4. Supporting reference images for ComfyUI IP-Adapter/ControlNet
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Set
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CharacterType(Enum):
    """Kinds of character that can appear in a narrative."""

    # Human character
    PERSON = "person"
    # Animal character
    ANIMAL = "animal"
    # Fantasy/fictional creature
    CREATURE = "creature"
    # Personified object
    OBJECT = "object"
    # Abstract entity
    ABSTRACT = "abstract"
|
||||
|
||||
|
||||
@dataclass
class Character:
    """
    Represents a character in the video narrative

    Stores visual description, reference images, and appearance history
    to maintain consistency across frames.
    """

    # Identity
    id: str  # Unique identifier
    name: str  # Character name (e.g., "小明", "the hero")
    aliases: List[str] = field(default_factory=list)  # Alternative names
    character_type: CharacterType = CharacterType.PERSON

    # Visual description (for prompt injection)
    appearance_description: str = ""  # Detailed visual description
    clothing_description: str = ""  # Clothing/outfit description
    distinctive_features: List[str] = field(default_factory=list)  # Unique features

    # Reference images (for IP-Adapter/ControlNet)
    reference_images: List[str] = field(default_factory=list)  # Paths to reference images
    primary_reference: Optional[str] = None  # Primary reference image

    # Prompt elements
    prompt_prefix: str = ""  # Pre-built prompt prefix
    negative_prompt: str = ""  # Negative prompt additions

    # Metadata
    is_active: bool = True  # Whether this character is active for logic
    first_appearance_frame: int = 0  # Frame index of first appearance
    appearance_frames: List[int] = field(default_factory=list)  # All frames with this character
    created_at: Optional[datetime] = None

    def __post_init__(self):
        """Fill in the creation time and, if none was given, the prompt prefix."""
        if self.created_at is None:
            self.created_at = datetime.now()
        # is_active is a dataclass field with a default, so it always exists
        # here and needs no fallback.
        if not self.prompt_prefix:
            self._build_prompt_prefix()

    def _build_prompt_prefix(self):
        """Build prompt prefix from visual descriptions"""
        elements = []

        if self.appearance_description:
            elements.append(self.appearance_description)
        if self.clothing_description:
            elements.append(f"wearing {self.clothing_description}")
        if self.distinctive_features:
            elements.append(", ".join(self.distinctive_features))

        # Joining an empty list already yields "".
        self.prompt_prefix = ", ".join(elements)

    def get_prompt_injection(self) -> str:
        """Get the prompt text to inject for this character"""
        if self.prompt_prefix:
            return f"({self.name}: {self.prompt_prefix})"
        return f"({self.name})"

    def add_reference_image(self, image_path: str, set_as_primary: bool = False):
        """
        Add a reference image for this character.

        The image becomes the primary reference when ``set_as_primary`` is
        true or when no primary reference has been set yet.
        """
        if image_path not in self.reference_images:
            self.reference_images.append(image_path)
        if set_as_primary or self.primary_reference is None:
            self.primary_reference = image_path

    def matches_name(self, name: str) -> bool:
        """Check if a name matches this character's name or an alias (case-insensitive)."""
        name_lower = name.lower().strip()
        if self.name.lower() == name_lower:
            return True
        return any(alias.lower() == name_lower for alias in self.aliases)

    def to_dict(self) -> dict:
        """Return the identity, description, reference and prompt fields as a dict."""
        return {
            "id": self.id,
            "name": self.name,
            "aliases": self.aliases,
            "type": self.character_type.value,
            "appearance_description": self.appearance_description,
            "clothing_description": self.clothing_description,
            "distinctive_features": self.distinctive_features,
            "reference_images": self.reference_images,
            "primary_reference": self.primary_reference,
            "prompt_prefix": self.prompt_prefix,
            "first_appearance_frame": self.first_appearance_frame,
        }
|
||||
|
||||
|
||||
@dataclass
class CharacterMemoryConfig:
    """Tunable switches for the character memory system."""

    # --- Detection ---
    # Automatically detect characters from narrations
    auto_detect_characters: bool = True
    # Use LLM to extract character info
    use_llm_detection: bool = True

    # --- Consistency ---
    # Inject character descriptions into prompts
    inject_character_prompts: bool = True
    # Use reference images for generation
    use_reference_images: bool = True

    # --- Reference images ---
    # Extract reference from first appearance
    extract_reference_from_first: bool = True
    # Max reference images per character
    max_reference_images: int = 3

    # --- Prompt injection ---
    # "start" or "end"
    prompt_injection_position: str = "start"
    # Include clothing in prompts
    include_clothing: bool = True
    # Include distinctive features
    include_features: bool = True
|
||||
|
||||
|
||||
class CharacterMemory:
|
||||
"""
|
||||
Character memory and consistency manager
|
||||
|
||||
Tracks characters across video frames and ensures visual consistency
|
||||
by injecting character descriptions and reference images into the
|
||||
generation pipeline.
|
||||
|
||||
Example:
|
||||
>>> memory = CharacterMemory(llm_service)
|
||||
>>>
|
||||
>>> # Register a character
|
||||
>>> char = memory.register_character(
|
||||
... name="小明",
|
||||
... appearance_description="young man with short black hair",
|
||||
... clothing_description="blue t-shirt"
|
||||
... )
|
||||
>>>
|
||||
>>> # Apply to prompt
|
||||
>>> enhanced_prompt = memory.apply_to_prompt(
|
||||
... prompt="A person walking in the park",
|
||||
... characters=["小明"]
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    llm_service=None,
    config: Optional[CharacterMemoryConfig] = None
):
    """
    Set up an empty character registry.

    Args:
        llm_service: Optional LLM service for character detection
        config: Character memory configuration; a default
            CharacterMemoryConfig is used when omitted
    """
    # Registry: character id -> Character, plus a name -> character id index.
    self._characters: Dict[str, Character] = {}
    self._name_index: Dict[str, str] = {}

    self.config = config if config else CharacterMemoryConfig()
    self.llm_service = llm_service
|
||||
|
||||
def register_character(
    self,
    name: str,
    appearance_description: str = "",
    clothing_description: str = "",
    distinctive_features: Optional[List[str]] = None,
    character_type: CharacterType = CharacterType.PERSON,
    first_frame: int = 0,
) -> Character:
    """
    Register a new character

    Args:
        name: Character name
        appearance_description: Visual appearance description
        clothing_description: Clothing/outfit description
        distinctive_features: List of distinctive features
        character_type: Type of character
        first_frame: Frame index of first appearance

    Returns:
        Created Character object
    """
    # Generate unique ID
    char_id = f"char_{len(self._characters)}_{name.replace(' ', '_').lower()}"

    character = Character(
        id=char_id,
        name=name,
        appearance_description=appearance_description,
        clothing_description=clothing_description,
        distinctive_features=distinctive_features or [],
        character_type=character_type,
        first_appearance_frame=first_frame,
        appearance_frames=[first_frame],
    )

    self._characters[char_id] = character
    # Normalize the index key the same way get_character() normalizes its
    # lookup key (lower-case + strip); otherwise a name registered with
    # surrounding whitespace could never be found again.
    self._name_index[name.lower().strip()] = char_id

    logger.info(f"Registered character: {name} (id={char_id})")

    return character
|
||||
|
||||
def get_character(self, name: str) -> Optional[Character]:
    """Look up a character by name, falling back to an alias search."""
    indexed_id = self._name_index.get(name.lower().strip())
    if indexed_id:
        return self._characters.get(indexed_id)

    # No index hit: ask each registered character whether it answers to this name.
    return next(
        (
            candidate
            for candidate in self._characters.values()
            if candidate.matches_name(name)
        ),
        None,
    )
|
||||
|
||||
def get_character_by_id(self, char_id: str) -> Optional[Character]:
    """Return the character stored under ``char_id``, or None when the ID is unknown."""
    try:
        return self._characters[char_id]
    except KeyError:
        return None
|
||||
|
||||
@property
def characters(self) -> List[Character]:
    """All registered characters, returned as a new list on every access."""
    return [*self._characters.values()]
|
||||
|
||||
async def detect_characters_from_narration(
    self,
    narration: str,
    frame_index: int = 0,
) -> List[Character]:
    """
    Find characters mentioned in a narration and record their appearance.

    Uses the LLM detector when it is enabled in the config and an LLM
    service is present; otherwise uses the regex-based detector.

    Args:
        narration: Narration text to analyze
        frame_index: Current frame index

    Returns:
        List of detected/registered characters (empty when auto-detection
        is disabled)
    """
    settings = self.config
    if not settings.auto_detect_characters:
        return []

    if settings.use_llm_detection and self.llm_service:
        return await self._detect_with_llm(narration, frame_index)

    return self._detect_basic(narration, frame_index)
|
||||
|
||||
async def _detect_with_llm(
    self,
    narration: str,
    frame_index: int,
) -> List[Character]:
    """
    Detect characters using the LLM service.

    The LLM is asked for a JSON list of characters. Known characters get
    ``frame_index`` recorded as an appearance; unknown ones are registered.
    Malformed list entries (non-dict items, missing or null names) are
    skipped individually instead of aborting the whole response.

    Args:
        narration: Narration text to analyze
        frame_index: Frame index the narration belongs to

    Returns:
        Characters mentioned in the narration, without duplicates. Empty
        when no LLM service is set or the response has no JSON list. On any
        exception the result of ``_detect_basic`` is returned instead.
    """
    if not self.llm_service:
        return []

    import json
    import re

    try:
        prompt = f"""分析以下文案,提取其中提到的角色/人物。

文案: {narration}

请用 JSON 格式返回角色列表,每个角色包含:
- name: 角色名称或代称
- type: person/animal/creature/object
- appearance: 外貌描述(如有)
- clothing: 服装描述(如有)

如果没有明确角色,返回空列表 []。

只返回 JSON,不要其他解释。"""

        response = await self.llm_service(prompt, temperature=0.1)

        # The model may wrap the JSON in prose or code fences; grab the
        # outermost [...] span.
        json_match = re.search(r'\[.*\]', response, re.DOTALL)
        if not json_match:
            return []

        characters_data = json.loads(json_match.group())

        result = []
        for char_data in characters_data:
            # One bad entry must not discard the rest of the response.
            if not isinstance(char_data, dict):
                continue

            # `or ""` also covers an explicit JSON null.
            name = str(char_data.get("name") or "").strip()
            if not name:
                continue

            existing = self.get_character(name)
            if existing:
                # Record each frame once, matching apply_to_prompt().
                if frame_index not in existing.appearance_frames:
                    existing.appearance_frames.append(frame_index)
                if existing not in result:
                    result.append(existing)
                continue

            # Register new character. Only "animal" and "creature" are
            # mapped; every other type string falls back to PERSON.
            type_str = str(char_data.get("type") or "person").strip().lower()
            if type_str == "animal":
                char_type = CharacterType.ANIMAL
            elif type_str == "creature":
                char_type = CharacterType.CREATURE
            else:
                char_type = CharacterType.PERSON

            char = self.register_character(
                name=name,
                appearance_description=char_data.get("appearance") or "",
                clothing_description=char_data.get("clothing") or "",
                character_type=char_type,
                first_frame=frame_index,
            )
            result.append(char)

        return result

    except Exception as e:
        # Best effort: any LLM or parsing failure degrades to regex detection.
        logger.warning(f"LLM character detection failed: {e}")
        return self._detect_basic(narration, frame_index)
|
||||
|
||||
def _detect_basic(
    self,
    narration: str,
    frame_index: int,
) -> List[Character]:
    """
    Basic character detection without LLM.

    Scans the narration with a few regex patterns and keeps only matches
    that resolve to an already registered character; this method never
    registers new characters.

    Args:
        narration: Narration text to scan
        frame_index: Frame index to record as an appearance

    Returns:
        Registered characters mentioned in the narration, without duplicates
    """
    import re

    # Simple pattern matching for common character references.
    # NOTE(review): \w also matches CJK characters, so the greedy {1,2} can
    # swallow the character that follows a two-character name.
    patterns = [
        r'(?:他|她|它)们?',  # Chinese pronouns
        r'(?:小\w{1,2})',  # Names like 小明, 小红
        r'(?:老\w{1,2})',  # Names like 老王, 老李
    ]

    detected = []
    for pattern in patterns:
        for match in re.findall(pattern, narration):
            existing = self.get_character(match)
            if not existing:
                continue

            # A name repeated within one narration is still one appearance
            # in this frame; record the frame once, as apply_to_prompt() does.
            if frame_index not in existing.appearance_frames:
                existing.appearance_frames.append(frame_index)
            if existing not in detected:
                detected.append(existing)

    return detected
|
||||
|
||||
def apply_to_prompt(
    self,
    prompt: str,
    character_names: Optional[List[str]] = None,
    frame_index: Optional[int] = None,
) -> str:
    """
    Add character descriptions to an image prompt.

    Args:
        prompt: Original image prompt
        character_names: Names of the characters to include; None or an
            empty list means every registered character
        frame_index: Current frame index; when given, it is recorded as an
            appearance for each included character

    Returns:
        The prompt with character text prepended or appended according to
        the config, or the unchanged prompt when injection is disabled or
        there is nothing to inject
    """
    if not self.config.inject_character_prompts:
        return prompt

    if character_names:
        looked_up = (self.get_character(entry) for entry in character_names)
        selected = [member for member in looked_up if member]
    else:
        # No explicit names: use the whole registry.
        selected = self.characters

    if not selected:
        return prompt

    fragments = []
    for member in selected:
        fragment = member.get_prompt_injection()
        if fragment:
            fragments.append(fragment)

        # Appearance tracking applies even when the character adds no text.
        if frame_index is not None and frame_index not in member.appearance_frames:
            member.appearance_frames.append(frame_index)

    if not fragments:
        return prompt

    joined = ", ".join(fragments)
    put_first = self.config.prompt_injection_position == "start"
    return f"{joined}, {prompt}" if put_first else f"{prompt}, {joined}"
|
||||
|
||||
def get_reference_images(
    self,
    character_names: Optional[List[str]] = None,
) -> List[str]:
    """
    Collect the primary reference image of each requested character.

    Args:
        character_names: Character names; None or an empty list means all
            registered characters

    Returns:
        Reference image paths, truncated to ``config.max_reference_images``;
        empty when reference images are disabled in the config
    """
    if not self.config.use_reference_images:
        return []

    if character_names:
        found = (self.get_character(entry) for entry in character_names)
        images = [
            member.primary_reference
            for member in found
            if member and member.primary_reference
        ]
    else:
        images = [
            member.primary_reference
            for member in self.characters
            if member.primary_reference
        ]

    limit = self.config.max_reference_images
    return images[:limit]
|
||||
|
||||
def set_reference_image(
    self,
    character_name: str,
    image_path: str,
    set_as_primary: bool = True,
):
    """
    Attach a reference image to a registered character.

    Logs a warning and does nothing when the character is unknown.

    Args:
        character_name: Character name
        image_path: Path to reference image
        set_as_primary: Whether to set as primary reference
    """
    target = self.get_character(character_name)
    if not target:
        logger.warning(f"Character not found: {character_name}")
        return

    target.add_reference_image(image_path, set_as_primary)
    logger.debug(f"Set reference image for {character_name}: {image_path}")
|
||||
|
||||
def update_character_appearance(
    self,
    character_name: str,
    appearance_description: Optional[str] = None,
    clothing_description: Optional[str] = None,
    distinctive_features: Optional[List[str]] = None,
):
    """
    Update a character's visual description.

    Only arguments with a truthy value are applied; None, "" and [] leave
    the corresponding field untouched. Logs a warning and does nothing when
    the character is unknown.

    Args:
        character_name: Name of the character to update
        appearance_description: New appearance description
        clothing_description: New clothing description
        distinctive_features: New list of distinctive features
    """
    char = self.get_character(character_name)
    if not char:
        # Report the miss the same way set_reference_image() does instead
        # of silently ignoring the update.
        logger.warning(f"Character not found: {character_name}")
        return

    if appearance_description:
        char.appearance_description = appearance_description
    if clothing_description:
        char.clothing_description = clothing_description
    if distinctive_features:
        char.distinctive_features = distinctive_features

    # Called after the fields change; presumably refreshes the cached
    # prompt text (helper is defined on Character, outside this view).
    char._build_prompt_prefix()
    logger.debug(f"Updated appearance for {character_name}")
|
||||
|
||||
def get_consistency_summary(self) -> str:
    """Build a multi-line, human-readable overview of the registry for logging."""
    if not self._characters:
        return "No characters registered"

    header = f"Characters ({len(self._characters)}):"
    # One row per character: appearance count and number of reference images.
    rows = [
        f" - {member.name}: {len(member.appearance_frames)} appearances, "
        f"ref_images={len(member.reference_images)}"
        for member in self.characters
    ]
    return "\n".join([header, *rows])
|
||||
|
||||
def reset(self):
    """Forget every registered character and empty the name index."""
    for table in (self._characters, self._name_index):
        table.clear()
    logger.info("Character memory cleared")
|
||||
316
pixelle_video/services/quality/content_filter.py
Normal file
316
pixelle_video/services/quality/content_filter.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
ContentFilter - Content moderation and safety filtering
|
||||
|
||||
Provides content safety checks for:
|
||||
- Text content (narrations, prompts)
|
||||
- Generated images
|
||||
- Generated videos
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Set
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class FilterCategory(Enum):
    """Classification a content check assigns to a piece of content."""

    # Nothing flagged.
    SAFE = "safe"
    # May require review.
    SENSITIVE = "sensitive"
    # Should not proceed.
    BLOCKED = "blocked"
|
||||
|
||||
|
||||
@dataclass
class FilterResult:
    """Outcome of one content-filter check."""

    category: FilterCategory  # Classification assigned to the content
    passed: bool  # Whether the content may proceed
    flagged_items: List[str] = field(default_factory=list)  # Matched terms
    reason: Optional[str] = None  # Explanation, when one is given
    confidence: float = 1.0

    def to_dict(self) -> dict:
        """Return the result as a plain dict, with the category as its string value."""
        payload = dict(
            category=self.category.value,
            passed=self.passed,
            flagged_items=self.flagged_items,
            reason=self.reason,
            confidence=self.confidence,
        )
        return payload
|
||||
|
||||
|
||||
@dataclass
class ContentFilterConfig:
    """Settings that control how content filtering behaves."""

    # --- Text filtering ---
    enable_keyword_filter: bool = True
    # Use LLM for semantic filtering
    enable_llm_filter: bool = False

    # Extra keywords to block, added on top of the default list
    custom_blocked_keywords: List[str] = field(default_factory=list)

    # Sensitivity level: "strict", "moderate", "permissive"
    sensitivity_level: str = "moderate"

    # --- Image filtering ---
    # Requires external service
    enable_image_filter: bool = False

    # --- Action on detection ---
    # Block content marked as sensitive
    block_on_sensitive: bool = False
    log_filtered_content: bool = True
|
||||
|
||||
|
||||
class ContentFilter:
    """
    Content moderation filter for generated content

    Provides safety filtering for text and media content to prevent
    inappropriate or harmful content from being generated.

    Example:
        >>> filter = ContentFilter()
        >>> result = await filter.check_text("Hello, world!")
        >>> if result.passed:
        ...     print("Content is safe")
    """

    # Default blocked keywords (minimal list for demonstration)
    DEFAULT_BLOCKED_PATTERNS = [
        r"\b(violence|gore|blood)\b",
        r"\b(nsfw|explicit|pornographic)\b",
        r"\b(illegal|drugs|weapons)\b",
    ]

    # Sensitive keywords (may require review)
    DEFAULT_SENSITIVE_PATTERNS = [
        r"\b(death|dying|kill)\b",
        r"\b(hate|racist|sexist)\b",
        r"\b(controversial|political)\b",
    ]

    # Severity ranking used to pick the worse of two categories. The enum
    # values are strings ("blocked" < "safe" < "sensitive" alphabetically),
    # so comparing `.value` directly gives the wrong order.
    _SEVERITY = {
        FilterCategory.SAFE: 0,
        FilterCategory.SENSITIVE: 1,
        FilterCategory.BLOCKED: 2,
    }

    _WORD_CHAR = re.compile(r"\w")
    # Kana and CJK ideographs: scripts written without spaces between words,
    # where a \b anchor would prevent a keyword from matching inside a sentence.
    _UNSPACED_SCRIPT_CHAR = re.compile(r"[\u3040-\u30ff\u3400-\u4dbf\u4e00-\u9fff]")

    def __init__(
        self,
        llm_service=None,
        config: Optional[ContentFilterConfig] = None
    ):
        """
        Initialize ContentFilter

        Args:
            llm_service: Optional awaitable LLM callable for semantic filtering
            config: Filter configuration (defaults to ContentFilterConfig())
        """
        self.llm_service = llm_service
        self.config = config or ContentFilterConfig()

        # Compile the built-in patterns once.
        self._blocked_patterns = [
            re.compile(p, re.IGNORECASE)
            for p in self.DEFAULT_BLOCKED_PATTERNS
        ]
        self._sensitive_patterns = [
            re.compile(p, re.IGNORECASE)
            for p in self.DEFAULT_SENSITIVE_PATTERNS
        ]

        # Add custom keywords
        for keyword in self.config.custom_blocked_keywords or []:
            self.add_blocked_keyword(keyword)

    @classmethod
    def _compile_keyword(cls, keyword: str):
        """
        Compile a literal keyword into a case-insensitive pattern.

        A ``\\b`` anchor is added on each side only where the keyword's edge
        character is a word character of a space-delimited script. Keywords
        that start or end with kana, CJK ideographs or punctuation match as
        plain substrings on that side, because a word boundary cannot occur
        there in running text.

        Returns:
            The compiled pattern, or None for an empty or blank keyword,
            which would otherwise match almost any text.
        """
        if not keyword or not keyword.strip():
            return None

        def anchor(edge_char: str) -> str:
            is_word = cls._WORD_CHAR.match(edge_char) is not None
            is_unspaced = cls._UNSPACED_SCRIPT_CHAR.match(edge_char) is not None
            return r"\b" if is_word and not is_unspaced else ""

        return re.compile(
            f"{anchor(keyword[0])}{re.escape(keyword)}{anchor(keyword[-1])}",
            re.IGNORECASE,
        )

    async def check_text(self, text: str) -> FilterResult:
        """
        Check text content for safety

        Runs the keyword patterns (if enabled), then the LLM check (if
        enabled and a service is set), and keeps the more severe category.

        Args:
            text: Text to check

        Returns:
            FilterResult with safety assessment
        """
        if not text:
            return FilterResult(
                category=FilterCategory.SAFE,
                passed=True
            )

        flagged_items = []
        category = FilterCategory.SAFE

        # Keyword filtering
        if self.config.enable_keyword_filter:
            # Check blocked patterns
            for pattern in self._blocked_patterns:
                matches = pattern.findall(text)
                if matches:
                    flagged_items.extend(matches)
                    category = FilterCategory.BLOCKED

            # Check sensitive patterns (if not already blocked)
            if category != FilterCategory.BLOCKED:
                for pattern in self._sensitive_patterns:
                    matches = pattern.findall(text)
                    if matches:
                        flagged_items.extend(matches)
                        category = FilterCategory.SENSITIVE

        # LLM-based semantic filtering
        if self.config.enable_llm_filter and self.llm_service:
            semantic_result = await self._check_with_llm(text)
            flagged_items.extend(semantic_result.flagged_items)
            # Escalate only: the LLM verdict may raise the category but
            # never lower one set by the keyword filter.
            if self._SEVERITY[semantic_result.category] > self._SEVERITY[category]:
                category = semantic_result.category

        # Determine if passed
        if category == FilterCategory.BLOCKED:
            passed = False
            reason = "Content contains blocked keywords or themes"
        elif category == FilterCategory.SENSITIVE and self.config.block_on_sensitive:
            passed = False
            reason = "Content contains sensitive themes (strict mode)"
        else:
            passed = True
            reason = None

        # Log if configured
        if self.config.log_filtered_content and flagged_items:
            logger.warning(f"Content filter flagged: {flagged_items}")

        return FilterResult(
            category=category,
            passed=passed,
            flagged_items=flagged_items,
            reason=reason,
        )

    async def check_texts(self, texts: List[str]) -> FilterResult:
        """
        Check multiple texts and return aggregate result

        Args:
            texts: List of texts to check

        Returns:
            Aggregate FilterResult carrying the most severe category found
            and the distinct flagged items in first-seen order
        """
        all_flagged = []
        worst_category = FilterCategory.SAFE

        for text in texts:
            result = await self.check_text(text)
            all_flagged.extend(result.flagged_items)
            if self._SEVERITY[result.category] > self._SEVERITY[worst_category]:
                worst_category = result.category

        passed = worst_category == FilterCategory.SAFE or (
            worst_category == FilterCategory.SENSITIVE and not self.config.block_on_sensitive
        )

        # Deduplicate while keeping a stable order, so the reported count
        # agrees with the returned list.
        unique_flagged = list(dict.fromkeys(all_flagged))

        return FilterResult(
            category=worst_category,
            passed=passed,
            flagged_items=unique_flagged,
            reason=f"Found {len(unique_flagged)} flagged items" if unique_flagged else None,
        )

    async def check_image(self, image_path: str) -> FilterResult:
        """
        Check image content for safety

        Note: This requires external NSFW detection service integration.
        Currently returns safe by default.

        Args:
            image_path: Path to image file

        Returns:
            FilterResult with safety assessment
        """
        if not self.config.enable_image_filter:
            return FilterResult(
                category=FilterCategory.SAFE,
                passed=True,
                reason="Image filtering disabled"
            )

        # TODO: Integrate with external NSFW detection
        # Options: TensorFlow NSFW model, Azure Content Moderator, etc.

        logger.debug(f"Image safety check: {image_path} (not implemented, assuming safe)")

        return FilterResult(
            category=FilterCategory.SAFE,
            passed=True,
            reason="Image filtering not implemented"
        )

    async def _check_with_llm(self, text: str) -> FilterResult:
        """
        Check text using LLM for semantic understanding

        Only the first 500 characters of the text are sent to the LLM.

        Returns:
            FilterResult whose category follows the LLM's one-word verdict.
            SAFE is returned when no LLM service is set or the call fails.
        """
        if not self.llm_service:
            return FilterResult(category=FilterCategory.SAFE, passed=True)

        try:
            prompt = f"""Analyze the following text for content safety.
Rate it as: SAFE, SENSITIVE, or BLOCKED.

SAFE: No concerning content
SENSITIVE: Contains themes that may need review (violence, death, controversial topics)
BLOCKED: Contains explicit, illegal, or harmful content

Text: {text[:500]}

Respond with only one word: SAFE, SENSITIVE, or BLOCKED."""

            response = await self.llm_service(prompt, temperature=0.0, max_tokens=10)
            response = response.strip().upper()

            if "BLOCKED" in response:
                return FilterResult(
                    category=FilterCategory.BLOCKED,
                    passed=False,
                    reason="LLM detected blocked content"
                )
            elif "SENSITIVE" in response:
                return FilterResult(
                    category=FilterCategory.SENSITIVE,
                    passed=not self.config.block_on_sensitive,
                    reason="LLM detected sensitive content"
                )
            else:
                return FilterResult(
                    category=FilterCategory.SAFE,
                    passed=True
                )

        except Exception as e:
            # NOTE(review): this fails open - an LLM outage makes the
            # semantic check report SAFE. Confirm that is the intended policy.
            logger.warning(f"LLM content check failed: {e}")
            return FilterResult(category=FilterCategory.SAFE, passed=True)

    def add_blocked_keyword(self, keyword: str):
        """Add a keyword to the blocked list (blank keywords are ignored)"""
        pattern = self._compile_keyword(keyword)
        if pattern is None:
            logger.warning("Ignoring empty blocked keyword")
            return
        self._blocked_patterns.append(pattern)

    def add_sensitive_keyword(self, keyword: str):
        """Add a keyword to the sensitive list (blank keywords are ignored)"""
        pattern = self._compile_keyword(keyword)
        if pattern is None:
            logger.warning("Ignoring empty sensitive keyword")
            return
        self._sensitive_patterns.append(pattern)
|
||||
140
pixelle_video/services/quality/models.py
Normal file
140
pixelle_video/services/quality/models.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Data models for quality assurance
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class QualityLevel(Enum):
    """Named quality bands; each comment gives the overall-score range."""

    # >= 0.8
    EXCELLENT = "excellent"
    # >= 0.6
    GOOD = "good"
    # >= 0.4
    ACCEPTABLE = "acceptable"
    # < 0.4
    POOR = "poor"
|
||||
|
||||
|
||||
@dataclass
class QualityScore:
    """Scores and diagnostics produced by evaluating one piece of generated content."""

    # Individual scores (0.0 - 1.0)
    aesthetic_score: float = 0.0  # Visual appeal / beauty
    text_match_score: float = 0.0  # How well image matches the prompt
    technical_score: float = 0.0  # Technical quality (clarity, no artifacts)

    # Overall
    overall_score: float = 0.0  # Weighted average
    passed: bool = False  # Whether it meets threshold

    # Diagnostics
    issues: List[str] = field(default_factory=list)  # Detected problems
    evaluation_time_ms: float = 0.0  # Time taken for evaluation

    @property
    def level(self) -> QualityLevel:
        """Map ``overall_score`` onto its quality band."""
        # Bands are checked from best to worst; the first floor reached wins.
        bands = (
            (0.8, QualityLevel.EXCELLENT),
            (0.6, QualityLevel.GOOD),
            (0.4, QualityLevel.ACCEPTABLE),
        )
        for floor, band in bands:
            if self.overall_score >= floor:
                return band
        return QualityLevel.POOR

    def to_dict(self) -> dict:
        """Serialize to a plain dict; ``evaluation_time_ms`` is not included."""
        payload = {}
        payload["aesthetic_score"] = self.aesthetic_score
        payload["text_match_score"] = self.text_match_score
        payload["technical_score"] = self.technical_score
        payload["overall_score"] = self.overall_score
        payload["passed"] = self.passed
        payload["level"] = self.level.value
        payload["issues"] = self.issues
        return payload
|
||||
|
||||
|
||||
@dataclass
class QualityConfig:
    """
    Configuration for quality evaluation.

    The three weights are rescaled in ``__post_init__`` so they sum to 1.0.

    Raises:
        ValueError: If any weight is negative or the weights sum to zero.
    """

    # Thresholds (0.0 - 1.0)
    overall_threshold: float = 0.6  # Minimum overall score to pass
    aesthetic_threshold: float = 0.5  # Minimum aesthetic score
    text_match_threshold: float = 0.6  # Minimum text-match score
    technical_threshold: float = 0.7  # Minimum technical score

    # Weights for overall score calculation
    aesthetic_weight: float = 0.3
    text_match_weight: float = 0.4
    technical_weight: float = 0.3

    # Evaluation settings
    use_vlm_evaluation: bool = True  # Use VLM for evaluation (vs local models)
    vlm_model: Optional[str] = None  # VLM model to use (None = use default LLM)
    skip_on_static_template: bool = True  # Skip image quality check for static templates

    def __post_init__(self):
        """Validate the weights and normalize them to sum to 1.0"""
        weights = (self.aesthetic_weight, self.text_match_weight, self.technical_weight)
        if any(w < 0 for w in weights):
            raise ValueError(f"Quality weights must be non-negative, got {weights}")

        total = sum(weights)
        if total <= 0:
            # Without this guard the normalization below divides by zero.
            raise ValueError("Quality weights must not all be zero")

        # Sums within 0.01 of 1.0 are left untouched.
        if abs(total - 1.0) > 0.01:
            self.aesthetic_weight /= total
            self.text_match_weight /= total
            self.technical_weight /= total
|
||||
|
||||
|
||||
@dataclass
class RetryConfig:
    """Settings that govern retrying a generation step."""

    # --- Attempt count and delays ---
    max_retries: int = 3  # Maximum retry attempts
    backoff_factor: float = 1.5  # Exponential backoff multiplier
    initial_delay_ms: int = 500  # Initial delay before first retry
    max_delay_ms: int = 10000  # Maximum delay between retries

    # --- Quality-based retry ---
    quality_threshold: float = 0.6  # Quality score threshold for pass

    # --- Fallback behavior ---
    enable_fallback: bool = True  # Enable fallback strategy on failure
    fallback_prompt_simplify: bool = True  # Simplify prompt on retry

    # --- Retry conditions ---
    retry_on_quality_fail: bool = True  # Retry when quality below threshold
    retry_on_error: bool = True  # Retry on generation errors
|
||||
|
||||
|
||||
class QualityError(Exception):
    """Raised when content still misses the quality bar after every retry."""

    def __init__(
        self,
        message: str,
        attempts: int = 0,
        last_score: Optional[QualityScore] = None,
    ):
        """
        Args:
            message: Human-readable description of the failure
            attempts: Number of attempts made
            last_score: Score of the final attempt, if one was produced
        """
        super().__init__(message)
        self.last_score = last_score
        self.attempts = attempts

    def __str__(self) -> str:
        """Append the attempt count, and the last overall score when known, to the message."""
        details = [f"attempts={self.attempts}"]
        if self.last_score:
            details.append(f"last_score={self.last_score.overall_score:.2f}")
        suffix = ", ".join(details)
        return f"{super().__str__()} ({suffix})"
|
||||
336
pixelle_video/services/quality/output_validator.py
Normal file
336
pixelle_video/services/quality/output_validator.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OutputValidator - LLM output validation and quality control
|
||||
|
||||
Validates LLM-generated content for:
|
||||
- Narrations: length, relevance, coherence
|
||||
- Image prompts: format, language, prompt-narration alignment
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
class ValidationConfig:
    """Limits and thresholds applied when validating LLM output."""

    # --- Narration validation ---
    # NOTE: the narration validator measures length with len(text), i.e. in
    # characters, despite the "words" in these names.
    min_narration_words: int = 5
    max_narration_words: int = 50
    relevance_threshold: float = 0.6
    coherence_threshold: float = 0.6

    # --- Image prompt validation ---
    min_prompt_words: int = 10
    max_prompt_words: int = 100
    require_english_prompts: bool = True
    prompt_match_threshold: float = 0.5
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Outcome of one validation run."""

    passed: bool
    issues: List[str] = field(default_factory=list)
    # 1.0 = perfect, 0.0 = failed
    score: float = 1.0
    suggestions: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the result as a plain dict."""
        payload = dict(
            passed=self.passed,
            score=self.score,
            issues=self.issues,
            suggestions=self.suggestions,
        )
        return payload
|
||||
|
||||
|
||||
class OutputValidator:
|
||||
"""
|
||||
Validator for LLM-generated outputs
|
||||
|
||||
Validates narrations and image prompts to ensure they meet quality standards
|
||||
before proceeding with media generation.
|
||||
|
||||
Example:
|
||||
>>> validator = OutputValidator(llm_service)
|
||||
>>> result = await validator.validate_narrations(
|
||||
... narrations=["旁白1", "旁白2"],
|
||||
... topic="人生哲理",
|
||||
... config=ValidationConfig()
|
||||
... )
|
||||
>>> if not result.passed:
|
||||
... print(f"Validation failed: {result.issues}")
|
||||
"""
|
||||
|
||||
def __init__(self, llm_service=None):
    """
    Create a validator.

    Args:
        llm_service: Optional awaitable LLM callable; the semantic checks
            (relevance, coherence) are only run when it is set
    """
    self.llm_service = llm_service
|
||||
|
||||
async def validate_narrations(
    self,
    narrations: List[str],
    topic: str,
    config: Optional[ValidationConfig] = None,
) -> ValidationResult:
    """
    Validate generated narrations.

    Checks:
    1. Non-empty content
    2. Length limits (measured with len(), i.e. in characters)
    3. Topic relevance (only when an LLM service is set)
    4. Coherence between narrations (only when an LLM service is set)

    The result passes when there are no issues or the mean score is at
    least 0.7.

    Args:
        narrations: List of narration texts
        topic: Original topic/theme
        config: Validation configuration (defaults to ValidationConfig())

    Returns:
        ValidationResult with pass/fail, issues, mean score and suggestions
    """
    settings = config if config else ValidationConfig()
    problems = []
    advice = []
    marks = []

    if not narrations:
        return ValidationResult(
            passed=False,
            issues=["No narrations provided"],
            score=0.0
        )

    # 1. Per-narration checks; labels are 1-based.
    for idx, text in enumerate(narrations, 1):
        char_total = len(text)

        if not text.strip():
            problems.append(f"分镜{idx}: 内容为空")
            marks.append(0.0)
            continue

        if char_total < settings.min_narration_words:
            problems.append(f"分镜{idx}: 内容过短 ({char_total}字,最少{settings.min_narration_words}字)")
            mark = 0.5
        elif char_total > settings.max_narration_words:
            problems.append(f"分镜{idx}: 内容过长 ({char_total}字,最多{settings.max_narration_words}字)")
            advice.append(f"考虑将分镜{idx}拆分为多个短句")
            mark = 0.7
        else:
            mark = 1.0
        marks.append(mark)

    # 2. Semantic checks; a failure here is logged and never fails validation.
    if self.llm_service:
        try:
            relevance = await self._check_relevance(narrations, topic)
            if relevance < settings.relevance_threshold:
                problems.append(f'内容与主题"{topic}"相关性不足 ({relevance:.0%})')
                advice.append("建议重新生成,确保内容紧扣主题")
            marks.append(relevance)

            coherence = await self._check_coherence(narrations)
            if coherence < settings.coherence_threshold:
                problems.append(f"内容连贯性不足 ({coherence:.0%})")
                advice.append("建议检查段落之间的逻辑衔接")
            marks.append(coherence)

        except Exception as e:
            logger.warning(f"Semantic validation failed: {e}")

    # Mean of all collected marks.
    overall_score = sum(marks) / len(marks) if marks else 0.0
    passed = (not problems) or overall_score >= 0.7

    logger.debug(f"Narration validation: score={overall_score:.2f}, issues={len(problems)}")

    return ValidationResult(
        passed=passed,
        issues=problems,
        score=overall_score,
        suggestions=advice,
    )
|
||||
|
||||
async def validate_image_prompts(
    self,
    prompts: List[str],
    narrations: List[str],
    config: Optional[ValidationConfig] = None,
) -> ValidationResult:
    """
    Validate generated image prompts.

    Checks:
    1. Prompt count equals narration count
    2. Non-empty content
    3. Length limits (whitespace-separated words)
    4. Language: more than 30% Chinese characters halves the prompt's
       score when English prompts are required

    The result passes when there are no issues or the mean score is at
    least 0.7.

    Args:
        prompts: List of image prompts
        narrations: Corresponding narrations (only their count is used)
        config: Validation configuration (defaults to ValidationConfig())

    Returns:
        ValidationResult with pass/fail, issues, mean score and suggestions
    """
    settings = config if config else ValidationConfig()
    problems = []
    advice = []
    marks = []

    if not prompts:
        return ValidationResult(
            passed=False,
            issues=["No image prompts provided"],
            score=0.0
        )

    if len(prompts) != len(narrations):
        problems.append(f"提示词数量({len(prompts)})与旁白数量({len(narrations)})不匹配")

    # Labels are 1-based.
    for idx, text in enumerate(prompts, 1):
        if not text.strip():
            problems.append(f"提示词{idx}: 内容为空")
            marks.append(0.0)
            continue

        n_words = len(text.split())

        # Length check
        if n_words < settings.min_prompt_words:
            problems.append(f"提示词{idx}: 过短 ({n_words}词,最少{settings.min_prompt_words}词)")
            mark = 0.5
        elif n_words > settings.max_prompt_words:
            problems.append(f"提示词{idx}: 过长 ({n_words}词,最多{settings.max_prompt_words}词)")
            mark = 0.8
        else:
            mark = 1.0

        # English check
        if settings.require_english_prompts:
            chinese_ratio = self._get_chinese_ratio(text)
            if chinese_ratio > 0.3:  # More than 30% Chinese characters
                problems.append(f"提示词{idx}: 应使用英文 (当前含{chinese_ratio:.0%}中文)")
                advice.append(f"将提示词{idx}翻译为英文以获得更好的生成效果")
                mark *= 0.5

        marks.append(mark)

    overall_score = sum(marks) / len(marks) if marks else 0.0
    passed = (not problems) or overall_score >= 0.7

    logger.debug(f"Image prompt validation: score={overall_score:.2f}, issues={len(problems)}")

    return ValidationResult(
        passed=passed,
        issues=problems,
        score=overall_score,
        suggestions=advice,
    )
|
||||
|
||||
async def _check_relevance(
|
||||
self,
|
||||
narrations: List[str],
|
||||
topic: str,
|
||||
) -> float:
|
||||
"""
|
||||
Check relevance of narrations to topic using LLM
|
||||
|
||||
Returns:
|
||||
Relevance score 0.0-1.0
|
||||
"""
|
||||
if not self.llm_service:
|
||||
return 0.8 # Default score when LLM not available
|
||||
|
||||
try:
|
||||
combined_text = "\n".join(narrations[:3]) # Check first 3 for efficiency
|
||||
|
||||
prompt = f"""评估以下内容与主题"{topic}"的相关性。
|
||||
|
||||
内容:
|
||||
{combined_text}
|
||||
|
||||
请用0-100的分数评估相关性,只输出数字。
|
||||
相关性越高,分数越高。"""
|
||||
|
||||
response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)
|
||||
|
||||
# Parse score from response
|
||||
score_match = re.search(r'\d+', response)
|
||||
if score_match:
|
||||
score = int(score_match.group()) / 100
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
return 0.7 # Default if parsing fails
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Relevance check failed: {e}")
|
||||
return 0.7
|
||||
|
||||
async def _check_coherence(
|
||||
self,
|
||||
narrations: List[str],
|
||||
) -> float:
|
||||
"""
|
||||
Check coherence between narrations using LLM
|
||||
|
||||
Returns:
|
||||
Coherence score 0.0-1.0
|
||||
"""
|
||||
if not self.llm_service or len(narrations) < 2:
|
||||
return 0.8 # Default score
|
||||
|
||||
try:
|
||||
numbered = "\n".join(f"{i+1}. {n}" for i, n in enumerate(narrations[:5]))
|
||||
|
||||
prompt = f"""评估以下段落之间的逻辑连贯性。
|
||||
|
||||
{numbered}
|
||||
|
||||
请用0-100的分数评估连贯性,只输出数字。
|
||||
段落之间逻辑顺畅、衔接自然则分数高。"""
|
||||
|
||||
response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)
|
||||
|
||||
score_match = re.search(r'\d+', response)
|
||||
if score_match:
|
||||
score = int(score_match.group()) / 100
|
||||
return min(1.0, max(0.0, score))
|
||||
|
||||
return 0.7
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Coherence check failed: {e}")
|
||||
return 0.7
|
||||
|
||||
def _get_chinese_ratio(self, text: str) -> float:
|
||||
"""Calculate ratio of Chinese characters in text"""
|
||||
if not text:
|
||||
return 0.0
|
||||
|
||||
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||
total_chars = len(text.replace(' ', ''))
|
||||
|
||||
if total_chars == 0:
|
||||
return 0.0
|
||||
|
||||
return chinese_chars / total_chars
|
||||
363
pixelle_video/services/quality/quality_gate.py
Normal file
363
pixelle_video/services/quality/quality_gate.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
QualityGate - Quality evaluation system for generated content
|
||||
|
||||
Evaluates images and videos based on:
|
||||
- Aesthetic quality (visual appeal)
|
||||
- Text-to-image matching (semantic alignment)
|
||||
- Technical quality (clarity, no artifacts)
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pixelle_video.services.quality.models import QualityScore, QualityConfig
|
||||
|
||||
|
||||
class QualityGate:
    """
    Quality evaluation gate for AI-generated content

    Scores an image or video on three axes and combines them with the
    weights from QualityConfig:
    1. Aesthetic quality - Is the image visually appealing?
    2. Text matching - Does the image match the prompt/narration?
    3. Technical quality - Is the image clear and free of artifacts?

    VLM-based scoring is not integrated yet; both VLM entry points currently
    delegate to the basic (metadata-only) checks.

    Example:
        >>> gate = QualityGate(llm_service, config)
        >>> score = await gate.evaluate_image(
        ...     image_path="output/frame_001.png",
        ...     prompt="A sunset over mountains",
        ...     narration="夕阳西下,余晖洒满山间"
        ... )
        >>> if score.passed:
        ...     print("Image quality approved!")
    """

    # Upper bound (seconds) for the ffprobe metadata probe, so a hung
    # ffprobe process cannot stall video evaluation indefinitely.
    _FFPROBE_TIMEOUT_S = 30

    def __init__(
        self,
        llm_service=None,
        config: Optional[QualityConfig] = None
    ):
        """
        Initialize QualityGate

        Args:
            llm_service: LLM service for VLM-based evaluation
            config: Quality configuration (a default QualityConfig is used
                when omitted)
        """
        self.llm_service = llm_service
        self.config = config or QualityConfig()

    async def evaluate_image(
        self,
        image_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate the quality of a generated image

        Args:
            image_path: Path to the image file
            prompt: The prompt used to generate the image
            narration: Optional narration text for context

        Returns:
            QualityScore with evaluation results. `passed` is decided here
            by comparing the overall score with config.overall_threshold.
        """
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # if the wall clock is adjusted during evaluation.
        start_time = time.perf_counter()

        # Validate image exists
        if not Path(image_path).exists():
            return QualityScore(
                passed=False,
                issues=["Image file not found"],
                evaluation_time_ms=(time.perf_counter() - start_time) * 1000
            )

        # Evaluate using VLM or fallback to basic checks
        if self.config.use_vlm_evaluation and self.llm_service:
            score = await self._evaluate_with_vlm(image_path, prompt, narration)
        else:
            score = await self._evaluate_basic(image_path, prompt)

        score.evaluation_time_ms = (time.perf_counter() - start_time) * 1000

        # The threshold comparison overrides whatever `passed` the evaluator set
        score.passed = score.overall_score >= self.config.overall_threshold

        logger.debug(
            f"Quality evaluation: overall={score.overall_score:.2f}, "
            f"passed={score.passed}, time={score.evaluation_time_ms:.0f}ms"
        )

        return score

    async def evaluate_video(
        self,
        video_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate the quality of a generated video

        Args:
            video_path: Path to the video file
            prompt: The prompt used to generate the video
            narration: Optional narration text for context

        Returns:
            QualityScore with evaluation results. `passed` is decided here
            by comparing the overall score with config.overall_threshold.
        """
        start_time = time.perf_counter()

        # Validate video exists
        if not Path(video_path).exists():
            return QualityScore(
                passed=False,
                issues=["Video file not found"],
                evaluation_time_ms=(time.perf_counter() - start_time) * 1000
            )

        # For video, we can extract key frames and evaluate
        # For now, use VLM with video input or sample frames
        if self.config.use_vlm_evaluation and self.llm_service:
            score = await self._evaluate_video_with_vlm(video_path, prompt, narration)
        else:
            score = await self._evaluate_video_basic(video_path)

        score.evaluation_time_ms = (time.perf_counter() - start_time) * 1000
        score.passed = score.overall_score >= self.config.overall_threshold

        return score

    async def _evaluate_with_vlm(
        self,
        image_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate image quality using Vision Language Model

        Intended to assess visual quality, prompt-image alignment and
        technical defects with a vision-capable LLM. The VLM call is not
        implemented yet, so this currently returns the basic evaluation.
        """
        # Built now so the pending VLM call below can use it once enabled
        evaluation_prompt = self._build_evaluation_prompt(prompt, narration)  # noqa: F841

        try:
            # Call LLM with image (requires VLM-capable model like GPT-4o, Qwen-VL)
            # Note: This requires the LLM service to support vision input

            # TODO: Implement actual VLM call when integrating with vision-capable LLM
            # response = await self.llm_service(
            #     prompt=evaluation_prompt,
            #     images=[image_path],
            #     response_type=ImageQualityResponse
            # )

            # Fallback to basic evaluation for now
            logger.debug("VLM evaluation: using basic fallback (VLM integration pending)")
            return await self._evaluate_basic(image_path, prompt)

        except Exception as e:
            logger.warning(f"VLM evaluation failed: {e}, falling back to basic")
            return await self._evaluate_basic(image_path, prompt)

    def _weighted_overall(
        self,
        aesthetic_score: float,
        text_match_score: float,
        technical_score: float,
    ) -> float:
        """Combine the three sub-scores using the weights from the config"""
        return (
            aesthetic_score * self.config.aesthetic_weight +
            text_match_score * self.config.text_match_weight +
            technical_score * self.config.technical_weight
        )

    async def _evaluate_basic(
        self,
        image_path: str,
        prompt: str,
    ) -> QualityScore:
        """
        Basic image quality evaluation without VLM

        Opens the image with PIL and checks:
        - Minimum dimensions (256x256)
        - Aspect ratio (at most 4:1)

        `prompt` is accepted for signature parity with the VLM path but is
        not used by these checks.

        Returns:
            QualityScore with fixed default sub-scores, lowered when an
            issue is found; overall_score 0.0 if the image cannot be read.
        """
        issues = []

        try:
            # Import PIL lazily so the module loads without Pillow installed
            from PIL import Image

            with Image.open(image_path) as img:
                width, height = img.size

                # Check minimum dimensions
                if width < 256 or height < 256:
                    issues.append(f"Image too small: {width}x{height}")

                # Check aspect ratio (not too extreme)
                aspect = max(width, height) / min(width, height)
                if aspect > 4:
                    issues.append(f"Extreme aspect ratio: {aspect:.1f}")

            # Basic scores (generous defaults when VLM not available)
            aesthetic_score = 0.7 if not issues else 0.4
            text_match_score = 0.7  # Can't properly evaluate without VLM
            technical_score = 0.8 if not issues else 0.5

            return QualityScore(
                aesthetic_score=aesthetic_score,
                text_match_score=text_match_score,
                technical_score=technical_score,
                overall_score=self._weighted_overall(
                    aesthetic_score, text_match_score, technical_score
                ),
                issues=issues,
            )

        except Exception as e:
            logger.error(f"Basic evaluation failed: {e}")
            return QualityScore(
                overall_score=0.0,
                passed=False,
                issues=[f"Evaluation error: {str(e)}"]
            )

    async def _evaluate_video_with_vlm(
        self,
        video_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """Evaluate video using VLM (placeholder: delegates to the basic check)"""
        # TODO: Implement video frame sampling and VLM evaluation
        return await self._evaluate_video_basic(video_path)

    async def _evaluate_video_basic(
        self,
        video_path: str,
    ) -> QualityScore:
        """
        Basic video quality evaluation

        Probes the file with ffprobe and checks that a video stream exists,
        that it is at least 256x256, and that the container duration is at
        least 0.5s. ffprobe runs in the default executor so the blocking
        subprocess call does not stall the event loop.

        Returns:
            QualityScore; overall_score is 0.5 when metadata cannot be read
            or an unexpected error (including a probe timeout) occurs, and
            0.0 when the file has no video stream.
        """
        issues = []

        try:
            import asyncio
            import functools
            import json
            import subprocess

            # Use ffprobe to get video info
            cmd = [
                "ffprobe", "-v", "quiet", "-print_format", "json",
                "-show_format", "-show_streams", video_path
            ]
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                functools.partial(
                    subprocess.run,
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=self._FFPROBE_TIMEOUT_S,
                ),
            )

            if result.returncode != 0:
                issues.append("Failed to read video metadata")
                return QualityScore(overall_score=0.5, issues=issues)

            info = json.loads(result.stdout)

            # Check for video stream (first stream whose codec_type is "video")
            video_stream = next(
                (
                    stream for stream in info.get("streams", [])
                    if stream.get("codec_type") == "video"
                ),
                None,
            )

            if not video_stream:
                issues.append("No video stream found")
                return QualityScore(overall_score=0.0, passed=False, issues=issues)

            # Check dimensions
            width = video_stream.get("width", 0)
            height = video_stream.get("height", 0)
            if width < 256 or height < 256:
                issues.append(f"Video too small: {width}x{height}")

            # Check duration
            duration = float(info.get("format", {}).get("duration", 0))
            if duration < 0.5:
                issues.append(f"Video too short: {duration:.1f}s")

            # Calculate scores (aesthetic/text match cannot be judged from metadata)
            aesthetic_score = 0.7
            text_match_score = 0.7
            technical_score = 0.8 if not issues else 0.5

            return QualityScore(
                aesthetic_score=aesthetic_score,
                text_match_score=text_match_score,
                technical_score=technical_score,
                overall_score=self._weighted_overall(
                    aesthetic_score, text_match_score, technical_score
                ),
                issues=issues,
            )

        except Exception as e:
            logger.error(f"Video evaluation failed: {e}")
            return QualityScore(
                overall_score=0.5,
                issues=[f"Evaluation error: {str(e)}"]
            )

    def _build_evaluation_prompt(
        self,
        prompt: str,
        narration: Optional[str] = None,
    ) -> str:
        """
        Build the evaluation prompt for VLM

        Args:
            prompt: The image generation prompt being judged
            narration: Optional narration; adds a "Narration:" line when given

        Returns:
            Instruction text asking for three 0.0-1.0 scores and an issue
            list in JSON format
        """
        context = f"Narration: {narration}\n" if narration else ""

        return f"""Evaluate this AI-generated image on the following criteria.
Rate each from 0.0 to 1.0.

Image Generation Prompt: {prompt}
{context}
Evaluation Criteria:

1. Aesthetic Quality (0.0-1.0):
- Is the image visually appealing?
- Good composition, colors, and style?

2. Prompt Matching (0.0-1.0):
- Does the image accurately represent the prompt?
- Are key elements from the prompt visible?

3. Technical Quality (0.0-1.0):
- Is the image clear and well-defined?
- Free of artifacts, distortions, or blurriness?
- Natural looking (no AI artifacts like extra fingers)?

Respond in JSON format:
{{
"aesthetic_score": 0.0,
"text_match_score": 0.0,
"technical_score": 0.0,
"issues": ["list of any problems found"]
}}
"""
|
||||
296
pixelle_video/services/quality/retry_manager.py
Normal file
296
pixelle_video/services/quality/retry_manager.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
RetryManager - Smart retry logic with quality-based decisions
|
||||
|
||||
Provides unified retry management for:
|
||||
- Media generation (images, videos)
|
||||
- LLM calls
|
||||
- Any async operations with quality evaluation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, TypeVar, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pixelle_video.services.quality.models import (
|
||||
QualityScore,
|
||||
QualityConfig,
|
||||
RetryConfig,
|
||||
QualityError,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
class RetryResult:
    """Outcome of one RetryManager run"""
    # True when the returned result was accepted
    success: bool
    # Value produced by the operation (or by the fallback)
    result: Any
    # Number of attempts made
    attempts: int
    # Score from the quality evaluator, when one was used
    quality_score: Optional[QualityScore] = None
    # Exception associated with the run, if any
    error: Optional[Exception] = None
|
||||
|
||||
|
||||
class RetryManager:
    """
    Smart retry manager with quality-based decisions

    Features:
    - Exponential backoff with configurable delays
    - Quality-aware retry (retries when quality below threshold)
    - Fallback strategies
    - Detailed logging and metrics

    Example:
        >>> retry_manager = RetryManager()
        >>>
        >>> async def generate():
        ...     return await media_service.generate_image(prompt)
        >>>
        >>> async def evaluate(image_path):
        ...     return await quality_gate.evaluate_image(image_path, prompt)
        >>>
        >>> result = await retry_manager.execute_with_retry(
        ...     operation=generate,
        ...     quality_evaluator=evaluate,
        ...     config=RetryConfig(max_retries=3)
        ... )
    """

    def __init__(self, config: Optional[RetryConfig] = None):
        """
        Initialize RetryManager

        Args:
            config: Default retry configuration (a default RetryConfig is
                used when omitted)
        """
        self.default_config = config or RetryConfig()

    @staticmethod
    def _backoff_delay_ms(cfg: RetryConfig, attempt: int) -> float:
        """Exponential backoff delay after the given 1-based attempt, capped at max_delay_ms"""
        return min(
            cfg.initial_delay_ms * (cfg.backoff_factor ** (attempt - 1)),
            cfg.max_delay_ms
        )

    async def execute_with_retry(
        self,
        operation: Callable[[], Any],
        quality_evaluator: Optional[Callable[[Any], QualityScore]] = None,
        config: Optional[RetryConfig] = None,
        fallback_operation: Optional[Callable[[], Any]] = None,
        operation_name: str = "operation",
    ) -> RetryResult:
        """
        Execute operation with automatic retry and quality evaluation

        The operation is attempted up to cfg.max_retries times in total.

        Args:
            operation: Async callable to execute
            quality_evaluator: Optional async callable to evaluate result quality
            config: Retry configuration (uses default if not provided)
            fallback_operation: Optional async fallback to try when all attempts fail
            operation_name: Name for logging purposes

        Returns:
            RetryResult with success status, result, and quality score.
            success is False when quality failed and retry_on_quality_fail
            is disabled, or when the fallback result fails the quality check.

        Raises:
            QualityError: When all attempts fail and no fallback succeeds.
                The last underlying exception, if any, is chained as the cause.
            Exception: The original exception, re-raised immediately, when
                retry_on_error is disabled.
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None
        last_score: Optional[QualityScore] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                # Execute the operation
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()

                # If no quality evaluator, accept the result
                if quality_evaluator is None:
                    logger.debug(f"{operation_name}: completed (no quality check)")
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                    )

                # Evaluate quality
                score = await quality_evaluator(result)
                last_score = score

                if score.passed:
                    logger.info(
                        f"{operation_name}: passed quality check "
                        f"(score={score.overall_score:.2f}, attempt={attempt})"
                    )
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

                # Quality check failed
                logger.warning(
                    f"{operation_name}: quality check failed "
                    f"(score={score.overall_score:.2f}, threshold={cfg.quality_threshold}, "
                    f"issues={score.issues})"
                )

                if not cfg.retry_on_quality_fail:
                    # Don't retry on quality failure: hand back the low-quality result
                    return RetryResult(
                        success=False,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

            except Exception as e:
                # Errors raised by the operation or by the evaluator land here
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed with error: {e}")

                if not cfg.retry_on_error:
                    raise

            # Back off before the next attempt (no wait after the final one)
            if attempt < cfg.max_retries:
                delay_ms = self._backoff_delay_ms(cfg, attempt)
                logger.debug(f"{operation_name}: waiting {delay_ms:.0f}ms before retry")
                await asyncio.sleep(delay_ms / 1000)

        # All retries exhausted
        logger.warning(f"{operation_name}: all {cfg.max_retries} attempts failed")

        # Try fallback if available
        if cfg.enable_fallback and fallback_operation:
            try:
                logger.info(f"{operation_name}: trying fallback strategy")
                result = await fallback_operation()

                # Evaluate fallback result if evaluator available
                if quality_evaluator:
                    score = await quality_evaluator(result)
                    return RetryResult(
                        success=score.passed if score else True,
                        result=result,
                        attempts=cfg.max_retries + 1,
                        quality_score=score,
                    )

                return RetryResult(
                    success=True,
                    result=result,
                    attempts=cfg.max_retries + 1,
                )

            except Exception as e:
                logger.error(f"{operation_name}: fallback also failed: {e}")
                last_error = e

        # Nothing worked
        if last_error:
            # Chain the underlying exception so its traceback is not lost
            raise QualityError(
                f"{operation_name} failed after {cfg.max_retries} attempts: {last_error}",
                attempts=cfg.max_retries,
                last_score=last_score,
            ) from last_error

        raise QualityError(
            f"{operation_name} failed to meet quality threshold after {cfg.max_retries} attempts",
            attempts=cfg.max_retries,
            last_score=last_score,
        )

    async def execute_simple(
        self,
        operation: Callable[[], Any],
        config: Optional[RetryConfig] = None,
        operation_name: str = "operation",
    ) -> Any:
        """
        Execute operation with simple retry (no quality evaluation)

        Args:
            operation: Async callable to execute
            config: Retry configuration
            operation_name: Name for logging

        Returns:
            Operation result

        Raises:
            Exception: Last exception if all retries fail
            ValueError: If max_retries is below 1, so no attempt was made
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()
                logger.debug(f"{operation_name}: completed on attempt {attempt}")
                return result

            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed: {e}")

                if attempt < cfg.max_retries:
                    await asyncio.sleep(self._backoff_delay_ms(cfg, attempt) / 1000)

        if last_error is None:
            # The loop never ran; `raise None` would surface as an opaque TypeError
            raise ValueError(
                f"{operation_name}: max_retries must be at least 1 (got {cfg.max_retries})"
            )

        raise last_error

    @staticmethod
    def create_prompt_simplifier(original_prompt: str) -> Callable[[], str]:
        """
        Create a fallback that simplifies the prompt

        The simplifier removes common quality-booster phrases regardless of
        letter case, truncates the text to at most 150 characters on a word
        boundary, and collapses repeated whitespace.

        Args:
            original_prompt: Original complex prompt

        Returns:
            Callable that returns a simplified prompt
        """
        import re

        # Complex modifiers that are dropped from the prompt
        remove_phrases = [
            "highly detailed",
            "ultra realistic",
            "8k resolution",
            "masterpiece",
            "best quality",
        ]

        def simplify() -> str:
            simplified = original_prompt

            # Case-insensitive so "Masterpiece" / "Highly Detailed" are removed too
            for phrase in remove_phrases:
                simplified = re.sub(re.escape(phrase), "", simplified, flags=re.IGNORECASE)

            # Truncate if too long, cutting back to the last complete word
            if len(simplified) > 150:
                simplified = simplified[:150].rsplit(" ", 1)[0]

            # Collapse whitespace left behind by the removals
            return " ".join(simplified.split())

        return simplify
|
||||
276
pixelle_video/services/quality/style_guard.py
Normal file
276
pixelle_video/services/quality/style_guard.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
StyleGuard - Visual style consistency engine
|
||||
|
||||
Ensures consistent visual style across all frames in a video by:
|
||||
1. Extracting style anchor from the first generated frame
|
||||
2. Injecting style constraints into subsequent frame prompts
|
||||
3. (Optional) Using style reference techniques like IP-Adapter
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
class StyleAnchor:
    """Visual style description taken from a reference frame"""

    # Individual style elements
    color_palette: str = ""      # e.g., "warm earth tones", "cool blues"
    art_style: str = ""          # e.g., "minimalist", "realistic", "anime"
    composition_style: str = ""  # e.g., "centered", "rule of thirds"
    texture: str = ""            # e.g., "smooth", "grainy", "watercolor"
    lighting: str = ""           # e.g., "soft ambient", "dramatic shadows"

    # Ready-made prefix; when set it takes priority over the elements above
    style_prefix: str = ""

    # Path of the reference image (for IP-Adapter style techniques)
    reference_image: Optional[str] = None

    def to_prompt_prefix(self) -> str:
        """
        Build the style text to attach to an image prompt

        Returns style_prefix when it is set; otherwise the non-empty art
        style, color palette, lighting and texture joined with ", ".
        composition_style is not included. Returns "" when nothing is set.
        """
        if self.style_prefix:
            return self.style_prefix

        # (value, template) pairs in output order; empty values are skipped
        parts = (
            (self.art_style, "{} style"),
            (self.color_palette, "{}"),
            (self.lighting, "{} lighting"),
            (self.texture, "{} texture"),
        )
        return ", ".join(template.format(value) for value, template in parts if value)

    def to_dict(self) -> dict:
        """Return all fields as a plain dict keyed by field name"""
        names = (
            "color_palette",
            "art_style",
            "composition_style",
            "texture",
            "lighting",
            "style_prefix",
            "reference_image",
        )
        return {name: getattr(self, name) for name in names}
|
||||
|
||||
|
||||
@dataclass
class StyleGuardConfig:
    """Settings that control how StyleGuard extracts and applies style"""

    # --- Extraction ---
    extract_from_first_frame: bool = True
    use_vlm_extraction: bool = True

    # --- Application ---
    apply_to_all_frames: bool = True
    prefix_position: str = "start"  # "start" or "end"

    # --- Optional externally supplied style ---
    external_style_image: Optional[str] = None
    custom_style_prefix: Optional[str] = None
|
||||
|
||||
|
||||
class StyleGuard:
    """
    Keeps the visual style of a video's frames consistent

    Workflow:
    1. Derive a StyleAnchor from the first frame (or another reference image)
    2. Attach the anchor's style text to the prompts of the following frames

    Example:
        >>> style_guard = StyleGuard(llm_service)
        >>>
        >>> # Extract style from first frame
        >>> anchor = await style_guard.extract_style_anchor(
        ...     image_path="output/frame_001.png"
        ... )
        >>>
        >>> # Apply to subsequent prompts
        >>> styled_prompt = style_guard.apply_style(
        ...     prompt="A cat sitting on a windowsill",
        ...     style_anchor=anchor
        ... )
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[StyleGuardConfig] = None
    ):
        """
        Create a StyleGuard

        Args:
            llm_service: LLM service for VLM-based style extraction
            config: StyleGuard configuration (defaults are used when omitted)
        """
        self.llm_service = llm_service
        self.config = config or StyleGuardConfig()
        # Most recently extracted anchor; None until extraction or after reset()
        self._current_anchor: Optional[StyleAnchor] = None

    async def extract_style_anchor(
        self,
        image_path: str,
    ) -> StyleAnchor:
        """
        Produce a style anchor for the reference image and remember it

        A configured custom_style_prefix takes priority over any extraction;
        otherwise the VLM path is used when enabled and an LLM service is
        present, and the basic path in all other cases.

        Args:
            image_path: Path to reference image

        Returns:
            StyleAnchor, also stored as the current anchor
        """
        logger.info(f"Extracting style anchor from: {image_path}")

        custom_prefix = self.config.custom_style_prefix
        if custom_prefix:
            # A user-supplied prefix short-circuits extraction
            self._current_anchor = StyleAnchor(
                style_prefix=custom_prefix,
                reference_image=image_path
            )
            return self._current_anchor

        use_vlm = self.config.use_vlm_extraction and self.llm_service
        if use_vlm:
            extracted = await self._extract_with_vlm(image_path)
        else:
            extracted = self._extract_basic(image_path)

        self._current_anchor = extracted
        logger.info(f"Style anchor extracted: {extracted.to_prompt_prefix()}")

        return extracted

    async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
        """Extract style via a Vision Language Model (currently a fixed placeholder)"""
        try:
            # TODO: Implement VLM call when vision-capable LLM is integrated
            logger.debug("VLM style extraction: using placeholder (VLM not yet integrated)")

            # Placeholder values until real extraction is available
            placeholder = StyleAnchor(
                art_style="consistent artistic",
                color_palette="harmonious colors",
                lighting="balanced",
                style_prefix="maintaining visual consistency, same artistic style as previous frames",
                reference_image=image_path,
            )
            return placeholder

        except Exception as e:
            logger.warning(f"VLM style extraction failed: {e}")
            return self._extract_basic(image_path)

    def _extract_basic(self, image_path: str) -> StyleAnchor:
        """Return a generic anchor that only records the reference image"""
        return StyleAnchor(
            style_prefix="consistent visual style",
            reference_image=image_path,
        )

    def apply_style(
        self,
        prompt: str,
        style_anchor: Optional[StyleAnchor] = None,
    ) -> str:
        """
        Attach the anchor's style text to an image prompt

        Args:
            prompt: Original image prompt
            style_anchor: Style anchor to apply (uses current if not provided)

        Returns:
            The prompt with the style text placed first when
            config.prefix_position is "start", otherwise last. The prompt is
            returned unchanged when there is no anchor or no style text.
        """
        anchor = style_anchor or self._current_anchor
        style_text = anchor.to_prompt_prefix() if anchor else ""

        if not style_text:
            return prompt

        if self.config.prefix_position != "start":
            return f"{prompt}, {style_text}"
        return f"{style_text}, {prompt}"

    def apply_style_to_batch(
        self,
        prompts: List[str],
        style_anchor: Optional[StyleAnchor] = None,
        skip_first: bool = True,
    ) -> List[str]:
        """
        Attach the anchor's style text to every prompt in a list

        Args:
            prompts: List of image prompts
            style_anchor: Style anchor to apply (uses current if not provided)
            skip_first: Leave the first prompt untouched (it is the reference)

        Returns:
            New list of styled prompts; the input list itself when it is
            empty or no anchor is available
        """
        if not prompts:
            return prompts

        anchor = style_anchor or self._current_anchor
        if not anchor:
            return prompts

        return [
            text if (skip_first and position == 0) else self.apply_style(text, anchor)
            for position, text in enumerate(prompts)
        ]

    def get_consistency_prompt_suffix(self) -> str:
        """
        Return an instruction sentence for LLM prompt generation

        Append it to the LLM prompt that produces image prompts to encourage
        consistent style descriptions.
        """
        sentences = (
            "Ensure all image prompts maintain consistent visual style, ",
            "including similar color palette, art style, lighting, and composition. ",
            "Each image should feel like it belongs to the same visual narrative.",
        )
        return "".join(sentences)

    @property
    def current_anchor(self) -> Optional[StyleAnchor]:
        """The anchor stored by the last extraction, or None"""
        return self._current_anchor

    def reset(self):
        """Forget the stored style anchor"""
        self._current_anchor = None
|
||||
Reference in New Issue
Block a user