- Add FeatureExtractor for CLIP-based image/text feature extraction - Add ObjectiveMetricsCalculator for technical quality metrics - Add VLMEvaluator for vision language model evaluation - Add HybridQualityGate combining objective + VLM evaluation - Enhance CharacterMemory with visual feature support - Add quality optional dependency (torch, ftfy, regex) - Add unit tests for new modules 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
618 lines
22 KiB
Python
618 lines
22 KiB
Python
# Copyright (C) 2025 AIDC-AI
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""
|
||
CharacterMemory - Character consistency and memory system
|
||
|
||
Maintains consistent character appearance across video frames by:
|
||
1. Detecting and registering characters from narrations
|
||
2. Extracting visual descriptions from first appearances
|
||
3. Injecting character consistency prompts into subsequent frames
|
||
4. Supporting reference images for ComfyUI IP-Adapter/ControlNet
|
||
"""
|
||
|
||
from dataclasses import dataclass, field
|
||
from typing import List, Dict, Optional, Set, Tuple, Any
|
||
from datetime import datetime
|
||
from enum import Enum
|
||
|
||
import numpy as np
|
||
from loguru import logger
|
||
|
||
|
||
class CharacterType(Enum):
    """Category of entity that a narrative character represents.

    The member values are the lowercase strings stored under the "type" key
    of ``Character.to_dict``.
    """

    # A human being.
    PERSON = "person"
    # A real-world animal.
    ANIMAL = "animal"
    # A fantasy or otherwise fictional creature.
    CREATURE = "creature"
    # An inanimate object given a personality.
    OBJECT = "object"
    # A non-physical / abstract entity.
    ABSTRACT = "abstract"
|
||
|
||
|
||
@dataclass
class Character:
    """
    A single character tracked across the frames of a video narrative.

    Holds the textual visual description used for prompt injection, paths to
    reference images (for IP-Adapter/ControlNet conditioning), the frames the
    character appeared in, and an optional visual feature vector used to check
    cross-frame consistency.
    """

    # Identity
    id: str  # Unique identifier
    name: str  # Character name (e.g., "小明", "the hero")
    aliases: List[str] = field(default_factory=list)  # Alternative names
    character_type: CharacterType = CharacterType.PERSON

    # Visual description (for prompt injection)
    appearance_description: str = ""  # Detailed visual description
    clothing_description: str = ""  # Clothing/outfit description
    distinctive_features: List[str] = field(default_factory=list)  # Unique features

    # Reference images (for IP-Adapter/ControlNet)
    reference_images: List[str] = field(default_factory=list)  # Paths to reference images
    primary_reference: Optional[str] = None  # Primary reference image

    # Prompt elements
    prompt_prefix: str = ""  # Pre-built prompt prefix; derived from descriptions when empty
    negative_prompt: str = ""  # Negative prompt additions

    # Metadata
    is_active: bool = True  # Whether this character is active for logic
    first_appearance_frame: int = 0  # Frame index of first appearance
    appearance_frames: List[int] = field(default_factory=list)  # All frames with this character
    created_at: Optional[datetime] = None  # Filled with datetime.now() when not given

    # Visual features (for cross-frame consistency)
    visual_features: Optional[Any] = None  # Feature vector (np.ndarray), e.g. from CLIP
    feature_extraction_frame: Optional[int] = None  # Frame where features were extracted
    similarity_history: List[float] = field(default_factory=list)  # Similarity scores history
    min_similarity_threshold: float = 0.75  # Minimum similarity for consistency

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.now()
        # A caller-supplied prefix wins; otherwise derive one from the descriptions.
        if not self.prompt_prefix:
            self._build_prompt_prefix()

    def _build_prompt_prefix(self):
        """Rebuild ``prompt_prefix`` from appearance, clothing and distinctive features."""
        elements = []

        if self.appearance_description:
            elements.append(self.appearance_description)
        if self.clothing_description:
            elements.append(f"wearing {self.clothing_description}")
        if self.distinctive_features:
            elements.append(", ".join(self.distinctive_features))

        self.prompt_prefix = ", ".join(elements) if elements else ""

    def get_prompt_injection(self) -> str:
        """Return the prompt fragment for this character: "(name: prefix)" or "(name)"."""
        if self.prompt_prefix:
            return f"({self.name}: {self.prompt_prefix})"
        return f"({self.name})"

    def add_reference_image(self, image_path: str, set_as_primary: bool = False):
        """
        Add a reference image path (ignored if already present).

        The image becomes the primary reference when ``set_as_primary`` is
        True or when no primary reference exists yet.
        """
        if image_path not in self.reference_images:
            self.reference_images.append(image_path)

        if set_as_primary or self.primary_reference is None:
            self.primary_reference = image_path

    def matches_name(self, name: str) -> bool:
        """Case- and surrounding-whitespace-insensitive match against name and aliases."""
        name_lower = name.lower().strip()
        if self.name.lower().strip() == name_lower:
            return True
        return any(alias.lower().strip() == name_lower for alias in self.aliases)

    def set_visual_features(self, features: np.ndarray, frame_index: int):
        """Store the reference feature vector and the frame it was extracted from."""
        self.visual_features = features
        self.feature_extraction_frame = frame_index

    def check_visual_similarity(self, other_features: np.ndarray) -> Tuple[bool, float]:
        """
        Compare a feature vector against this character's stored features.

        The score is the cosine similarity mapped from [-1, 1] to [0, 1], so it
        does not depend on whether the vectors were L2-normalized upstream.
        Both vectors are flattened first, so shapes like (D,) and (1, D) work.
        Each computed score is appended to ``similarity_history``.

        Args:
            other_features: Feature vector of the image being checked.

        Returns:
            ``(is_match, similarity)`` where ``is_match`` is True when
            ``similarity >= min_similarity_threshold``. When no reference
            features are stored, returns ``(True, 1.0)`` and records nothing.

        Raises:
            ValueError: If the two vectors have different numbers of elements.
        """
        if self.visual_features is None:
            return True, 1.0

        reference = np.asarray(self.visual_features, dtype=np.float64).ravel()
        candidate = np.asarray(other_features, dtype=np.float64).ravel()

        norm_product = float(np.linalg.norm(reference) * np.linalg.norm(candidate))
        if norm_product > 0.0:
            cosine = float(np.dot(reference, candidate)) / norm_product
        else:
            # A zero vector has no direction; treat it as orthogonal (score 0.5).
            cosine = 0.0
        # Guard against tiny floating-point overshoot outside [-1, 1].
        cosine = max(-1.0, min(1.0, cosine))

        similarity = (cosine + 1.0) / 2.0  # Map [-1, 1] -> [0, 1]

        self.similarity_history.append(similarity)
        is_match = similarity >= self.min_similarity_threshold
        return is_match, similarity

    def to_dict(self) -> dict:
        """Serialize identity, description and reference fields to a plain dict."""
        return {
            "id": self.id,
            "name": self.name,
            "aliases": self.aliases,
            "type": self.character_type.value,
            "appearance_description": self.appearance_description,
            "clothing_description": self.clothing_description,
            "distinctive_features": self.distinctive_features,
            "reference_images": self.reference_images,
            "primary_reference": self.primary_reference,
            "prompt_prefix": self.prompt_prefix,
            "first_appearance_frame": self.first_appearance_frame,
        }
|
||
|
||
|
||
@dataclass
class CharacterMemoryConfig:
    """
    Tunable settings for the character memory system.

    Fields are grouped by concern: detection, consistency, reference images,
    prompt injection and visual features. All have defaults, so an instance
    can be created with no arguments.
    """

    # -- Detection ---------------------------------------------------------
    # Detect characters from narrations automatically.
    auto_detect_characters: bool = True
    # Ask the LLM service to extract character information.
    use_llm_detection: bool = True

    # -- Consistency -------------------------------------------------------
    # Add character descriptions to image prompts.
    inject_character_prompts: bool = True
    # Supply reference images to the generation step.
    use_reference_images: bool = True

    # -- Reference images --------------------------------------------------
    # Take a reference from the character's first appearance.
    extract_reference_from_first: bool = True
    # Upper bound on reference images.
    max_reference_images: int = 3

    # -- Prompt injection --------------------------------------------------
    # Where character text goes: "start" or "end".
    prompt_injection_position: str = "start"
    # Include clothing descriptions in injected prompts.
    include_clothing: bool = True
    # Include distinctive features in injected prompts.
    include_features: bool = True

    # -- Visual features ---------------------------------------------------
    # Turn on CLIP-based visual feature extraction.
    enable_visual_features: bool = True
    # Minimum similarity considered consistent.
    visual_similarity_threshold: float = 0.75
    # Extract features when a character first appears.
    extract_features_on_first: bool = True
|
||
|
||
|
||
class CharacterMemory:
    """
    Character memory and consistency manager

    Tracks characters across video frames and ensures visual consistency
    by injecting character descriptions and reference images into the
    generation pipeline.

    Example:
        >>> memory = CharacterMemory(llm_service)
        >>>
        >>> # Register a character
        >>> char = memory.register_character(
        ...     name="小明",
        ...     appearance_description="young man with short black hair",
        ...     clothing_description="blue t-shirt"
        ... )
        >>>
        >>> # Apply to prompt
        >>> enhanced_prompt = memory.apply_to_prompt(
        ...     prompt="A person walking in the park",
        ...     character_names=["小明"]
        ... )
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[CharacterMemoryConfig] = None
    ):
        """
        Initialize CharacterMemory

        Args:
            llm_service: Optional async callable used for character detection;
                it is awaited as ``llm_service(prompt, temperature=0.1)``.
            config: Character memory configuration (defaults are used if None)
        """
        self.llm_service = llm_service
        self.config = config or CharacterMemoryConfig()
        self._characters: Dict[str, Character] = {}
        self._name_index: Dict[str, str] = {}  # normalized name -> character_id
        self._feature_extractor = None  # Lazy-loaded by the `feature_extractor` property
        self._feature_extractor_failed = False  # Set once loading fails, to avoid retrying

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_name(name: str) -> str:
        """Key used for name lookups: lower-cased, surrounding whitespace removed."""
        return name.lower().strip()

    @staticmethod
    def _record_appearance(char: Character, frame_index: int) -> None:
        """Add ``frame_index`` to the character's appearance list at most once."""
        if frame_index not in char.appearance_frames:
            char.appearance_frames.append(frame_index)

    @staticmethod
    def _parse_character_type(value) -> CharacterType:
        """Map a type string (e.g. from the LLM) to CharacterType; unknown -> PERSON."""
        try:
            return CharacterType(str(value or "person").strip().lower())
        except ValueError:
            return CharacterType.PERSON

    @staticmethod
    def _mentions(text_lower: str, name: str) -> bool:
        """
        Whether ``name`` occurs in the already lower-cased ``text_lower``.

        ASCII names must appear as a whole word (so "al" does not match
        "also"); other names, e.g. Chinese which has no word delimiters,
        match as plain substrings.
        """
        import re

        key = name.lower().strip()
        if not key:
            return False
        if key.isascii():
            pattern = rf"(?<![a-z0-9_]){re.escape(key)}(?![a-z0-9_])"
            return re.search(pattern, text_lower) is not None
        return key in text_lower

    def _build_injection(self, char: Character) -> str:
        """
        Prompt fragment for ``char`` honoring include_clothing / include_features.

        With both flags enabled (the default) this is exactly
        ``char.get_prompt_injection()``. If either is disabled, the fragment is
        rebuilt from the individual descriptions, because ``prompt_prefix`` may
        already embed the excluded parts.
        """
        if self.config.include_clothing and self.config.include_features:
            return char.get_prompt_injection()

        elements = []
        if char.appearance_description:
            elements.append(char.appearance_description)
        if self.config.include_clothing and char.clothing_description:
            elements.append(f"wearing {char.clothing_description}")
        if self.config.include_features and char.distinctive_features:
            elements.append(", ".join(char.distinctive_features))

        if elements:
            return f"({char.name}: {', '.join(elements)})"
        return f"({char.name})"

    # ------------------------------------------------------------------
    # Registration and lookup
    # ------------------------------------------------------------------

    def register_character(
        self,
        name: str,
        appearance_description: str = "",
        clothing_description: str = "",
        distinctive_features: Optional[List[str]] = None,
        character_type: CharacterType = CharacterType.PERSON,
        first_frame: int = 0,
    ) -> Character:
        """
        Register a new character

        Args:
            name: Character name
            appearance_description: Visual appearance description
            clothing_description: Clothing/outfit description
            distinctive_features: List of distinctive features
            character_type: Type of character
            first_frame: Frame index of first appearance

        Returns:
            Created Character object. Its similarity threshold is taken from
            ``config.visual_similarity_threshold``. Registering a name that is
            already known creates a new character and re-points name lookups
            to it (a warning is logged).
        """
        # len() is a safe counter: characters are only ever removed all at once (reset()).
        char_id = f"char_{len(self._characters)}_{name.replace(' ', '_').lower()}"

        character = Character(
            id=char_id,
            name=name,
            appearance_description=appearance_description,
            clothing_description=clothing_description,
            distinctive_features=distinctive_features or [],
            character_type=character_type,
            first_appearance_frame=first_frame,
            appearance_frames=[first_frame],
            min_similarity_threshold=self.config.visual_similarity_threshold,
        )

        key = self._normalize_name(name)
        if key in self._name_index:
            logger.warning(f"Character name already registered, replacing lookup: {name}")

        self._characters[char_id] = character
        self._name_index[key] = char_id

        logger.info(f"Registered character: {name} (id={char_id})")

        return character

    def get_character(self, name: str) -> Optional[Character]:
        """Get a character by name (case-insensitive), falling back to aliases."""
        char_id = self._name_index.get(self._normalize_name(name))
        if char_id:
            return self._characters.get(char_id)

        # Search aliases
        for char in self._characters.values():
            if char.matches_name(name):
                return char

        return None

    def get_character_by_id(self, char_id: str) -> Optional[Character]:
        """Get a character by ID"""
        return self._characters.get(char_id)

    @property
    def characters(self) -> List[Character]:
        """Get all registered characters, in registration order"""
        return list(self._characters.values())

    # ------------------------------------------------------------------
    # Detection
    # ------------------------------------------------------------------

    async def detect_characters_from_narration(
        self,
        narration: str,
        frame_index: int = 0,
    ) -> List[Character]:
        """
        Detect and register characters mentioned in narration

        Args:
            narration: Narration text to analyze
            frame_index: Current frame index

        Returns:
            List of detected/registered characters (empty when
            ``config.auto_detect_characters`` is False)
        """
        if not self.config.auto_detect_characters:
            return []

        if self.config.use_llm_detection and self.llm_service:
            return await self._detect_with_llm(narration, frame_index)
        return self._detect_basic(narration, frame_index)

    async def _detect_with_llm(
        self,
        narration: str,
        frame_index: int,
    ) -> List[Character]:
        """
        Detect characters using the LLM.

        Unknown names are registered; known ones get ``frame_index`` recorded.
        Returns [] when the response contains no JSON array. On any error
        (LLM call or JSON parsing) falls back to ``_detect_basic``.
        """
        if not self.llm_service:
            return []

        import json
        import re

        try:
            prompt = f"""分析以下文案,提取其中提到的角色/人物。

文案: {narration}

请用 JSON 格式返回角色列表,每个角色包含:
- name: 角色名称或代称
- type: person/animal/creature/object
- appearance: 外貌描述(如有)
- clothing: 服装描述(如有)

如果没有明确角色,返回空列表 []。

只返回 JSON,不要其他解释。"""

            response = await self.llm_service(prompt, temperature=0.1)

            # The model may wrap the JSON array in prose; grab the outermost [...].
            json_match = re.search(r'\[.*\]', response, re.DOTALL)
            if not json_match:
                return []

            characters_data = json.loads(json_match.group())
            if not isinstance(characters_data, list):
                return []

            result: List[Character] = []
            seen_ids = set()
            for char_data in characters_data:
                if not isinstance(char_data, dict):
                    continue
                # `or ""` also covers explicit JSON nulls.
                name = str(char_data.get("name") or "").strip()
                if not name:
                    continue

                char = self.get_character(name)
                if char:
                    self._record_appearance(char, frame_index)
                else:
                    char = self.register_character(
                        name=name,
                        appearance_description=str(char_data.get("appearance") or ""),
                        clothing_description=str(char_data.get("clothing") or ""),
                        character_type=self._parse_character_type(char_data.get("type")),
                        first_frame=frame_index,
                    )

                # The LLM may list the same character twice.
                if char.id not in seen_ids:
                    seen_ids.add(char.id)
                    result.append(char)

            return result

        except Exception as e:
            logger.warning(f"LLM character detection failed: {e}")
            return self._detect_basic(narration, frame_index)

    def _detect_basic(
        self,
        narration: str,
        frame_index: int,
    ) -> List[Character]:
        """
        Basic character detection without LLM.

        Only finds characters that are already registered (never registers
        new ones): a character matches when its name or one of its aliases
        occurs in the narration, case-insensitively. Matched characters get
        ``frame_index`` recorded.
        """
        text_lower = narration.lower()
        detected: List[Character] = []

        for char in self.characters:
            if any(self._mentions(text_lower, candidate) for candidate in [char.name, *char.aliases]):
                self._record_appearance(char, frame_index)
                detected.append(char)

        return detected

    # ------------------------------------------------------------------
    # Prompt and reference-image application
    # ------------------------------------------------------------------

    def apply_to_prompt(
        self,
        prompt: str,
        character_names: Optional[List[str]] = None,
        frame_index: Optional[int] = None,
    ) -> str:
        """
        Apply character consistency to an image prompt

        Args:
            prompt: Original image prompt
            character_names: Specific characters to include (None/empty = all
                registered characters); unknown and repeated names are skipped
            frame_index: Current frame index for tracking

        Returns:
            Enhanced prompt with character consistency. Character text is
            prepended when ``config.prompt_injection_position == "start"`` and
            appended otherwise; an empty prompt yields just the character text.
        """
        if not self.config.inject_character_prompts:
            return prompt

        if character_names:
            characters_to_include: List[Character] = []
            seen_ids = set()
            for name in character_names:
                char = self.get_character(name)
                if char and char.id not in seen_ids:
                    seen_ids.add(char.id)
                    characters_to_include.append(char)
        else:
            # Include all characters that have appeared
            characters_to_include = self.characters

        if not characters_to_include:
            return prompt

        injections = []
        for char in characters_to_include:
            injection = self._build_injection(char)
            if injection:
                injections.append(injection)

            # Track appearance
            if frame_index is not None:
                self._record_appearance(char, frame_index)

        if not injections:
            return prompt

        character_prompt = ", ".join(injections)

        if not prompt:
            return character_prompt
        if self.config.prompt_injection_position == "start":
            return f"{character_prompt}, {prompt}"
        return f"{prompt}, {character_prompt}"

    def get_reference_images(
        self,
        character_names: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Get reference images for specified characters

        Args:
            character_names: Character names (None = all characters)

        Returns:
            Primary reference image paths, capped in total at
            ``config.max_reference_images``
        """
        if not self.config.use_reference_images:
            return []

        if character_names:
            selected = [self.get_character(name) for name in character_names]
        else:
            selected = self.characters

        images = [char.primary_reference for char in selected if char and char.primary_reference]
        return images[:self.config.max_reference_images]

    def set_reference_image(
        self,
        character_name: str,
        image_path: str,
        set_as_primary: bool = True,
    ):
        """
        Set a reference image for a character

        Args:
            character_name: Character name
            image_path: Path to reference image
            set_as_primary: Whether to set as primary reference
        """
        char = self.get_character(character_name)
        if char:
            char.add_reference_image(image_path, set_as_primary)
            logger.debug(f"Set reference image for {character_name}: {image_path}")
        else:
            logger.warning(f"Character not found: {character_name}")

    def update_character_appearance(
        self,
        character_name: str,
        appearance_description: Optional[str] = None,
        clothing_description: Optional[str] = None,
        distinctive_features: Optional[List[str]] = None,
    ):
        """
        Update a character's visual description and rebuild its prompt prefix.

        Empty/None arguments leave the corresponding field unchanged.
        Logs a warning when the character is unknown.
        """
        char = self.get_character(character_name)
        if not char:
            logger.warning(f"Character not found: {character_name}")
            return

        if appearance_description:
            char.appearance_description = appearance_description
        if clothing_description:
            char.clothing_description = clothing_description
        if distinctive_features:
            char.distinctive_features = distinctive_features
        char._build_prompt_prefix()
        logger.debug(f"Updated appearance for {character_name}")

    def get_consistency_summary(self) -> str:
        """Get a summary of character consistency for logging"""
        if not self._characters:
            return "No characters registered"

        lines = [f"Characters ({len(self._characters)}):"]
        for char in self.characters:
            lines.append(
                f"  - {char.name}: {len(char.appearance_frames)} appearances, "
                f"ref_images={len(char.reference_images)}"
            )
        return "\n".join(lines)

    def reset(self):
        """Clear all character memory"""
        self._characters.clear()
        self._name_index.clear()
        logger.info("Character memory cleared")

    # ------------------------------------------------------------------
    # Visual features
    # ------------------------------------------------------------------

    @property
    def feature_extractor(self):
        """
        Lazily construct the FeatureExtractor.

        Returns None when visual features are disabled or the extractor could
        not be created (e.g. optional dependencies are not installed). A failed
        load is logged once and not retried.
        """
        if (
            self._feature_extractor is None
            and self.config.enable_visual_features
            and not self._feature_extractor_failed
        ):
            try:
                from pixelle_video.services.quality.feature_extractor import FeatureExtractor
                self._feature_extractor = FeatureExtractor()
            except Exception as e:
                # Visual features are optional; degrade instead of crashing callers.
                self._feature_extractor_failed = True
                logger.warning(f"Visual feature extractor unavailable: {e}")
        return self._feature_extractor

    async def extract_character_features(
        self,
        character_name: str,
        image_path: str,
        frame_index: int = 0
    ) -> bool:
        """
        Extract and store visual features for a character.

        On success the image also becomes the character's primary reference.

        Returns:
            True if features were stored, False otherwise.
        """
        if not self.config.enable_visual_features:
            return False

        char = self.get_character(character_name)
        if not char:
            logger.warning(f"Character not found: {character_name}")
            return False

        extractor = self.feature_extractor
        if extractor is None or not extractor.is_available:
            logger.debug("Feature extractor not available")
            return False

        features = extractor.extract_image_features(image_path)
        if features is None:
            return False

        char.set_visual_features(features, frame_index)
        char.add_reference_image(image_path, set_as_primary=True)
        logger.info(f"Extracted visual features for {character_name}")
        return True

    async def check_character_consistency(
        self,
        character_name: str,
        image_path: str
    ) -> Tuple[bool, float]:
        """
        Check if image maintains character consistency.

        Returns ``(True, 1.0)`` whenever a comparison is impossible (unknown
        character, no stored features, extractor unavailable, or extraction
        returned None); otherwise delegates to
        ``Character.check_visual_similarity``.
        """
        char = self.get_character(character_name)
        if not char or char.visual_features is None:
            return True, 1.0

        extractor = self.feature_extractor
        if extractor is None or not extractor.is_available:
            return True, 1.0

        new_features = extractor.extract_image_features(image_path)
        if new_features is None:
            return True, 1.0

        return char.check_visual_similarity(new_features)