feat: Add comprehensive timeline editor with frame editing and regeneration capabilities
This commit is contained in:
84
pixelle_video/services/quality/__init__.py
Normal file
84
pixelle_video/services/quality/__init__.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Quality Assurance Services for Pixelle-Video
|
||||
|
||||
This module provides quality control mechanisms for video generation:
|
||||
- QualityGate: Evaluates generated content quality
|
||||
- RetryManager: Smart retry with quality-based decisions
|
||||
- OutputValidator: LLM output validation
|
||||
- StyleGuard: Visual style consistency
|
||||
- ContentFilter: Content moderation
|
||||
- CharacterMemory: Character consistency across frames
|
||||
"""
|
||||
|
||||
from pixelle_video.services.quality.models import (
|
||||
QualityScore,
|
||||
QualityConfig,
|
||||
RetryConfig,
|
||||
QualityError,
|
||||
)
|
||||
from pixelle_video.services.quality.quality_gate import QualityGate
|
||||
from pixelle_video.services.quality.retry_manager import RetryManager
|
||||
from pixelle_video.services.quality.output_validator import (
|
||||
OutputValidator,
|
||||
ValidationConfig,
|
||||
ValidationResult,
|
||||
)
|
||||
from pixelle_video.services.quality.style_guard import (
|
||||
StyleGuard,
|
||||
StyleGuardConfig,
|
||||
StyleAnchor,
|
||||
)
|
||||
from pixelle_video.services.quality.content_filter import (
|
||||
ContentFilter,
|
||||
ContentFilterConfig,
|
||||
FilterResult,
|
||||
FilterCategory,
|
||||
)
|
||||
from pixelle_video.services.quality.character_memory import (
|
||||
CharacterMemory,
|
||||
CharacterMemoryConfig,
|
||||
Character,
|
||||
CharacterType,
|
||||
)
|
||||
|
||||
# Public API of the quality package; mirrors the imports above so
# `from pixelle_video.services.quality import *` exposes exactly these names.
__all__ = [
    # Quality evaluation
    "QualityScore",
    "QualityConfig",
    "RetryConfig",
    "QualityError",
    "QualityGate",
    "RetryManager",
    # Output validation
    "OutputValidator",
    "ValidationConfig",
    "ValidationResult",
    # Style consistency
    "StyleGuard",
    "StyleGuardConfig",
    "StyleAnchor",
    # Content moderation
    "ContentFilter",
    "ContentFilterConfig",
    "FilterResult",
    "FilterCategory",
    # Character memory
    "CharacterMemory",
    "CharacterMemoryConfig",
    "Character",
    "CharacterType",
]
|
||||
|
||||
|
||||
530
pixelle_video/services/quality/character_memory.py
Normal file
530
pixelle_video/services/quality/character_memory.py
Normal file
@@ -0,0 +1,530 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
CharacterMemory - Character consistency and memory system
|
||||
|
||||
Maintains consistent character appearance across video frames by:
|
||||
1. Detecting and registering characters from narrations
|
||||
2. Extracting visual descriptions from first appearances
|
||||
3. Injecting character consistency prompts into subsequent frames
|
||||
4. Supporting reference images for ComfyUI IP-Adapter/ControlNet
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Set
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CharacterType(Enum):
    """Type of character.

    Values are lowercase strings; `Character.to_dict` serializes
    `character_type` via `.value`, so these strings are part of the
    persisted/external format.
    """
    PERSON = "person"  # Human character
    ANIMAL = "animal"  # Animal character
    CREATURE = "creature"  # Fantasy/fictional creature
    OBJECT = "object"  # Personified object
    ABSTRACT = "abstract"  # Abstract entity
|
||||
|
||||
|
||||
@dataclass
class Character:
    """
    Represents a character in the video narrative.

    Stores visual description, reference images, and appearance history
    to maintain consistency across frames.
    """

    # Identity
    id: str  # Unique identifier
    name: str  # Character name (e.g., "小明", "the hero")
    aliases: List[str] = field(default_factory=list)  # Alternative names
    character_type: CharacterType = CharacterType.PERSON

    # Visual description (for prompt injection)
    appearance_description: str = ""  # Detailed visual description
    clothing_description: str = ""  # Clothing/outfit description
    distinctive_features: List[str] = field(default_factory=list)  # Unique features

    # Reference images (for IP-Adapter/ControlNet)
    reference_images: List[str] = field(default_factory=list)  # Paths to reference images
    primary_reference: Optional[str] = None  # Primary reference image

    # Prompt elements
    prompt_prefix: str = ""  # Pre-built prompt prefix
    negative_prompt: str = ""  # Negative prompt additions

    # Metadata
    is_active: bool = True  # Whether this character is active for logic
    first_appearance_frame: int = 0  # Frame index of first appearance
    appearance_frames: List[int] = field(default_factory=list)  # All frames with this character
    created_at: Optional[datetime] = None  # Set lazily in __post_init__

    def __post_init__(self):
        """Fill in derived defaults after dataclass field assignment."""
        if self.created_at is None:
            self.created_at = datetime.now()
        # BUG FIX (dead code removed): a previous
        # `if not hasattr(self, 'is_active')` guard could never fire, because
        # a dataclass field with a default is always assigned before
        # __post_init__ runs.
        if not self.prompt_prefix:
            self._build_prompt_prefix()

    def _build_prompt_prefix(self):
        """Build prompt prefix from visual descriptions."""
        elements = []

        if self.appearance_description:
            elements.append(self.appearance_description)
        if self.clothing_description:
            elements.append(f"wearing {self.clothing_description}")
        if self.distinctive_features:
            elements.append(", ".join(self.distinctive_features))

        self.prompt_prefix = ", ".join(elements) if elements else ""

    def get_prompt_injection(self) -> str:
        """Get the prompt text to inject for this character.

        Returns `(name: description)` when a prefix exists, else `(name)`.
        """
        if self.prompt_prefix:
            return f"({self.name}: {self.prompt_prefix})"
        return f"({self.name})"

    def add_reference_image(self, image_path: str, set_as_primary: bool = False):
        """Add a reference image for this character.

        The first image added always becomes the primary reference; later
        images replace it only when `set_as_primary` is True. Duplicate paths
        are ignored. (No cap is enforced here; `CharacterMemory` truncates
        when collecting references.)
        """
        if image_path not in self.reference_images:
            self.reference_images.append(image_path)
        if set_as_primary or self.primary_reference is None:
            self.primary_reference = image_path

    def matches_name(self, name: str) -> bool:
        """Check if a name matches this character (case-insensitive, aliases included)."""
        name_lower = name.lower().strip()
        if self.name.lower() == name_lower:
            return True
        return any(alias.lower() == name_lower for alias in self.aliases)

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict (enum stored as its string value)."""
        return {
            "id": self.id,
            "name": self.name,
            "aliases": self.aliases,
            "type": self.character_type.value,
            "appearance_description": self.appearance_description,
            "clothing_description": self.clothing_description,
            "distinctive_features": self.distinctive_features,
            "reference_images": self.reference_images,
            "primary_reference": self.primary_reference,
            "prompt_prefix": self.prompt_prefix,
            "first_appearance_frame": self.first_appearance_frame,
        }
|
||||
|
||||
|
||||
@dataclass
class CharacterMemoryConfig:
    """Configuration for the character memory system.

    All flags are read by `CharacterMemory`; defaults favor automatic,
    LLM-assisted consistency.
    """

    # Detection settings
    auto_detect_characters: bool = True  # Automatically detect characters from narrations
    use_llm_detection: bool = True  # Use LLM to extract character info (falls back to regex)

    # Consistency settings
    inject_character_prompts: bool = True  # Inject character descriptions into prompts
    use_reference_images: bool = True  # Use reference images for generation

    # Reference image settings
    extract_reference_from_first: bool = True  # Extract reference from first appearance
    max_reference_images: int = 3  # Max reference images per character

    # Prompt injection settings
    prompt_injection_position: str = "start"  # "start" or "end" of the prompt
    include_clothing: bool = True  # Include clothing in prompts
    include_features: bool = True  # Include distinctive features
