- Add FeatureExtractor for CLIP-based image/text feature extraction
- Add ObjectiveMetricsCalculator for technical quality metrics
- Add VLMEvaluator for vision language model evaluation
- Add HybridQualityGate combining objective + VLM evaluation
- Enhance CharacterMemory with visual feature support
- Add quality optional dependency (torch, ftfy, regex)
- Add unit tests for new modules

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
262 lines · 7.6 KiB · Python
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
FeatureExtractor - CLIP-based visual feature extraction

Provides:
1. Image-to-vector encoding for similarity comparison
2. Text-to-vector encoding for CLIP score calculation
3. Lazy loading of CLIP model (optional dependency)

Note: CLIP is an optional dependency. Install with:
    pip install pixelle-video[quality]
"""

from dataclasses import dataclass
from typing import Optional, Union
from pathlib import Path

import numpy as np
from loguru import logger


@dataclass
class FeatureExtractorConfig:
    """Configuration for feature extraction"""

    # Model settings
    model_name: str = "ViT-B/32"
    device: str = "auto"  # "auto", "cpu", "cuda", "mps"

    # Performance settings
    batch_size: int = 8
    cache_features: bool = True

    # Similarity thresholds
    character_similarity_threshold: float = 0.75
    style_similarity_threshold: float = 0.70


class FeatureExtractor:
    """
    CLIP-based feature extraction for quality evaluation

    Features:
    - Lazy loading: CLIP model only loaded when first needed
    - Graceful degradation: Returns None if CLIP unavailable
    - Caching: Optional feature caching for performance

    Example:
        >>> extractor = FeatureExtractor()
        >>> if extractor.is_available:
        ...     score = extractor.calculate_clip_score(
        ...         image_path="frame_001.png",
        ...         text="A sunset over mountains"
        ...     )
    """

    def __init__(self, config: Optional[FeatureExtractorConfig] = None):
        self.config = config or FeatureExtractorConfig()
        self._model = None
        self._preprocess = None
        self._device = None
        self._available: Optional[bool] = None
        self._feature_cache: dict = {}

    @property
    def is_available(self) -> bool:
        """Check if CLIP is available (lazy check)"""
        if self._available is None:
            self._available = self._check_availability()
        return self._available

    def _check_availability(self) -> bool:
        """Check if CLIP dependencies are installed"""
        try:
            import torch
            import clip
            return True
        except ImportError:
            logger.warning(
                "CLIP not available. Install with: "
                "pip install torch clip-by-openai"
            )
            return False

    def _load_model(self):
        """Lazy load CLIP model"""
        if self._model is not None:
            return

        if not self.is_available:
            return

        import torch
        import clip

        # Determine device
        if self.config.device == "auto":
            if torch.cuda.is_available():
                self._device = "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                self._device = "mps"
            else:
                self._device = "cpu"
        else:
            self._device = self.config.device

        logger.info(f"Loading CLIP model {self.config.model_name} on {self._device}")
        self._model, self._preprocess = clip.load(
            self.config.model_name,
            device=self._device
        )
        logger.info("CLIP model loaded successfully")

    def extract_image_features(
        self,
        image_path: Union[str, Path]
    ) -> Optional[np.ndarray]:
        """
        Extract CLIP features from an image

        Args:
            image_path: Path to image file

        Returns:
            Normalized feature vector (512-dim for ViT-B/32) or None
        """
        if not self.is_available:
            return None

        self._load_model()

        # Check cache
        cache_key = str(image_path)
        if self.config.cache_features and cache_key in self._feature_cache:
            return self._feature_cache[cache_key]

        try:
            import torch
            from PIL import Image

            image = Image.open(image_path).convert("RGB")
            image_input = self._preprocess(image).unsqueeze(0).to(self._device)

            with torch.no_grad():
                features = self._model.encode_image(image_input)
                features = features / features.norm(dim=-1, keepdim=True)
                features = features.cpu().numpy().flatten()

            # Cache result
            if self.config.cache_features:
                self._feature_cache[cache_key] = features

            return features

        except Exception as e:
            logger.warning(f"Failed to extract image features: {e}")
            return None

    def extract_text_features(self, text: str) -> Optional[np.ndarray]:
        """
        Extract CLIP features from text

        Args:
            text: Text to encode

        Returns:
            Normalized feature vector or None
        """
        if not self.is_available:
            return None

        self._load_model()

        try:
            import torch
            import clip

            # CLIP's text encoder accepts at most 77 tokens; pre-trim very long
            # prompts and let the tokenizer truncate the remainder.
            text = text[:300]

            text_input = clip.tokenize([text], truncate=True).to(self._device)

            with torch.no_grad():
                features = self._model.encode_text(text_input)
                features = features / features.norm(dim=-1, keepdim=True)
                features = features.cpu().numpy().flatten()

            return features

        except Exception as e:
            logger.warning(f"Failed to extract text features: {e}")
            return None

    def calculate_clip_score(
        self,
        image_path: Union[str, Path],
        text: str
    ) -> Optional[float]:
        """
        Calculate CLIP score (image-text similarity)

        Args:
            image_path: Path to image
            text: Text prompt to compare

        Returns:
            Similarity score (0.0-1.0) or None if unavailable
        """
        image_features = self.extract_image_features(image_path)
        text_features = self.extract_text_features(text)

        if image_features is None or text_features is None:
            return None

        # Cosine similarity (features are already normalized)
        similarity = float(np.dot(image_features, text_features))

        # Convert from [-1, 1] to [0, 1] range
        score = (similarity + 1) / 2

        return score

    def calculate_image_similarity(
        self,
        image_path_1: Union[str, Path],
        image_path_2: Union[str, Path]
    ) -> Optional[float]:
        """
        Calculate similarity between two images

        Args:
            image_path_1: Path to first image
            image_path_2: Path to second image

        Returns:
            Similarity score (0.0-1.0) or None
        """
        features_1 = self.extract_image_features(image_path_1)
        features_2 = self.extract_image_features(image_path_2)

        if features_1 is None or features_2 is None:
            return None

        similarity = float(np.dot(features_1, features_2))
        return (similarity + 1) / 2

    def clear_cache(self):
        """Clear feature cache"""
        self._feature_cache.clear()
        logger.debug("Feature cache cleared")
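

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module). It shows the intended flow: check availability, score image-text
# alignment, then compare two frames for visual consistency. The file names
# "frame_001.png" and "frame_002.png" are assumed example paths; if they are
# missing or the [quality] extra is not installed, every call simply
# returns None thanks to the graceful-degradation behavior above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    extractor = FeatureExtractor(
        FeatureExtractorConfig(model_name="ViT-B/32", device="auto")
    )

    if extractor.is_available:
        # Image-text alignment (CLIP score), mapped to the [0, 1] range
        clip_score = extractor.calculate_clip_score(
            image_path="frame_001.png",
            text="A sunset over mountains",
        )
        logger.info(f"CLIP score: {clip_score}")

        # Visual similarity between two frames, e.g. for character consistency
        similarity = extractor.calculate_image_similarity(
            "frame_001.png", "frame_002.png"
        )
        logger.info(f"Frame similarity: {similarity}")
    else:
        logger.info("CLIP not installed; skipping feature extraction demo")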