feat: Add comprehensive timeline editor with frame editing and regeneration capabilities

This commit is contained in:
empty
2026-01-05 14:48:43 +08:00
parent 7d78dcd078
commit ca018a9b1f
68 changed files with 14904 additions and 57 deletions

View File

@@ -73,6 +73,11 @@ class StoryboardFrame:
duration: float = 0.0 # Frame duration (seconds, from audio or video)
created_at: Optional[datetime] = None
# Quality tracking (added for quality assurance)
quality_score: Optional[float] = None # Overall quality score (0.0-1.0)
quality_issues: Optional[List[str]] = None # List of detected quality issues
retry_count: int = 0 # Number of generation retries
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.now()

View File

@@ -176,6 +176,16 @@ class StandardPipeline(LinearVideoPipeline):
min_words = ctx.params.get("min_image_prompt_words", 30)
max_words = ctx.params.get("max_image_prompt_words", 60)
# Auto-detect characters from narrations if CharacterMemory is enabled
character_memory = ctx.params.get("character_memory")
if character_memory and character_memory.config.auto_detect_characters:
logger.info("🔍 Auto-detecting characters from narrations...")
for i, narration in enumerate(ctx.narrations):
await character_memory.detect_characters_from_narration(narration, frame_index=i)
if character_memory.characters:
detected_names = [c.name for c in character_memory.characters]
logger.info(f"✅ Detected {len(character_memory.characters)} characters: {detected_names}")
# Override prompt_prefix if provided
original_prefix = None
if prompt_prefix is not None:
@@ -205,6 +215,16 @@ class StandardPipeline(LinearVideoPipeline):
progress_callback=image_prompt_progress
)
# Apply character memory enhancement (if available)
character_memory = ctx.params.get("character_memory")
if character_memory and character_memory.characters:
active_characters = [c for c in character_memory.characters if getattr(c, 'is_active', True)]
if active_characters:
logger.info(f"🎭 Applying {len(active_characters)} active characters to {len(base_image_prompts)} prompts")
base_image_prompts = [
character_memory.apply_to_prompt(p) for p in base_image_prompts
]
# Apply prompt prefix
image_config = self.core.config.get("comfyui", {}).get("image", {})
prompt_prefix_to_use = prompt_prefix if prompt_prefix is not None else image_config.get("prompt_prefix", "")

View File

@@ -27,19 +27,41 @@ from loguru import logger
from pixelle_video.models.progress import ProgressEvent
from pixelle_video.models.storyboard import Storyboard, StoryboardFrame, StoryboardConfig
from pixelle_video.services.quality import (
QualityGate,
QualityConfig,
RetryManager,
RetryConfig,
QualityError,
)
class FrameProcessor:
"""Frame processor"""
def __init__(self, pixelle_video_core):
def __init__(
self,
pixelle_video_core,
quality_config: Optional[QualityConfig] = None,
retry_config: Optional[RetryConfig] = None,
enable_quality_check: bool = True,
):
"""
Initialize
Args:
pixelle_video_core: PixelleVideoCore instance
quality_config: Quality evaluation configuration
retry_config: Retry behavior configuration
enable_quality_check: Whether to enable quality checking
"""
self.core = pixelle_video_core
self.enable_quality_check = enable_quality_check
self.quality_gate = QualityGate(
llm_service=pixelle_video_core.llm if hasattr(pixelle_video_core, 'llm') else None,
config=quality_config or QualityConfig()
)
self.retry_manager = RetryManager(config=retry_config or RetryConfig())
async def __call__(
self,
@@ -199,11 +221,14 @@ class FrameProcessor:
frame: StoryboardFrame,
config: StoryboardConfig
):
"""Step 2: Generate media (image or video) using ComfyKit"""
"""
Step 2: Generate media (image or video) using ComfyKit
Enhanced with quality evaluation and retry logic.
"""
logger.debug(f" 2/4: Generating media for frame {frame.index}...")
# Determine media type based on workflow
# video_ prefix in workflow name indicates video generation
workflow_name = config.media_workflow or ""
is_video_workflow = "video_" in workflow_name.lower()
media_type = "video" if is_video_workflow else "image"
@@ -213,57 +238,87 @@ class FrameProcessor:
# Build media generation parameters
media_params = {
"prompt": frame.image_prompt,
"workflow": config.media_workflow, # Pass workflow from config (None = use default)
"workflow": config.media_workflow,
"media_type": media_type,
"width": config.media_width,
"height": config.media_height,
"index": frame.index + 1, # 1-based index for workflow
"index": frame.index + 1,
}
# For video workflows: pass audio duration as target video duration
# This ensures video length matches audio length from the source
if is_video_workflow and frame.duration:
media_params["duration"] = frame.duration
logger.info(f" → Generating video with target duration: {frame.duration:.2f}s (from TTS audio)")
logger.info(f" → Generating video with target duration: {frame.duration:.2f}s")
# Call Media generation
media_result = await self.core.media(**media_params)
# Define generation operation
async def generate_and_download():
media_result = await self.core.media(**media_params)
local_path = await self._download_media(
media_result.url,
frame.index,
config.task_id,
media_type=media_result.media_type
)
return (media_result, local_path)
# Store media type
# Define quality evaluator
async def evaluate_quality(result):
media_result, local_path = result
if media_result.is_video:
return await self.quality_gate.evaluate_video(
local_path, frame.image_prompt, frame.narration
)
else:
return await self.quality_gate.evaluate_image(
local_path, frame.image_prompt, frame.narration
)
# Execute with retry and quality check
if self.enable_quality_check:
try:
retry_result = await self.retry_manager.execute_with_retry(
operation=generate_and_download,
quality_evaluator=evaluate_quality,
operation_name=f"frame_{frame.index}_media",
)
media_result, local_path = retry_result.result
# Store quality metrics on frame
if retry_result.quality_score:
frame.quality_score = retry_result.quality_score.overall_score
frame.quality_issues = retry_result.quality_score.issues
frame.retry_count = retry_result.attempts - 1 # first attempt is not a retry
except QualityError as e:
logger.warning(f" ⚠ Quality check failed after retries: {e}")
# Still try to use the last result if available
media_result, local_path = await generate_and_download()
frame.quality_issues = [str(e)]
else:
# Quality check disabled - just generate
media_result, local_path = await generate_and_download()
# Store results on frame
frame.media_type = media_result.media_type
if media_result.is_image:
# Download image to local (pass task_id)
local_path = await self._download_media(
media_result.url,
frame.index,
config.task_id,
media_type="image"
)
frame.image_path = local_path
logger.debug(f" ✓ Image generated: {local_path}")
elif media_result.is_video:
# Download video to local (pass task_id)
local_path = await self._download_media(
media_result.url,
frame.index,
config.task_id,
media_type="video"
)
frame.video_path = local_path
# Update duration from video if available
if media_result.duration:
frame.duration = media_result.duration
logger.debug(f" ✓ Video generated: {local_path} (duration: {frame.duration:.2f}s)")
else:
# Get video duration from file
frame.duration = await self._get_video_duration(local_path)
logger.debug(f" ✓ Video generated: {local_path} (duration: {frame.duration:.2f}s)")
logger.debug(f" ✓ Video generated: {local_path} (duration: {frame.duration:.2f}s)")
else:
raise ValueError(f"Unknown media type: {media_result.media_type}")
# Log quality result
if frame.quality_score is not None:
logger.info(
f" 📊 Quality: {frame.quality_score:.2f} "
f"(retries: {frame.retry_count}, issues: {len(frame.quality_issues or [])})"
)
async def _step_compose_frame(
self,

View File

@@ -0,0 +1,134 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Publishing service for multi-platform video distribution.
Supports:
- Format conversion + export (Douyin/Kuaishou)
- API-based upload (Bilibili/YouTube)
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, List, Dict, Any
from datetime import datetime
class Platform(Enum):
    """Supported publishing platforms.

    Each member's value is the lowercase platform identifier string.
    """
    EXPORT = "export"  # Format conversion only
    DOUYIN = "douyin"  # Douyin (via export or CDP)
    KUAISHOU = "kuaishou"  # Kuaishou (via export or CDP)
    BILIBILI = "bilibili"  # Bilibili (API)
    YOUTUBE = "youtube"  # YouTube (API)
class PublishStatus(Enum):
    """Publishing task status.

    Each member's value is the lowercase status string. Publishers report
    their outcome as either ``PUBLISHED`` or ``FAILED``.
    """
    PENDING = "pending"
    CONVERTING = "converting"
    UPLOADING = "uploading"
    PROCESSING = "processing"
    PUBLISHED = "published"
    FAILED = "failed"
@dataclass
class VideoMetadata:
    """Video metadata for publishing.

    Only ``title`` is required; every other field has a default.
    """
    title: str
    description: str = ""
    tags: List[str] = field(default_factory=list)  # fresh list per instance
    category: Optional[str] = None
    cover_path: Optional[str] = None
    privacy: str = "public"  # public, private, unlisted
    # Platform-specific options (free-form; each publisher reads the keys it
    # understands from this dict)
    platform_options: Dict[str, Any] = field(default_factory=dict)
@dataclass
class PublishResult:
    """Result of a publishing operation.

    ``success``, ``platform`` and ``status`` are always required; the
    remaining fields are optional and default to ``None``.
    """
    success: bool
    platform: Platform
    status: PublishStatus
    # On success
    video_url: Optional[str] = None
    platform_video_id: Optional[str] = None
    # On failure
    error_message: Optional[str] = None
    # Export result (path of the exported file)
    export_path: Optional[str] = None
    # Timestamps
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
@dataclass
class PublishTask:
    """A publishing task for background processing.

    A new task starts in ``PublishStatus.PENDING`` with no result.
    """
    id: str
    video_path: str
    platform: Platform
    metadata: VideoMetadata
    status: PublishStatus = PublishStatus.PENDING
    result: Optional[PublishResult] = None
    # Naive local timestamp taken when the instance is created
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: Optional[datetime] = None
class Publisher(ABC):
    """Common interface every platform publisher implements.

    Subclasses set the ``platform`` class attribute and provide ``publish``
    and ``validate_credentials``; ``get_platform_requirements`` may be
    overridden to replace the defaults below.
    """

    platform: Platform

    @abstractmethod
    async def publish(
        self,
        video_path: str,
        metadata: VideoMetadata,
        progress_callback: Optional[callable] = None
    ) -> PublishResult:
        """
        Publish a video to the platform.

        Args:
            video_path: Path to the video file
            metadata: Video metadata (title, description, tags, etc.)
            progress_callback: Optional callback for progress updates

        Returns:
            PublishResult with success/failure details
        """

    @abstractmethod
    async def validate_credentials(self) -> bool:
        """Report whether the platform credentials are usable."""

    def get_platform_requirements(self) -> Dict[str, Any]:
        """Return the default platform constraints (size, duration, formats, ...)."""
        requirements: Dict[str, Any] = {}
        requirements["max_file_size_mb"] = 128
        requirements["max_duration_seconds"] = 900  # 15 minutes
        requirements["supported_formats"] = ["mp4", "webm"]
        requirements["recommended_resolution"] = (1080, 1920)  # Portrait 9:16
        requirements["recommended_codec"] = "h264"
        return requirements

View File

@@ -0,0 +1,426 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Bilibili Publisher - Upload videos to Bilibili using their Open Platform API.
Flow:
1. Get preupload info (upos_uri, auth, chunk_size)
2. Upload video chunks (8MB each)
3. Merge chunks
4. Submit video with metadata
"""
import os
import math
import aiohttp
import asyncio
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, Any
from loguru import logger
from pixelle_video.services.publishing import (
Publisher,
Platform,
PublishStatus,
VideoMetadata,
PublishResult,
)
# Bilibili API endpoints
# Preupload: queried first to obtain upos_uri / auth / endpoint for the upload
BILIBILI_PREUPLOAD_URL = "https://member.bilibili.com/preupload"
# Submit: receives the final JSON payload with title, description, tags, etc.
BILIBILI_SUBMIT_URL = "https://member.bilibili.com/x/vu/web/add"
# Chunk size: 8MB (recommended by Bilibili); expressed in bytes
CHUNK_SIZE = 8 * 1024 * 1024
class BilibiliPublisher(Publisher):
"""
Publisher for Bilibili video platform.
Requires:
- access_token: OAuth access token
- refresh_token: For token refresh (optional)
"""
platform = Platform.BILIBILI
def __init__(
    self,
    access_token: Optional[str] = None,
    refresh_token: Optional[str] = None,
    sessdata: Optional[str] = None,  # Alternative: use cookies
    bili_jct: Optional[str] = None,
):
    """Store credentials, falling back to environment variables.

    Any argument that is falsy is replaced by the matching ``BILIBILI_*``
    environment variable (which may itself be unset, giving ``None``).
    """
    read_env = os.getenv

    self.access_token = access_token if access_token else read_env("BILIBILI_ACCESS_TOKEN")
    self.refresh_token = refresh_token if refresh_token else read_env("BILIBILI_REFRESH_TOKEN")
    self.sessdata = sessdata if sessdata else read_env("BILIBILI_SESSDATA")
    self.bili_jct = bili_jct if bili_jct else read_env("BILIBILI_BILI_JCT")

    # Per-upload state, filled in by the preupload step
    self._upload_id = None
    self._upos_uri = None
    self._auth = None
    self._endpoint = None
async def publish(
    self,
    video_path: str,
    metadata: VideoMetadata,
    progress_callback: Optional[callable] = None
) -> PublishResult:
    """Upload and publish video to Bilibili.

    Runs the four-step flow: preupload, chunked upload, merge, submit.
    Every failure (including unexpected exceptions) is reported through the
    returned ``PublishResult`` rather than raised.

    Args:
        video_path: Path to the local video file.
        metadata: Title, description, tags and platform options.
        progress_callback: Optional ``callback(progress, message)`` invoked
            with a fraction in [0, 1] and a status message.

    Returns:
        A ``PublishResult`` with ``PUBLISHED`` status and the video URL on
        success, or ``FAILED`` status and ``error_message`` otherwise.
    """
    started_at = datetime.now()

    def failed(message: str) -> PublishResult:
        # Single place that builds the failure result for every error path
        return PublishResult(
            success=False,
            platform=Platform.BILIBILI,
            status=PublishStatus.FAILED,
            error_message=message,
            started_at=started_at,
            completed_at=datetime.now(),
        )

    def report(progress: float, message: str) -> None:
        if progress_callback:
            progress_callback(progress, message)

    try:
        if not await self.validate_credentials():
            return failed("B站凭证未配置或已过期")

        video_file = Path(video_path)
        if not video_file.exists():
            return failed(f"视频文件不存在: {video_path}")

        file_size = video_file.stat().st_size
        if file_size == 0:
            # Without this guard chunk_count is 0, nothing is uploaded, and
            # the merge/submit steps would still run for an empty video.
            return failed(f"视频文件为空: {video_path}")

        report(0.05, "获取上传信息...")

        # Step 1: Get preupload info
        preupload_info = await self._preupload(video_file.name, file_size)
        if not preupload_info:
            return failed("获取上传信息失败")

        report(0.1, "上传视频分片...")

        # Step 2: Upload chunks
        chunk_count = math.ceil(file_size / CHUNK_SIZE)
        uploaded_chunks = 0
        async with aiohttp.ClientSession() as session:
            with open(video_path, "rb") as f:
                for chunk_index in range(chunk_count):
                    chunk_data = f.read(CHUNK_SIZE)
                    chunk_start = chunk_index * CHUNK_SIZE
                    chunk_end = min(chunk_start + len(chunk_data), file_size)
                    success = await self._upload_chunk(
                        session,
                        chunk_data,
                        chunk_index,
                        chunk_count,
                        chunk_start,
                        chunk_end,
                        file_size,
                    )
                    if not success:
                        return failed(f"分片 {chunk_index + 1}/{chunk_count} 上传失败")
                    uploaded_chunks += 1
                    # Chunk upload covers the 10%..80% band of the progress bar
                    progress = 0.1 + (0.7 * uploaded_chunks / chunk_count)
                    report(progress, f"上传分片 {uploaded_chunks}/{chunk_count}")

        report(0.85, "合并视频...")

        # Step 3: Merge chunks
        video_filename = await self._merge_chunks(chunk_count, file_size)
        if not video_filename:
            return failed("视频合并失败")

        report(0.9, "提交稿件...")

        # Step 4: Submit video
        bvid = await self._submit_video(video_filename, metadata)
        if not bvid:
            return failed("稿件提交失败")

        report(1.0, "发布成功")
        return PublishResult(
            success=True,
            platform=Platform.BILIBILI,
            status=PublishStatus.PUBLISHED,
            video_url=f"https://www.bilibili.com/video/{bvid}",
            platform_video_id=bvid,
            started_at=started_at,
            completed_at=datetime.now(),
        )
    except Exception as e:
        # Top-level boundary: callers only ever see a PublishResult
        logger.error(f"Bilibili publish failed: {e}")
        return failed(str(e))
async def _preupload(self, filename: str, file_size: int) -> Optional[Dict]:
    """Request upload parameters and cache them on the instance.

    On an HTTP 200 response the JSON body's ``upos_uri``, ``auth``,
    ``endpoint`` and ``upload_id`` values are stored for the later steps
    and the whole body is returned. Any other status, or any exception
    during the request, yields ``None``.
    """
    query = {
        "name": filename,
        "size": file_size,
        "r": "upos",
        "profile": "ugcupos/bup",
    }
    request_headers = self._get_headers()

    try:
        async with aiohttp.ClientSession() as http:
            async with http.get(
                BILIBILI_PREUPLOAD_URL,
                params=query,
                headers=request_headers
            ) as response:
                if response.status == 200:
                    payload = await response.json()
                    self._upos_uri = payload.get("upos_uri")
                    self._auth = payload.get("auth")
                    self._endpoint = payload.get("endpoint")
                    self._upload_id = payload.get("upload_id")
                    logger.info(f"Preupload success: {self._upos_uri}")
                    return payload

                logger.error(f"Preupload failed: {response.status}")
                return None
    except Exception as exc:
        logger.error(f"Preupload error: {exc}")
        return None
async def _upload_chunk(
    self,
    session: aiohttp.ClientSession,
    chunk_data: bytes,
    chunk_index: int,
    chunk_count: int,
    chunk_start: int,
    chunk_end: int,
    total_size: int,
) -> bool:
    """PUT one chunk to the upload endpoint.

    Returns ``True`` when the server answers 200, 201 or 204. Returns
    ``False`` when preupload state is missing, on any other status, or when
    the request raises.
    """
    have_upload_state = bool(self._upos_uri) and bool(self._auth)
    if not have_upload_state:
        return False

    # NOTE(review): plain concatenation of endpoint and upos_uri — if the
    # preupload response returns upos_uri with a scheme prefix this will not
    # be a valid URL; confirm against a real response.
    target = f"https:{self._endpoint}{self._upos_uri}"
    query = {
        "uploadId": self._upload_id,
        "partNumber": chunk_index + 1,  # 1-based, unlike "chunk"
        "chunk": chunk_index,
        "chunks": chunk_count,
        "size": len(chunk_data),
        "start": chunk_start,
        "end": chunk_end,
        "total": total_size,
    }
    request_headers = {
        "X-Upos-Auth": self._auth,
        "Content-Type": "application/octet-stream",
    }

    try:
        async with session.put(
            target,
            params=query,
            headers=request_headers,
            data=chunk_data
        ) as response:
            if response.status in (200, 201, 204):
                return True
            logger.error(f"Chunk upload failed: {response.status}")
            return False
    except Exception as exc:
        logger.error(f"Chunk upload error: {exc}")
        return False
async def _merge_chunks(self, chunk_count: int, file_size: int) -> Optional[str]:
    """Ask the upload endpoint to merge the uploaded chunks.

    Args:
        chunk_count: Number of parts that were uploaded.
        file_size: Total file size in bytes (not used by this method).

    Returns:
        The last path segment of ``upos_uri`` with its extension removed,
        or ``None`` when preupload state is missing, the server does not
        answer 200, or the request raises.
    """
    if not self._upos_uri:
        return None
    merge_url = f"https:{self._endpoint}{self._upos_uri}"
    params = {
        "output": "json",
        "name": self._upos_uri.split("/")[-1],
        "profile": "ugcupos/bup",
        "uploadId": self._upload_id,
        "biz_id": "",
    }
    # Build parts list
    # NOTE(review): every part carries the literal placeholder "etag" rather
    # than a value taken from the chunk upload responses — confirm the
    # server accepts this.
    parts = [{"partNumber": i + 1, "eTag": "etag"} for i in range(chunk_count)]
    headers = {
        "X-Upos-Auth": self._auth,
        "Content-Type": "application/json",
    }
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                merge_url,
                params=params,
                headers=headers,
                json={"parts": parts}
            ) as resp:
                if resp.status != 200:
                    logger.error(f"Merge failed: {resp.status}")
                    return None
                # The body is parsed but not inspected; a non-JSON body raises
                # here and is reported through the except branch below.
                data = await resp.json()
                return self._upos_uri.split("/")[-1].split(".")[0]
    except Exception as e:
        logger.error(f"Merge error: {e}")
        return None
async def _submit_video(
    self,
    video_filename: str,
    metadata: VideoMetadata
) -> Optional[str]:
    """POST the submission payload and return the resulting ``bvid``.

    Returns ``None`` when the response ``code`` is not 0 or the request
    raises.
    """
    # Default to the "Life" category (tid=160) unless overridden
    category_id = metadata.platform_options.get("tid", 160)
    tag_string = ",".join(metadata.tags) if metadata.tags else ""

    payload = {
        "copyright": 1,  # 1 = original work
        "videos": [{
            "filename": video_filename,
            "title": metadata.title,
            "desc": metadata.description,
        }],
        "title": metadata.title,
        "desc": metadata.description,
        "tid": category_id,
        "tag": tag_string,
        "source": "",
        "cover": metadata.cover_path or "",
        "no_reprint": 1,
        "open_elec": 0,
    }

    request_headers = self._get_headers()
    request_headers["Content-Type"] = "application/json"

    try:
        async with aiohttp.ClientSession() as http:
            async with http.post(
                BILIBILI_SUBMIT_URL,
                headers=request_headers,
                json=payload
            ) as response:
                body = await response.json()
                if body.get("code") != 0:
                    logger.error(f"Submit failed: {body}")
                    return None
                bvid = body.get("data", {}).get("bvid")
                logger.info(f"Video submitted: {bvid}")
                return bvid
    except Exception as exc:
        logger.error(f"Submit error: {exc}")
        return None
def _get_headers(self) -> Dict[str, str]:
    """Get common headers with authentication.

    Always includes ``User-Agent`` and ``Referer``. Adds a bearer
    ``Authorization`` header when an access token is configured, and a
    ``Cookie`` header built from whichever of ``sessdata`` / ``bili_jct``
    are configured.

    Returns:
        A new header dict; the caller may mutate it freely.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
        "Referer": "https://www.bilibili.com/",
    }
    if self.access_token:
        headers["Authorization"] = f"Bearer {self.access_token}"

    # Collect cookie pairs first so that bili_jct without sessdata does not
    # try to append to a Cookie header that was never created.
    cookie_pairs = []
    if self.sessdata:
        cookie_pairs.append(f"SESSDATA={self.sessdata}")
    if self.bili_jct:
        cookie_pairs.append(f"bili_jct={self.bili_jct}")
    if cookie_pairs:
        headers["Cookie"] = "; ".join(cookie_pairs)
    return headers