|
||||
|
||||
|
||||
class CharacterMemory:
    """
    Character memory and consistency manager.

    Tracks characters across video frames and ensures visual consistency
    by injecting character descriptions and reference images into the
    generation pipeline.

    Example:
        >>> memory = CharacterMemory(llm_service)
        >>>
        >>> # Register a character
        >>> char = memory.register_character(
        ...     name="小明",
        ...     appearance_description="young man with short black hair",
        ...     clothing_description="blue t-shirt"
        ... )
        >>>
        >>> # Apply to prompt (FIX: example previously used a nonexistent
        >>> # `characters=` kwarg; the parameter is `character_names`)
        >>> enhanced_prompt = memory.apply_to_prompt(
        ...     prompt="A person walking in the park",
        ...     character_names=["小明"]
        ... )
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[CharacterMemoryConfig] = None
    ):
        """
        Initialize CharacterMemory.

        Args:
            llm_service: Optional async callable LLM service for character
                detection (called as `await llm_service(prompt, ...)`)
            config: Character memory configuration (defaults applied if None)
        """
        self.llm_service = llm_service
        self.config = config or CharacterMemoryConfig()
        self._characters: Dict[str, Character] = {}
        self._name_index: Dict[str, str] = {}  # lowercased name -> character_id

    def register_character(
        self,
        name: str,
        appearance_description: str = "",
        clothing_description: str = "",
        distinctive_features: Optional[List[str]] = None,
        character_type: CharacterType = CharacterType.PERSON,
        first_frame: int = 0,
        aliases: Optional[List[str]] = None,
    ) -> Character:
        """
        Register a new character.

        Args:
            name: Character name
            appearance_description: Visual appearance description
            clothing_description: Clothing/outfit description
            distinctive_features: List of distinctive features
            character_type: Type of character
            first_frame: Frame index of first appearance
            aliases: Optional alternative names (new, backward-compatible;
                the Character model always supported aliases but there was
                no way to supply them at registration)

        Returns:
            Created Character object
        """
        # Generate unique ID (count-based, so IDs stay unique even when two
        # characters share a name)
        char_id = f"char_{len(self._characters)}_{name.replace(' ', '_').lower()}"

        character = Character(
            id=char_id,
            name=name,
            aliases=aliases or [],
            appearance_description=appearance_description,
            clothing_description=clothing_description,
            distinctive_features=distinctive_features or [],
            character_type=character_type,
            first_appearance_frame=first_frame,
            appearance_frames=[first_frame],
        )

        self._characters[char_id] = character
        self._name_index[name.lower()] = char_id
        # Index aliases too, so get_character hits the fast path
        for alias in (aliases or []):
            self._name_index[alias.lower()] = char_id

        logger.info(f"Registered character: {name} (id={char_id})")

        return character

    def get_character(self, name: str) -> Optional[Character]:
        """Get a character by name (case-insensitive; falls back to alias search)."""
        name_lower = name.lower().strip()
        char_id = self._name_index.get(name_lower)
        if char_id:
            return self._characters.get(char_id)

        # Search aliases
        for char in self._characters.values():
            if char.matches_name(name):
                return char

        return None

    def get_character_by_id(self, char_id: str) -> Optional[Character]:
        """Get a character by its unique ID."""
        return self._characters.get(char_id)

    @property
    def characters(self) -> List[Character]:
        """All registered characters (snapshot list)."""
        return list(self._characters.values())

    async def detect_characters_from_narration(
        self,
        narration: str,
        frame_index: int = 0,
    ) -> List[Character]:
        """
        Detect and register characters mentioned in narration.

        Args:
            narration: Narration text to analyze
            frame_index: Current frame index

        Returns:
            List of detected/registered characters (empty when auto-detection
            is disabled)
        """
        if not self.config.auto_detect_characters:
            return []

        detected = []

        if self.config.use_llm_detection and self.llm_service:
            detected = await self._detect_with_llm(narration, frame_index)
        else:
            detected = self._detect_basic(narration, frame_index)

        return detected

    async def _detect_with_llm(
        self,
        narration: str,
        frame_index: int,
    ) -> List[Character]:
        """Detect characters using the LLM; falls back to regex on any failure."""
        if not self.llm_service:
            return []

        try:
            prompt = f"""分析以下文案,提取其中提到的角色/人物。

文案: {narration}

请用 JSON 格式返回角色列表,每个角色包含:
- name: 角色名称或代称
- type: person/animal/creature/object
- appearance: 外貌描述(如有)
- clothing: 服装描述(如有)

如果没有明确角色,返回空列表 []。

只返回 JSON,不要其他解释。"""

            response = await self.llm_service(prompt, temperature=0.1)

            # Parse response
            import json
            import re

            # Extract the first JSON array from the (possibly chatty) response
            json_match = re.search(r'\[.*\]', response, re.DOTALL)
            if json_match:
                characters_data = json.loads(json_match.group())

                result = []
                for char_data in characters_data:
                    name = char_data.get("name", "").strip()
                    if not name:
                        continue

                    # Check if already registered
                    existing = self.get_character(name)
                    if existing:
                        existing.appearance_frames.append(frame_index)
                        result.append(existing)
                    else:
                        # BUG FIX: the prompt asks the LLM for
                        # person/animal/creature/object, but only "animal" and
                        # "creature" were mapped — "object" silently became
                        # PERSON. Enum values mirror the type strings, so
                        # construct directly and fall back to PERSON for
                        # unknown types.
                        type_str = char_data.get("type", "person").lower()
                        try:
                            char_type = CharacterType(type_str)
                        except ValueError:
                            char_type = CharacterType.PERSON

                        char = self.register_character(
                            name=name,
                            appearance_description=char_data.get("appearance", ""),
                            clothing_description=char_data.get("clothing", ""),
                            character_type=char_type,
                            first_frame=frame_index,
                        )
                        result.append(char)

                return result

            return []

        except Exception as e:
            logger.warning(f"LLM character detection failed: {e}")
            return self._detect_basic(narration, frame_index)

    def _detect_basic(
        self,
        narration: str,
        frame_index: int,
    ) -> List[Character]:
        """Basic character detection without LLM.

        Only matches *already registered* characters against simple Chinese
        name/pronoun patterns; it never registers new ones.
        """
        # Simple pattern matching for common character references
        import re

        patterns = [
            r'(?:他|她|它)们?',  # Chinese pronouns
            r'(?:小\w{1,2})',  # Names like 小明, 小红
            r'(?:老\w{1,2})',  # Names like 老王, 老李
        ]

        detected = []
        for pattern in patterns:
            matches = re.findall(pattern, narration)
            for match in matches:
                existing = self.get_character(match)
                if existing:
                    existing.appearance_frames.append(frame_index)
                    if existing not in detected:
                        detected.append(existing)

        return detected

    def apply_to_prompt(
        self,
        prompt: str,
        character_names: Optional[List[str]] = None,
        frame_index: Optional[int] = None,
    ) -> str:
        """
        Apply character consistency to an image prompt.

        Args:
            prompt: Original image prompt
            character_names: Specific characters to include (None = all registered)
            frame_index: Current frame index for appearance tracking

        Returns:
            Enhanced prompt with character consistency (or the original prompt
            when injection is disabled or no characters apply)
        """
        if not self.config.inject_character_prompts:
            return prompt

        characters_to_include = []

        if character_names:
            for name in character_names:
                char = self.get_character(name)
                if char:
                    characters_to_include.append(char)
        else:
            # Include all characters that have appeared
            characters_to_include = self.characters

        if not characters_to_include:
            return prompt

        # Build character injection
        injections = []
        for char in characters_to_include:
            injection = char.get_prompt_injection()
            if injection:
                injections.append(injection)

            # Track appearance
            if frame_index is not None and frame_index not in char.appearance_frames:
                char.appearance_frames.append(frame_index)

        if not injections:
            return prompt

        character_prompt = ", ".join(injections)

        if self.config.prompt_injection_position == "start":
            return f"{character_prompt}, {prompt}"
        else:
            return f"{prompt}, {character_prompt}"

    def get_reference_images(
        self,
        character_names: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Get reference images for specified characters.

        Args:
            character_names: Character names (None = all characters)

        Returns:
            List of primary reference image paths, truncated to
            `config.max_reference_images` in total
        """
        if not self.config.use_reference_images:
            return []

        images = []

        if character_names:
            for name in character_names:
                char = self.get_character(name)
                if char and char.primary_reference:
                    images.append(char.primary_reference)
        else:
            for char in self.characters:
                if char.primary_reference:
                    images.append(char.primary_reference)

        return images[:self.config.max_reference_images]

    def set_reference_image(
        self,
        character_name: str,
        image_path: str,
        set_as_primary: bool = True,
    ):
        """
        Set a reference image for a character.

        Args:
            character_name: Character name
            image_path: Path to reference image
            set_as_primary: Whether to set as primary reference
        """
        char = self.get_character(character_name)
        if char:
            char.add_reference_image(image_path, set_as_primary)
            logger.debug(f"Set reference image for {character_name}: {image_path}")
        else:
            logger.warning(f"Character not found: {character_name}")

    def update_character_appearance(
        self,
        character_name: str,
        appearance_description: Optional[str] = None,
        clothing_description: Optional[str] = None,
        distinctive_features: Optional[List[str]] = None,
    ):
        """Update a character's visual description and rebuild its prompt prefix."""
        char = self.get_character(character_name)
        if char:
            if appearance_description:
                char.appearance_description = appearance_description
            if clothing_description:
                char.clothing_description = clothing_description
            if distinctive_features:
                char.distinctive_features = distinctive_features
            char._build_prompt_prefix()
            logger.debug(f"Updated appearance for {character_name}")

    def get_consistency_summary(self) -> str:
        """Get a human-readable summary of registered characters for logging."""
        if not self._characters:
            return "No characters registered"

        lines = [f"Characters ({len(self._characters)}):"]
        for char in self.characters:
            lines.append(
                f"  - {char.name}: {len(char.appearance_frames)} appearances, "
                f"ref_images={len(char.reference_images)}"
            )
        return "\n".join(lines)

    def reset(self):
        """Clear all character memory."""
        self._characters.clear()
        self._name_index.clear()
        logger.info("Character memory cleared")
