feat: Add comprehensive timeline editor with frame editing and regeneration capabilities

This commit is contained in:
empty
2026-01-05 14:48:43 +08:00
parent 7d78dcd078
commit ca018a9b1f
68 changed files with 14904 additions and 57 deletions

View File

@@ -0,0 +1,84 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Quality Assurance Services for Pixelle-Video
This module provides quality control mechanisms for video generation:
- QualityGate: Evaluates generated content quality
- RetryManager: Smart retry with quality-based decisions
- OutputValidator: LLM output validation
- StyleGuard: Visual style consistency
- ContentFilter: Content moderation
- CharacterMemory: Character consistency across frames
"""
from pixelle_video.services.quality.models import (
QualityScore,
QualityConfig,
RetryConfig,
QualityError,
)
from pixelle_video.services.quality.quality_gate import QualityGate
from pixelle_video.services.quality.retry_manager import RetryManager
from pixelle_video.services.quality.output_validator import (
OutputValidator,
ValidationConfig,
ValidationResult,
)
from pixelle_video.services.quality.style_guard import (
StyleGuard,
StyleGuardConfig,
StyleAnchor,
)
from pixelle_video.services.quality.content_filter import (
ContentFilter,
ContentFilterConfig,
FilterResult,
FilterCategory,
)
from pixelle_video.services.quality.character_memory import (
CharacterMemory,
CharacterMemoryConfig,
Character,
CharacterType,
)
# Explicit public API of the quality package, grouped by concern.
# Keep in sync with the imports above when adding new services.
__all__ = [
    # Quality evaluation
    "QualityScore",
    "QualityConfig",
    "RetryConfig",
    "QualityError",
    "QualityGate",
    "RetryManager",
    # Output validation
    "OutputValidator",
    "ValidationConfig",
    "ValidationResult",
    # Style consistency
    "StyleGuard",
    "StyleGuardConfig",
    "StyleAnchor",
    # Content moderation
    "ContentFilter",
    "ContentFilterConfig",
    "FilterResult",
    "FilterCategory",
    # Character memory
    "CharacterMemory",
    "CharacterMemoryConfig",
    "Character",
    "CharacterType",
]

View File

@@ -0,0 +1,530 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CharacterMemory - Character consistency and memory system
Maintains consistent character appearance across video frames by:
1. Detecting and registering characters from narrations
2. Extracting visual descriptions from first appearances
3. Injecting character consistency prompts into subsequent frames
4. Supporting reference images for ComfyUI IP-Adapter/ControlNet
"""
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Set
from datetime import datetime
from enum import Enum
from loguru import logger
class CharacterType(Enum):
    """Kind of entity a tracked narrative character represents.

    Used by the character-memory system to classify detected characters:
    humans, animals, fantasy creatures, personified objects, and
    abstract entities.
    """

    PERSON = "person"
    ANIMAL = "animal"
    CREATURE = "creature"
    OBJECT = "object"
    ABSTRACT = "abstract"
@dataclass
class Character:
    """
    Represents a character in the video narrative.

    Stores visual description, reference images, and appearance history
    so that generation steps can keep the character looking consistent
    across frames.
    """
    # --- Identity ---
    id: str  # Unique identifier
    name: str  # Character name (e.g., "小明", "the hero")
    aliases: List[str] = field(default_factory=list)  # Alternative names
    character_type: CharacterType = CharacterType.PERSON
    # --- Visual description (used for prompt injection) ---
    appearance_description: str = ""  # Detailed visual description
    clothing_description: str = ""  # Clothing/outfit description
    distinctive_features: List[str] = field(default_factory=list)  # Unique features
    # --- Reference images (for IP-Adapter/ControlNet) ---
    reference_images: List[str] = field(default_factory=list)  # Paths to reference images
    primary_reference: Optional[str] = None  # Primary reference image
    # --- Prompt elements ---
    prompt_prefix: str = ""  # Pre-built prompt prefix (auto-built when empty)
    negative_prompt: str = ""  # Negative prompt additions
    # --- Metadata ---
    is_active: bool = True  # Whether this character is active for logic
    first_appearance_frame: int = 0  # Frame index of first appearance
    appearance_frames: List[int] = field(default_factory=list)  # All frames with this character
    created_at: Optional[datetime] = None  # Filled with now() in __post_init__

    def __post_init__(self):
        """Fill derived defaults after dataclass initialization."""
        if self.created_at is None:
            self.created_at = datetime.now()
        # Fix: removed the dead `if not hasattr(self, 'is_active')` guard.
        # A dataclass field with a default always exists on the instance,
        # so that branch could never execute.
        if not self.prompt_prefix:
            self._build_prompt_prefix()

    def _build_prompt_prefix(self):
        """Build `prompt_prefix` from the visual description fields."""
        elements = []
        if self.appearance_description:
            elements.append(self.appearance_description)
        if self.clothing_description:
            elements.append(f"wearing {self.clothing_description}")
        if self.distinctive_features:
            elements.append(", ".join(self.distinctive_features))
        self.prompt_prefix = ", ".join(elements) if elements else ""

    def get_prompt_injection(self) -> str:
        """Return the text to inject into an image prompt for this character."""
        if self.prompt_prefix:
            return f"({self.name}: {self.prompt_prefix})"
        return f"({self.name})"

    def add_reference_image(self, image_path: str, set_as_primary: bool = False):
        """Add a reference image; the first image becomes primary by default."""
        if image_path not in self.reference_images:
            self.reference_images.append(image_path)
        if set_as_primary or self.primary_reference is None:
            self.primary_reference = image_path

    def matches_name(self, name: str) -> bool:
        """Case-insensitively match `name` against the name or any alias."""
        name_lower = name.lower().strip()
        if self.name.lower() == name_lower:
            return True
        return any(alias.lower() == name_lower for alias in self.aliases)

    def to_dict(self) -> dict:
        """Serialize the character for logging / persistence."""
        return {
            "id": self.id,
            "name": self.name,
            "aliases": self.aliases,
            "type": self.character_type.value,
            "appearance_description": self.appearance_description,
            "clothing_description": self.clothing_description,
            "distinctive_features": self.distinctive_features,
            "reference_images": self.reference_images,
            "primary_reference": self.primary_reference,
            "prompt_prefix": self.prompt_prefix,
            "first_appearance_frame": self.first_appearance_frame,
        }
@dataclass
class CharacterMemoryConfig:
    """Tunable settings for the character memory system."""

    # --- Detection ---
    auto_detect_characters: bool = True  # scan narrations for characters
    use_llm_detection: bool = True       # prefer LLM extraction over regex
    # --- Consistency ---
    inject_character_prompts: bool = True  # add character descriptions to prompts
    use_reference_images: bool = True      # pass reference images to generation
    # --- Reference images ---
    extract_reference_from_first: bool = True  # take reference from first appearance
    max_reference_images: int = 3              # cap on reference images
    # --- Prompt injection ---
    prompt_injection_position: str = "start"  # "start" or "end"
    include_clothing: bool = True             # include clothing in prompts
    include_features: bool = True             # include distinctive features