async def validate_credentials(self) -> bool:
    """Return True when an access token or SESSDATA cookie is configured.

    Only presence is checked; no request is made to verify the credential.
    """
    if self.access_token:
        return True
    return bool(self.sessdata)
def get_platform_requirements(self) -> Dict[str, Any]:
    """Return the upload constraints this publisher advertises for Bilibili."""
    limits: Dict[str, Any] = {}
    limits["max_file_size_mb"] = 4096  # 4GB
    limits["max_duration_seconds"] = 14400  # 4 hours
    limits["supported_formats"] = ["mp4", "flv", "webm", "mov"]
    limits["recommended_resolution"] = (1920, 1080)
    limits["recommended_codec"] = "h264"
    limits["chunk_size_mb"] = 8
    return limits

View File

@@ -0,0 +1,182 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Export Publisher - Format conversion and local export for platforms
without API access (Douyin, Kuaishou).
"""
import asyncio
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from typing import Optional
from loguru import logger
from pixelle_video.services.publishing import (
Publisher,
Platform,
PublishStatus,
VideoMetadata,
PublishResult,
)
class ExportPublisher(Publisher):
"""
Publisher that converts video to platform-optimized format
and exports to local filesystem for manual upload.
"""
platform = Platform.EXPORT
def __init__(
    self,
    output_dir: str = "./output/exports",
    target_resolution: tuple = (1080, 1920),  # Portrait 9:16
    target_codec: str = "h264",
    max_file_size_mb: int = 128,
):
    """Record export settings and create the output directory.

    The directory (and any missing parents) is created immediately; an
    already existing directory is not an error.
    """
    destination = Path(output_dir)
    destination.mkdir(parents=True, exist_ok=True)

    self.output_dir = destination
    self.target_resolution = target_resolution
    self.target_codec = target_codec
    self.max_file_size_mb = max_file_size_mb
async def publish(
    self,
    video_path: str,
    metadata: VideoMetadata,
    progress_callback: Optional[callable] = None
) -> PublishResult:
    """
    Convert video to optimized format and export.

    The output file is named from a sanitized, 50-character-max version of
    the title plus a timestamp and is written into ``self.output_dir``.

    Args:
        video_path: Path to the source video.
        metadata: Video metadata; only ``title`` is used here.
        progress_callback: Optional ``callback(progress, message)``.

    Returns:
        A ``PublishResult`` with ``export_path`` set on success, or with
        ``error_message`` set on failure. Exceptions are not propagated.
    """
    started_at = datetime.now()
    try:
        if progress_callback:
            progress_callback(0.1, "分析视频...")

        # Generate output filename: keep alphanumerics, space, '-' and '_'
        safe_title = "".join(c if c.isalnum() or c in " -_" else "" for c in metadata.title)
        safe_title = safe_title[:50].strip() or "video"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"{safe_title}_{timestamp}.mp4"
        output_path = self.output_dir / output_filename

        if progress_callback:
            progress_callback(0.2, "转换格式...")

        # Convert video
        success = await self._convert_video(
            video_path,
            str(output_path),
            progress_callback
        )
        if not success:
            return PublishResult(
                success=False,
                platform=Platform.EXPORT,
                status=PublishStatus.FAILED,
                error_message="视频转换失败",
                started_at=started_at,
                completed_at=datetime.now(),
            )

        if progress_callback:
            progress_callback(1.0, "导出完成")

        # Verify file size: the export is still delivered when it is too
        # large, but the overrun is surfaced instead of silently ignored.
        file_size_mb = output_path.stat().st_size / (1024 * 1024)
        if file_size_mb > self.max_file_size_mb:
            logger.warning(
                f"Exported file is {file_size_mb:.1f} MB, above the "
                f"{self.max_file_size_mb} MB limit: {output_path}"
            )

        return PublishResult(
            success=True,
            platform=Platform.EXPORT,
            status=PublishStatus.PUBLISHED,
            export_path=str(output_path),
            started_at=started_at,
            completed_at=datetime.now(),
        )
    except Exception as e:
        # Top-level boundary: callers only ever see a PublishResult
        logger.error(f"Export failed: {e}")
        return PublishResult(
            success=False,
            platform=Platform.EXPORT,
            status=PublishStatus.FAILED,
            error_message=str(e),
            started_at=started_at,
            completed_at=datetime.now(),
        )
async def _convert_video(
    self,
    input_path: str,
    output_path: str,
    progress_callback: Optional[callable] = None
) -> bool:
    """Re-encode ``input_path`` to ``output_path`` with FFmpeg.

    Returns ``True`` when FFmpeg exits with code 0, ``False`` on a non-zero
    exit. When the ``ffmpeg`` executable cannot be found the input is
    copied unchanged to ``output_path`` and ``True`` is returned.
    """
    target_w, target_h = self.target_resolution
    # Scale to fit inside the target frame, then pad to the exact size
    video_filter = (
        f"scale={target_w}:{target_h}:force_original_aspect_ratio=decrease,"
        f"pad={target_w}:{target_h}:(ow-iw)/2:(oh-ih)/2"
    )

    # FFmpeg command for H.264 conversion with size optimization
    ffmpeg_args = ["ffmpeg", "-y", "-i", input_path]
    ffmpeg_args += ["-c:v", "libx264", "-preset", "medium", "-crf", "23"]
    ffmpeg_args += ["-c:a", "aac", "-b:a", "128k"]
    ffmpeg_args += ["-vf", video_filter]
    ffmpeg_args += ["-movflags", "+faststart", output_path]

    try:
        proc = await asyncio.create_subprocess_exec(
            *ffmpeg_args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        _, err_output = await proc.communicate()
        if proc.returncode == 0:
            return True
        logger.error(f"FFmpeg error: {err_output.decode()}")
        return False
    except FileNotFoundError:
        logger.error("FFmpeg not found. Please install FFmpeg.")
        # Fallback: just copy the file if FFmpeg is not available
        shutil.copy(input_path, output_path)
        return True
async def validate_credentials(self) -> bool:
    """Always succeed: local export does not use any credentials."""
    credentials_ok = True
    return credentials_ok
def get_platform_requirements(self):
    """Return the export constraints derived from this instance's settings."""
    requirements = dict(
        max_file_size_mb=self.max_file_size_mb,
        recommended_resolution=self.target_resolution,
        recommended_codec=self.target_codec,
    )
    requirements["output_format"] = "mp4"
    requirements["platforms"] = ["douyin", "kuaishou"]
    return requirements