|
||||
316
pixelle_video/services/quality/content_filter.py
Normal file
316
pixelle_video/services/quality/content_filter.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
ContentFilter - Content moderation and safety filtering
|
||||
|
||||
Provides content safety checks for:
|
||||
- Text content (narrations, prompts)
|
||||
- Generated images
|
||||
- Generated videos
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Set
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class FilterCategory(Enum):
    """Content filter categories (serialized via `.value` in FilterResult.to_dict)."""
    SAFE = "safe"  # No concerning content
    SENSITIVE = "sensitive"  # May require review
    BLOCKED = "blocked"  # Should not proceed
|
||||
|
||||
|
||||
@dataclass
class FilterResult:
    """Outcome of a content-safety check."""
    category: FilterCategory  # Severity classification of the content
    passed: bool  # True when the content may proceed
    flagged_items: List[str] = field(default_factory=list)  # Matched keywords/phrases
    reason: Optional[str] = None  # Human-readable explanation, if any
    confidence: float = 1.0  # Confidence in the classification

    def to_dict(self) -> dict:
        """Serialize this result into a plain, JSON-friendly dictionary."""
        payload = dict(
            category=self.category.value,
            passed=self.passed,
            flagged_items=self.flagged_items,
            reason=self.reason,
            confidence=self.confidence,
        )
        return payload
|
||||
|
||||
|
||||
@dataclass
class ContentFilterConfig:
    """Configuration for content filtering."""

    # Text filtering
    enable_keyword_filter: bool = True  # Regex keyword matching
    enable_llm_filter: bool = False  # Use LLM for semantic filtering

    # Custom keywords to block (added to default list)
    custom_blocked_keywords: List[str] = field(default_factory=list)

    # Sensitivity level: "strict", "moderate", "permissive"
    # NOTE(review): this field is not read anywhere in the visible
    # ContentFilter code — confirm whether it is consumed elsewhere.
    sensitivity_level: str = "moderate"

    # Image filtering
    enable_image_filter: bool = False  # Requires external service

    # Action on detection
    block_on_sensitive: bool = False  # Block content marked as sensitive
    log_filtered_content: bool = True  # Log flagged items via loguru
|
||||
|
||||
|
||||
class ContentFilter:
    """
    Content moderation filter for generated content.

    Provides safety filtering for text and media content to prevent
    inappropriate or harmful content from being generated.

    Example:
        >>> filter = ContentFilter()
        >>> result = await filter.check_text("Hello, world!")
        >>> if result.passed:
        ...     print("Content is safe")
    """

    # Default blocked keywords (minimal list for demonstration)
    DEFAULT_BLOCKED_PATTERNS = [
        r"\b(violence|gore|blood)\b",
        r"\b(nsfw|explicit|pornographic)\b",
        r"\b(illegal|drugs|weapons)\b",
    ]

    # Sensitive keywords (may require review)
    DEFAULT_SENSITIVE_PATTERNS = [
        r"\b(death|dying|kill)\b",
        r"\b(hate|racist|sexist)\b",
        r"\b(controversial|political)\b",
    ]

    # Explicit severity ranking for category comparison (higher = worse).
    # Needed because comparing the enums' *string* values is lexicographic
    # and therefore wrong ("blocked" < "safe" < "sensitive").
    _CATEGORY_SEVERITY = {
        FilterCategory.SAFE: 0,
        FilterCategory.SENSITIVE: 1,
        FilterCategory.BLOCKED: 2,
    }

    def __init__(
        self,
        llm_service=None,
        config: Optional[ContentFilterConfig] = None
    ):
        """
        Initialize ContentFilter.

        Args:
            llm_service: Optional async callable LLM service for semantic filtering
            config: Filter configuration (defaults applied if None)
        """
        self.llm_service = llm_service
        self.config = config or ContentFilterConfig()

        # Compile patterns once; they are reused on every check
        self._blocked_patterns = [
            re.compile(p, re.IGNORECASE)
            for p in self.DEFAULT_BLOCKED_PATTERNS
        ]
        self._sensitive_patterns = [
            re.compile(p, re.IGNORECASE)
            for p in self.DEFAULT_SENSITIVE_PATTERNS
        ]

        # Add custom keywords (escaped, matched as whole words)
        if self.config.custom_blocked_keywords:
            for keyword in self.config.custom_blocked_keywords:
                pattern = re.compile(rf"\b{re.escape(keyword)}\b", re.IGNORECASE)
                self._blocked_patterns.append(pattern)

    async def check_text(self, text: str) -> FilterResult:
        """
        Check text content for safety.

        Args:
            text: Text to check

        Returns:
            FilterResult with safety assessment (empty text is SAFE)
        """
        if not text:
            return FilterResult(
                category=FilterCategory.SAFE,
                passed=True
            )

        flagged_items = []
        category = FilterCategory.SAFE

        # Keyword filtering
        if self.config.enable_keyword_filter:
            # Check blocked patterns
            for pattern in self._blocked_patterns:
                matches = pattern.findall(text)
                if matches:
                    flagged_items.extend(matches)
                    category = FilterCategory.BLOCKED

            # Check sensitive patterns (if not already blocked)
            if category != FilterCategory.BLOCKED:
                for pattern in self._sensitive_patterns:
                    matches = pattern.findall(text)
                    if matches:
                        flagged_items.extend(matches)
                        category = FilterCategory.SENSITIVE

        # LLM-based semantic filtering
        if self.config.enable_llm_filter and self.llm_service:
            semantic_result = await self._check_with_llm(text)
            # BUG FIX: previously compared `category.value` strings
            # lexicographically, so an LLM BLOCKED verdict ("blocked") could
            # never override a keyword-filter SAFE verdict ("safe").
            # Compare explicit severity ranks instead.
            if self._CATEGORY_SEVERITY[semantic_result.category] > self._CATEGORY_SEVERITY[category]:
                category = semantic_result.category
                flagged_items.extend(semantic_result.flagged_items)

        # Determine if passed
        if category == FilterCategory.BLOCKED:
            passed = False
            reason = "Content contains blocked keywords or themes"
        elif category == FilterCategory.SENSITIVE and self.config.block_on_sensitive:
            passed = False
            reason = "Content contains sensitive themes (strict mode)"
        else:
            passed = True
            reason = None

        # Log if configured
        if self.config.log_filtered_content and flagged_items:
            logger.warning(f"Content filter flagged: {flagged_items}")

        return FilterResult(
            category=category,
            passed=passed,
            flagged_items=flagged_items,
            reason=reason,
        )

    async def check_texts(self, texts: List[str]) -> FilterResult:
        """
        Check multiple texts and return an aggregate result.

        Args:
            texts: List of texts to check

        Returns:
            Aggregate FilterResult carrying the worst category seen
        """
        all_flagged = []
        worst_category = FilterCategory.SAFE

        for text in texts:
            result = await self.check_text(text)
            all_flagged.extend(result.flagged_items)

            if result.category == FilterCategory.BLOCKED:
                worst_category = FilterCategory.BLOCKED
            elif result.category == FilterCategory.SENSITIVE and worst_category != FilterCategory.BLOCKED:
                worst_category = FilterCategory.SENSITIVE

        passed = worst_category == FilterCategory.SAFE or (
            worst_category == FilterCategory.SENSITIVE and not self.config.block_on_sensitive
        )

        return FilterResult(
            category=worst_category,
            passed=passed,
            flagged_items=list(set(all_flagged)),  # Deduplicate
            reason=f"Found {len(all_flagged)} flagged items" if all_flagged else None,
        )

    async def check_image(self, image_path: str) -> FilterResult:
        """
        Check image content for safety.

        Note: This requires external NSFW detection service integration.
        Currently returns safe by default.

        Args:
            image_path: Path to image file

        Returns:
            FilterResult with safety assessment
        """
        if not self.config.enable_image_filter:
            return FilterResult(
                category=FilterCategory.SAFE,
                passed=True,
                reason="Image filtering disabled"
            )

        # TODO: Integrate with external NSFW detection
        # Options: TensorFlow NSFW model, Azure Content Moderator, etc.

        logger.debug(f"Image safety check: {image_path} (not implemented, assuming safe)")

        return FilterResult(
            category=FilterCategory.SAFE,
            passed=True,
            reason="Image filtering not implemented"
        )

    async def _check_with_llm(self, text: str) -> FilterResult:
        """Check text using the LLM for semantic understanding.

        Fails open: any LLM error yields a SAFE result (keyword filtering
        still applies in check_text).
        """
        if not self.llm_service:
            return FilterResult(category=FilterCategory.SAFE, passed=True)

        try:
            prompt = f"""Analyze the following text for content safety.
Rate it as: SAFE, SENSITIVE, or BLOCKED.

SAFE: No concerning content
SENSITIVE: Contains themes that may need review (violence, death, controversial topics)
BLOCKED: Contains explicit, illegal, or harmful content

Text: {text[:500]}

Respond with only one word: SAFE, SENSITIVE, or BLOCKED."""

            response = await self.llm_service(prompt, temperature=0.0, max_tokens=10)
            response = response.strip().upper()

            # Order matters: check BLOCKED first since both words could appear
            if "BLOCKED" in response:
                return FilterResult(
                    category=FilterCategory.BLOCKED,
                    passed=False,
                    reason="LLM detected blocked content"
                )
            elif "SENSITIVE" in response:
                return FilterResult(
                    category=FilterCategory.SENSITIVE,
                    passed=not self.config.block_on_sensitive,
                    reason="LLM detected sensitive content"
                )
            else:
                return FilterResult(
                    category=FilterCategory.SAFE,
                    passed=True
                )

        except Exception as e:
            logger.warning(f"LLM content check failed: {e}")
            return FilterResult(category=FilterCategory.SAFE, passed=True)

    def add_blocked_keyword(self, keyword: str):
        """Add a keyword to the blocked list (escaped, whole-word, case-insensitive)."""
        pattern = re.compile(rf"\b{re.escape(keyword)}\b", re.IGNORECASE)
        self._blocked_patterns.append(pattern)

    def add_sensitive_keyword(self, keyword: str):
        """Add a keyword to the sensitive list (escaped, whole-word, case-insensitive)."""
        pattern = re.compile(rf"\b{re.escape(keyword)}\b", re.IGNORECASE)
        self._sensitive_patterns.append(pattern)