class CharacterMemory:
    """
    Character memory and consistency manager.

    Tracks characters across video frames and ensures visual consistency
    by injecting character descriptions and reference images into the
    generation pipeline.

    Example:
        >>> memory = CharacterMemory(llm_service)
        >>>
        >>> # Register a character
        >>> char = memory.register_character(
        ...     name="小明",
        ...     appearance_description="young man with short black hair",
        ...     clothing_description="blue t-shirt"
        ... )
        >>>
        >>> # Apply to prompt
        >>> enhanced_prompt = memory.apply_to_prompt(
        ...     prompt="A person walking in the park",
        ...     characters=["小明"]
        ... )
    """
    def __init__(
        self,
        llm_service=None,
        config: Optional[CharacterMemoryConfig] = None
    ):
        """
        Initialize CharacterMemory

        Args:
            llm_service: Optional LLM service for character detection.
                NOTE(review): used as an awaitable callable taking
                (prompt, temperature=...) — confirm against the actual
                service interface.
            config: Character memory configuration
        """
        self.llm_service = llm_service
        self.config = config or CharacterMemoryConfig()
        # All known characters keyed by generated character id.
        self._characters: Dict[str, Character] = {}
        # Lowercased primary name -> character_id for O(1) lookup.
        # Aliases are NOT indexed here; get_character() falls back to a
        # linear scan over aliases.
        self._name_index: Dict[str, str] = {}
    def register_character(
        self,
        name: str,
        appearance_description: str = "",
        clothing_description: str = "",
        distinctive_features: Optional[List[str]] = None,
        character_type: CharacterType = CharacterType.PERSON,
        first_frame: int = 0,
    ) -> Character:
        """
        Register a new character

        Args:
            name: Character name
            appearance_description: Visual appearance description
            clothing_description: Clothing/outfit description
            distinctive_features: List of distinctive features
            character_type: Type of character
            first_frame: Frame index of first appearance
        Returns:
            Created Character object
        """
        # Generate unique ID from the registration order and the name.
        # NOTE(review): registering the same name twice overwrites the
        # name-index entry; callers are expected to check get_character()
        # first (as the detection paths below do).
        char_id = f"char_{len(self._characters)}_{name.replace(' ', '_').lower()}"
        character = Character(
            id=char_id,
            name=name,
            appearance_description=appearance_description,
            clothing_description=clothing_description,
            distinctive_features=distinctive_features or [],
            character_type=character_type,
            first_appearance_frame=first_frame,
            appearance_frames=[first_frame],
        )
        self._characters[char_id] = character
        self._name_index[name.lower()] = char_id
        logger.info(f"Registered character: {name} (id={char_id})")
        return character
    def get_character(self, name: str) -> Optional[Character]:
        """Get a character by name (primary name lookup, then alias scan)."""
        name_lower = name.lower().strip()
        char_id = self._name_index.get(name_lower)
        if char_id:
            return self._characters.get(char_id)
        # Search aliases (linear scan; aliases are not in the name index)
        for char in self._characters.values():
            if char.matches_name(name):
                return char
        return None
    def get_character_by_id(self, char_id: str) -> Optional[Character]:
        """Get a character by ID"""
        return self._characters.get(char_id)
    @property
    def characters(self) -> List[Character]:
        """Get all registered characters"""
        return list(self._characters.values())
    async def detect_characters_from_narration(
        self,
        narration: str,
        frame_index: int = 0,
    ) -> List[Character]:
        """
        Detect and register characters mentioned in narration

        Uses the LLM path when configured and available, otherwise falls
        back to regex-based detection.

        Args:
            narration: Narration text to analyze
            frame_index: Current frame index
        Returns:
            List of detected/registered characters
        """
        if not self.config.auto_detect_characters:
            return []
        detected = []
        if self.config.use_llm_detection and self.llm_service:
            detected = await self._detect_with_llm(narration, frame_index)
        else:
            detected = self._detect_basic(narration, frame_index)
        return detected
    async def _detect_with_llm(
        self,
        narration: str,
        frame_index: int,
    ) -> List[Character]:
        """Detect characters using LLM.

        Asks the LLM for a JSON list of characters, registers any new
        ones, and records the frame appearance for known ones. Falls back
        to _detect_basic() on any error.
        """
        if not self.llm_service:
            return []
        try:
            prompt = f"""分析以下文案,提取其中提到的角色/人物。
文案: {narration}
请用 JSON 格式返回角色列表,每个角色包含:
- name: 角色名称或代称
- type: person/animal/creature/object
- appearance: 外貌描述(如有)
- clothing: 服装描述(如有)
如果没有明确角色,返回空列表 []。
只返回 JSON不要其他解释。"""
            response = await self.llm_service(prompt, temperature=0.1)
            # Parse response
            import json
            import re
            # Extract the first JSON array from the response; greedy match
            # spans from the first '[' to the last ']'.
            json_match = re.search(r'\[.*\]', response, re.DOTALL)
            if json_match:
                characters_data = json.loads(json_match.group())
                result = []
                for char_data in characters_data:
                    name = char_data.get("name", "").strip()
                    if not name:
                        continue
                    # Check if already registered
                    existing = self.get_character(name)
                    if existing:
                        existing.appearance_frames.append(frame_index)
                        result.append(existing)
                    else:
                        # Register new character; unknown type strings
                        # (including "object") fall back to PERSON here.
                        char_type = CharacterType.PERSON
                        type_str = char_data.get("type", "person").lower()
                        if type_str == "animal":
                            char_type = CharacterType.ANIMAL
                        elif type_str == "creature":
                            char_type = CharacterType.CREATURE
                        char = self.register_character(
                            name=name,
                            appearance_description=char_data.get("appearance", ""),
                            clothing_description=char_data.get("clothing", ""),
                            character_type=char_type,
                            first_frame=frame_index,
                        )
                        result.append(char)
                return result
            return []
        except Exception as e:
            logger.warning(f"LLM character detection failed: {e}")
            return self._detect_basic(narration, frame_index)
    def _detect_basic(
        self,
        narration: str,
        frame_index: int,
    ) -> List[Character]:
        """Basic character detection without LLM.

        Note: this path only re-detects characters that were previously
        registered — pattern matches whose name is unknown are ignored,
        never auto-registered.
        """
        # Simple pattern matching for common character references
        import re
        patterns = [
            r'(?:他|她|它)们?',  # Chinese pronouns (he/she/it, plural suffix)
            r'(?:小\w{1,2})',  # Names like 小明, 小红
            r'(?:老\w{1,2})',  # Names like 老王, 老李
        ]
        detected = []
        for pattern in patterns:
            matches = re.findall(pattern, narration)
            for match in matches:
                existing = self.get_character(match)
                if existing:
                    existing.appearance_frames.append(frame_index)
                    if existing not in detected:
                        detected.append(existing)
        return detected
    def apply_to_prompt(
        self,
        prompt: str,
        character_names: Optional[List[str]] = None,
        frame_index: Optional[int] = None,
    ) -> str:
        """
        Apply character consistency to an image prompt

        Args:
            prompt: Original image prompt
            character_names: Specific characters to include (None = all registered)
            frame_index: Current frame index for tracking
        Returns:
            Enhanced prompt with character consistency
        """
        if not self.config.inject_character_prompts:
            return prompt
        characters_to_include = []
        if character_names:
            for name in character_names:
                char = self.get_character(name)
                if char:
                    characters_to_include.append(char)
        else:
            # Include all characters that have appeared
            characters_to_include = self.characters
        if not characters_to_include:
            return prompt
        # Build character injection
        injections = []
        for char in characters_to_include:
            injection = char.get_prompt_injection()
            if injection:
                injections.append(injection)
                # Track appearance (only once per frame index)
                if frame_index is not None and frame_index not in char.appearance_frames:
                    char.appearance_frames.append(frame_index)
        if not injections:
            return prompt
        character_prompt = ", ".join(injections)
        if self.config.prompt_injection_position == "start":
            return f"{character_prompt}, {prompt}"
        else:
            return f"{prompt}, {character_prompt}"
    def get_reference_images(
        self,
        character_names: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Get reference images for specified characters

        Args:
            character_names: Character names (None = all characters)
        Returns:
            List of reference image paths.
            NOTE(review): the max_reference_images cap is applied to the
            combined list across all characters here, although the config
            comment describes it as per-character — confirm intent.
        """
        if not self.config.use_reference_images:
            return []
        images = []
        if character_names:
            for name in character_names:
                char = self.get_character(name)
                if char and char.primary_reference:
                    images.append(char.primary_reference)
        else:
            for char in self.characters:
                if char.primary_reference:
                    images.append(char.primary_reference)
        return images[:self.config.max_reference_images]
    def set_reference_image(
        self,
        character_name: str,
        image_path: str,
        set_as_primary: bool = True,
    ):
        """
        Set a reference image for a character

        Args:
            character_name: Character name
            image_path: Path to reference image
            set_as_primary: Whether to set as primary reference
        """
        char = self.get_character(character_name)
        if char:
            char.add_reference_image(image_path, set_as_primary)
            logger.debug(f"Set reference image for {character_name}: {image_path}")
        else:
            logger.warning(f"Character not found: {character_name}")
    def update_character_appearance(
        self,
        character_name: str,
        appearance_description: Optional[str] = None,
        clothing_description: Optional[str] = None,
        distinctive_features: Optional[List[str]] = None,
    ):
        """Update a character's visual description.

        Only non-empty arguments overwrite the stored values; the prompt
        prefix is rebuilt afterwards so injections reflect the update.
        """
        char = self.get_character(character_name)
        if char:
            if appearance_description:
                char.appearance_description = appearance_description
            if clothing_description:
                char.clothing_description = clothing_description
            if distinctive_features:
                char.distinctive_features = distinctive_features
            char._build_prompt_prefix()
            logger.debug(f"Updated appearance for {character_name}")
    def get_consistency_summary(self) -> str:
        """Get a human-readable summary of tracked characters for logging."""
        if not self._characters:
            return "No characters registered"
        lines = [f"Characters ({len(self._characters)}):"]
        for char in self.characters:
            lines.append(
                f"  - {char.name}: {len(char.appearance_frames)} appearances, "
                f"ref_images={len(char.reference_images)}"
            )
        return "\n".join(lines)
    def reset(self):
        """Clear all character memory"""
        self._characters.clear()
        self._name_index.clear()
        logger.info("Character memory cleared")

View File

@@ -0,0 +1,316 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ContentFilter - Content moderation and safety filtering
Provides content safety checks for:
- Text content (narrations, prompts)
- Generated images
- Generated videos
"""
import re
from dataclasses import dataclass, field
from typing import List, Optional, Set
from enum import Enum
from loguru import logger
class FilterCategory(Enum):
    """Safety verdict buckets for moderated content.

    SAFE content passes untouched, SENSITIVE content may require human
    review, and BLOCKED content must not proceed.
    """

    SAFE = "safe"
    SENSITIVE = "sensitive"
    BLOCKED = "blocked"
@dataclass
class FilterResult:
    """Outcome of a single content-moderation check."""

    category: FilterCategory  # verdict bucket
    passed: bool  # whether the content may proceed
    flagged_items: List[str] = field(default_factory=list)  # matched offending terms
    reason: Optional[str] = None  # human-readable explanation
    confidence: float = 1.0  # confidence in the verdict

    def to_dict(self) -> dict:
        """Serialize the result for logging / API responses."""
        payload = {
            "category": self.category.value,
            "passed": self.passed,
            "flagged_items": self.flagged_items,
            "reason": self.reason,
            "confidence": self.confidence,
        }
        return payload
@dataclass
class ContentFilterConfig:
    """Settings controlling text and image moderation behavior."""

    # Text filtering toggles
    enable_keyword_filter: bool = True
    enable_llm_filter: bool = False  # semantic check via LLM
    # Extra keywords to block on top of the built-in patterns
    custom_blocked_keywords: List[str] = field(default_factory=list)
    # One of: "strict", "moderate", "permissive"
    sensitivity_level: str = "moderate"
    # Image filtering (requires an external detection service)
    enable_image_filter: bool = False
    # Behavior when content is flagged
    block_on_sensitive: bool = False  # also block SENSITIVE verdicts
    log_filtered_content: bool = True
class ContentFilter:
    """
    Content moderation filter for generated content.

    Provides safety filtering for text and media content to prevent
    inappropriate or harmful content from being generated.

    Example:
        >>> filter = ContentFilter()
        >>> result = await filter.check_text("Hello, world!")
        >>> if result.passed:
        ...     print("Content is safe")
    """
    # Default blocked keywords (minimal list for demonstration)
    DEFAULT_BLOCKED_PATTERNS = [
        r"\b(violence|gore|blood)\b",
        r"\b(nsfw|explicit|pornographic)\b",
        r"\b(illegal|drugs|weapons)\b",
    ]
    # Sensitive keywords (may require review)
    DEFAULT_SENSITIVE_PATTERNS = [
        r"\b(death|dying|kill)\b",
        r"\b(hate|racist|sexist)\b",
        r"\b(controversial|political)\b",
    ]
    # Severity ranking used to compare categories. FilterCategory values
    # are strings, so comparing `.value` with `>` would order them
    # alphabetically ("blocked" < "safe" < "sensitive") — wrong.
    _CATEGORY_SEVERITY = {
        FilterCategory.SAFE: 0,
        FilterCategory.SENSITIVE: 1,
        FilterCategory.BLOCKED: 2,
    }

    def __init__(
        self,
        llm_service=None,
        config: Optional[ContentFilterConfig] = None
    ):
        """
        Initialize ContentFilter

        Args:
            llm_service: Optional LLM service for semantic filtering
            config: Filter configuration
        """
        self.llm_service = llm_service
        self.config = config or ContentFilterConfig()
        # Pre-compile all patterns once; they are reused on every check.
        self._blocked_patterns = [
            re.compile(p, re.IGNORECASE)
            for p in self.DEFAULT_BLOCKED_PATTERNS
        ]
        self._sensitive_patterns = [
            re.compile(p, re.IGNORECASE)
            for p in self.DEFAULT_SENSITIVE_PATTERNS
        ]
        # Add custom keywords as whole-word, case-insensitive patterns.
        if self.config.custom_blocked_keywords:
            for keyword in self.config.custom_blocked_keywords:
                pattern = re.compile(rf"\b{re.escape(keyword)}\b", re.IGNORECASE)
                self._blocked_patterns.append(pattern)

    async def check_text(self, text: str) -> FilterResult:
        """
        Check text content for safety

        Args:
            text: Text to check
        Returns:
            FilterResult with safety assessment
        """
        # Empty/None text has nothing to flag.
        if not text:
            return FilterResult(
                category=FilterCategory.SAFE,
                passed=True
            )
        flagged_items = []
        category = FilterCategory.SAFE
        # Keyword filtering
        if self.config.enable_keyword_filter:
            # Check blocked patterns
            for pattern in self._blocked_patterns:
                matches = pattern.findall(text)
                if matches:
                    flagged_items.extend(matches)
                    category = FilterCategory.BLOCKED
            # Check sensitive patterns (if not already blocked)
            if category != FilterCategory.BLOCKED:
                for pattern in self._sensitive_patterns:
                    matches = pattern.findall(text)
                    if matches:
                        flagged_items.extend(matches)
                        category = FilterCategory.SENSITIVE
        # LLM-based semantic filtering
        if self.config.enable_llm_filter and self.llm_service:
            semantic_result = await self._check_with_llm(text)
            # Fix: escalate by severity rank instead of comparing the
            # enum's string values lexicographically — the old comparison
            # could never escalate SAFE to BLOCKED ("blocked" < "safe").
            if self._CATEGORY_SEVERITY[semantic_result.category] > self._CATEGORY_SEVERITY[category]:
                category = semantic_result.category
                flagged_items.extend(semantic_result.flagged_items)
        # Determine if passed
        if category == FilterCategory.BLOCKED:
            passed = False
            reason = "Content contains blocked keywords or themes"
        elif category == FilterCategory.SENSITIVE and self.config.block_on_sensitive:
            passed = False
            reason = "Content contains sensitive themes (strict mode)"
        else:
            passed = True
            reason = None
        # Log if configured
        if self.config.log_filtered_content and flagged_items:
            logger.warning(f"Content filter flagged: {flagged_items}")
        return FilterResult(
            category=category,
            passed=passed,
            flagged_items=flagged_items,
            reason=reason,
        )

    async def check_texts(self, texts: List[str]) -> FilterResult:
        """
        Check multiple texts and return aggregate result

        The aggregate category is the worst category seen across all
        texts; flagged items are merged and deduplicated.

        Args:
            texts: List of texts to check
        Returns:
            Aggregate FilterResult
        """
        all_flagged = []
        worst_category = FilterCategory.SAFE
        for text in texts:
            result = await self.check_text(text)
            all_flagged.extend(result.flagged_items)
            if result.category == FilterCategory.BLOCKED:
                worst_category = FilterCategory.BLOCKED
            elif result.category == FilterCategory.SENSITIVE and worst_category != FilterCategory.BLOCKED:
                worst_category = FilterCategory.SENSITIVE
        passed = worst_category == FilterCategory.SAFE or (
            worst_category == FilterCategory.SENSITIVE and not self.config.block_on_sensitive
        )
        return FilterResult(
            category=worst_category,
            passed=passed,
            flagged_items=list(set(all_flagged)),  # Deduplicate
            reason=f"Found {len(all_flagged)} flagged items" if all_flagged else None,
        )

    async def check_image(self, image_path: str) -> FilterResult:
        """
        Check image content for safety

        Note: This requires external NSFW detection service integration.
        Currently returns safe by default.

        Args:
            image_path: Path to image file
        Returns:
            FilterResult with safety assessment
        """
        if not self.config.enable_image_filter:
            return FilterResult(
                category=FilterCategory.SAFE,
                passed=True,
                reason="Image filtering disabled"
            )
        # TODO: Integrate with external NSFW detection
        # Options: TensorFlow NSFW model, Azure Content Moderator, etc.
        logger.debug(f"Image safety check: {image_path} (not implemented, assuming safe)")
        return FilterResult(
            category=FilterCategory.SAFE,
            passed=True,
            reason="Image filtering not implemented"
        )

    async def _check_with_llm(self, text: str) -> FilterResult:
        """Check text using LLM for semantic understanding.

        Fails open: any LLM error yields a SAFE result so moderation
        never blocks the pipeline on infrastructure problems.
        """
        if not self.llm_service:
            return FilterResult(category=FilterCategory.SAFE, passed=True)
        try:
            prompt = f"""Analyze the following text for content safety.
Rate it as: SAFE, SENSITIVE, or BLOCKED.
SAFE: No concerning content
SENSITIVE: Contains themes that may need review (violence, death, controversial topics)
BLOCKED: Contains explicit, illegal, or harmful content
Text: {text[:500]}
Respond with only one word: SAFE, SENSITIVE, or BLOCKED."""
            response = await self.llm_service(prompt, temperature=0.0, max_tokens=10)
            response = response.strip().upper()
            if "BLOCKED" in response:
                return FilterResult(
                    category=FilterCategory.BLOCKED,
                    passed=False,
                    reason="LLM detected blocked content"
                )
            elif "SENSITIVE" in response:
                return FilterResult(
                    category=FilterCategory.SENSITIVE,
                    passed=not self.config.block_on_sensitive,
                    reason="LLM detected sensitive content"
                )
            else:
                return FilterResult(
                    category=FilterCategory.SAFE,
                    passed=True
                )
        except Exception as e:
            logger.warning(f"LLM content check failed: {e}")
            return FilterResult(category=FilterCategory.SAFE, passed=True)

    def add_blocked_keyword(self, keyword: str):
        """Add a whole-word, case-insensitive keyword to the blocked list."""
        pattern = re.compile(rf"\b{re.escape(keyword)}\b", re.IGNORECASE)
        self._blocked_patterns.append(pattern)

    def add_sensitive_keyword(self, keyword: str):
        """Add a whole-word, case-insensitive keyword to the sensitive list."""
        pattern = re.compile(rf"\b{re.escape(keyword)}\b", re.IGNORECASE)
        self._sensitive_patterns.append(pattern)

View File

@@ -0,0 +1,140 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data models for quality assurance
"""
from dataclasses import dataclass, field
from typing import List, Optional
from enum import Enum
class QualityLevel(Enum):
    """Coarse quality bucket derived from an overall score.

    Thresholds (see QualityScore.level): EXCELLENT >= 0.8, GOOD >= 0.6,
    ACCEPTABLE >= 0.4, POOR otherwise.
    """

    EXCELLENT = "excellent"
    GOOD = "good"
    ACCEPTABLE = "acceptable"
    POOR = "poor"
@dataclass
class QualityScore:
    """Quality evaluation result for generated content."""

    # Individual scores, each in [0.0, 1.0]
    aesthetic_score: float = 0.0  # visual appeal / beauty
    text_match_score: float = 0.0  # how well the image matches the prompt
    technical_score: float = 0.0  # technical quality (clarity, no artifacts)
    # Overall verdict
    overall_score: float = 0.0  # weighted average of the above
    passed: bool = False  # whether it meets the configured threshold
    # Diagnostics
    issues: List[str] = field(default_factory=list)  # detected problems
    evaluation_time_ms: float = 0.0  # time taken for evaluation

    @property
    def level(self) -> QualityLevel:
        """Map overall_score onto a coarse QualityLevel bucket."""
        score = self.overall_score
        for threshold, bucket in (
            (0.8, QualityLevel.EXCELLENT),
            (0.6, QualityLevel.GOOD),
            (0.4, QualityLevel.ACCEPTABLE),
        ):
            if score >= threshold:
                return bucket
        return QualityLevel.POOR

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization"""
        payload = {
            "aesthetic_score": self.aesthetic_score,
            "text_match_score": self.text_match_score,
            "technical_score": self.technical_score,
            "overall_score": self.overall_score,
            "passed": self.passed,
            "level": self.level.value,
            "issues": self.issues,
        }
        return payload
@dataclass
class QualityConfig:
    """Configuration for quality evaluation."""

    # Thresholds (0.0 - 1.0)
    overall_threshold: float = 0.6  # Minimum overall score to pass
    aesthetic_threshold: float = 0.5  # Minimum aesthetic score
    text_match_threshold: float = 0.6  # Minimum text-match score
    technical_threshold: float = 0.7  # Minimum technical score
    # Weights for overall score calculation (normalized to sum to 1.0)
    aesthetic_weight: float = 0.3
    text_match_weight: float = 0.4
    technical_weight: float = 0.3
    # Evaluation settings
    use_vlm_evaluation: bool = True  # Use VLM for evaluation (vs local models)
    vlm_model: Optional[str] = None  # VLM model to use (None = use default LLM)
    skip_on_static_template: bool = True  # Skip image quality check for static templates

    def __post_init__(self):
        """Normalize the three weights so they sum to 1.0.

        Fix: guard against an all-zero weight configuration, which
        previously raised ZeroDivisionError during normalization; such a
        config is now left untouched.
        """
        total = self.aesthetic_weight + self.text_match_weight + self.technical_weight
        if total > 0 and abs(total - 1.0) > 0.01:
            self.aesthetic_weight /= total
            self.text_match_weight /= total
            self.technical_weight /= total
@dataclass
class RetryConfig:
    """Settings governing retry behavior for generation steps."""

    # Attempt limits and delays
    max_retries: int = 3  # maximum retry attempts
    backoff_factor: float = 1.5  # exponential backoff multiplier
    initial_delay_ms: int = 500  # delay before the first retry
    max_delay_ms: int = 10000  # ceiling for the delay between retries
    # Quality-based retry
    quality_threshold: float = 0.6  # score required to count as a pass
    # Fallback behavior
    enable_fallback: bool = True  # enable fallback strategy on failure
    fallback_prompt_simplify: bool = True  # simplify the prompt on retry
    # Retry conditions
    retry_on_quality_fail: bool = True  # retry when quality is below threshold
    retry_on_error: bool = True  # retry on generation errors
class QualityError(Exception):
    """Raised when quality standards are not met after all retries."""

    def __init__(
        self,
        message: str,
        attempts: int = 0,
        last_score: Optional[QualityScore] = None
    ):
        """Store the attempt count and the final score alongside the message."""
        super().__init__(message)
        self.attempts = attempts
        self.last_score = last_score

    def __str__(self) -> str:
        """Append attempt/score diagnostics to the base message."""
        base = super().__str__()
        if not self.last_score:
            return f"{base} (attempts={self.attempts})"
        return f"{base} (attempts={self.attempts}, last_score={self.last_score.overall_score:.2f})"

View File

@@ -0,0 +1,336 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OutputValidator - LLM output validation and quality control
Validates LLM-generated content for:
- Narrations: length, relevance, coherence
- Image prompts: format, language, prompt-narration alignment
"""
import re
from dataclasses import dataclass, field
from typing import List, Optional
from loguru import logger
@dataclass
class ValidationConfig:
    """Thresholds and limits for validating LLM-generated outputs."""

    # Narration limits (lengths are measured with len(), i.e. characters
    # for CJK text)
    min_narration_words: int = 5
    max_narration_words: int = 50
    relevance_threshold: float = 0.6  # minimum topic relevance
    coherence_threshold: float = 0.6  # minimum inter-narration coherence
    # Image-prompt limits
    min_prompt_words: int = 10
    max_prompt_words: int = 100
    require_english_prompts: bool = True
    prompt_match_threshold: float = 0.5  # minimum prompt/narration alignment
@dataclass
class ValidationResult:
    """Outcome of a validation pass over generated content."""

    passed: bool  # overall pass/fail verdict
    issues: List[str] = field(default_factory=list)  # detected problems
    score: float = 1.0  # 1.0 = perfect, 0.0 = failed
    suggestions: List[str] = field(default_factory=list)  # remediation hints

    def to_dict(self) -> dict:
        """Serialize the result for logging / API responses."""
        payload = {
            "passed": self.passed,
            "score": self.score,
            "issues": self.issues,
            "suggestions": self.suggestions,
        }
        return payload
class OutputValidator:
    """
    Quality gate for text produced by the LLM.

    Runs before any expensive media generation. Structural checks (emptiness,
    length bounds) always execute; semantic checks (topic relevance, coherence)
    run only when an LLM service was supplied, and never hard-fail validation.

    Example:
        >>> validator = OutputValidator(llm_service)
        >>> report = await validator.validate_narrations(
        ...     narrations=["..."], topic="...", config=ValidationConfig()
        ... )
        >>> if not report.passed:
        ...     print(report.issues)
    """

    def __init__(self, llm_service=None):
        """
        Args:
            llm_service: Optional async LLM callable used for semantic checks;
                structural checks work without it.
        """
        self.llm_service = llm_service

    async def validate_narrations(
        self,
        narrations: List[str],
        topic: str,
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated narrations.

        Structural: non-empty text, character-count bounds per narration.
        Semantic (LLM only): topic relevance and inter-narration coherence.

        Args:
            narrations: List of narration texts.
            topic: Original topic/theme the narrations should cover.
            config: Validation thresholds; defaults applied when omitted.

        Returns:
            ValidationResult; passes when no issues were found or the
            averaged score still reaches 0.7.
        """
        rules = config or ValidationConfig()
        if not narrations:
            return ValidationResult(
                passed=False,
                issues=["No narrations provided"],
                score=0.0
            )
        issues: List[str] = []
        suggestions: List[str] = []
        marks: List[float] = []
        # 1. Structural pass — lengths counted in characters (fits CJK text).
        for idx, text in enumerate(narrations, 1):
            char_count = len(text)
            if not text.strip():
                issues.append(f"分镜{idx}: 内容为空")
                marks.append(0.0)
            elif char_count < rules.min_narration_words:
                issues.append(f"分镜{idx}: 内容过短 ({char_count}字,最少{rules.min_narration_words}字)")
                marks.append(0.5)
            elif char_count > rules.max_narration_words:
                issues.append(f"分镜{idx}: 内容过长 ({char_count}字,最多{rules.max_narration_words}字)")
                suggestions.append(f"考虑将分镜{idx}拆分为多个短句")
                marks.append(0.7)
            else:
                marks.append(1.0)
        # 2. Semantic pass — best effort; an LLM failure never blocks validation.
        if self.llm_service:
            try:
                relevance = await self._check_relevance(narrations, topic)
                if relevance < rules.relevance_threshold:
                    issues.append(f'内容与主题"{topic}"相关性不足 ({relevance:.0%})')
                    suggestions.append("建议重新生成,确保内容紧扣主题")
                marks.append(relevance)
                coherence = await self._check_coherence(narrations)
                if coherence < rules.coherence_threshold:
                    issues.append(f"内容连贯性不足 ({coherence:.0%})")
                    suggestions.append("建议检查段落之间的逻辑衔接")
                marks.append(coherence)
            except Exception as e:
                logger.warning(f"Semantic validation failed: {e}")
        overall = sum(marks) / len(marks) if marks else 0.0
        verdict = not issues or overall >= 0.7
        logger.debug(f"Narration validation: score={overall:.2f}, issues={len(issues)}")
        return ValidationResult(
            passed=verdict,
            issues=issues,
            score=overall,
            suggestions=suggestions,
        )

    async def validate_image_prompts(
        self,
        prompts: List[str],
        narrations: List[str],
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated image prompts.

        Checks count parity with narrations, per-prompt word-count bounds,
        and (optionally) that prompts are written in English. Note that
        semantic prompt/narration alignment is not evaluated here.

        Args:
            prompts: List of image prompts.
            narrations: Corresponding narrations (for count parity only).
            config: Validation thresholds; defaults applied when omitted.

        Returns:
            ValidationResult; passes when no issues were found or the
            averaged score still reaches 0.7.
        """
        rules = config or ValidationConfig()
        if not prompts:
            return ValidationResult(
                passed=False,
                issues=["No image prompts provided"],
                score=0.0
            )
        issues: List[str] = []
        suggestions: List[str] = []
        marks: List[float] = []
        if len(prompts) != len(narrations):
            issues.append(f"提示词数量({len(prompts)})与旁白数量({len(narrations)})不匹配")
        for idx, prompt in enumerate(prompts, 1):
            if not prompt.strip():
                issues.append(f"提示词{idx}: 内容为空")
                marks.append(0.0)
                continue
            word_count = len(prompt.split())
            # Word-count bounds (prompts are expected to be English).
            if word_count < rules.min_prompt_words:
                issues.append(f"提示词{idx}: 过短 ({word_count}词,最少{rules.min_prompt_words}词)")
                marks.append(0.5)
            elif word_count > rules.max_prompt_words:
                issues.append(f"提示词{idx}: 过长 ({word_count}词,最多{rules.max_prompt_words}词)")
                marks.append(0.8)
            else:
                marks.append(1.0)
            # Language check — heavily Chinese prompts halve the current mark.
            if rules.require_english_prompts:
                zh_ratio = self._get_chinese_ratio(prompt)
                if zh_ratio > 0.3:  # more than 30% Chinese characters
                    issues.append(f"提示词{idx}: 应使用英文 (当前含{zh_ratio:.0%}中文)")
                    suggestions.append(f"将提示词{idx}翻译为英文以获得更好的生成效果")
                    marks[-1] *= 0.5
        overall = sum(marks) / len(marks) if marks else 0.0
        verdict = not issues or overall >= 0.7
        logger.debug(f"Image prompt validation: score={overall:.2f}, issues={len(issues)}")
        return ValidationResult(
            passed=verdict,
            issues=issues,
            score=overall,
            suggestions=suggestions,
        )

    async def _check_relevance(
        self,
        narrations: List[str],
        topic: str,
    ) -> float:
        """
        Score topic relevance via the LLM; 0.0-1.0.

        Falls back to neutral defaults (0.8 without an LLM, 0.7 on call or
        parse failure) so this check can never hard-fail validation.
        """
        if not self.llm_service:
            return 0.8
        try:
            sample = "\n".join(narrations[:3])  # first 3 only, for speed/cost
            prompt = f"""评估以下内容与主题"{topic}"的相关性。
内容:
{sample}
请用0-100的分数评估相关性只输出数字。
相关性越高,分数越高。"""
            reply = await self.llm_service(prompt, temperature=0.1, max_tokens=10)
            # The model is asked for a bare 0-100 number; grab the first digits.
            found = re.search(r'\d+', reply)
            if found:
                return min(1.0, max(0.0, int(found.group()) / 100))
            return 0.7
        except Exception as e:
            logger.warning(f"Relevance check failed: {e}")
            return 0.7

    async def _check_coherence(
        self,
        narrations: List[str],
    ) -> float:
        """
        Score logical flow between narrations via the LLM; 0.0-1.0.

        Returns 0.8 when no LLM is available or there are fewer than two
        narrations; 0.7 when the call or parsing fails.
        """
        if not self.llm_service or len(narrations) < 2:
            return 0.8
        try:
            listing = "\n".join(f"{pos}. {seg}" for pos, seg in enumerate(narrations[:5], 1))
            prompt = f"""评估以下段落之间的逻辑连贯性。
{listing}
请用0-100的分数评估连贯性只输出数字。
段落之间逻辑顺畅、衔接自然则分数高。"""
            reply = await self.llm_service(prompt, temperature=0.1, max_tokens=10)
            found = re.search(r'\d+', reply)
            if found:
                return min(1.0, max(0.0, int(found.group()) / 100))
            return 0.7
        except Exception as e:
            logger.warning(f"Coherence check failed: {e}")
            return 0.7

    def _get_chinese_ratio(self, text: str) -> float:
        """Fraction of CJK Unified Ideograph (U+4E00-U+9FFF) characters, ignoring spaces."""
        if not text:
            return 0.0
        compact = text.replace(" ", "")
        if not compact:
            return 0.0
        han_count = sum('\u4e00' <= ch <= '\u9fff' for ch in compact)
        return han_count / len(compact)

View File

@@ -0,0 +1,363 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
QualityGate - Quality evaluation system for generated content
Evaluates images and videos based on:
- Aesthetic quality (visual appeal)
- Text-to-image matching (semantic alignment)
- Technical quality (clarity, no artifacts)
"""
import time
from typing import Optional
from pathlib import Path
from loguru import logger
from pixelle_video.services.quality.models import QualityScore, QualityConfig
class QualityGate:
    """
    Quality evaluation gate for AI-generated content

    Uses VLM (Vision Language Model) or local checks to evaluate:
    1. Aesthetic quality - Is the image visually appealing?
    2. Text matching - Does the image match the prompt/narration?
    3. Technical quality - Is the image clear and free of artifacts?

    NOTE: VLM invocation is not integrated yet; the VLM paths currently
    fall back to cheap structural checks (existence, dimensions, aspect).

    Example:
        >>> gate = QualityGate(llm_service, config)
        >>> score = await gate.evaluate_image(
        ...     image_path="output/frame_001.png",
        ...     prompt="A sunset over mountains",
        ...     narration="夕阳西下,余晖洒满山间"
        ... )
        >>> if score.passed:
        ...     print("Image quality approved!")
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[QualityConfig] = None
    ):
        """
        Initialize QualityGate

        Args:
            llm_service: LLM service for VLM-based evaluation
            config: Quality configuration (defaults to QualityConfig())
        """
        self.llm_service = llm_service
        self.config = config or QualityConfig()

    def _weighted_overall(
        self,
        aesthetic: float,
        text_match: float,
        technical: float,
    ) -> float:
        """Combine the three component scores using the configured weights."""
        return (
            aesthetic * self.config.aesthetic_weight +
            text_match * self.config.text_match_weight +
            technical * self.config.technical_weight
        )

    async def evaluate_image(
        self,
        image_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate the quality of a generated image

        Args:
            image_path: Path to the image file
            prompt: The prompt used to generate the image
            narration: Optional narration text for context

        Returns:
            QualityScore with evaluation results; `passed` is derived from
            the configured overall threshold.
        """
        start_time = time.time()
        # Fail fast when there is nothing to evaluate.
        if not Path(image_path).exists():
            return QualityScore(
                passed=False,
                issues=["Image file not found"],
                evaluation_time_ms=(time.time() - start_time) * 1000
            )
        # Prefer VLM-based evaluation when configured and available.
        if self.config.use_vlm_evaluation and self.llm_service:
            score = await self._evaluate_with_vlm(image_path, prompt, narration)
        else:
            score = await self._evaluate_basic(image_path, prompt)
        score.evaluation_time_ms = (time.time() - start_time) * 1000
        # Final verdict against the configured overall threshold.
        score.passed = score.overall_score >= self.config.overall_threshold
        logger.debug(
            f"Quality evaluation: overall={score.overall_score:.2f}, "
            f"passed={score.passed}, time={score.evaluation_time_ms:.0f}ms"
        )
        return score

    async def evaluate_video(
        self,
        video_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate the quality of a generated video

        Args:
            video_path: Path to the video file
            prompt: The prompt used to generate the video
            narration: Optional narration text for context

        Returns:
            QualityScore with evaluation results
        """
        start_time = time.time()
        if not Path(video_path).exists():
            return QualityScore(
                passed=False,
                issues=["Video file not found"],
                evaluation_time_ms=(time.time() - start_time) * 1000
            )
        # For video, key frames could be extracted and evaluated; until a
        # VLM is integrated both paths reduce to metadata-level checks.
        if self.config.use_vlm_evaluation and self.llm_service:
            score = await self._evaluate_video_with_vlm(video_path, prompt, narration)
        else:
            score = await self._evaluate_video_basic(video_path)
        score.evaluation_time_ms = (time.time() - start_time) * 1000
        score.passed = score.overall_score >= self.config.overall_threshold
        return score

    async def _evaluate_with_vlm(
        self,
        image_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate image quality using a Vision Language Model.

        The actual VLM call is not wired in yet; this delegates to the basic
        checks. `_build_evaluation_prompt()` is the prompt that will be sent
        once a vision-capable LLM (e.g. GPT-4o, Qwen-VL) is integrated.
        """
        try:
            # TODO: Implement the actual VLM call, roughly:
            #   response = await self.llm_service(
            #       prompt=self._build_evaluation_prompt(prompt, narration),
            #       images=[image_path],
            #       response_type=ImageQualityResponse,
            #   )
            logger.debug("VLM evaluation: using basic fallback (VLM integration pending)")
            return await self._evaluate_basic(image_path, prompt)
        except Exception as e:
            logger.warning(f"VLM evaluation failed: {e}, falling back to basic")
            return await self._evaluate_basic(image_path, prompt)

    async def _evaluate_basic(
        self,
        image_path: str,
        prompt: str,
    ) -> QualityScore:
        """
        Basic image quality evaluation without VLM.

        Performs simple structural checks only:
        - minimum dimensions
        - aspect ratio sanity
        Component scores are generous defaults since content cannot be judged.
        """
        issues = []
        try:
            # PIL imported lazily so the gate works even when Pillow is
            # absent and only VLM evaluation is used.
            from PIL import Image

            with Image.open(image_path) as img:
                width, height = img.size
                # Check minimum dimensions
                if width < 256 or height < 256:
                    issues.append(f"Image too small: {width}x{height}")
                # Check aspect ratio (not too extreme)
                aspect = max(width, height) / min(width, height)
                if aspect > 4:
                    issues.append(f"Extreme aspect ratio: {aspect:.1f}")
            # Basic scores (generous defaults when VLM not available)
            aesthetic_score = 0.7 if not issues else 0.4
            text_match_score = 0.7  # Can't properly evaluate without VLM
            technical_score = 0.8 if not issues else 0.5
            overall = self._weighted_overall(
                aesthetic_score, text_match_score, technical_score
            )
            return QualityScore(
                aesthetic_score=aesthetic_score,
                text_match_score=text_match_score,
                technical_score=technical_score,
                overall_score=overall,
                issues=issues,
            )
        except Exception as e:
            logger.error(f"Basic evaluation failed: {e}")
            return QualityScore(
                overall_score=0.0,
                passed=False,
                issues=[f"Evaluation error: {str(e)}"]
            )

    async def _evaluate_video_with_vlm(
        self,
        video_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """Evaluate video using VLM (placeholder for future implementation)"""
        # TODO: Implement video frame sampling and VLM evaluation
        return await self._evaluate_video_basic(video_path)

    async def _evaluate_video_basic(
        self,
        video_path: str,
    ) -> QualityScore:
        """
        Basic video quality evaluation via ffprobe metadata.

        Checks for a video stream, minimum dimensions, and minimum duration.
        Any ffprobe failure degrades to a neutral 0.5 score rather than
        raising, so a broken probe never blocks the pipeline.
        """
        issues = []
        try:
            import subprocess
            import json

            # Use ffprobe to get video info
            cmd = [
                "ffprobe", "-v", "quiet", "-print_format", "json",
                "-show_format", "-show_streams", video_path
            ]
            # timeout guards against a hung ffprobe blocking the pipeline;
            # TimeoutExpired is handled by the except below.
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
            if result.returncode != 0:
                issues.append("Failed to read video metadata")
                return QualityScore(overall_score=0.5, issues=issues)
            info = json.loads(result.stdout)
            # Check for video stream
            video_stream = None
            for stream in info.get("streams", []):
                if stream.get("codec_type") == "video":
                    video_stream = stream
                    break
            if not video_stream:
                issues.append("No video stream found")
                return QualityScore(overall_score=0.0, passed=False, issues=issues)
            # Check dimensions
            width = video_stream.get("width", 0)
            height = video_stream.get("height", 0)
            if width < 256 or height < 256:
                issues.append(f"Video too small: {width}x{height}")
            # Check duration
            duration = float(info.get("format", {}).get("duration", 0))
            if duration < 0.5:
                issues.append(f"Video too short: {duration:.1f}s")
            # Generous defaults; content cannot be judged from metadata.
            aesthetic_score = 0.7
            text_match_score = 0.7
            technical_score = 0.8 if not issues else 0.5
            overall = self._weighted_overall(
                aesthetic_score, text_match_score, technical_score
            )
            return QualityScore(
                aesthetic_score=aesthetic_score,
                text_match_score=text_match_score,
                technical_score=technical_score,
                overall_score=overall,
                issues=issues,
            )
        except Exception as e:
            logger.error(f"Video evaluation failed: {e}")
            return QualityScore(
                overall_score=0.5,
                issues=[f"Evaluation error: {str(e)}"]
            )

    def _build_evaluation_prompt(
        self,
        prompt: str,
        narration: Optional[str] = None,
    ) -> str:
        """Build the evaluation prompt for VLM"""
        context = f"Narration: {narration}\n" if narration else ""
        return f"""Evaluate this AI-generated image on the following criteria.
Rate each from 0.0 to 1.0.
Image Generation Prompt: {prompt}
{context}
Evaluation Criteria:
1. Aesthetic Quality (0.0-1.0):
- Is the image visually appealing?
- Good composition, colors, and style?
2. Prompt Matching (0.0-1.0):
- Does the image accurately represent the prompt?
- Are key elements from the prompt visible?
3. Technical Quality (0.0-1.0):
- Is the image clear and well-defined?
- Free of artifacts, distortions, or blurriness?
- Natural looking (no AI artifacts like extra fingers)?
Respond in JSON format:
{{
"aesthetic_score": 0.0,
"text_match_score": 0.0,
"technical_score": 0.0,
"issues": ["list of any problems found"]
}}
"""

View File

@@ -0,0 +1,296 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RetryManager - Smart retry logic with quality-based decisions
Provides unified retry management for:
- Media generation (images, videos)
- LLM calls
- Any async operations with quality evaluation
"""
import asyncio
import re
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar

from loguru import logger

from pixelle_video.services.quality.models import (
    QualityScore,
    QualityConfig,
    RetryConfig,
    QualityError,
)
T = TypeVar("T")
@dataclass
class RetryResult:
    """Outcome of a RetryManager operation.

    Captures whether the operation ultimately succeeded, what it produced,
    and how many attempts were consumed.
    """
    success: bool  # True when an attempt (or the fallback) was accepted
    result: Any  # value returned by the operation
    attempts: int  # number of attempts made (max_retries + 1 means fallback was used)
    quality_score: Optional[QualityScore] = None  # last quality evaluation, when an evaluator ran
    error: Optional[Exception] = None  # NOTE(review): never populated by the visible RetryManager code
class RetryManager:
    """
    Smart retry manager with quality-based decisions

    Features:
    - Exponential backoff with configurable delays
    - Quality-aware retry (retries when quality below threshold)
    - Fallback strategies
    - Detailed logging and metrics

    Example:
        >>> retry_manager = RetryManager()
        >>>
        >>> async def generate():
        ...     return await media_service.generate_image(prompt)
        >>>
        >>> async def evaluate(image_path):
        ...     return await quality_gate.evaluate_image(image_path, prompt)
        >>>
        >>> result = await retry_manager.execute_with_retry(
        ...     operation=generate,
        ...     quality_evaluator=evaluate,
        ...     config=RetryConfig(max_retries=3)
        ... )
    """

    def __init__(self, config: Optional[RetryConfig] = None):
        """
        Initialize RetryManager

        Args:
            config: Default retry configuration (RetryConfig() when omitted)
        """
        self.default_config = config or RetryConfig()

    @staticmethod
    def _backoff_delay_ms(cfg: RetryConfig, attempt: int) -> float:
        """Exponential backoff delay for a 1-based attempt, capped at max_delay_ms."""
        return min(
            cfg.initial_delay_ms * (cfg.backoff_factor ** (attempt - 1)),
            cfg.max_delay_ms
        )

    async def execute_with_retry(
        self,
        operation: Callable[[], Awaitable[Any]],
        quality_evaluator: Optional[Callable[[Any], Awaitable[QualityScore]]] = None,
        config: Optional[RetryConfig] = None,
        fallback_operation: Optional[Callable[[], Awaitable[Any]]] = None,
        operation_name: str = "operation",
    ) -> RetryResult:
        """
        Execute operation with automatic retry and quality evaluation

        Args:
            operation: Async callable to execute
            quality_evaluator: Optional async callable to evaluate result quality
            config: Retry configuration (uses default if not provided)
            fallback_operation: Optional fallback to try when all retries fail
            operation_name: Name for logging purposes

        Returns:
            RetryResult with success status, result, and quality score

        Raises:
            QualityError: When all retries fail and no fallback available
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None
        last_score: Optional[QualityScore] = None
        for attempt in range(1, cfg.max_retries + 1):
            try:
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()
                # Without an evaluator any successful call is accepted as-is.
                if quality_evaluator is None:
                    logger.debug(f"{operation_name}: completed (no quality check)")
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                    )
                score = await quality_evaluator(result)
                last_score = score
                if score.passed:
                    logger.info(
                        f"{operation_name}: passed quality check "
                        f"(score={score.overall_score:.2f}, attempt={attempt})"
                    )
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )
                logger.warning(
                    f"{operation_name}: quality check failed "
                    f"(score={score.overall_score:.2f}, threshold={cfg.quality_threshold}, "
                    f"issues={score.issues})"
                )
                if not cfg.retry_on_quality_fail:
                    # Hand the low-quality result back instead of retrying.
                    return RetryResult(
                        success=False,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )
            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed with error: {e}")
                if not cfg.retry_on_error:
                    raise
            # Back off before the next attempt (skipped after the last one).
            if attempt < cfg.max_retries:
                delay_ms = self._backoff_delay_ms(cfg, attempt)
                logger.debug(f"{operation_name}: waiting {delay_ms:.0f}ms before retry")
                await asyncio.sleep(delay_ms / 1000)
        # All retries exhausted
        logger.warning(f"{operation_name}: all {cfg.max_retries} attempts failed")
        # Try fallback if available
        if cfg.enable_fallback and fallback_operation:
            try:
                logger.info(f"{operation_name}: trying fallback strategy")
                result = await fallback_operation()
                # Evaluate fallback result if evaluator available
                if quality_evaluator:
                    score = await quality_evaluator(result)
                    return RetryResult(
                        success=score.passed,
                        result=result,
                        attempts=cfg.max_retries + 1,
                        quality_score=score,
                    )
                return RetryResult(
                    success=True,
                    result=result,
                    attempts=cfg.max_retries + 1,
                )
            except Exception as e:
                logger.error(f"{operation_name}: fallback also failed: {e}")
                last_error = e
        # Nothing worked
        if last_error:
            raise QualityError(
                f"{operation_name} failed after {cfg.max_retries} attempts: {last_error}",
                attempts=cfg.max_retries,
                last_score=last_score,
            )
        raise QualityError(
            f"{operation_name} failed to meet quality threshold after {cfg.max_retries} attempts",
            attempts=cfg.max_retries,
            last_score=last_score,
        )

    async def execute_simple(
        self,
        operation: Callable[[], Awaitable[Any]],
        config: Optional[RetryConfig] = None,
        operation_name: str = "operation",
    ) -> Any:
        """
        Execute operation with simple retry (no quality evaluation)

        Args:
            operation: Async callable to execute
            config: Retry configuration
            operation_name: Name for logging

        Returns:
            Operation result

        Raises:
            Exception: Last exception if all retries fail
            RuntimeError: If max_retries < 1 so no attempt was ever made
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None
        for attempt in range(1, cfg.max_retries + 1):
            try:
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()
                logger.debug(f"{operation_name}: completed on attempt {attempt}")
                return result
            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed: {e}")
                if attempt < cfg.max_retries:
                    await asyncio.sleep(self._backoff_delay_ms(cfg, attempt) / 1000)
        if last_error is None:
            # max_retries < 1: previously this fell through to `raise None`
            # (a TypeError); fail with an explicit diagnostic instead.
            raise RuntimeError(
                f"{operation_name}: no attempts were made (max_retries={cfg.max_retries})"
            )
        raise last_error

    @staticmethod
    def create_prompt_simplifier(original_prompt: str) -> Callable[[], str]:
        """
        Create a fallback that simplifies the prompt

        The returned callable strips common quality-booster phrases
        (case-insensitively), truncates overly long prompts at a word
        boundary, and normalizes whitespace.

        Args:
            original_prompt: Original complex prompt

        Returns:
            Zero-argument callable that returns the simplified prompt
        """
        remove_phrases = [
            "highly detailed",
            "ultra realistic",
            "8k resolution",
            "masterpiece",
            "best quality",
        ]
        # One compiled alternation; IGNORECASE makes removal robust to
        # capitalization ("Masterpiece", "HIGHLY DETAILED", ...), which the
        # previous phrase/phrase.lower() double-replace never achieved since
        # every phrase was already lowercase.
        pattern = re.compile(
            "|".join(re.escape(p) for p in remove_phrases), re.IGNORECASE
        )

        def simplify() -> str:
            simplified = pattern.sub("", original_prompt)
            # Truncate at a word boundary if still too long.
            if len(simplified) > 150:
                simplified = simplified[:150].rsplit(" ", 1)[0]
            # Collapse whitespace runs left behind by the removals.
            return " ".join(simplified.split())

        return simplify

View File

@@ -0,0 +1,276 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
StyleGuard - Visual style consistency engine
Ensures consistent visual style across all frames in a video by:
1. Extracting style anchor from the first generated frame
2. Injecting style constraints into subsequent frame prompts
3. (Optional) Using style reference techniques like IP-Adapter
"""
from dataclasses import dataclass, field
from typing import List, Optional
from loguru import logger
@dataclass
class StyleAnchor:
    """Visual style fingerprint extracted from a reference frame.

    Individual elements (palette, art style, lighting, texture) are combined
    into a prompt prefix; an explicit style_prefix, when set, wins outright.
    """
    # Core style elements
    color_palette: str = ""  # e.g. "warm earth tones", "cool blues"
    art_style: str = ""  # e.g. "minimalist", "realistic", "anime"
    composition_style: str = ""  # e.g. "centered", "rule of thirds"
    texture: str = ""  # e.g. "smooth", "grainy", "watercolor"
    lighting: str = ""  # e.g. "soft ambient", "dramatic shadows"
    # Pre-combined prefix; overrides the element-by-element assembly
    style_prefix: str = ""
    # Reference image path (for IP-Adapter-like style techniques)
    reference_image: Optional[str] = None

    def to_prompt_prefix(self) -> str:
        """Render the anchor as a comma-joined prefix for image prompts."""
        if self.style_prefix:
            return self.style_prefix
        # Each (value, template) pair contributes only when the value is set.
        parts = [
            template.format(value)
            for value, template in (
                (self.art_style, "{} style"),
                (self.color_palette, "{}"),
                (self.lighting, "{} lighting"),
                (self.texture, "{} texture"),
            )
            if value
        ]
        return ", ".join(parts)

    def to_dict(self) -> dict:
        """Serialize all style fields to a plain dict."""
        keys = (
            "color_palette",
            "art_style",
            "composition_style",
            "texture",
            "lighting",
            "style_prefix",
            "reference_image",
        )
        return {key: getattr(self, key) for key in keys}
@dataclass
class StyleGuardConfig:
    """Configuration for StyleGuard style extraction and application."""
    # Extraction settings
    extract_from_first_frame: bool = True  # NOTE(review): not consulted in StyleGuard's visible code
    use_vlm_extraction: bool = True  # try VLM-based extraction when an LLM service exists
    # Application settings
    apply_to_all_frames: bool = True  # NOTE(review): not consulted in StyleGuard's visible code
    prefix_position: str = "start"  # where style text is attached to prompts: "start" or "end"
    # Optional external style reference
    external_style_image: Optional[str] = None  # NOTE(review): not consulted in the visible code
    custom_style_prefix: Optional[str] = None  # when set, bypasses extraction entirely
class StyleGuard:
    """
    Keeps the visual style of every generated frame consistent.

    Workflow: analyze the first frame (or an external reference) to build a
    StyleAnchor, then attach that anchor's style description to the prompts
    of all remaining frames.

    Example:
        >>> guard = StyleGuard(llm_service)
        >>> anchor = await guard.extract_style_anchor("output/frame_001.png")
        >>> styled = guard.apply_style("A cat sitting on a windowsill", anchor)
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[StyleGuardConfig] = None
    ):
        """
        Args:
            llm_service: LLM service used for VLM-based style extraction.
            config: StyleGuard configuration; defaults used when omitted.
        """
        self.llm_service = llm_service
        self.config = config or StyleGuardConfig()
        self._current_anchor: Optional[StyleAnchor] = None

    async def extract_style_anchor(
        self,
        image_path: str,
    ) -> StyleAnchor:
        """
        Build a StyleAnchor from a reference image and cache it.

        A configured custom_style_prefix takes precedence over extraction;
        otherwise VLM extraction is attempted (when enabled and available),
        falling back to a generic anchor.

        Args:
            image_path: Path to the reference image.

        Returns:
            The extracted (and now current) StyleAnchor.
        """
        logger.info(f"Extracting style anchor from: {image_path}")
        if self.config.custom_style_prefix:
            # Caller pinned the style explicitly - no analysis needed.
            anchor = StyleAnchor(
                style_prefix=self.config.custom_style_prefix,
                reference_image=image_path
            )
            self._current_anchor = anchor
            return anchor
        if self.config.use_vlm_extraction and self.llm_service:
            anchor = await self._extract_with_vlm(image_path)
        else:
            anchor = self._extract_basic(image_path)
        self._current_anchor = anchor
        logger.info(f"Style anchor extracted: {anchor.to_prompt_prefix()}")
        return anchor

    async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
        """VLM-based extraction (placeholder until a vision LLM is integrated)."""
        try:
            # TODO: replace this placeholder with a real VLM call once a
            # vision-capable LLM is wired in.
            logger.debug("VLM style extraction: using placeholder (VLM not yet integrated)")
            return StyleAnchor(
                art_style="consistent artistic",
                color_palette="harmonious colors",
                lighting="balanced",
                style_prefix="maintaining visual consistency, same artistic style as previous frames",
                reference_image=image_path,
            )
        except Exception as e:
            logger.warning(f"VLM style extraction failed: {e}")
            return self._extract_basic(image_path)

    def _extract_basic(self, image_path: str) -> StyleAnchor:
        """Generic anchor used when no VLM is available."""
        return StyleAnchor(
            style_prefix="consistent visual style",
            reference_image=image_path,
        )

    def apply_style(
        self,
        prompt: str,
        style_anchor: Optional[StyleAnchor] = None,
    ) -> str:
        """
        Return *prompt* decorated with the anchor's style description.

        Args:
            prompt: Original image prompt.
            style_anchor: Anchor to apply; falls back to the current anchor.

        Returns:
            The styled prompt, or the prompt unchanged when no anchor or
            no style text is available.
        """
        active = style_anchor if style_anchor is not None else self._current_anchor
        if active is None:
            return prompt
        prefix = active.to_prompt_prefix()
        if not prefix:
            return prompt
        # Attachment side is configurable; default is a leading prefix.
        if self.config.prefix_position == "start":
            return f"{prefix}, {prompt}"
        return f"{prompt}, {prefix}"

    def apply_style_to_batch(
        self,
        prompts: List[str],
        style_anchor: Optional[StyleAnchor] = None,
        skip_first: bool = True,
    ) -> List[str]:
        """
        Apply style constraints to a batch of prompts.

        Args:
            prompts: List of image prompts.
            style_anchor: Anchor to apply; falls back to the current anchor.
            skip_first: Leave the first prompt untouched (it is the reference).

        Returns:
            List of styled prompts (unchanged when no anchor is available).
        """
        if not prompts:
            return prompts
        active = style_anchor or self._current_anchor
        if active is None:
            return prompts
        return [
            prompt if (skip_first and idx == 0) else self.apply_style(prompt, active)
            for idx, prompt in enumerate(prompts)
        ]

    def get_consistency_prompt_suffix(self) -> str:
        """
        Instruction snippet for the LLM that generates image prompts,
        nudging it toward style-consistent descriptions.
        """
        return (
            "Ensure all image prompts maintain consistent visual style, "
            "including similar color palette, art style, lighting, and composition. "
            "Each image should feel like it belongs to the same visual narrative."
        )

    @property
    def current_anchor(self) -> Optional[StyleAnchor]:
        """The most recently extracted style anchor, if any."""
        return self._current_anchor

    def reset(self):
        """Forget the current style anchor."""
        self._current_anchor = None