View File

@@ -0,0 +1,299 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Publish Task Manager - Background task queue for video publishing.
Features:
- Async task queue with configurable workers
- Task persistence (in-memory, Redis optional)
- Progress tracking and callbacks
- Retry logic for failed tasks
"""
import asyncio
import uuid
from datetime import datetime
from typing import Optional, Dict, List, Callable
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger
from pixelle_video.services.publishing import (
Publisher,
Platform,
PublishStatus,
VideoMetadata,
PublishResult,
PublishTask,
)
from pixelle_video.services.publishing.export_publisher import ExportPublisher
from pixelle_video.services.publishing.bilibili_publisher import BilibiliPublisher
from pixelle_video.services.publishing.youtube_publisher import YouTubePublisher
class TaskPriority(Enum):
    """Priority label attached to a queued publish task (higher value = higher priority)."""
    LOW = 0
    NORMAL = 1
    HIGH = 2
@dataclass
class QueuedTask:
    """Extended task with queue metadata"""
    task: PublishTask
    priority: TaskPriority = TaskPriority.NORMAL
    retries: int = 0  # number of retries performed so far
    max_retries: int = 3  # a failed task is re-queued while retries < max_retries
    retry_delay: float = 5.0  # seconds slept before a failed task is re-queued
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    progress: float = 0.0  # last value reported through the progress callback
    progress_message: str = ""  # last message reported through the progress callback
class PublishTaskManager:
"""
Manages background publishing tasks with async queue.
Usage:
manager = PublishTaskManager()
await manager.start()
task_id = await manager.enqueue(
video_path="/path/to/video.mp4",
platform=Platform.BILIBILI,
metadata=VideoMetadata(title="My Video")
)
status = manager.get_task(task_id)
"""
def __init__(
    self,
    max_workers: int = 3,
    max_queue_size: int = 100,
):
    """Set up empty task storage and one publisher per supported platform.

    The queue itself is not created here; it stays ``None`` until
    ``start()`` runs.
    """
    self.max_workers = max_workers
    self.max_queue_size = max_queue_size

    # Task storage and worker bookkeeping
    self._tasks: Dict[str, QueuedTask] = {}
    self._queue: asyncio.Queue = None
    self._workers: List[asyncio.Task] = []
    self._running = False

    # Publishers, keyed by the platform they serve
    publishers: Dict[Platform, Publisher] = {}
    publishers[Platform.EXPORT] = ExportPublisher()
    publishers[Platform.BILIBILI] = BilibiliPublisher()
    publishers[Platform.YOUTUBE] = YouTubePublisher()
    self._publishers = publishers

    # Callbacks (unset until registered)
    self._on_complete: Optional[Callable] = None
    self._on_progress: Optional[Callable] = None
async def start(self):
    """Create the queue and spawn the worker tasks (no-op if already running)."""
    if not self._running:
        self._queue = asyncio.Queue(maxsize=self.max_queue_size)
        self._running = True

        # One asyncio task per worker slot
        for worker_id in range(self.max_workers):
            self._workers.append(asyncio.create_task(self._worker(worker_id)))

        logger.info(f"✅ Publish task manager started with {self.max_workers} workers")
async def stop(self):
    """Stop all workers.

    Cancels every worker task and waits for them to finish. Items still
    sitting in the queue are not removed by this method.
    """
    self._running = False
    # Cancel all workers
    for worker in self._workers:
        worker.cancel()
    # return_exceptions=True so the workers' cancellation is collected
    # instead of being raised here
    await asyncio.gather(*self._workers, return_exceptions=True)
    self._workers.clear()
    logger.info("✅ Publish task manager stopped")
async def enqueue(
    self,
    video_path: str,
    platform: Platform,
    metadata: VideoMetadata,
    priority: TaskPriority = TaskPriority.NORMAL,
) -> str:
    """
    Add a publish task to the queue.

    Args:
        video_path: Path to the video file to publish.
        platform: Target platform.
        metadata: Video metadata for the upload.
        priority: Priority label stored on the queued task.

    Returns:
        Task ID for tracking

    Raises:
        RuntimeError: If ``start()`` has not been awaited yet, so there is
            no queue to put the task on.
    """
    if self._queue is None:
        # Fail with a clear message instead of an AttributeError on None
        raise RuntimeError("PublishTaskManager.start() must be awaited before enqueue()")

    # Short 8-character IDs can collide; never overwrite a tracked task
    task_id = str(uuid.uuid4())[:8]
    while task_id in self._tasks:
        task_id = str(uuid.uuid4())[:8]

    task = PublishTask(
        id=task_id,
        video_path=video_path,
        platform=platform,
        metadata=metadata,
        status=PublishStatus.PENDING,
    )
    queued_task = QueuedTask(task=task, priority=priority)
    self._tasks[task_id] = queued_task
    # Waits here when the queue already holds max_queue_size items
    await self._queue.put(queued_task)
    logger.info(f"📥 Queued task {task_id}: {metadata.title} → {platform.value}")
    return task_id
def get_task(self, task_id: str) -> Optional[QueuedTask]:
    """Look up a tracked task by its ID; ``None`` when the ID is unknown."""
    try:
        return self._tasks[task_id]
    except KeyError:
        return None
def get_all_tasks(self) -> List[QueuedTask]:
    """Return a new list containing every tracked task."""
    return [*self._tasks.values()]
def get_pending_tasks(self) -> List[QueuedTask]:
    """Return the tracked tasks whose status is still ``PENDING``."""
    waiting: List[QueuedTask] = []
    for entry in self._tasks.values():
        if entry.task.status == PublishStatus.PENDING:
            waiting.append(entry)
    return waiting
def get_active_tasks(self) -> List[QueuedTask]:
    """Return the tasks that are converting, uploading or processing."""
    in_flight = (
        PublishStatus.CONVERTING,
        PublishStatus.UPLOADING,
        PublishStatus.PROCESSING,
    )
    return [entry for entry in self._tasks.values() if entry.task.status in in_flight]
def set_on_complete(self, callback: Callable):
    """Register the callable invoked as ``callback(task_id, result)`` when a task finishes."""
    setattr(self, "_on_complete", callback)
def set_on_progress(self, callback: Callable):
    """Register the callable invoked as ``callback(task_id, progress, message)`` on progress."""
    setattr(self, "_on_progress", callback)
async def _worker(self, worker_id: int):
    """Worker coroutine that processes tasks from queue.

    Loops while the manager is running, polling the queue with a one-second
    timeout so a cleared ``_running`` flag is noticed promptly. Exits on
    cancellation; any other exception is logged and the loop continues.

    Args:
        worker_id: Index of this worker, used only in log messages.
    """
    logger.debug(f"Worker {worker_id} started")
    while self._running:
        try:
            # Get task from queue with timeout
            try:
                queued_task = await asyncio.wait_for(
                    self._queue.get(),
                    timeout=1.0
                )
            except asyncio.TimeoutError:
                continue
            try:
                await self._process_task(queued_task, worker_id)
            finally:
                # Always balance the get(): otherwise a failing task leaves
                # the queue's unfinished count stuck and join() never returns.
                self._queue.task_done()
        except asyncio.CancelledError:
            break
        except Exception as e:
            logger.error(f"Worker {worker_id} error: {e}")
async def _process_task(self, queued_task: QueuedTask, worker_id: int):
    """Process a single publish task.

    Marks the task as uploading, runs the platform's publisher, and stores
    the result on the task. A failed result is re-queued after
    ``retry_delay`` seconds while ``retries < max_retries``. Once the task
    reaches a final state (including "no publisher for this platform") the
    completion timestamps are set and the completion callback is invoked.

    Args:
        queued_task: The queue entry to process.
        worker_id: Index of the calling worker, used only in log messages.
    """
    task = queued_task.task
    task_id = task.id
    logger.info(f"🔄 Worker {worker_id} processing task {task_id}")
    queued_task.started_at = datetime.now()
    task.status = PublishStatus.UPLOADING

    # Progress callback: mirror progress on the queue entry, then forward it
    def progress_callback(progress: float, message: str):
        queued_task.progress = progress
        queued_task.progress_message = message
        if self._on_progress:
            self._on_progress(task_id, progress, message)

    # Get publisher
    publisher = self._publishers.get(task.platform)
    if publisher is None:
        logger.error(f"Task {task_id} has no publisher for platform: {task.platform}")
        task.status = PublishStatus.FAILED
        task.result = PublishResult(
            success=False,
            platform=task.platform,
            status=PublishStatus.FAILED,
            error_message=f"No publisher for platform: {task.platform}",
        )
    else:
        try:
            # Execute publish
            result = await publisher.publish(
                task.video_path,
                task.metadata,
                progress_callback=progress_callback
            )
            task.result = result
            task.status = result.status
            if result.success:
                logger.info(f"✅ Task {task_id} completed: {result.video_url or result.export_path}")
            else:
                logger.warning(f"❌ Task {task_id} failed: {result.error_message}")
                # Retry if applicable
                if queued_task.retries < queued_task.max_retries:
                    queued_task.retries += 1
                    task.status = PublishStatus.PENDING
                    logger.info(f"🔄 Retrying task {task_id} ({queued_task.retries}/{queued_task.max_retries})")
                    # NOTE(review): this worker is occupied for the whole
                    # delay, and put() waits if the queue is full.
                    await asyncio.sleep(queued_task.retry_delay)
                    await self._queue.put(queued_task)
                    return  # not final yet: skip completion bookkeeping
        except Exception as e:
            logger.error(f"Task {task_id} exception: {e}")
            task.status = PublishStatus.FAILED
            task.result = PublishResult(
                success=False,
                platform=task.platform,
                status=PublishStatus.FAILED,
                error_message=str(e),
            )

    # Final state reached: stamp the task and notify the listener. The
    # missing-publisher path goes through here too, so it is reported like
    # any other failure.
    queued_task.completed_at = datetime.now()
    task.updated_at = datetime.now()
    # Call completion callback
    if self._on_complete:
        self._on_complete(task_id, task.result)
# Module-wide manager, created on first request
_publish_manager: Optional[PublishTaskManager] = None


def get_publish_manager() -> PublishTaskManager:
    """Return the shared publish task manager, building it on first use."""
    global _publish_manager
    manager = _publish_manager
    if manager is None:
        manager = PublishTaskManager()
        _publish_manager = manager
    return manager

View File

@@ -0,0 +1,310 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
YouTube Publisher - Upload videos to YouTube using Data API v3.
Requires:
- Google Cloud project with YouTube Data API v3 enabled
- OAuth 2.0 credentials (client_secrets.json)
"""
import os
import pickle
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, Any
from loguru import logger
from pixelle_video.services.publishing import (
Publisher,
Platform,
PublishStatus,
VideoMetadata,
PublishResult,
)
# Short category name -> YouTube category ID (IDs are strings, as sent in
# the upload request's ``snippet.categoryId``).
YOUTUBE_CATEGORIES = {
    "film": "1",
    "autos": "2",
    "music": "10",
    "pets": "15",
    "sports": "17",
    "travel": "19",
    "gaming": "20",
    "people": "22",  # also used as the fallback category by the publisher
    "comedy": "23",
    "entertainment": "24",
    "news": "25",
    "howto": "26",
    "education": "27",
    "science": "28",
    "nonprofits": "29",
}
class YouTubePublisher(Publisher):
    """
    Publisher for YouTube video platform.

    Uses Google API Python Client for uploading videos. The Google
    libraries are imported lazily, so this module can be imported without
    them; a missing library is reported as a failed publish.

    Setup:
        1. Create project in Google Cloud Console
        2. Enable YouTube Data API v3
        3. Create OAuth 2.0 credentials
        4. Download client_secrets.json
    """

    platform = Platform.YOUTUBE

    def __init__(
        self,
        client_secrets_path: Optional[str] = None,
        token_path: Optional[str] = None,
    ):
        """
        Initialize the publisher.

        Args:
            client_secrets_path: Path to the OAuth client secrets file.
                Falls back to the YOUTUBE_CLIENT_SECRETS environment
                variable, then to ./config/youtube_client_secrets.json.
            token_path: Path of the pickled credentials cache. Falls back
                to the YOUTUBE_TOKEN_PATH environment variable, then to
                ./config/youtube_token.pickle.
        """
        self.client_secrets_path = client_secrets_path or os.getenv(
            "YOUTUBE_CLIENT_SECRETS",
            "./config/youtube_client_secrets.json"
        )
        self.token_path = token_path or os.getenv(
            "YOUTUBE_TOKEN_PATH",
            "./config/youtube_token.pickle"
        )
        self._youtube_service = None

    @staticmethod
    def _failed_result(error_message: str, started_at: datetime) -> PublishResult:
        """Build the FAILED PublishResult returned by every error path."""
        return PublishResult(
            success=False,
            platform=Platform.YOUTUBE,
            status=PublishStatus.FAILED,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(),
        )

    async def publish(
        self,
        video_path: str,
        metadata: VideoMetadata,
        progress_callback: Optional[callable] = None
    ) -> PublishResult:
        """
        Upload and publish video to YouTube.

        Args:
            video_path: Path of the video file to upload.
            metadata: Title, description, tags, category, privacy and
                platform options for the upload.
            progress_callback: Optional ``callback(progress, message)``
                invoked with a progress fraction and a status message.

        Returns:
            PublishResult with status PUBLISHED and the watch URL on
            success, or status FAILED with ``error_message`` set. This
            method does not raise; exceptions are converted into a failed
            result.
        """
        started_at = datetime.now()

        def report(progress: float, message: str):
            if progress_callback:
                progress_callback(progress, message)

        try:
            if not await self.validate_credentials():
                return self._failed_result(
                    "YouTube 凭证未配置。请配置 client_secrets.json", started_at
                )

            video_file = Path(video_path)
            if not video_file.exists():
                return self._failed_result(f"视频文件不存在: {video_path}", started_at)

            report(0.1, "初始化 YouTube API...")

            # Initialize YouTube service
            youtube = await self._get_youtube_service()
            if not youtube:
                return self._failed_result("无法初始化 YouTube API 服务", started_at)

            report(0.2, "准备上传...")

            # Prepare video metadata
            category_id = self._get_category_id(metadata.category)
            privacy_status = self._map_privacy(metadata.privacy)

            body = {
                "snippet": {
                    "title": metadata.title,
                    "description": metadata.description,
                    "tags": metadata.tags,
                    "categoryId": category_id,
                },
                "status": {
                    "privacyStatus": privacy_status,
                    "selfDeclaredMadeForKids": False,
                }
            }

            # Check for synthetic media flag
            if metadata.platform_options.get("contains_synthetic_media"):
                body["status"]["containsSyntheticMedia"] = True

            report(0.3, "上传视频...")

            # Upload using resumable upload
            video_id = await self._upload_video(youtube, video_path, body, progress_callback)
            if not video_id:
                return self._failed_result("视频上传失败", started_at)

            report(1.0, "发布成功")

            return PublishResult(
                success=True,
                platform=Platform.YOUTUBE,
                status=PublishStatus.PUBLISHED,
                video_url=f"https://www.youtube.com/watch?v={video_id}",
                platform_video_id=video_id,
                started_at=started_at,
                completed_at=datetime.now(),
            )

        except Exception as e:
            logger.error(f"YouTube publish failed: {e}")
            return self._failed_result(str(e), started_at)

    def _build_youtube_service(self):
        """
        Blocking part of service creation.

        Loads saved credentials, refreshes them or runs the local OAuth
        flow when needed, saves the result to ``token_path`` and builds
        the API client. Returns None (after logging) on any failure.
        """
        try:
            from google_auth_oauthlib.flow import InstalledAppFlow
            from googleapiclient.discovery import build

            SCOPES = ["https://www.googleapis.com/auth/youtube.upload"]

            creds = None

            # Load saved token
            # SECURITY NOTE(review): pickle.load executes arbitrary code
            # from the file. This is only safe while token_path is written
            # exclusively by this process; consider a JSON token format.
            if os.path.exists(self.token_path):
                with open(self.token_path, "rb") as token:
                    creds = pickle.load(token)

            # Refresh or get new credentials
            if not creds or not creds.valid:
                if creds and creds.expired and creds.refresh_token:
                    from google.auth.transport.requests import Request
                    creds.refresh(Request())
                else:
                    if not os.path.exists(self.client_secrets_path):
                        logger.error(f"Client secrets not found: {self.client_secrets_path}")
                        return None

                    flow = InstalledAppFlow.from_client_secrets_file(
                        self.client_secrets_path,
                        SCOPES
                    )
                    creds = flow.run_local_server(port=0)

                # Save token; create its directory first so a fresh
                # checkout without ./config does not lose the credentials.
                token_dir = os.path.dirname(self.token_path)
                if token_dir:
                    os.makedirs(token_dir, exist_ok=True)
                with open(self.token_path, "wb") as token:
                    pickle.dump(creds, token)

            return build("youtube", "v3", credentials=creds)

        except ImportError:
            logger.error("Google API libraries not installed. Run: pip install google-api-python-client google-auth-oauthlib")
            return None
        except Exception as e:
            logger.error(f"Failed to initialize YouTube service: {e}")
            return None

    async def _get_youtube_service(self):
        """
        Get authenticated YouTube service.

        The credential handling and client construction are blocking
        (file I/O, network refresh, possibly an interactive OAuth flow),
        so they run in the default executor instead of on the event loop.

        Returns:
            The API client, or None if it could not be created.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._build_youtube_service)

    async def _upload_video(
        self,
        youtube,
        video_path: str,
        body: dict,
        progress_callback: Optional[callable] = None
    ) -> Optional[str]:
        """
        Upload video using resumable upload.

        Each chunk request is a blocking HTTP call, so it is executed in
        the default executor; the progress callback is still invoked from
        the event loop between chunks.

        Args:
            youtube: API client returned by ``_get_youtube_service``.
            video_path: Path of the file to upload.
            body: Request body; its top-level keys are sent as ``part``.
            progress_callback: Optional ``callback(progress, message)``;
                upload progress is mapped onto the 0.3-0.9 range.

        Returns:
            The ``id`` field of the API response, or None if the upload
            raised an exception.
        """
        try:
            import asyncio
            from googleapiclient.http import MediaFileUpload

            media = MediaFileUpload(
                video_path,
                chunksize=1024 * 1024,  # 1MB chunks
                resumable=True
            )

            request = youtube.videos().insert(
                part=",".join(body.keys()),
                body=body,
                media_body=media
            )

            loop = asyncio.get_running_loop()
            response = None
            while response is None:
                status, response = await loop.run_in_executor(None, request.next_chunk)
                if status and progress_callback:
                    fraction = status.progress()
                    progress_callback(
                        0.3 + (0.6 * fraction),
                        f"上传 {int(fraction * 100)}%"
                    )

            video_id = response.get("id")
            logger.info(f"Video uploaded: {video_id}")
            return video_id

        except Exception as e:
            logger.error(f"Upload failed: {e}")
            return None

    def _get_category_id(self, category: Optional[str]) -> str:
        """Map category name to YouTube category ID ("22" when unknown or unset)."""
        if not category:
            return "22"  # Default: People & Blogs
        return YOUTUBE_CATEGORIES.get(category.lower(), "22")

    def _map_privacy(self, privacy: str) -> str:
        """Map privacy setting to YouTube format; unrecognized values become "private"."""
        mapping = {
            "public": "public",
            "private": "private",
            "unlisted": "unlisted",
        }
        return mapping.get(privacy, "private")

    async def validate_credentials(self) -> bool:
        """Check if YouTube credentials are configured (secrets file or saved token exists)."""
        return os.path.exists(self.client_secrets_path) or os.path.exists(self.token_path)

    def get_platform_requirements(self) -> Dict[str, Any]:
        """Return the static upload constraints and recommendations for YouTube."""
        return {
            "max_file_size_mb": 256000,  # 256GB
            "max_duration_seconds": 43200,  # 12 hours
            "supported_formats": ["mp4", "mov", "avi", "webm", "mkv"],
            "recommended_resolution": (1920, 1080),
            "recommended_codec": "h264",
            "quota_cost_per_upload": 100,
        }

View File

@@ -0,0 +1,84 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Quality Assurance Services for Pixelle-Video
This module provides quality control mechanisms for video generation:
- QualityGate: Evaluates generated content quality
- RetryManager: Smart retry with quality-based decisions
- OutputValidator: LLM output validation
- StyleGuard: Visual style consistency
- ContentFilter: Content moderation
- CharacterMemory: Character consistency across frames
"""
from pixelle_video.services.quality.models import (
QualityScore,
QualityConfig,
RetryConfig,
QualityError,
)
from pixelle_video.services.quality.quality_gate import QualityGate
from pixelle_video.services.quality.retry_manager import RetryManager
from pixelle_video.services.quality.output_validator import (
OutputValidator,
ValidationConfig,
ValidationResult,
)
from pixelle_video.services.quality.style_guard import (
StyleGuard,
StyleGuardConfig,
StyleAnchor,
)
from pixelle_video.services.quality.content_filter import (
ContentFilter,
ContentFilterConfig,
FilterResult,
FilterCategory,
)
from pixelle_video.services.quality.character_memory import (
CharacterMemory,
CharacterMemoryConfig,
Character,
CharacterType,
)
# Public API of the quality package, grouped by the submodule that
# provides each name.
__all__ = [
    # models / quality_gate / retry_manager: scoring and retry
    "QualityScore",
    "QualityConfig",
    "RetryConfig",
    "QualityError",
    "QualityGate",
    "RetryManager",
    # output_validator: LLM output validation
    "OutputValidator",
    "ValidationConfig",
    "ValidationResult",
    # style_guard: visual style consistency
    "StyleGuard",
    "StyleGuardConfig",
    "StyleAnchor",
    # content_filter: content moderation
    "ContentFilter",
    "ContentFilterConfig",
    "FilterResult",
    "FilterCategory",
    # character_memory: character consistency
    "CharacterMemory",
    "CharacterMemoryConfig",
    "Character",
    "CharacterType",
]