|
||||
140
pixelle_video/services/quality/models.py
Normal file
140
pixelle_video/services/quality/models.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Data models for quality assurance
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class QualityLevel(Enum):
    """Discrete quality bands derived from the numeric overall score."""
    EXCELLENT = "excellent"    # >= 0.8
    GOOD = "good"              # >= 0.6
    ACCEPTABLE = "acceptable"  # >= 0.4
    POOR = "poor"              # < 0.4


@dataclass
class QualityScore:
    """Quality evaluation result for generated content."""

    # Individual scores (0.0 - 1.0)
    aesthetic_score: float = 0.0   # visual appeal / beauty
    text_match_score: float = 0.0  # how well the image matches the prompt
    technical_score: float = 0.0   # technical quality (clarity, no artifacts)

    # Overall
    overall_score: float = 0.0  # weighted average of the three scores
    passed: bool = False        # whether the score meets the threshold

    # Diagnostics
    issues: List[str] = field(default_factory=list)  # detected problems
    evaluation_time_ms: float = 0.0                  # time taken for evaluation

    @property
    def level(self) -> QualityLevel:
        """Map the overall score onto its QualityLevel band."""
        bands = (
            (0.8, QualityLevel.EXCELLENT),
            (0.6, QualityLevel.GOOD),
            (0.4, QualityLevel.ACCEPTABLE),
        )
        for cutoff, band in bands:
            if self.overall_score >= cutoff:
                return band
        return QualityLevel.POOR

    def to_dict(self) -> dict:
        """Serialize this score to a plain dictionary."""
        data = {
            "aesthetic_score": self.aesthetic_score,
            "text_match_score": self.text_match_score,
            "technical_score": self.technical_score,
            "overall_score": self.overall_score,
            "passed": self.passed,
            "level": self.level.value,
            "issues": self.issues,
        }
        return data
|
||||
|
||||
|
||||
@dataclass
class QualityConfig:
    """Configuration for quality evaluation."""

    # Thresholds (0.0 - 1.0)
    overall_threshold: float = 0.6     # minimum overall score to pass
    aesthetic_threshold: float = 0.5   # minimum aesthetic score
    text_match_threshold: float = 0.6  # minimum text-match score
    technical_threshold: float = 0.7   # minimum technical score

    # Weights used to combine the three scores into the overall score
    aesthetic_weight: float = 0.3
    text_match_weight: float = 0.4
    technical_weight: float = 0.3

    # Evaluation settings
    use_vlm_evaluation: bool = True       # use VLM for evaluation (vs local models)
    vlm_model: Optional[str] = None       # VLM model to use (None = default LLM)
    skip_on_static_template: bool = True  # skip image check for static templates

    def __post_init__(self):
        """Normalize the three weights so they sum to 1.0."""
        weight_sum = self.aesthetic_weight + self.text_match_weight + self.technical_weight
        # Tolerate tiny float drift; only renormalize a real mismatch
        if abs(weight_sum - 1.0) <= 0.01:
            return
        self.aesthetic_weight = self.aesthetic_weight / weight_sum
        self.text_match_weight = self.text_match_weight / weight_sum
        self.technical_weight = self.technical_weight / weight_sum
|
||||
|
||||
|
||||
@dataclass
class RetryConfig:
    """Configuration for retry behavior."""

    max_retries: int = 3         # maximum retry attempts
    backoff_factor: float = 1.5  # exponential backoff multiplier
    initial_delay_ms: int = 500  # delay before the first retry
    max_delay_ms: int = 10000    # ceiling on the delay between retries

    # Quality-based retry
    quality_threshold: float = 0.6  # quality score required to pass

    # Fallback behavior
    enable_fallback: bool = True           # enable fallback strategy on failure
    fallback_prompt_simplify: bool = True  # simplify the prompt on retry

    # Retry conditions
    retry_on_quality_fail: bool = True  # retry when quality is below threshold
    retry_on_error: bool = True         # retry on generation errors
|
||||
|
||||
|
||||
class QualityError(Exception):
    """Raised when quality standards are not met after all retries."""

    def __init__(
        self,
        message: str,
        attempts: int = 0,
        last_score: Optional[QualityScore] = None
    ):
        super().__init__(message)
        # Number of generation attempts made before giving up
        self.attempts = attempts
        # Score of the final failed attempt, if one was evaluated
        self.last_score = last_score

    def __str__(self) -> str:
        base = super().__str__()
        detail = f"attempts={self.attempts}"
        if self.last_score:
            detail = f"{detail}, last_score={self.last_score.overall_score:.2f}"
        return f"{base} ({detail})"