View File

@@ -0,0 +1,530 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CharacterMemory - Character consistency and memory system
Maintains consistent character appearance across video frames by:
1. Detecting and registering characters from narrations
2. Extracting visual descriptions from first appearances
3. Injecting character consistency prompts into subsequent frames
4. Supporting reference images for ComfyUI IP-Adapter/ControlNet
"""
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Set
from datetime import datetime
from enum import Enum
from loguru import logger
class CharacterType(Enum):
    """Broad category a narrative character belongs to."""

    # Human character
    PERSON = "person"
    # Animal character
    ANIMAL = "animal"
    # Fantasy/fictional creature
    CREATURE = "creature"
    # Personified object
    OBJECT = "object"
    # Abstract entity
    ABSTRACT = "abstract"
@dataclass
class Character:
    """
    Represents a character in the video narrative.

    Stores visual description, reference images, and appearance history
    to maintain consistency across frames.
    """

    # Identity
    id: str  # Unique identifier
    name: str  # Character name (e.g., "小明", "the hero")
    aliases: List[str] = field(default_factory=list)  # Alternative names
    character_type: CharacterType = CharacterType.PERSON

    # Visual description (for prompt injection)
    appearance_description: str = ""  # Detailed visual description
    clothing_description: str = ""  # Clothing/outfit description
    distinctive_features: List[str] = field(default_factory=list)  # Unique features

    # Reference images (for IP-Adapter/ControlNet)
    reference_images: List[str] = field(default_factory=list)  # Paths to reference images
    primary_reference: Optional[str] = None  # Primary reference image

    # Prompt elements
    prompt_prefix: str = ""  # Pre-built prompt prefix
    negative_prompt: str = ""  # Negative prompt additions

    # Metadata
    is_active: bool = True  # Whether this character is active for logic
    first_appearance_frame: int = 0  # Frame index of first appearance
    appearance_frames: List[int] = field(default_factory=list)  # All frames with this character
    created_at: Optional[datetime] = None

    def __post_init__(self):
        """Fill in the creation time and, if none was given, the prompt prefix."""
        if self.created_at is None:
            self.created_at = datetime.now()
        # is_active needs no fallback here: the generated __init__ has
        # already assigned every field (including defaults) before
        # __post_init__ runs.
        if not self.prompt_prefix:
            self._build_prompt_prefix()

    def _build_prompt_prefix(self):
        """Build prompt prefix from visual descriptions."""
        elements = []
        if self.appearance_description:
            elements.append(self.appearance_description)
        if self.clothing_description:
            elements.append(f"wearing {self.clothing_description}")
        if self.distinctive_features:
            elements.append(", ".join(self.distinctive_features))
        # Joining an empty list yields "", so no separate empty case is needed.
        self.prompt_prefix = ", ".join(elements)

    def get_prompt_injection(self) -> str:
        """
        Get the prompt text to inject for this character.

        Returns:
            "(name: prefix)" when a prompt prefix exists, otherwise "(name)".
        """
        if self.prompt_prefix:
            return f"({self.name}: {self.prompt_prefix})"
        return f"({self.name})"

    def add_reference_image(self, image_path: str, set_as_primary: bool = False):
        """
        Add a reference image for this character.

        Args:
            image_path: Path to the image; ignored if already registered.
            set_as_primary: Make this image the primary reference. The
                first image added becomes primary regardless.
        """
        if image_path not in self.reference_images:
            self.reference_images.append(image_path)
        if set_as_primary or self.primary_reference is None:
            self.primary_reference = image_path

    def matches_name(self, name: str) -> bool:
        """Check if a name matches this character's name or an alias (case-insensitive)."""
        name_lower = name.lower().strip()
        if self.name.lower() == name_lower:
            return True
        return any(alias.lower() == name_lower for alias in self.aliases)

    def to_dict(self) -> dict:
        """Return the identity, description and reference fields as a plain dict."""
        return {
            "id": self.id,
            "name": self.name,
            "aliases": self.aliases,
            "type": self.character_type.value,
            "appearance_description": self.appearance_description,
            "clothing_description": self.clothing_description,
            "distinctive_features": self.distinctive_features,
            "reference_images": self.reference_images,
            "primary_reference": self.primary_reference,
            "prompt_prefix": self.prompt_prefix,
            "first_appearance_frame": self.first_appearance_frame,
        }
@dataclass
class CharacterMemoryConfig:
    """
    Configuration for character memory system.

    Groups the switches for character detection, prompt/reference-image
    consistency and prompt injection.
    """

    # --- Detection ---
    # Automatically detect characters from narrations
    auto_detect_characters: bool = True
    # Use LLM to extract character info
    use_llm_detection: bool = True

    # --- Consistency ---
    # Inject character descriptions into prompts
    inject_character_prompts: bool = True
    # Use reference images for generation
    use_reference_images: bool = True

    # --- Reference images ---
    # Extract reference from first appearance
    extract_reference_from_first: bool = True
    # Max reference images per character
    max_reference_images: int = 3

    # --- Prompt injection ---
    # "start" or "end"
    prompt_injection_position: str = "start"
    # Include clothing in prompts
    include_clothing: bool = True
    # Include distinctive features
    include_features: bool = True
class CharacterMemory:
    """
    Character memory and consistency manager.

    Tracks characters across video frames and ensures visual consistency
    by injecting character descriptions and reference images into the
    generation pipeline.

    Example:
        >>> memory = CharacterMemory(llm_service)
        >>>
        >>> # Register a character
        >>> char = memory.register_character(
        ...     name="小明",
        ...     appearance_description="young man with short black hair",
        ...     clothing_description="blue t-shirt"
        ... )
        >>>
        >>> # Apply to prompt
        >>> enhanced_prompt = memory.apply_to_prompt(
        ...     prompt="A person walking in the park",
        ...     character_names=["小明"]
        ... )
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[CharacterMemoryConfig] = None
    ):
        """
        Initialize CharacterMemory.

        Args:
            llm_service: Optional LLM service for character detection
            config: Character memory configuration
        """
        self.llm_service = llm_service
        self.config = config or CharacterMemoryConfig()
        self._characters: Dict[str, Character] = {}
        self._name_index: Dict[str, str] = {}  # name -> character_id mapping

    @staticmethod
    def _record_appearance(character: Character, frame_index: int) -> None:
        """Add frame_index to the character's appearance list once."""
        if frame_index not in character.appearance_frames:
            character.appearance_frames.append(frame_index)

    @staticmethod
    def _parse_character_type(raw_type) -> CharacterType:
        """Map an LLM-provided type string to CharacterType (PERSON if unknown)."""
        try:
            return CharacterType(str(raw_type or "person").strip().lower())
        except ValueError:
            return CharacterType.PERSON

    def register_character(
        self,
        name: str,
        appearance_description: str = "",
        clothing_description: str = "",
        distinctive_features: Optional[List[str]] = None,
        character_type: CharacterType = CharacterType.PERSON,
        first_frame: int = 0,
    ) -> Character:
        """
        Register a new character.

        Args:
            name: Character name
            appearance_description: Visual appearance description
            clothing_description: Clothing/outfit description
            distinctive_features: List of distinctive features
            character_type: Type of character
            first_frame: Frame index of first appearance

        Returns:
            Created Character object
        """
        # Generate unique ID
        char_id = f"char_{len(self._characters)}_{name.replace(' ', '_').lower()}"

        character = Character(
            id=char_id,
            name=name,
            appearance_description=appearance_description,
            clothing_description=clothing_description,
            distinctive_features=distinctive_features or [],
            character_type=character_type,
            first_appearance_frame=first_frame,
            appearance_frames=[first_frame],
        )

        self._characters[char_id] = character
        # Normalize the key the same way get_character() normalizes lookups.
        self._name_index[name.lower().strip()] = char_id

        logger.info(f"Registered character: {name} (id={char_id})")
        return character

    def get_character(self, name: str) -> Optional[Character]:
        """Get a character by name or alias (case-insensitive); None if unknown."""
        name_lower = name.lower().strip()
        char_id = self._name_index.get(name_lower)
        if char_id:
            return self._characters.get(char_id)

        # Search aliases
        for char in self._characters.values():
            if char.matches_name(name):
                return char

        return None

    def get_character_by_id(self, char_id: str) -> Optional[Character]:
        """Get a character by ID."""
        return self._characters.get(char_id)

    @property
    def characters(self) -> List[Character]:
        """Get all registered characters."""
        return list(self._characters.values())

    async def detect_characters_from_narration(
        self,
        narration: str,
        frame_index: int = 0,
    ) -> List[Character]:
        """
        Detect and register characters mentioned in narration.

        Uses the LLM when ``use_llm_detection`` is enabled and an LLM
        service is available; otherwise falls back to pattern matching.

        Args:
            narration: Narration text to analyze
            frame_index: Current frame index

        Returns:
            List of detected/registered characters (empty when automatic
            detection is disabled)
        """
        if not self.config.auto_detect_characters:
            return []

        if self.config.use_llm_detection and self.llm_service:
            return await self._detect_with_llm(narration, frame_index)
        return self._detect_basic(narration, frame_index)

    async def _detect_with_llm(
        self,
        narration: str,
        frame_index: int,
    ) -> List[Character]:
        """
        Detect characters using LLM.

        Asks the LLM for a JSON list of characters, registers unknown
        ones and records the frame for known ones. Any failure (LLM call
        or parsing) falls back to basic detection.
        """
        if not self.llm_service:
            return []

        try:
            prompt = f"""分析以下文案,提取其中提到的角色/人物。
文案: {narration}
请用 JSON 格式返回角色列表,每个角色包含:
- name: 角色名称或代称
- type: person/animal/creature/object
- appearance: 外貌描述(如有)
- clothing: 服装描述(如有)
如果没有明确角色,返回空列表 []。
只返回 JSON不要其他解释。"""

            response = await self.llm_service(prompt, temperature=0.1)

            # Parse response
            import json
            import re

            # Extract JSON from response
            json_match = re.search(r'\[.*\]', response, re.DOTALL)
            if not json_match:
                return []

            characters_data = json.loads(json_match.group())
            result = []

            for char_data in characters_data:
                # Skip malformed entries instead of discarding the whole
                # response; JSON null fields are treated as missing.
                if not isinstance(char_data, dict):
                    continue
                name = str(char_data.get("name") or "").strip()
                if not name:
                    continue

                # Check if already registered
                existing = self.get_character(name)
                if existing:
                    self._record_appearance(existing, frame_index)
                    if existing not in result:
                        result.append(existing)
                    continue

                # Register new character
                char = self.register_character(
                    name=name,
                    appearance_description=char_data.get("appearance") or "",
                    clothing_description=char_data.get("clothing") or "",
                    character_type=self._parse_character_type(char_data.get("type")),
                    first_frame=frame_index,
                )
                result.append(char)

            return result

        except Exception as e:
            logger.warning(f"LLM character detection failed: {e}")
            return self._detect_basic(narration, frame_index)

    def _detect_basic(
        self,
        narration: str,
        frame_index: int,
    ) -> List[Character]:
        """
        Basic character detection without LLM.

        Only recognizes characters that are already registered: pattern
        matches are looked up by name and unknown matches are ignored.
        """
        # Simple pattern matching for common character references
        import re

        patterns = [
            r'(?:他|她|它)们?',  # Chinese pronouns
            r'(?:小\w{1,2})',  # Names like 小明, 小红
            r'(?:老\w{1,2})',  # Names like 老王, 老李
        ]

        detected = []
        for pattern in patterns:
            for match in re.findall(pattern, narration):
                existing = self.get_character(match)
                if not existing:
                    continue
                self._record_appearance(existing, frame_index)
                if existing not in detected:
                    detected.append(existing)

        return detected

    def apply_to_prompt(
        self,
        prompt: str,
        character_names: Optional[List[str]] = None,
        frame_index: Optional[int] = None,
    ) -> str:
        """
        Apply character consistency to an image prompt.

        Args:
            prompt: Original image prompt
            character_names: Specific characters to include; when None or
                empty, every registered character is included
            frame_index: Current frame index for tracking

        Returns:
            Enhanced prompt with character consistency, or the original
            prompt when injection is disabled or no character applies
        """
        if not self.config.inject_character_prompts:
            return prompt

        if character_names:
            looked_up = (self.get_character(name) for name in character_names)
            characters_to_include = [char for char in looked_up if char]
        else:
            # Include all registered characters
            characters_to_include = self.characters

        if not characters_to_include:
            return prompt

        # Build character injection
        injections = []
        for char in characters_to_include:
            injection = char.get_prompt_injection()
            if injection:
                injections.append(injection)

            # Track appearance
            if frame_index is not None:
                self._record_appearance(char, frame_index)

        if not injections:
            return prompt

        character_prompt = ", ".join(injections)

        if self.config.prompt_injection_position == "start":
            return f"{character_prompt}, {prompt}"
        return f"{prompt}, {character_prompt}"

    def get_reference_images(
        self,
        character_names: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Get reference images for specified characters.

        Args:
            character_names: Character names (None = all characters)

        Returns:
            Primary reference image paths, truncated to
            ``max_reference_images`` entries in total
        """
        if not self.config.use_reference_images:
            return []

        if character_names:
            candidates = [self.get_character(name) for name in character_names]
        else:
            candidates = self.characters

        images = [
            char.primary_reference
            for char in candidates
            if char and char.primary_reference
        ]
        return images[:self.config.max_reference_images]

    def set_reference_image(
        self,
        character_name: str,
        image_path: str,
        set_as_primary: bool = True,
    ):
        """
        Set a reference image for a character.

        Args:
            character_name: Character name
            image_path: Path to reference image
            set_as_primary: Whether to set as primary reference
        """
        char = self.get_character(character_name)
        if char:
            char.add_reference_image(image_path, set_as_primary)
            logger.debug(f"Set reference image for {character_name}: {image_path}")
        else:
            logger.warning(f"Character not found: {character_name}")

    def update_character_appearance(
        self,
        character_name: str,
        appearance_description: Optional[str] = None,
        clothing_description: Optional[str] = None,
        distinctive_features: Optional[List[str]] = None,
    ):
        """
        Update a character's visual description.

        Only truthy arguments are applied; the prompt prefix is rebuilt
        afterwards. An unknown character name is logged and ignored.
        """
        char = self.get_character(character_name)
        if not char:
            logger.warning(f"Character not found: {character_name}")
            return

        if appearance_description:
            char.appearance_description = appearance_description
        if clothing_description:
            char.clothing_description = clothing_description
        if distinctive_features:
            char.distinctive_features = distinctive_features
        char._build_prompt_prefix()
        logger.debug(f"Updated appearance for {character_name}")

    def get_consistency_summary(self) -> str:
        """Get a summary of character consistency for logging."""
        if not self._characters:
            return "No characters registered"

        lines = [f"Characters ({len(self._characters)}):"]
        for char in self.characters:
            lines.append(
                f" - {char.name}: {len(char.appearance_frames)} appearances, "
                f"ref_images={len(char.reference_images)}"
            )
        return "\n".join(lines)

    def reset(self):
        """Clear all character memory."""
        self._characters.clear()
        self._name_index.clear()
        logger.info("Character memory cleared")

View File

@@ -0,0 +1,316 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ContentFilter - Content moderation and safety filtering
Provides content safety checks for:
- Text content (narrations, prompts)
- Generated images
- Generated videos
"""
import re
from dataclasses import dataclass, field
from typing import List, Optional, Set
from enum import Enum
from loguru import logger
class FilterCategory(Enum):
    """Outcome categories a content check can assign."""

    SAFE = "safe"
    # May require review
    SENSITIVE = "sensitive"
    # Should not proceed
    BLOCKED = "blocked"
@dataclass
class FilterResult:
    """Outcome of one content-filter check."""

    category: FilterCategory  # Category assigned to the content
    passed: bool  # Whether the content may proceed
    flagged_items: List[str] = field(default_factory=list)  # Matched terms
    reason: Optional[str] = None  # Explanation, if any
    confidence: float = 1.0

    def to_dict(self) -> dict:
        """Return the result as a plain dict, with the category as its string value."""
        serialized = {"category": self.category.value}
        serialized["passed"] = self.passed
        serialized["flagged_items"] = self.flagged_items
        serialized["reason"] = self.reason
        serialized["confidence"] = self.confidence
        return serialized
@dataclass
class ContentFilterConfig:
    """
    Configuration for content filtering.

    Covers text filtering (keyword and optional LLM check), the image
    filter switch, and what happens when content is flagged.
    """

    # --- Text filtering ---
    enable_keyword_filter: bool = True
    # Use LLM for semantic filtering
    enable_llm_filter: bool = False

    # Custom keywords to block (added to default list)
    custom_blocked_keywords: List[str] = field(default_factory=list)

    # Sensitivity level: "strict", "moderate", "permissive"
    sensitivity_level: str = "moderate"

    # --- Image filtering ---
    # Requires external service
    enable_image_filter: bool = False

    # --- Action on detection ---
    # Block content marked as sensitive
    block_on_sensitive: bool = False
    log_filtered_content: bool = True
class ContentFilter:
"""
Content moderation filter for generated content
Provides safety filtering for text and media content to prevent
inappropriate or harmful content from being generated.
Example:
>>> filter = ContentFilter()
>>> result = await filter.check_text("Hello, world!")
>>> if result.passed:
... print("Content is safe")
"""
# Default blocked keywords (minimal list for demonstration)
DEFAULT_BLOCKED_PATTERNS = [
r"\b(violence|gore|blood)\b",
r"\b(nsfw|explicit|pornographic)\b",
r"\b(illegal|drugs|weapons)\b",
]
# Sensitive keywords (may require review)
DEFAULT_SENSITIVE_PATTERNS = [
r"\b(death|dying|kill)\b",
r"\b(hate|racist|sexist)\b",
r"\b(controversial|political)\b",
]
def __init__(
self,
llm_service=None,
config: Optional[ContentFilterConfig] = None
):
"""
Initialize ContentFilter
Args:
llm_service: Optional LLM service for semantic filtering
config: Filter configuration
"""
self.llm_service = llm_service
self.config = config or ContentFilterConfig()
# Compile patterns
self._blocked_patterns = [
re.compile(p, re.IGNORECASE)
for p in self.DEFAULT_BLOCKED_PATTERNS
]
self._sensitive_patterns = [
re.compile(p, re.IGNORECASE)
for p in self.DEFAULT_SENSITIVE_PATTERNS
]
# Add custom keywords
if self.config.custom_blocked_keywords:
for keyword in self.config.custom_blocked_keywords:
pattern = re.compile(rf"\b{re.escape(keyword)}\b", re.IGNORECASE)
self._blocked_patterns.append(pattern)
async def check_text(self, text: str) -> FilterResult:
"""
Check text content for safety
Args:
text: Text to check
Returns:
FilterResult with safety assessment
"""
if not text:
return FilterResult(
category=FilterCategory.SAFE,
passed=True
)
flagged_items = []
category = FilterCategory.SAFE
# Keyword filtering
if self.config.enable_keyword_filter:
# Check blocked patterns
for pattern in self._blocked_patterns:
matches = pattern.findall(text)
if matches:
flagged_items.extend(matches)
category = FilterCategory.BLOCKED
# Check sensitive patterns (if not already blocked)
if category != FilterCategory.BLOCKED:
for pattern in self._sensitive_patterns:
matches = pattern.findall(text)
if matches:
flagged_items.extend(matches)
category = FilterCategory.SENSITIVE
# LLM-based semantic filtering
if self.config.enable_llm_filter and self.llm_service:
semantic_result = await self._check_with_llm(text)
if semantic_result.category.value > category.value:
category = semantic_result.category
flagged_items.extend(semantic_result.flagged_items)
# Determine if passed
if category == FilterCategory.BLOCKED:
passed = False
reason = "Content contains blocked keywords or themes"
elif category == FilterCategory.SENSITIVE and self.config.block_on_sensitive:
passed = False
reason = "Content contains sensitive themes (strict mode)"
else:
passed = True
reason = None
# Log if configured
if self.config.log_filtered_content and flagged_items:
logger.warning(f"Content filter flagged: {flagged_items}")
return FilterResult(
category=category,
passed=passed,
flagged_items=flagged_items,
reason=reason,
)
async def check_texts(self, texts: List[str]) -> FilterResult:
"""
Check multiple texts and return aggregate result
Args:
texts: List of texts to check
Returns:
Aggregate FilterResult
"""
all_flagged = []
worst_category = FilterCategory.SAFE
for text in texts:
result = await self.check_text(text)
all_flagged.extend(result.flagged_items)
if result.category == FilterCategory.BLOCKED:
worst_category = FilterCategory.BLOCKED
elif result.category == FilterCategory.SENSITIVE and worst_category != FilterCategory.BLOCKED:
worst_category = FilterCategory.SENSITIVE
passed = worst_category == FilterCategory.SAFE or (
worst_category == FilterCategory.SENSITIVE and not self.config.block_on_sensitive
)
return FilterResult(
category=worst_category,
passed=passed,
flagged_items=list(set(all_flagged)), # Deduplicate
reason=f"Found {len(all_flagged)} flagged items" if all_flagged else None,
)
async def check_image(self, image_path: str) -> FilterResult:
"""
Check image content for safety
Note: This requires external NSFW detection service integration.
Currently returns safe by default.
Args:
image_path: Path to image file
Returns:
FilterResult with safety assessment
"""
if not self.config.enable_image_filter:
return FilterResult(
category=FilterCategory.SAFE,
passed=True,
reason="Image filtering disabled"
)
# TODO: Integrate with external NSFW detection
# Options: TensorFlow NSFW model, Azure Content Moderator, etc.
logger.debug(f"Image safety check: {image_path} (not implemented, assuming safe)")
return FilterResult(
category=FilterCategory.SAFE,
passed=True,
reason="Image filtering not implemented"
)
async def _check_with_llm(self, text: str) -> FilterResult:
"""Check text using LLM for semantic understanding"""
if not self.llm_service:
return FilterResult(category=FilterCategory.SAFE, passed=True)
try:
prompt = f"""Analyze the following text for content safety.
Rate it as: SAFE, SENSITIVE, or BLOCKED.
SAFE: No concerning content
SENSITIVE: Contains themes that may need review (violence, death, controversial topics)
BLOCKED: Contains explicit, illegal, or harmful content
Text: {text[:500]}
Respond with only one word: SAFE, SENSITIVE, or BLOCKED."""
response = await self.llm_service(prompt, temperature=0.0, max_tokens=10)
response = response.strip().upper()
if "BLOCKED" in response:
return FilterResult(
category=FilterCategory.BLOCKED,
passed=False,
reason="LLM detected blocked content"
)
elif "SENSITIVE" in response:
return FilterResult(
category=FilterCategory.SENSITIVE,
passed=not self.config.block_on_sensitive,
reason="LLM detected sensitive content"
)
else:
return FilterResult(
category=FilterCategory.SAFE,
passed=True
)
except Exception as e:
logger.warning(f"LLM content check failed: {e}")
return FilterResult(category=FilterCategory.SAFE, passed=True)
def add_blocked_keyword(self, keyword: str):
"""Add a keyword to the blocked list"""
pattern = re.compile(rf"\b{re.escape(keyword)}\b", re.IGNORECASE)
self._blocked_patterns.append(pattern)
def add_sensitive_keyword(self, keyword: str):
    """
    Add a keyword to the sensitive list

    The keyword is matched literally and case-insensitively. A whole-word
    guard is applied only on the sides where the keyword starts or ends
    with an ASCII letter, digit or underscore, so short Latin keywords do
    not match inside longer words while CJK keywords still match inside
    running CJK text.

    Args:
        keyword: Keyword to flag; empty or whitespace-only values are ignored
    """
    if not keyword or not keyword.strip():
        # An empty keyword would compile to a pattern that matches at every
        # word boundary and therefore flag all text.
        return

    def is_ascii_word_char(ch: str) -> bool:
        return ch == "_" or (ch.isascii() and ch.isalnum())

    # ``\b`` is unsuitable here: CJK ideographs count as word characters, so
    # a ``\b``-wrapped CJK keyword never matches inside a Chinese sentence.
    ascii_word = "[A-Za-z0-9_]"
    prefix = f"(?<!{ascii_word})" if is_ascii_word_char(keyword[0]) else ""
    suffix = f"(?!{ascii_word})" if is_ascii_word_char(keyword[-1]) else ""
    pattern = re.compile(f"{prefix}{re.escape(keyword)}{suffix}", re.IGNORECASE)
    self._sensitive_patterns.append(pattern)

View File

@@ -0,0 +1,140 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data models for quality assurance
"""
from dataclasses import dataclass, field
from typing import List, Optional
from enum import Enum
class QualityLevel(Enum):
    """
    Coarse quality bucket for an overall score

    Each member is annotated with the overall-score band it stands for.
    """

    # overall score >= 0.8
    EXCELLENT = "excellent"
    # overall score >= 0.6
    GOOD = "good"
    # overall score >= 0.4
    ACCEPTABLE = "acceptable"
    # overall score < 0.4
    POOR = "poor"
@dataclass
class QualityScore:
    """
    Outcome of a quality evaluation of generated content

    All scores are expected in the 0.0 - 1.0 range.
    """

    # Per-dimension scores
    aesthetic_score: float = 0.0  # visual appeal / beauty
    text_match_score: float = 0.0  # how well the image matches the prompt
    technical_score: float = 0.0  # clarity, absence of artifacts

    # Aggregate result
    overall_score: float = 0.0  # weighted average of the scores above
    passed: bool = False  # whether the threshold was met

    # Diagnostics
    issues: List[str] = field(default_factory=list)  # detected problems
    evaluation_time_ms: float = 0.0  # time spent evaluating

    @property
    def level(self) -> QualityLevel:
        """Map ``overall_score`` onto a QualityLevel band"""
        bands = (
            (0.8, QualityLevel.EXCELLENT),
            (0.6, QualityLevel.GOOD),
            (0.4, QualityLevel.ACCEPTABLE),
        )
        for floor, band in bands:
            if self.overall_score >= floor:
                return band
        return QualityLevel.POOR

    def to_dict(self) -> dict:
        """
        Serialize to a plain dictionary

        ``evaluation_time_ms`` is not included; ``level`` is emitted as the
        enum's string value.
        """
        numeric_keys = (
            "aesthetic_score",
            "text_match_score",
            "technical_score",
            "overall_score",
            "passed",
        )
        payload = {key: getattr(self, key) for key in numeric_keys}
        payload["level"] = self.level.value
        payload["issues"] = self.issues
        return payload
@dataclass
class QualityConfig:
    """
    Configuration for quality evaluation

    The three weights are rescaled in ``__post_init__`` so that they sum
    to 1.0 whenever the supplied values do not already (within 0.01).

    Raises:
        ValueError: If the weights do not sum to a positive value
    """

    # Thresholds (0.0 - 1.0)
    overall_threshold: float = 0.6  # Minimum overall score to pass
    aesthetic_threshold: float = 0.5  # Minimum aesthetic score
    text_match_threshold: float = 0.6  # Minimum text-match score
    technical_threshold: float = 0.7  # Minimum technical score

    # Weights for overall score calculation
    aesthetic_weight: float = 0.3
    text_match_weight: float = 0.4
    technical_weight: float = 0.3

    # Evaluation settings
    use_vlm_evaluation: bool = True  # Use VLM for evaluation (vs local models)
    vlm_model: Optional[str] = None  # VLM model to use (None = use default LLM)
    skip_on_static_template: bool = True  # Skip image quality check for static templates

    def __post_init__(self):
        """Normalize the weights so that they sum to 1.0"""
        total = self.aesthetic_weight + self.text_match_weight + self.technical_weight
        if total <= 0:
            # Dividing by zero would crash with an unhelpful error, and a
            # negative total would flip the sign of every weight.
            raise ValueError(
                f"Quality weights must sum to a positive value, got {total}"
            )
        if abs(total - 1.0) > 0.01:
            self.aesthetic_weight /= total
            self.text_match_weight /= total
            self.technical_weight /= total
@dataclass
class RetryConfig:
    """
    Settings that control retry behavior

    Groups the attempt/backoff limits, the quality threshold, the fallback
    switches and the conditions under which a retry is attempted.
    """

    # --- attempts and backoff ---
    max_retries: int = 3  # maximum retry attempts
    backoff_factor: float = 1.5  # exponential backoff multiplier
    initial_delay_ms: int = 500  # delay before the first retry
    max_delay_ms: int = 10000  # upper bound for the delay between retries

    # --- quality-based retry ---
    quality_threshold: float = 0.6  # quality score threshold for a pass

    # --- fallback behavior ---
    enable_fallback: bool = True  # use the fallback strategy on failure
    fallback_prompt_simplify: bool = True  # simplify the prompt on retry

    # --- retry conditions ---
    retry_on_quality_fail: bool = True  # retry when quality is below threshold
    retry_on_error: bool = True  # retry on generation errors
class QualityError(Exception):
    """
    Raised when quality standards are still not met after all retries

    Carries the number of attempts made and, when available, the last
    quality score so callers can inspect why the operation gave up.
    """

    def __init__(
        self,
        message: str,
        attempts: int = 0,
        last_score: Optional[QualityScore] = None
    ):
        super().__init__(message)
        self.attempts = attempts
        self.last_score = last_score

    def __str__(self) -> str:
        """Render the message followed by the attempt count and last score"""
        message = super().__str__()
        details = f"attempts={self.attempts}"
        if self.last_score:
            details += f", last_score={self.last_score.overall_score:.2f}"
        return f"{message} ({details})"

View File

@@ -0,0 +1,336 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OutputValidator - LLM output validation and quality control
Validates LLM-generated content for:
- Narrations: length, relevance, coherence
- Image prompts: format, language, prompt-narration alignment
"""
import re
from dataclasses import dataclass, field
from typing import List, Optional
from loguru import logger
@dataclass
class ValidationConfig:
    """
    Limits and thresholds applied when validating LLM output

    The first group applies to narrations, the second to image prompts.
    """

    # --- narrations ---
    min_narration_words: int = 5
    max_narration_words: int = 50
    relevance_threshold: float = 0.6
    coherence_threshold: float = 0.6

    # --- image prompts ---
    min_prompt_words: int = 10
    max_prompt_words: int = 100
    require_english_prompts: bool = True
    prompt_match_threshold: float = 0.5
@dataclass
class ValidationResult:
    """
    Outcome of a validation run

    ``score`` ranges from 0.0 (failed) to 1.0 (perfect); ``issues`` and
    ``suggestions`` carry the human-readable findings.
    """

    passed: bool
    issues: List[str] = field(default_factory=list)
    score: float = 1.0
    suggestions: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the result as a plain dictionary"""
        return dict(
            passed=self.passed,
            score=self.score,
            issues=self.issues,
            suggestions=self.suggestions,
        )
class OutputValidator:
"""
Validator for LLM-generated outputs
Validates narrations and image prompts to ensure they meet quality standards
before proceeding with media generation.
Example:
>>> validator = OutputValidator(llm_service)
>>> result = await validator.validate_narrations(
... narrations=["旁白1", "旁白2"],
... topic="人生哲理",
... config=ValidationConfig()
... )
>>> if not result.passed:
... print(f"Validation failed: {result.issues}")
"""
def __init__(self, llm_service=None):
"""
Initialize OutputValidator
Args:
llm_service: Optional LLM service for semantic validation
"""
self.llm_service = llm_service
async def validate_narrations(
    self,
    narrations: List[str],
    topic: str,
    config: Optional[ValidationConfig] = None,
) -> ValidationResult:
    """
    Validate generated narrations

    Checks, in order:
        1. Non-empty content and length limits (length is the character
           count of each narration)
        2. Topic relevance (only when an LLM service is configured)
        3. Coherence between narrations (only when an LLM service is configured)

    A semantic score is only added to the average when it falls below its
    threshold. The result passes when there are no issues or when the
    average score is at least 0.7.

    Args:
        narrations: List of narration texts
        topic: Original topic/theme
        config: Validation configuration (defaults are used when omitted)

    Returns:
        ValidationResult with pass/fail, score, issues and suggestions
    """
    settings = config or ValidationConfig()

    if not narrations:
        return ValidationResult(
            passed=False,
            issues=["No narrations provided"],
            score=0.0
        )

    problems = []
    advice = []
    frame_scores = []

    # Rule-based pass: one score per narration
    for index, text in enumerate(narrations, 1):
        length = len(text)

        if not text.strip():
            problems.append(f"分镜{index}: 内容为空")
            frame_scores.append(0.0)
        elif length < settings.min_narration_words:
            problems.append(f"分镜{index}: 内容过短 ({length}字,最少{settings.min_narration_words}字)")
            frame_scores.append(0.5)
        elif length > settings.max_narration_words:
            problems.append(f"分镜{index}: 内容过长 ({length}字,最多{settings.max_narration_words}字)")
            advice.append(f"考虑将分镜{index}拆分为多个短句")
            frame_scores.append(0.7)
        else:
            frame_scores.append(1.0)

    # Semantic pass: best effort, failures are logged and ignored
    if self.llm_service:
        try:
            relevance = await self._check_relevance(narrations, topic)
            if relevance < settings.relevance_threshold:
                problems.append(f'内容与主题"{topic}"相关性不足 ({relevance:.0%})')
                advice.append("建议重新生成,确保内容紧扣主题")
                frame_scores.append(relevance)

            coherence = await self._check_coherence(narrations)
            if coherence < settings.coherence_threshold:
                problems.append(f"内容连贯性不足 ({coherence:.0%})")
                advice.append("建议检查段落之间的逻辑衔接")
                frame_scores.append(coherence)
        except Exception as exc:
            logger.warning(f"Semantic validation failed: {exc}")

    average = sum(frame_scores) / len(frame_scores) if frame_scores else 0.0
    verdict = not problems or average >= 0.7

    logger.debug(f"Narration validation: score={average:.2f}, issues={len(problems)}")

    return ValidationResult(
        passed=verdict,
        issues=problems,
        score=average,
        suggestions=advice,
    )
async def validate_image_prompts(
    self,
    prompts: List[str],
    narrations: List[str],
    config: Optional[ValidationConfig] = None,
) -> ValidationResult:
    """
    Validate generated image prompts

    Checks:
        1. Prompt count matches narration count
        2. Non-empty content and word-count limits
        3. Language: a prompt with more than 30% Chinese characters is
           flagged and its score halved (when English is required)

    The result passes when there are no issues or when the average score
    is at least 0.7.

    Args:
        prompts: List of image prompts
        narrations: Corresponding narrations
        config: Validation configuration (defaults are used when omitted)

    Returns:
        ValidationResult with pass/fail, score, issues and suggestions
    """
    settings = config or ValidationConfig()

    if not prompts:
        return ValidationResult(
            passed=False,
            issues=["No image prompts provided"],
            score=0.0
        )

    problems = []
    advice = []
    prompt_scores = []

    if len(prompts) != len(narrations):
        problems.append(f"提示词数量({len(prompts)})与旁白数量({len(narrations)})不匹配")

    for index, text in enumerate(prompts, 1):
        if not text.strip():
            problems.append(f"提示词{index}: 内容为空")
            prompt_scores.append(0.0)
            continue

        n_words = len(text.split())

        # Length check decides the base score
        if n_words < settings.min_prompt_words:
            problems.append(f"提示词{index}: 过短 ({n_words}词,最少{settings.min_prompt_words}词)")
            base = 0.5
        elif n_words > settings.max_prompt_words:
            problems.append(f"提示词{index}: 过长 ({n_words}词,最多{settings.max_prompt_words}词)")
            base = 0.8
        else:
            base = 1.0
        prompt_scores.append(base)

        # Language check halves the score just recorded
        if not settings.require_english_prompts:
            continue
        ratio = self._get_chinese_ratio(text)
        if ratio > 0.3:
            problems.append(f"提示词{index}: 应使用英文 (当前含{ratio:.0%}中文)")
            advice.append(f"将提示词{index}翻译为英文以获得更好的生成效果")
            prompt_scores[-1] *= 0.5

    average = sum(prompt_scores) / len(prompt_scores) if prompt_scores else 0.0
    verdict = not problems or average >= 0.7

    logger.debug(f"Image prompt validation: score={average:.2f}, issues={len(problems)}")

    return ValidationResult(
        passed=verdict,
        issues=problems,
        score=average,
        suggestions=advice,
    )
async def _check_relevance(
self,
narrations: List[str],
topic: str,
) -> float:
"""
Check relevance of narrations to topic using LLM
Returns:
Relevance score 0.0-1.0
"""
if not self.llm_service:
return 0.8 # Default score when LLM not available
try:
combined_text = "\n".join(narrations[:3]) # Check first 3 for efficiency
prompt = f"""评估以下内容与主题"{topic}"的相关性。
内容:
{combined_text}
请用0-100的分数评估相关性只输出数字。
相关性越高,分数越高。"""
response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)
# Parse score from response
score_match = re.search(r'\d+', response)
if score_match:
score = int(score_match.group()) / 100
return min(1.0, max(0.0, score))
return 0.7 # Default if parsing fails
except Exception as e:
logger.warning(f"Relevance check failed: {e}")
return 0.7
async def _check_coherence(
self,
narrations: List[str],
) -> float:
"""
Check coherence between narrations using LLM
Returns:
Coherence score 0.0-1.0
"""
if not self.llm_service or len(narrations) < 2:
return 0.8 # Default score
try:
numbered = "\n".join(f"{i+1}. {n}" for i, n in enumerate(narrations[:5]))
prompt = f"""评估以下段落之间的逻辑连贯性。
{numbered}
请用0-100的分数评估连贯性只输出数字。
段落之间逻辑顺畅、衔接自然则分数高。"""
response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)
score_match = re.search(r'\d+', response)
if score_match:
score = int(score_match.group()) / 100
return min(1.0, max(0.0, score))
return 0.7
except Exception as e:
logger.warning(f"Coherence check failed: {e}")
return 0.7
def _get_chinese_ratio(self, text: str) -> float:
"""Calculate ratio of Chinese characters in text"""
if not text:
return 0.0
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
total_chars = len(text.replace(' ', ''))
if total_chars == 0:
return 0.0
return chinese_chars / total_chars

View File

@@ -0,0 +1,363 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
QualityGate - Quality evaluation system for generated content
Evaluates images and videos based on:
- Aesthetic quality (visual appeal)
- Text-to-image matching (semantic alignment)
- Technical quality (clarity, no artifacts)
"""
import time
from typing import Optional
from pathlib import Path
from loguru import logger
from pixelle_video.services.quality.models import QualityScore, QualityConfig
class QualityGate:
"""
Quality evaluation gate for AI-generated content
Uses VLM (Vision Language Model) or local models to evaluate:
1. Aesthetic quality - Is the image visually appealing?
2. Text matching - Does the image match the prompt/narration?
3. Technical quality - Is the image clear and free of artifacts?
Example:
>>> gate = QualityGate(llm_service, config)
>>> score = await gate.evaluate_image(
... image_path="output/frame_001.png",
... prompt="A sunset over mountains",
... narration="夕阳西下,余晖洒满山间"
... )
>>> if score.passed:
... print("Image quality approved!")
"""
def __init__(
    self,
    llm_service=None,
    config: Optional[QualityConfig] = None
):
    """
    Create the quality gate

    Args:
        llm_service: LLM service used for VLM-based evaluation
        config: Quality configuration; a default QualityConfig is created
            when none is supplied
    """
    if not config:
        config = QualityConfig()
    self.llm_service = llm_service
    self.config = config
async def evaluate_image(
    self,
    image_path: str,
    prompt: str,
    narration: Optional[str] = None,
) -> QualityScore:
    """
    Score a generated image

    A missing file yields a failed score immediately. Otherwise the VLM
    evaluator is used when enabled and an LLM service is configured, and
    the basic evaluator in every other case. The pass flag is derived from
    the configured overall threshold.

    Args:
        image_path: Path to the image file
        prompt: The prompt used to generate the image
        narration: Optional narration text for context

    Returns:
        QualityScore with scores, pass flag and evaluation time
    """
    started_at = time.time()

    if not Path(image_path).exists():
        return QualityScore(
            passed=False,
            issues=["Image file not found"],
            evaluation_time_ms=(time.time() - started_at) * 1000
        )

    vlm_available = self.config.use_vlm_evaluation and self.llm_service
    if vlm_available:
        score = await self._evaluate_with_vlm(image_path, prompt, narration)
    else:
        score = await self._evaluate_basic(image_path, prompt)

    # Timing and verdict are filled in here, whichever evaluator ran
    score.evaluation_time_ms = (time.time() - started_at) * 1000
    score.passed = score.overall_score >= self.config.overall_threshold

    logger.debug(
        f"Quality evaluation: overall={score.overall_score:.2f}, "
        f"passed={score.passed}, time={score.evaluation_time_ms:.0f}ms"
    )
    return score
async def evaluate_video(
    self,
    video_path: str,
    prompt: str,
    narration: Optional[str] = None,
) -> QualityScore:
    """
    Score a generated video

    A missing file yields a failed score immediately. Otherwise the VLM
    video evaluator is used when enabled and an LLM service is configured,
    and the basic video evaluator in every other case. The pass flag is
    derived from the configured overall threshold.

    Args:
        video_path: Path to the video file
        prompt: The prompt used to generate the video
        narration: Optional narration text for context

    Returns:
        QualityScore with scores, pass flag and evaluation time
    """
    started_at = time.time()

    if not Path(video_path).exists():
        return QualityScore(
            passed=False,
            issues=["Video file not found"],
            evaluation_time_ms=(time.time() - started_at) * 1000
        )

    vlm_available = self.config.use_vlm_evaluation and self.llm_service
    if vlm_available:
        score = await self._evaluate_video_with_vlm(video_path, prompt, narration)
    else:
        score = await self._evaluate_video_basic(video_path)

    # Timing and verdict are filled in here, whichever evaluator ran
    score.evaluation_time_ms = (time.time() - started_at) * 1000
    score.passed = score.overall_score >= self.config.overall_threshold
    return score
async def _evaluate_with_vlm(
    self,
    image_path: str,
    prompt: str,
    narration: Optional[str] = None,
) -> QualityScore:
    """
    Evaluate image quality using a Vision Language Model

    The VLM call itself is not implemented yet: the evaluation prompt is
    built, but the result always comes from the basic evaluator.

    Args:
        image_path: Path to the image file
        prompt: The prompt used to generate the image
        narration: Optional narration text for context

    Returns:
        QualityScore produced by the basic evaluator
    """
    evaluation_prompt = self._build_evaluation_prompt(prompt, narration)

    try:
        # TODO: Call a vision-capable LLM (e.g. GPT-4o, Qwen-VL) once the
        # LLM service accepts image input, along the lines of:
        #   response = await self.llm_service(
        #       prompt=evaluation_prompt,
        #       images=[image_path],
        #       response_type=ImageQualityResponse
        #   )
        logger.debug("VLM evaluation: using basic fallback (VLM integration pending)")
        basic_score = await self._evaluate_basic(image_path, prompt)
        return basic_score
    except Exception as exc:
        logger.warning(f"VLM evaluation failed: {exc}, falling back to basic")
        return await self._evaluate_basic(image_path, prompt)
async def _evaluate_basic(
    self,
    image_path: str,
    prompt: str,
) -> QualityScore:
    """
    Basic image quality evaluation without VLM

    Opens the image with PIL and flags images smaller than 256 pixels on
    either side or with an aspect ratio above 4. The sub-scores are fixed
    values that drop when any issue was found; ``prompt`` is not used.

    Args:
        image_path: Path to the image file
        prompt: The prompt used to generate the image

    Returns:
        QualityScore with the weighted overall score, or a zero score
        carrying the error text when the image cannot be evaluated
    """
    found = []

    try:
        # Imported lazily so PIL is only needed when this path runs
        from PIL import Image

        with Image.open(image_path) as img:
            width, height = img.size

            if width < 256 or height < 256:
                found.append(f"Image too small: {width}x{height}")

            # Long side over short side; anything above 4 is considered extreme
            aspect = max(width, height) / min(width, height)
            if aspect > 4:
                found.append(f"Extreme aspect ratio: {aspect:.1f}")

        # Generous defaults, because nothing here can judge the content
        has_issues = bool(found)
        aesthetic_score = 0.4 if has_issues else 0.7
        text_match_score = 0.7  # cannot be evaluated properly without VLM
        technical_score = 0.5 if has_issues else 0.8

        weights = self.config
        overall = (
            aesthetic_score * weights.aesthetic_weight +
            text_match_score * weights.text_match_weight +
            technical_score * weights.technical_weight
        )

        return QualityScore(
            aesthetic_score=aesthetic_score,
            text_match_score=text_match_score,
            technical_score=technical_score,
            overall_score=overall,
            issues=found,
        )
    except Exception as exc:
        logger.error(f"Basic evaluation failed: {exc}")
        return QualityScore(
            overall_score=0.0,
            passed=False,
            issues=[f"Evaluation error: {str(exc)}"]
        )
async def _evaluate_video_with_vlm(
    self,
    video_path: str,
    prompt: str,
    narration: Optional[str] = None,
) -> QualityScore:
    """
    Evaluate video using VLM (placeholder)

    Currently delegates to the basic video evaluation; ``prompt`` and
    ``narration`` are not used yet.
    """
    # TODO: Implement video frame sampling and VLM evaluation
    basic_score = await self._evaluate_video_basic(video_path)
    return basic_score
async def _evaluate_video_basic(
    self,
    video_path: str,
) -> QualityScore:
    """
    Basic video quality evaluation from ffprobe metadata

    Runs ``ffprobe`` on the file and flags a missing video stream, frames
    smaller than 256 pixels on either side, and durations under 0.5s.
    ffprobe runs in the default executor with a timeout, so a slow or hung
    probe neither blocks the event loop nor waits forever.

    Args:
        video_path: Path to the video file

    Returns:
        QualityScore with the weighted overall score. A score of 0.5 with
        the problem listed in ``issues`` is returned when the metadata
        cannot be read or the probe fails; 0.0 when there is no video stream.
    """
    issues = []
    probe_timeout_s = 30  # upper bound for one ffprobe run

    try:
        import asyncio
        import functools
        import json
        import subprocess

        # Use ffprobe to get video info
        cmd = [
            "ffprobe", "-v", "quiet", "-print_format", "json",
            "-show_format", "-show_streams", video_path
        ]
        run_probe = functools.partial(
            subprocess.run,
            cmd,
            capture_output=True,
            text=True,
            timeout=probe_timeout_s,
        )
        # subprocess.run blocks, so keep it off the event loop thread
        result = await asyncio.get_running_loop().run_in_executor(None, run_probe)

        if result.returncode != 0:
            issues.append("Failed to read video metadata")
            return QualityScore(overall_score=0.5, issues=issues)

        info = json.loads(result.stdout)

        # Pick the first video stream, if any
        video_stream = next(
            (
                stream
                for stream in info.get("streams", [])
                if stream.get("codec_type") == "video"
            ),
            None,
        )
        if not video_stream:
            issues.append("No video stream found")
            return QualityScore(overall_score=0.0, passed=False, issues=issues)

        # Check dimensions
        width = video_stream.get("width", 0)
        height = video_stream.get("height", 0)
        if width < 256 or height < 256:
            issues.append(f"Video too small: {width}x{height}")

        # Check duration
        duration = float(info.get("format", {}).get("duration", 0))
        if duration < 0.5:
            issues.append(f"Video too short: {duration:.1f}s")

        # Fixed sub-scores; only the technical score reacts to issues
        aesthetic_score = 0.7
        text_match_score = 0.7
        technical_score = 0.8 if not issues else 0.5

        overall = (
            aesthetic_score * self.config.aesthetic_weight +
            text_match_score * self.config.text_match_weight +
            technical_score * self.config.technical_weight
        )

        return QualityScore(
            aesthetic_score=aesthetic_score,
            text_match_score=text_match_score,
            technical_score=technical_score,
            overall_score=overall,
            issues=issues,
        )
    except Exception as e:
        # Includes a missing ffprobe binary, a probe timeout and bad JSON
        logger.error(f"Video evaluation failed: {e}")
        return QualityScore(
            overall_score=0.5,
            issues=[f"Evaluation error: {str(e)}"]
        )
def _build_evaluation_prompt(
self,
prompt: str,
narration: Optional[str] = None,
) -> str:
"""Build the evaluation prompt for VLM"""
context = f"Narration: {narration}\n" if narration else ""
return f"""Evaluate this AI-generated image on the following criteria.
Rate each from 0.0 to 1.0.
Image Generation Prompt: {prompt}
{context}
Evaluation Criteria:
1. Aesthetic Quality (0.0-1.0):
- Is the image visually appealing?
- Good composition, colors, and style?
2. Prompt Matching (0.0-1.0):
- Does the image accurately represent the prompt?
- Are key elements from the prompt visible?
3. Technical Quality (0.0-1.0):
- Is the image clear and well-defined?
- Free of artifacts, distortions, or blurriness?
- Natural looking (no AI artifacts like extra fingers)?
Respond in JSON format:
{{
"aesthetic_score": 0.0,
"text_match_score": 0.0,
"technical_score": 0.0,
"issues": ["list of any problems found"]
}}
"""

View File

@@ -0,0 +1,296 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RetryManager - Smart retry logic with quality-based decisions
Provides unified retry management for:
- Media generation (images, videos)
- LLM calls
- Any async operations with quality evaluation
"""
import asyncio
from typing import Callable, TypeVar, Optional, Tuple, Any
from dataclasses import dataclass
from loguru import logger
from pixelle_video.services.quality.models import (
QualityScore,
QualityConfig,
RetryConfig,
QualityError,
)
T = TypeVar("T")
@dataclass
class RetryResult:
    """
    Outcome of a retried operation

    Bundles the success flag, the value produced, the number of attempts
    made, and optionally the quality score and error that go with it.
    """

    success: bool  # whether the operation was accepted
    result: Any  # value returned by the operation
    attempts: int  # number of attempts made
    quality_score: Optional[QualityScore] = None  # score of the result, if evaluated
    error: Optional[Exception] = None  # associated error, if any
class RetryManager:
"""
Smart retry manager with quality-based decisions
Features:
- Exponential backoff with configurable delays
- Quality-aware retry (retries when quality below threshold)
- Fallback strategies
- Detailed logging and metrics
Example:
>>> retry_manager = RetryManager()
>>>
>>> async def generate():
... return await media_service.generate_image(prompt)
>>>
>>> async def evaluate(image_path):
... return await quality_gate.evaluate_image(image_path, prompt)
>>>
>>> result = await retry_manager.execute_with_retry(
... operation=generate,
... quality_evaluator=evaluate,
... config=RetryConfig(max_retries=3)
... )
"""
def __init__(self, config: Optional[RetryConfig] = None):
    """
    Create the retry manager

    Args:
        config: Retry configuration used when a call does not supply its
            own; a default RetryConfig is created when omitted
    """
    if not config:
        config = RetryConfig()
    self.default_config = config
async def execute_with_retry(
    self,
    operation: Callable[[], Any],
    quality_evaluator: Optional[Callable[[Any], QualityScore]] = None,
    config: Optional[RetryConfig] = None,
    fallback_operation: Optional[Callable[[], Any]] = None,
    operation_name: str = "operation",
) -> RetryResult:
    """
    Execute operation with automatic retry and quality evaluation

    Each attempt runs ``operation`` and, when an evaluator is given, scores
    the result; the first result whose score has ``passed`` set is
    returned. Attempts are separated by an exponential backoff delay. At
    least one attempt is always made, even when ``max_retries`` is 0 or
    negative. When every attempt fails, the fallback (if enabled and
    supplied) is tried once.

    Args:
        operation: Async callable to execute
        quality_evaluator: Optional async callable to evaluate result quality
        config: Retry configuration (uses default if not provided)
        fallback_operation: Optional fallback to try when all retries fail
        operation_name: Name for logging purposes

    Returns:
        RetryResult with success status, result, and quality score. The
        result is unsuccessful (not raised) when a quality check fails and
        ``retry_on_quality_fail`` is off, or when the fallback result does
        not pass evaluation.

    Raises:
        QualityError: When all attempts fail and no fallback succeeds; the
            last exception, if any, is chained as the cause
        Exception: The operation's own exception, immediately, when
            ``retry_on_error`` is off
    """
    cfg = config or self.default_config
    # A non-positive max_retries would otherwise skip the operation entirely
    max_attempts = max(1, cfg.max_retries)
    last_error: Optional[Exception] = None
    last_score: Optional[QualityScore] = None

    for attempt in range(1, max_attempts + 1):
        try:
            # Execute the operation
            logger.debug(f"{operation_name}: attempt {attempt}/{max_attempts}")
            result = await operation()

            # If no quality evaluator, accept the result
            if quality_evaluator is None:
                logger.debug(f"{operation_name}: completed (no quality check)")
                return RetryResult(
                    success=True,
                    result=result,
                    attempts=attempt,
                )

            # Evaluate quality
            score = await quality_evaluator(result)
            last_score = score
            # This attempt ran without raising, so an exception from an
            # earlier attempt is no longer the most recent failure.
            last_error = None

            if score.passed:
                logger.info(
                    f"{operation_name}: passed quality check "
                    f"(score={score.overall_score:.2f}, attempt={attempt})"
                )
                return RetryResult(
                    success=True,
                    result=result,
                    attempts=attempt,
                    quality_score=score,
                )

            # Quality check failed.
            # NOTE: the threshold logged here is RetryConfig.quality_threshold;
            # the pass/fail decision itself is score.passed from the evaluator.
            logger.warning(
                f"{operation_name}: quality check failed "
                f"(score={score.overall_score:.2f}, threshold={cfg.quality_threshold}, "
                f"issues={score.issues})"
            )

            if not cfg.retry_on_quality_fail:
                # Don't retry on quality failure
                return RetryResult(
                    success=False,
                    result=result,
                    attempts=attempt,
                    quality_score=score,
                )

        except Exception as e:
            last_error = e
            logger.warning(f"{operation_name}: attempt {attempt} failed with error: {e}")

            if not cfg.retry_on_error:
                raise

        # Exponential backoff, capped at max_delay_ms; skipped after the last attempt
        if attempt < max_attempts:
            delay_ms = min(
                cfg.initial_delay_ms * (cfg.backoff_factor ** (attempt - 1)),
                cfg.max_delay_ms
            )
            logger.debug(f"{operation_name}: waiting {delay_ms:.0f}ms before retry")
            await asyncio.sleep(delay_ms / 1000)

    # All retries exhausted
    logger.warning(f"{operation_name}: all {max_attempts} attempts failed")

    # Try fallback if available
    if cfg.enable_fallback and fallback_operation:
        try:
            logger.info(f"{operation_name}: trying fallback strategy")
            result = await fallback_operation()

            # Evaluate fallback result if evaluator available
            if quality_evaluator:
                score = await quality_evaluator(result)
                return RetryResult(
                    success=score.passed if score else True,
                    result=result,
                    attempts=max_attempts + 1,
                    quality_score=score,
                )

            return RetryResult(
                success=True,
                result=result,
                attempts=max_attempts + 1,
            )
        except Exception as e:
            logger.error(f"{operation_name}: fallback also failed: {e}")
            last_error = e

    # Nothing worked
    if last_error is not None:
        raise QualityError(
            f"{operation_name} failed after {max_attempts} attempts: {last_error}",
            attempts=max_attempts,
            last_score=last_score,
        ) from last_error

    raise QualityError(
        f"{operation_name} failed to meet quality threshold after {max_attempts} attempts",
        attempts=max_attempts,
        last_score=last_score,
    )
async def execute_simple(
    self,
    operation: Callable[[], Any],
    config: Optional[RetryConfig] = None,
    operation_name: str = "operation",
) -> Any:
    """
    Execute operation with simple retry (no quality evaluation)

    At least one attempt is always made, even when ``max_retries`` is 0 or
    negative. Attempts are separated by an exponential backoff delay.

    Args:
        operation: Async callable to execute
        config: Retry configuration (uses default if not provided)
        operation_name: Name for logging

    Returns:
        Operation result

    Raises:
        Exception: Last exception if all attempts fail
    """
    cfg = config or self.default_config
    # With zero iterations the final ``raise`` would have nothing to raise
    # and would fail with a TypeError instead of a meaningful error.
    max_attempts = max(1, cfg.max_retries)
    last_error: Optional[Exception] = None

    for attempt in range(1, max_attempts + 1):
        try:
            logger.debug(f"{operation_name}: attempt {attempt}/{max_attempts}")
            result = await operation()
            logger.debug(f"{operation_name}: completed on attempt {attempt}")
            return result
        except Exception as e:
            last_error = e
            logger.warning(f"{operation_name}: attempt {attempt} failed: {e}")

            # Exponential backoff, capped at max_delay_ms; skipped after the last attempt
            if attempt < max_attempts:
                delay_ms = min(
                    cfg.initial_delay_ms * (cfg.backoff_factor ** (attempt - 1)),
                    cfg.max_delay_ms
                )
                await asyncio.sleep(delay_ms / 1000)

    # Reached only after the final attempt raised, so last_error is set
    raise last_error
@staticmethod
def create_prompt_simplifier(original_prompt: str) -> Callable[[], str]:
"""
Create a fallback that simplifies the prompt
Args:
original_prompt: Original complex prompt
Returns:
Callable that returns a simplified prompt
"""
def simplify() -> str:
# Simple prompt simplification strategies
simplified = original_prompt
# Remove complex modifiers
remove_phrases = [
"highly detailed",
"ultra realistic",
"8k resolution",
"masterpiece",
"best quality",
]
for phrase in remove_phrases:
simplified = simplified.replace(phrase, "").replace(phrase.lower(), "")
# Truncate if too long
if len(simplified) > 150:
simplified = simplified[:150].rsplit(" ", 1)[0]
# Clean up
simplified = " ".join(simplified.split())
return simplified
return simplify

View File

@@ -0,0 +1,276 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
StyleGuard - Visual style consistency engine
Ensures consistent visual style across all frames in a video by:
1. Extracting style anchor from the first generated frame
2. Injecting style constraints into subsequent frame prompts
3. (Optional) Using style reference techniques like IP-Adapter
"""
from dataclasses import dataclass, field
from typing import List, Optional
from loguru import logger
@dataclass
class StyleAnchor:
    """
    Visual style description taken from a reference frame

    Holds the individual style elements, an optional ready-made prompt
    prefix, and the path of the reference image.
    """

    # Core style elements
    color_palette: str = ""  # e.g. "warm earth tones", "cool blues"
    art_style: str = ""  # e.g. "minimalist", "realistic", "anime"
    composition_style: str = ""  # e.g. "centered", "rule of thirds"
    texture: str = ""  # e.g. "smooth", "grainy", "watercolor"
    lighting: str = ""  # e.g. "soft ambient", "dramatic shadows"

    # Ready-made prefix; takes precedence over the elements above
    style_prefix: str = ""

    # Reference image path (for IP-Adapter style techniques)
    reference_image: Optional[str] = None

    def to_prompt_prefix(self) -> str:
        """
        Build the style prefix for image prompts

        Returns ``style_prefix`` when it is set. Otherwise the non-empty
        elements are joined in the order art style, color palette,
        lighting, texture; the composition style is not part of the prefix.
        """
        if self.style_prefix:
            return self.style_prefix

        labelled = (
            (self.art_style, "{} style"),
            (self.color_palette, "{}"),
            (self.lighting, "{} lighting"),
            (self.texture, "{} texture"),
        )
        return ", ".join(
            template.format(value) for value, template in labelled if value
        )

    def to_dict(self) -> dict:
        """Return all fields as a plain dictionary"""
        field_names = (
            "color_palette",
            "art_style",
            "composition_style",
            "texture",
            "lighting",
            "style_prefix",
            "reference_image",
        )
        return {name: getattr(self, name) for name in field_names}
@dataclass
class StyleGuardConfig:
    """
    Settings for StyleGuard

    Covers how the style is extracted, how it is applied to prompts, and
    optional externally supplied style references.
    """

    # --- extraction ---
    extract_from_first_frame: bool = True
    use_vlm_extraction: bool = True

    # --- application ---
    apply_to_all_frames: bool = True
    prefix_position: str = "start"  # "start" or "end"

    # --- optional external style reference ---
    external_style_image: Optional[str] = None
    custom_style_prefix: Optional[str] = None
class StyleGuard:
    """
    Style consistency guardian for video generation

    Keeps the frames of a video visually consistent by:
    1. Deriving a style anchor from the first frame (or a reference image)
    2. Applying that anchor's style prefix to subsequent frame prompts

    Note:
        VLM-based extraction is currently a placeholder: it returns a fixed,
        generic anchor and does not inspect the image.

    Example:
        >>> style_guard = StyleGuard(llm_service)
        >>>
        >>> # Extract style from first frame
        >>> anchor = await style_guard.extract_style_anchor(
        ...     image_path="output/frame_001.png"
        ... )
        >>>
        >>> # Apply to subsequent prompts
        >>> styled_prompt = style_guard.apply_style(
        ...     prompt="A cat sitting on a windowsill",
        ...     style_anchor=anchor
        ... )
    """

    # NOTE(review): config.extract_from_first_frame, config.apply_to_all_frames
    # and config.external_style_image are not read by any method of this
    # class - confirm whether callers are expected to honour them.

    def __init__(
        self,
        llm_service=None,
        config: Optional[StyleGuardConfig] = None
    ):
        """
        Initialize StyleGuard

        Args:
            llm_service: LLM service for VLM-based style extraction
            config: StyleGuard configuration (defaults are used when omitted)
        """
        self.llm_service = llm_service
        self.config = config or StyleGuardConfig()
        self._current_anchor: Optional[StyleAnchor] = None

    async def extract_style_anchor(
        self,
        image_path: str,
    ) -> StyleAnchor:
        """
        Extract style anchor from reference image

        Resolution order:
        1. ``config.custom_style_prefix`` if set (used verbatim)
        2. VLM extraction when enabled and an LLM service is available
        3. Basic generic anchor otherwise

        The resulting anchor is stored as the current anchor.

        Args:
            image_path: Path to reference image

        Returns:
            StyleAnchor with extracted style elements
        """
        logger.info(f"Extracting style anchor from: {image_path}")

        if self.config.custom_style_prefix:
            # A caller-supplied prefix wins over any extraction.
            anchor = StyleAnchor(
                style_prefix=self.config.custom_style_prefix,
                reference_image=image_path
            )
        elif self.config.use_vlm_extraction and self.llm_service:
            anchor = await self._extract_with_vlm(image_path)
        else:
            anchor = self._extract_basic(image_path)

        # Single exit point: every path stores and logs the anchor uniformly.
        self._current_anchor = anchor
        logger.info(f"Style anchor extracted: {anchor.to_prompt_prefix()}")
        return anchor

    async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
        """
        Extract style using Vision Language Model

        Currently a placeholder that returns a fixed generic anchor; falls
        back to basic extraction if anything goes wrong.
        """
        try:
            # TODO: Implement VLM call when vision-capable LLM is integrated
            logger.debug("VLM style extraction: using placeholder (VLM not yet integrated)")

            # Placeholder anchor. Because style_prefix is set, it takes
            # precedence over the individual attributes in the prompt prefix.
            return StyleAnchor(
                art_style="consistent artistic",
                color_palette="harmonious colors",
                lighting="balanced",
                style_prefix="maintaining visual consistency, same artistic style as previous frames",
                reference_image=image_path,
            )
        except Exception as e:
            # Deliberate best-effort: style extraction must never abort generation.
            logger.warning(f"VLM style extraction failed: {e}")
            return self._extract_basic(image_path)

    def _extract_basic(self, image_path: str) -> StyleAnchor:
        """Basic style extraction without VLM (returns a generic anchor)"""
        return StyleAnchor(
            style_prefix="consistent visual style",
            reference_image=image_path,
        )

    def apply_style(
        self,
        prompt: str,
        style_anchor: Optional[StyleAnchor] = None,
    ) -> str:
        """
        Apply style constraints to an image prompt

        Args:
            prompt: Original image prompt
            style_anchor: Style anchor to apply (uses current if not provided)

        Returns:
            Modified prompt with style constraints. The prompt is returned
            unchanged when no anchor is available or the anchor yields an
            empty prefix. For an empty/blank prompt only the style text is
            returned, so no dangling ", " separator is produced.
        """
        anchor = style_anchor if style_anchor is not None else self._current_anchor
        if anchor is None:
            return prompt

        style_prefix = anchor.to_prompt_prefix()
        if not style_prefix:
            return prompt

        # Nothing to join with: avoid emitting "style, " or ", style".
        if not prompt or not prompt.strip():
            return style_prefix

        if self.config.prefix_position == "start":
            return f"{style_prefix}, {prompt}"
        # Any value other than "start" places the style text after the prompt.
        return f"{prompt}, {style_prefix}"

    def apply_style_to_batch(
        self,
        prompts: List[str],
        style_anchor: Optional[StyleAnchor] = None,
        skip_first: bool = True,
    ) -> List[str]:
        """
        Apply style constraints to a batch of prompts

        Args:
            prompts: List of image prompts
            style_anchor: Style anchor to apply (uses current if not provided)
            skip_first: Skip first prompt (used as reference)

        Returns:
            A new list of styled prompts. The input list is never returned
            itself, so callers can mutate the result without affecting it.
        """
        if not prompts:
            return []

        anchor = style_anchor if style_anchor is not None else self._current_anchor
        if anchor is None:
            return list(prompts)

        return [
            prompt if skip_first and index == 0 else self.apply_style(prompt, anchor)
            for index, prompt in enumerate(prompts)
        ]

    def get_consistency_prompt_suffix(self) -> str:
        """
        Get a consistency prompt suffix for LLM prompt generation

        This can be added to the LLM prompt when generating image prompts
        to encourage consistent style descriptions.
        """
        return (
            "Ensure all image prompts maintain consistent visual style, "
            "including similar color palette, art style, lighting, and composition. "
            "Each image should feel like it belongs to the same visual narrative."
        )

    @property
    def current_anchor(self) -> Optional[StyleAnchor]:
        """Get the current style anchor (None until one is extracted)"""
        return self._current_anchor

    def reset(self) -> None:
        """Reset the current style anchor"""
        self._current_anchor = None