|
||||
336
pixelle_video/services/quality/output_validator.py
Normal file
336
pixelle_video/services/quality/output_validator.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OutputValidator - LLM output validation and quality control
|
||||
|
||||
Validates LLM-generated content for:
|
||||
- Narrations: length, relevance, coherence
|
||||
- Image prompts: format, language, prompt-narration alignment
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
class ValidationConfig:
    """Configuration for output validation."""

    # Narration validation
    min_narration_words: int = 5      # shortest acceptable narration
    max_narration_words: int = 50     # longest acceptable narration
    relevance_threshold: float = 0.6  # minimum topic-relevance score
    coherence_threshold: float = 0.6  # minimum inter-narration coherence

    # Image prompt validation
    min_prompt_words: int = 10            # shortest acceptable prompt
    max_prompt_words: int = 100           # longest acceptable prompt
    require_english_prompts: bool = True  # flag prompts containing too much Chinese
    prompt_match_threshold: float = 0.5   # minimum prompt/narration alignment
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Outcome of a validation pass."""

    passed: bool  # overall pass/fail verdict
    issues: List[str] = field(default_factory=list)  # problems found
    score: float = 1.0  # 1.0 = perfect, 0.0 = failed
    suggestions: List[str] = field(default_factory=list)  # remediation hints

    def to_dict(self) -> dict:
        """Serialize this result to a plain dictionary."""
        return dict(
            passed=self.passed,
            score=self.score,
            issues=self.issues,
            suggestions=self.suggestions,
        )
|
||||
|
||||
|
||||
class OutputValidator:
    """
    Validator for LLM-generated outputs

    Validates narrations and image prompts to ensure they meet quality standards
    before proceeding with media generation.

    Example:
        >>> validator = OutputValidator(llm_service)
        >>> result = await validator.validate_narrations(
        ...     narrations=["旁白1", "旁白2"],
        ...     topic="人生哲理",
        ...     config=ValidationConfig()
        ... )
        >>> if not result.passed:
        ...     print(f"Validation failed: {result.issues}")
    """

    def __init__(self, llm_service=None):
        """
        Initialize OutputValidator

        Args:
            llm_service: Optional async LLM callable used for semantic checks
                (relevance, coherence). When None, only rule-based checks run.
        """
        self.llm_service = llm_service

    async def validate_narrations(
        self,
        narrations: List[str],
        topic: str,
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated narrations

        Checks:
            1. Length constraints
            2. Non-empty content
            3. Topic relevance (if LLM available)
            4. Coherence between narrations (if LLM available)

        Args:
            narrations: List of narration texts
            topic: Original topic/theme
            config: Validation configuration

        Returns:
            ValidationResult with pass/fail and issues
        """
        cfg = config or ValidationConfig()
        issues = []
        suggestions = []
        scores = []  # per-check scores, averaged into the overall score

        # No content at all is an immediate hard failure
        if not narrations:
            return ValidationResult(
                passed=False,
                issues=["No narrations provided"],
                score=0.0
            )

        # 1. Length validation
        for i, narration in enumerate(narrations, 1):
            # NOTE(review): len() counts characters, not words — consistent
            # with the "字" (characters) wording in the messages below, but
            # the config fields are named *_words; confirm the intended unit.
            word_count = len(narration)

            if not narration.strip():
                issues.append(f"分镜{i}: 内容为空")
                scores.append(0.0)
                continue

            if word_count < cfg.min_narration_words:
                issues.append(f"分镜{i}: 内容过短 ({word_count}字,最少{cfg.min_narration_words}字)")
                scores.append(0.5)  # too short: half credit
            elif word_count > cfg.max_narration_words:
                issues.append(f"分镜{i}: 内容过长 ({word_count}字,最多{cfg.max_narration_words}字)")
                suggestions.append(f"考虑将分镜{i}拆分为多个短句")
                scores.append(0.7)  # too long is penalized less than too short
            else:
                scores.append(1.0)

        # 2. Semantic validation (if LLM available)
        if self.llm_service:
            try:
                relevance = await self._check_relevance(narrations, topic)
                if relevance < cfg.relevance_threshold:
                    issues.append(f'内容与主题"{topic}"相关性不足 ({relevance:.0%})')
                    suggestions.append("建议重新生成,确保内容紧扣主题")
                scores.append(relevance)

                coherence = await self._check_coherence(narrations)
                if coherence < cfg.coherence_threshold:
                    issues.append(f"内容连贯性不足 ({coherence:.0%})")
                    suggestions.append("建议检查段落之间的逻辑衔接")
                scores.append(coherence)

            except Exception as e:
                logger.warning(f"Semantic validation failed: {e}")
                # Don't fail on semantic check errors

        # Calculate overall score
        overall_score = sum(scores) / len(scores) if scores else 0.0
        # A high enough average can compensate for individual issues
        passed = len(issues) == 0 or overall_score >= 0.7

        logger.debug(f"Narration validation: score={overall_score:.2f}, issues={len(issues)}")

        return ValidationResult(
            passed=passed,
            issues=issues,
            score=overall_score,
            suggestions=suggestions,
        )

    async def validate_image_prompts(
        self,
        prompts: List[str],
        narrations: List[str],
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated image prompts

        Checks:
            1. Length constraints
            2. Language (should be English)
            3. Prompt-narration alignment

        Args:
            prompts: List of image prompts
            narrations: Corresponding narrations
            config: Validation configuration

        Returns:
            ValidationResult with pass/fail and issues
        """
        cfg = config or ValidationConfig()
        issues = []
        suggestions = []
        scores = []

        if not prompts:
            return ValidationResult(
                passed=False,
                issues=["No image prompts provided"],
                score=0.0
            )

        # Prompts and narrations are expected to pair up one-to-one
        if len(prompts) != len(narrations):
            issues.append(f"提示词数量({len(prompts)})与旁白数量({len(narrations)})不匹配")

        for i, prompt in enumerate(prompts, 1):
            if not prompt.strip():
                issues.append(f"提示词{i}: 内容为空")
                scores.append(0.0)
                continue

            # Prompts are whitespace-separated English, so split() counts words
            word_count = len(prompt.split())

            # Length check
            if word_count < cfg.min_prompt_words:
                issues.append(f"提示词{i}: 过短 ({word_count}词,最少{cfg.min_prompt_words}词)")
                scores.append(0.5)
            elif word_count > cfg.max_prompt_words:
                issues.append(f"提示词{i}: 过长 ({word_count}词,最多{cfg.max_prompt_words}词)")
                scores.append(0.8)
            else:
                scores.append(1.0)

            # English check
            if cfg.require_english_prompts:
                chinese_ratio = self._get_chinese_ratio(prompt)
                if chinese_ratio > 0.3:  # More than 30% Chinese characters
                    issues.append(f"提示词{i}: 应使用英文 (当前含{chinese_ratio:.0%}中文)")
                    suggestions.append(f"将提示词{i}翻译为英文以获得更好的生成效果")
                    scores[-1] *= 0.5  # halve this prompt's length score as a penalty

        overall_score = sum(scores) / len(scores) if scores else 0.0
        passed = len(issues) == 0 or overall_score >= 0.7

        logger.debug(f"Image prompt validation: score={overall_score:.2f}, issues={len(issues)}")

        return ValidationResult(
            passed=passed,
            issues=issues,
            score=overall_score,
            suggestions=suggestions,
        )

    async def _check_relevance(
        self,
        narrations: List[str],
        topic: str,
    ) -> float:
        """
        Check relevance of narrations to topic using LLM

        Returns:
            Relevance score 0.0-1.0
        """
        if not self.llm_service:
            return 0.8  # Default score when LLM not available

        try:
            combined_text = "\n".join(narrations[:3])  # Check first 3 for efficiency

            prompt = f"""评估以下内容与主题"{topic}"的相关性。

内容:
{combined_text}

请用0-100的分数评估相关性,只输出数字。
相关性越高,分数越高。"""

            # Low temperature + tiny token budget: we only want a number back
            response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)

            # Parse score from response
            score_match = re.search(r'\d+', response)
            if score_match:
                # Scale the 0-100 reply into 0.0-1.0 and clamp
                score = int(score_match.group()) / 100
                return min(1.0, max(0.0, score))

            return 0.7  # Default if parsing fails

        except Exception as e:
            logger.warning(f"Relevance check failed: {e}")
            return 0.7

    async def _check_coherence(
        self,
        narrations: List[str],
    ) -> float:
        """
        Check coherence between narrations using LLM

        Returns:
            Coherence score 0.0-1.0
        """
        # Coherence is meaningless for fewer than two narrations
        if not self.llm_service or len(narrations) < 2:
            return 0.8  # Default score

        try:
            # Number the first five narrations for the LLM to judge
            numbered = "\n".join(f"{i+1}. {n}" for i, n in enumerate(narrations[:5]))

            prompt = f"""评估以下段落之间的逻辑连贯性。

{numbered}

请用0-100的分数评估连贯性,只输出数字。
段落之间逻辑顺畅、衔接自然则分数高。"""

            response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)

            score_match = re.search(r'\d+', response)
            if score_match:
                # Scale the 0-100 reply into 0.0-1.0 and clamp
                score = int(score_match.group()) / 100
                return min(1.0, max(0.0, score))

            return 0.7

        except Exception as e:
            logger.warning(f"Coherence check failed: {e}")
            return 0.7

    def _get_chinese_ratio(self, text: str) -> float:
        """Calculate ratio of Chinese characters in text"""
        if not text:
            return 0.0

        # Count characters in the CJK Unified Ideographs block
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        # Spaces are excluded from the denominator so English word spacing
        # does not dilute the ratio
        total_chars = len(text.replace(' ', ''))

        if total_chars == 0:
            return 0.0

        return chinese_chars / total_chars
|
||||
363
pixelle_video/services/quality/quality_gate.py
Normal file
363
pixelle_video/services/quality/quality_gate.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
QualityGate - Quality evaluation system for generated content
|
||||
|
||||
Evaluates images and videos based on:
|
||||
- Aesthetic quality (visual appeal)
|
||||
- Text-to-image matching (semantic alignment)
|
||||
- Technical quality (clarity, no artifacts)
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pixelle_video.services.quality.models import QualityScore, QualityConfig
|
||||
|
||||
|
||||
class QualityGate:
    """
    Quality evaluation gate for AI-generated content

    Uses VLM (Vision Language Model) or local models to evaluate:
    1. Aesthetic quality - Is the image visually appealing?
    2. Text matching - Does the image match the prompt/narration?
    3. Technical quality - Is the image clear and free of artifacts?

    Example:
        >>> gate = QualityGate(llm_service, config)
        >>> score = await gate.evaluate_image(
        ...     image_path="output/frame_001.png",
        ...     prompt="A sunset over mountains",
        ...     narration="夕阳西下,余晖洒满山间"
        ... )
        >>> if score.passed:
        ...     print("Image quality approved!")
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[QualityConfig] = None
    ):
        """
        Initialize QualityGate

        Args:
            llm_service: LLM service for VLM-based evaluation
            config: Quality configuration (defaults to QualityConfig())
        """
        self.llm_service = llm_service
        self.config = config or QualityConfig()

    async def evaluate_image(
        self,
        image_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate the quality of a generated image

        Args:
            image_path: Path to the image file
            prompt: The prompt used to generate the image
            narration: Optional narration text for context

        Returns:
            QualityScore with evaluation results
        """
        start_time = time.time()
        # (fixed: removed a dead local `issues = []` that was never used —
        # each evaluation path builds its own issues list)

        # Validate image exists before doing any work
        if not Path(image_path).exists():
            return QualityScore(
                passed=False,
                issues=["Image file not found"],
                evaluation_time_ms=(time.time() - start_time) * 1000
            )

        # Evaluate using VLM or fall back to basic checks
        if self.config.use_vlm_evaluation and self.llm_service:
            score = await self._evaluate_with_vlm(image_path, prompt, narration)
        else:
            score = await self._evaluate_basic(image_path, prompt)

        # Set evaluation time
        score.evaluation_time_ms = (time.time() - start_time) * 1000

        # Determine if passed against the configured overall threshold
        score.passed = score.overall_score >= self.config.overall_threshold

        logger.debug(
            f"Quality evaluation: overall={score.overall_score:.2f}, "
            f"passed={score.passed}, time={score.evaluation_time_ms:.0f}ms"
        )

        return score

    async def evaluate_video(
        self,
        video_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate the quality of a generated video

        Args:
            video_path: Path to the video file
            prompt: The prompt used to generate the video
            narration: Optional narration text for context

        Returns:
            QualityScore with evaluation results
        """
        start_time = time.time()

        # Validate video exists
        if not Path(video_path).exists():
            return QualityScore(
                passed=False,
                issues=["Video file not found"],
                evaluation_time_ms=(time.time() - start_time) * 1000
            )

        # For video, we can extract key frames and evaluate
        # For now, use VLM with video input or sample frames
        if self.config.use_vlm_evaluation and self.llm_service:
            score = await self._evaluate_video_with_vlm(video_path, prompt, narration)
        else:
            score = await self._evaluate_video_basic(video_path)

        score.evaluation_time_ms = (time.time() - start_time) * 1000
        score.passed = score.overall_score >= self.config.overall_threshold

        return score

    async def _evaluate_with_vlm(
        self,
        image_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """
        Evaluate image quality using Vision Language Model

        Uses the LLM with vision capability to assess:
        - Visual quality and aesthetics
        - Prompt-image alignment
        - Technical defects
        """
        # Built now so the pending VLM integration below has its input ready
        evaluation_prompt = self._build_evaluation_prompt(prompt, narration)

        try:
            # Call LLM with image (requires VLM-capable model like GPT-4o, Qwen-VL)
            # Note: This requires the LLM service to support vision input
            # For now, we'll use a basic score if VLM is not available

            # TODO: Implement actual VLM call when integrating with vision-capable LLM
            # response = await self.llm_service(
            #     prompt=evaluation_prompt,
            #     images=[image_path],
            #     response_type=ImageQualityResponse
            # )

            # Fallback to basic evaluation for now
            logger.debug("VLM evaluation: using basic fallback (VLM integration pending)")
            return await self._evaluate_basic(image_path, prompt)

        except Exception as e:
            logger.warning(f"VLM evaluation failed: {e}, falling back to basic")
            return await self._evaluate_basic(image_path, prompt)

    async def _evaluate_basic(
        self,
        image_path: str,
        prompt: str,
    ) -> QualityScore:
        """
        Basic image quality evaluation without VLM

        Performs simple checks:
        - File size and dimensions
        - Image format validation
        """
        issues = []

        try:
            # Import PIL locally so the module loads even without Pillow
            from PIL import Image

            with Image.open(image_path) as img:
                width, height = img.size

                # Check minimum dimensions
                if width < 256 or height < 256:
                    issues.append(f"Image too small: {width}x{height}")

                # Check aspect ratio (not too extreme)
                aspect = max(width, height) / min(width, height)
                if aspect > 4:
                    issues.append(f"Extreme aspect ratio: {aspect:.1f}")

            # Basic scores (generous defaults when VLM not available)
            aesthetic_score = 0.7 if not issues else 0.4
            text_match_score = 0.7  # Can't properly evaluate without VLM
            technical_score = 0.8 if not issues else 0.5

            # Calculate overall as the configured weighted average
            overall = (
                aesthetic_score * self.config.aesthetic_weight +
                text_match_score * self.config.text_match_weight +
                technical_score * self.config.technical_weight
            )

            return QualityScore(
                aesthetic_score=aesthetic_score,
                text_match_score=text_match_score,
                technical_score=technical_score,
                overall_score=overall,
                issues=issues,
            )

        except Exception as e:
            logger.error(f"Basic evaluation failed: {e}")
            return QualityScore(
                overall_score=0.0,
                passed=False,
                issues=[f"Evaluation error: {str(e)}"]
            )

    async def _evaluate_video_with_vlm(
        self,
        video_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """Evaluate video using VLM (placeholder for future implementation)"""
        # TODO: Implement video frame sampling and VLM evaluation
        return await self._evaluate_video_basic(video_path)

    async def _evaluate_video_basic(
        self,
        video_path: str,
    ) -> QualityScore:
        """Basic video quality evaluation via ffprobe metadata checks"""
        issues = []

        try:
            import subprocess
            import json

            # Use ffprobe to get video info (list args, shell=False implied)
            cmd = [
                "ffprobe", "-v", "quiet", "-print_format", "json",
                "-show_format", "-show_streams", video_path
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode != 0:
                # Metadata failure is not fatal: return a neutral mid score
                issues.append("Failed to read video metadata")
                return QualityScore(overall_score=0.5, issues=issues)

            info = json.loads(result.stdout)

            # Check for video stream
            video_stream = None
            for stream in info.get("streams", []):
                if stream.get("codec_type") == "video":
                    video_stream = stream
                    break

            if not video_stream:
                issues.append("No video stream found")
                return QualityScore(overall_score=0.0, passed=False, issues=issues)

            # Check dimensions
            width = video_stream.get("width", 0)
            height = video_stream.get("height", 0)
            if width < 256 or height < 256:
                issues.append(f"Video too small: {width}x{height}")

            # Check duration
            duration = float(info.get("format", {}).get("duration", 0))
            if duration < 0.5:
                issues.append(f"Video too short: {duration:.1f}s")

            # Calculate scores (generous defaults without VLM)
            aesthetic_score = 0.7
            text_match_score = 0.7
            technical_score = 0.8 if not issues else 0.5

            overall = (
                aesthetic_score * self.config.aesthetic_weight +
                text_match_score * self.config.text_match_weight +
                technical_score * self.config.technical_weight
            )

            return QualityScore(
                aesthetic_score=aesthetic_score,
                text_match_score=text_match_score,
                technical_score=technical_score,
                overall_score=overall,
                issues=issues,
            )

        except Exception as e:
            logger.error(f"Video evaluation failed: {e}")
            return QualityScore(
                overall_score=0.5,
                issues=[f"Evaluation error: {str(e)}"]
            )

    def _build_evaluation_prompt(
        self,
        prompt: str,
        narration: Optional[str] = None,
    ) -> str:
        """Build the evaluation prompt for VLM"""
        # Include the narration line only when one was supplied
        context = f"Narration: {narration}\n" if narration else ""

        return f"""Evaluate this AI-generated image on the following criteria.
Rate each from 0.0 to 1.0.

Image Generation Prompt: {prompt}
{context}
Evaluation Criteria:

1. Aesthetic Quality (0.0-1.0):
   - Is the image visually appealing?
   - Good composition, colors, and style?

2. Prompt Matching (0.0-1.0):
   - Does the image accurately represent the prompt?
   - Are key elements from the prompt visible?

3. Technical Quality (0.0-1.0):
   - Is the image clear and well-defined?
   - Free of artifacts, distortions, or blurriness?
   - Natural looking (no AI artifacts like extra fingers)?

Respond in JSON format:
{{
    "aesthetic_score": 0.0,
    "text_match_score": 0.0,
    "technical_score": 0.0,
    "issues": ["list of any problems found"]
}}
"""
|
||||
296
pixelle_video/services/quality/retry_manager.py
Normal file
296
pixelle_video/services/quality/retry_manager.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
RetryManager - Smart retry logic with quality-based decisions
|
||||
|
||||
Provides unified retry management for:
|
||||
- Media generation (images, videos)
|
||||
- LLM calls
|
||||
- Any async operations with quality evaluation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, TypeVar, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pixelle_video.services.quality.models import (
|
||||
QualityScore,
|
||||
QualityConfig,
|
||||
RetryConfig,
|
||||
QualityError,
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
class RetryResult:
    """Outcome of a retry-managed operation."""

    success: bool  # whether the operation ultimately succeeded
    result: Any    # value returned by the final attempt
    attempts: int  # number of attempts actually made
    quality_score: Optional[QualityScore] = None  # score of accepted result, if evaluated
    error: Optional[Exception] = None             # last error when the operation failed
|
||||
|
||||
|
||||
class RetryManager:
    """
    Smart retry manager with quality-based decisions

    Features:
    - Exponential backoff with configurable delays
    - Quality-aware retry (retries when quality below threshold)
    - Fallback strategies
    - Detailed logging and metrics

    Example:
        >>> retry_manager = RetryManager()
        >>>
        >>> async def generate():
        ...     return await media_service.generate_image(prompt)
        >>>
        >>> async def evaluate(image_path):
        ...     return await quality_gate.evaluate_image(image_path, prompt)
        >>>
        >>> result = await retry_manager.execute_with_retry(
        ...     operation=generate,
        ...     quality_evaluator=evaluate,
        ...     config=RetryConfig(max_retries=3)
        ... )
    """

    def __init__(self, config: Optional[RetryConfig] = None):
        """
        Initialize RetryManager

        Args:
            config: Default retry configuration (fresh RetryConfig when omitted)
        """
        self.default_config = config or RetryConfig()

    @staticmethod
    def _backoff_delay_ms(cfg: RetryConfig, attempt: int) -> float:
        """Exponential backoff delay in ms after the given 1-based attempt,
        capped at cfg.max_delay_ms.

        Shared by execute_with_retry and execute_simple so the two paths
        cannot drift apart.
        """
        return min(
            cfg.initial_delay_ms * (cfg.backoff_factor ** (attempt - 1)),
            cfg.max_delay_ms,
        )

    async def execute_with_retry(
        self,
        operation: Callable[[], Awaitable[Any]],
        quality_evaluator: Optional[Callable[[Any], Awaitable[QualityScore]]] = None,
        config: Optional[RetryConfig] = None,
        fallback_operation: Optional[Callable[[], Awaitable[Any]]] = None,
        operation_name: str = "operation",
    ) -> RetryResult:
        """
        Execute operation with automatic retry and quality evaluation

        Args:
            operation: Async callable to execute
            quality_evaluator: Optional async callable to evaluate result quality
            config: Retry configuration (uses default if not provided)
            fallback_operation: Optional fallback to try when all retries fail
            operation_name: Name for logging purposes

        Returns:
            RetryResult with success status, result, and quality score

        Raises:
            QualityError: When all retries fail and no fallback available
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None
        last_score: Optional[QualityScore] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                # Execute the operation
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()

                # If no quality evaluator, accept the result
                if quality_evaluator is None:
                    logger.debug(f"{operation_name}: completed (no quality check)")
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                    )

                # Evaluate quality
                score = await quality_evaluator(result)
                last_score = score

                if score.passed:
                    logger.info(
                        f"{operation_name}: passed quality check "
                        f"(score={score.overall_score:.2f}, attempt={attempt})"
                    )
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

                # Quality check failed
                logger.warning(
                    f"{operation_name}: quality check failed "
                    f"(score={score.overall_score:.2f}, threshold={cfg.quality_threshold}, "
                    f"issues={score.issues})"
                )

                if not cfg.retry_on_quality_fail:
                    # Don't retry on quality failure
                    return RetryResult(
                        success=False,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed with error: {e}")

                if not cfg.retry_on_error:
                    raise

            # Back off before the next attempt (skipped after the last one)
            if attempt < cfg.max_retries:
                delay_ms = self._backoff_delay_ms(cfg, attempt)
                logger.debug(f"{operation_name}: waiting {delay_ms:.0f}ms before retry")
                await asyncio.sleep(delay_ms / 1000)

        # All retries exhausted
        logger.warning(f"{operation_name}: all {cfg.max_retries} attempts failed")

        # Try fallback if available
        if cfg.enable_fallback and fallback_operation:
            try:
                logger.info(f"{operation_name}: trying fallback strategy")
                result = await fallback_operation()

                # Evaluate fallback result if evaluator available
                if quality_evaluator:
                    score = await quality_evaluator(result)
                    return RetryResult(
                        # a None score from the evaluator is treated as a pass
                        success=score.passed if score else True,
                        result=result,
                        attempts=cfg.max_retries + 1,
                        quality_score=score,
                    )

                return RetryResult(
                    success=True,
                    result=result,
                    attempts=cfg.max_retries + 1,
                )

            except Exception as e:
                logger.error(f"{operation_name}: fallback also failed: {e}")
                last_error = e

        # Nothing worked: distinguish error-driven from quality-driven failure
        if last_error:
            raise QualityError(
                f"{operation_name} failed after {cfg.max_retries} attempts: {last_error}",
                attempts=cfg.max_retries,
                last_score=last_score,
            )

        raise QualityError(
            f"{operation_name} failed to meet quality threshold after {cfg.max_retries} attempts",
            attempts=cfg.max_retries,
            last_score=last_score,
        )

    async def execute_simple(
        self,
        operation: Callable[[], Awaitable[Any]],
        config: Optional[RetryConfig] = None,
        operation_name: str = "operation",
    ) -> Any:
        """
        Execute operation with simple retry (no quality evaluation)

        Args:
            operation: Async callable to execute
            config: Retry configuration
            operation_name: Name for logging

        Returns:
            Operation result

        Raises:
            Exception: Last exception if all retries fail
            RuntimeError: If the configuration allows zero attempts
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()
                logger.debug(f"{operation_name}: completed on attempt {attempt}")
                return result

            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed: {e}")

                if attempt < cfg.max_retries:
                    await asyncio.sleep(self._backoff_delay_ms(cfg, attempt) / 1000)

        if last_error is None:
            # max_retries < 1 means the loop never ran; raising None would be a
            # TypeError, so make the misconfiguration explicit instead
            raise RuntimeError(
                f"{operation_name}: no attempts were made (max_retries={cfg.max_retries})"
            )
        raise last_error

    @staticmethod
    def create_prompt_simplifier(original_prompt: str) -> Callable[[], str]:
        """
        Create a fallback that simplifies the prompt

        Args:
            original_prompt: Original complex prompt

        Returns:
            Callable that returns a simplified prompt
        """
        def simplify() -> str:
            # Simple prompt simplification strategies
            simplified = original_prompt

            # Remove complex quality-boilerplate modifiers. All phrases are
            # lowercase, so a single case-sensitive replace per phrase suffices
            # (the previous extra .replace(phrase.lower(), "") was a no-op).
            remove_phrases = [
                "highly detailed",
                "ultra realistic",
                "8k resolution",
                "masterpiece",
                "best quality",
            ]
            for phrase in remove_phrases:
                simplified = simplified.replace(phrase, "")

            # Truncate if too long, cutting at a word boundary
            if len(simplified) > 150:
                simplified = simplified[:150].rsplit(" ", 1)[0]

            # Collapse whitespace left behind by the removals
            simplified = " ".join(simplified.split())

            return simplified

        return simplify
|
||||
276
pixelle_video/services/quality/style_guard.py
Normal file
276
pixelle_video/services/quality/style_guard.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
StyleGuard - Visual style consistency engine
|
||||
|
||||
Ensures consistent visual style across all frames in a video by:
|
||||
1. Extracting style anchor from the first generated frame
|
||||
2. Injecting style constraints into subsequent frame prompts
|
||||
3. (Optional) Using style reference techniques like IP-Adapter
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
class StyleAnchor:
    """Style anchor extracted from reference frame.

    Captures the visual style of a reference image so it can be re-applied
    to later frame prompts, keeping a video visually consistent.
    """

    # Core style elements
    color_palette: str = ""  # e.g., "warm earth tones", "cool blues"
    art_style: str = ""  # e.g., "minimalist", "realistic", "anime"
    composition_style: str = ""  # e.g., "centered", "rule of thirds"
    texture: str = ""  # e.g., "smooth", "grainy", "watercolor"
    lighting: str = ""  # e.g., "soft ambient", "dramatic shadows"

    # Combined style prefix for prompts; when set it overrides the
    # individual elements in to_prompt_prefix()
    style_prefix: str = ""

    # Reference image path (for IP-Adapter style techniques)
    reference_image: Optional[str] = None

    def to_prompt_prefix(self) -> str:
        """Generate a style prefix for image prompts.

        Returns:
            The explicit style_prefix when set; otherwise a comma-joined
            summary of the populated style elements ("" when none are set).
            Note: composition_style is not included in the generated summary.
        """
        if self.style_prefix:
            return self.style_prefix

        elements = []
        if self.art_style:
            elements.append(f"{self.art_style} style")
        if self.color_palette:
            # Palette descriptions are already complete phrases
            elements.append(self.color_palette)
        if self.lighting:
            elements.append(f"{self.lighting} lighting")
        if self.texture:
            elements.append(f"{self.texture} texture")

        # join of an empty list is already "", no conditional needed
        return ", ".join(elements)

    def to_dict(self) -> dict:
        """Serialize all fields to a plain dict (e.g., for JSON or logging)."""
        return {
            "color_palette": self.color_palette,
            "art_style": self.art_style,
            "composition_style": self.composition_style,
            "texture": self.texture,
            "lighting": self.lighting,
            "style_prefix": self.style_prefix,
            "reference_image": self.reference_image,
        }
|
||||
|
||||
|
||||
@dataclass
class StyleGuardConfig:
    """Configuration for StyleGuard"""

    # Extraction settings
    # NOTE(review): not read by StyleGuard in this module — presumably
    # consumed by the calling pipeline; confirm before relying on it.
    extract_from_first_frame: bool = True
    # When True (and an llm_service is available), style is extracted via VLM;
    # otherwise the basic generic anchor is used
    use_vlm_extraction: bool = True

    # Application settings
    # NOTE(review): not read by StyleGuard in this module — verify against
    # callers of apply_style/apply_style_to_batch.
    apply_to_all_frames: bool = True
    # Where the style prefix is inserted in prompts: "start" or "end"
    prefix_position: str = "start"  # "start" or "end"

    # Optional external style reference
    # NOTE(review): external_style_image is not read by StyleGuard in this
    # module; custom_style_prefix, when set, bypasses extraction entirely.
    external_style_image: Optional[str] = None
    custom_style_prefix: Optional[str] = None
|
||||
|
||||
|
||||
class StyleGuard:
    """
    Style consistency guardian for video generation

    Keeps every frame of a generated video visually coherent by:
    1. Analyzing the first frame (or a reference image) to derive a StyleAnchor
    2. Injecting that anchor's style constraints into subsequent frame prompts

    Example:
        >>> style_guard = StyleGuard(llm_service)
        >>>
        >>> # Extract style from first frame
        >>> anchor = await style_guard.extract_style_anchor(
        ...     image_path="output/frame_001.png"
        ... )
        >>>
        >>> # Apply to subsequent prompts
        >>> styled_prompt = style_guard.apply_style(
        ...     prompt="A cat sitting on a windowsill",
        ...     style_anchor=anchor
        ... )
    """

    def __init__(self, llm_service=None, config: Optional[StyleGuardConfig] = None):
        """
        Initialize StyleGuard

        Args:
            llm_service: LLM service for VLM-based style extraction
            config: StyleGuard configuration
        """
        self.llm_service = llm_service
        self.config = config or StyleGuardConfig()
        self._current_anchor: Optional[StyleAnchor] = None

    async def extract_style_anchor(self, image_path: str) -> StyleAnchor:
        """
        Extract style anchor from reference image

        Args:
            image_path: Path to reference image

        Returns:
            StyleAnchor with extracted style elements
        """
        logger.info(f"Extracting style anchor from: {image_path}")

        # An explicit user-supplied prefix short-circuits any extraction
        custom = self.config.custom_style_prefix
        if custom:
            self._current_anchor = StyleAnchor(
                style_prefix=custom,
                reference_image=image_path,
            )
            return self._current_anchor

        wants_vlm = self.config.use_vlm_extraction and self.llm_service
        anchor = (
            await self._extract_with_vlm(image_path)
            if wants_vlm
            else self._extract_basic(image_path)
        )

        self._current_anchor = anchor
        logger.info(f"Style anchor extracted: {anchor.to_prompt_prefix()}")
        return anchor

    async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
        """Extract style using Vision Language Model"""
        try:
            # TODO: Implement VLM call when vision-capable LLM is integrated
            # For now, return a placeholder
            logger.debug("VLM style extraction: using placeholder (VLM not yet integrated)")

            # Placeholder extraction based on common styles
            return StyleAnchor(
                art_style="consistent artistic",
                color_palette="harmonious colors",
                lighting="balanced",
                style_prefix="maintaining visual consistency, same artistic style as previous frames",
                reference_image=image_path,
            )

        except Exception as e:
            # Degrade gracefully to the non-VLM path
            logger.warning(f"VLM style extraction failed: {e}")
            return self._extract_basic(image_path)

    def _extract_basic(self, image_path: str) -> StyleAnchor:
        """Basic style extraction without VLM: a generic consistency anchor."""
        return StyleAnchor(
            style_prefix="consistent visual style",
            reference_image=image_path,
        )

    def apply_style(self, prompt: str, style_anchor: Optional[StyleAnchor] = None) -> str:
        """
        Apply style constraints to an image prompt

        Args:
            prompt: Original image prompt
            style_anchor: Style anchor to apply (uses current if not provided)

        Returns:
            Modified prompt with style constraints
        """
        active = style_anchor or self._current_anchor
        if not active:
            return prompt

        prefix = active.to_prompt_prefix()
        if not prefix:
            return prompt

        if self.config.prefix_position == "start":
            return f"{prefix}, {prompt}"
        return f"{prompt}, {prefix}"

    def apply_style_to_batch(
        self,
        prompts: List[str],
        style_anchor: Optional[StyleAnchor] = None,
        skip_first: bool = True,
    ) -> List[str]:
        """
        Apply style constraints to a batch of prompts

        Args:
            prompts: List of image prompts
            style_anchor: Style anchor to apply
            skip_first: Skip first prompt (used as reference)

        Returns:
            List of styled prompts
        """
        if not prompts:
            return prompts

        active = style_anchor or self._current_anchor
        if not active:
            return prompts

        return [
            text if (skip_first and index == 0) else self.apply_style(text, active)
            for index, text in enumerate(prompts)
        ]

    def get_consistency_prompt_suffix(self) -> str:
        """
        Get a consistency prompt suffix for LLM prompt generation

        This can be added to the LLM prompt when generating image prompts
        to encourage consistent style descriptions.
        """
        return (
            "Ensure all image prompts maintain consistent visual style, "
            "including similar color palette, art style, lighting, and composition. "
            "Each image should feel like it belongs to the same visual narrative."
        )

    @property
    def current_anchor(self) -> Optional[StyleAnchor]:
        """Get the current style anchor"""
        return self._current_anchor

    def reset(self):
        """Reset the current style anchor"""
        self._current_anchor = None
|
||||
Reference in New Issue
Block a user