diff --git a/config.example.yaml b/config.example.yaml index f5613ee..4a55a0a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -65,3 +65,22 @@ template: # - 1920x1080 (horizontal/landscape): image_film.html, image_full.html, etc. # See templates/ directory for all available templates default_template: "1080x1920/image_default.html" + +# ==================== Quality Control Configuration ==================== +# Configure quality evaluation for generated content +quality: + # Enable quality checking (set to false to skip all quality checks) + enable_quality_check: true + + # Hybrid evaluation settings + hybrid: + enable_clip_score: true # Use CLIP for image-text matching + clip_model: "ViT-B/32" # CLIP model variant + enable_technical_metrics: true # Use technical quality metrics + enable_smart_skip: true # Skip VLM when objective scores are good + smart_skip_threshold: 0.75 # Threshold for smart skip + + # Character consistency settings + character: + enable_visual_features: true # Enable CLIP visual features for characters + visual_similarity_threshold: 0.75 # Min similarity for character consistency diff --git a/pixelle_video/services/quality/__init__.py b/pixelle_video/services/quality/__init__.py index 4ab31ec..4c910f8 100644 --- a/pixelle_video/services/quality/__init__.py +++ b/pixelle_video/services/quality/__init__.py @@ -28,7 +28,11 @@ from pixelle_video.services.quality.models import ( RetryConfig, QualityError, ) -from pixelle_video.services.quality.quality_gate import QualityGate +from pixelle_video.services.quality.quality_gate import ( + QualityGate, + HybridQualityGate, + HybridQualityConfig, +) from pixelle_video.services.quality.retry_manager import RetryManager from pixelle_video.services.quality.output_validator import ( OutputValidator, @@ -52,14 +56,29 @@ from pixelle_video.services.quality.character_memory import ( Character, CharacterType, ) +from pixelle_video.services.quality.feature_extractor import ( + FeatureExtractor, + FeatureExtractorConfig, +) +from pixelle_video.services.quality.objective_metrics import ( + ObjectiveMetricsCalculator, + TechnicalMetrics, +) +from pixelle_video.services.quality.vlm_evaluator import ( + VLMEvaluator, + VLMEvaluatorConfig, + VLMEvaluationResult, +) __all__ = [ # Quality evaluation "QualityScore", - "QualityConfig", + "QualityConfig", "RetryConfig", "QualityError", "QualityGate", + "HybridQualityGate", + "HybridQualityConfig", "RetryManager", # Output validation "OutputValidator", @@ -79,6 +98,16 @@ __all__ = [ "CharacterMemoryConfig", "Character", "CharacterType", + # Feature extraction (NEW) + "FeatureExtractor", + "FeatureExtractorConfig", + # Objective metrics (NEW) + "ObjectiveMetricsCalculator", + "TechnicalMetrics", + # VLM evaluation (NEW) + "VLMEvaluator", + "VLMEvaluatorConfig", + "VLMEvaluationResult", ] diff --git a/pixelle_video/services/quality/character_memory.py b/pixelle_video/services/quality/character_memory.py index 40ba19a..02cd01a 100644 --- a/pixelle_video/services/quality/character_memory.py +++ b/pixelle_video/services/quality/character_memory.py @@ -21,10 +21,11 @@ Maintains consistent character appearance across video frames by: """ from dataclasses import dataclass, field -from typing import List, Dict, Optional, Set +from typing import List, Dict, Optional, Set, Tuple, Any from datetime import datetime from enum import Enum +import numpy as np from loguru import logger @@ -70,6 +71,12 @@ class Character: first_appearance_frame: int = 0 # Frame index of first appearance appearance_frames: 
List[int] = field(default_factory=list) # All frames with this character created_at: Optional[datetime] = None + + # Visual features (NEW - for cross-frame consistency) + visual_features: Optional[Any] = None # CLIP feature vector (np.ndarray) + feature_extraction_frame: Optional[int] = None # Frame where features were extracted + similarity_history: List[float] = field(default_factory=list) # Similarity scores history + min_similarity_threshold: float = 0.75 # Minimum similarity for consistency def __post_init__(self): if self.created_at is None: @@ -111,6 +118,23 @@ class Character: if self.name.lower() == name_lower: return True return any(alias.lower() == name_lower for alias in self.aliases) + + def set_visual_features(self, features: np.ndarray, frame_index: int): + """Set visual features from reference frame""" + self.visual_features = features + self.feature_extraction_frame = frame_index + + def check_visual_similarity(self, other_features: np.ndarray) -> Tuple[bool, float]: + """Check if other features match this character""" + if self.visual_features is None: + return True, 1.0 + + similarity = float(np.dot(self.visual_features, other_features)) + similarity = (similarity + 1) / 2 # Normalize to 0-1 + + self.similarity_history.append(similarity) + is_match = similarity >= self.min_similarity_threshold + return is_match, similarity def to_dict(self) -> dict: return { @@ -149,6 +173,11 @@ class CharacterMemoryConfig: include_clothing: bool = True # Include clothing in prompts include_features: bool = True # Include distinctive features + # Visual feature settings (NEW) + enable_visual_features: bool = True # Enable CLIP visual features + visual_similarity_threshold: float = 0.75 # Min similarity for consistency + extract_features_on_first: bool = True # Extract features on first appearance + class CharacterMemory: """ @@ -191,6 +220,7 @@ class CharacterMemory: self.config = config or CharacterMemoryConfig() self._characters: Dict[str, Character] = {} self._name_index: Dict[str, str] = {} # name -> character_id mapping + self._feature_extractor = None # Lazy-loaded def register_character( self, @@ -528,3 +558,60 @@ class CharacterMemory: self._characters.clear() self._name_index.clear() logger.info("Character memory cleared") + + @property + def feature_extractor(self): + """Lazy-load feature extractor""" + if self._feature_extractor is None and self.config.enable_visual_features: + from pixelle_video.services.quality.feature_extractor import FeatureExtractor + self._feature_extractor = FeatureExtractor() + return self._feature_extractor + + async def extract_character_features( + self, + character_name: str, + image_path: str, + frame_index: int = 0 + ) -> bool: + """Extract and store visual features for a character""" + if not self.config.enable_visual_features: + return False + + char = self.get_character(character_name) + if not char: + logger.warning(f"Character not found: {character_name}") + return False + + extractor = self.feature_extractor + if extractor is None or not extractor.is_available: + logger.debug("Feature extractor not available") + return False + + features = extractor.extract_image_features(image_path) + if features is None: + return False + + char.set_visual_features(features, frame_index) + char.add_reference_image(image_path, set_as_primary=True) + logger.info(f"Extracted visual features for {character_name}") + return True + + async def check_character_consistency( + self, + character_name: str, + image_path: str + ) -> Tuple[bool, float]: + """Check if 
image maintains character consistency""" + char = self.get_character(character_name) + if not char or char.visual_features is None: + return True, 1.0 + + extractor = self.feature_extractor + if extractor is None or not extractor.is_available: + return True, 1.0 + + new_features = extractor.extract_image_features(image_path) + if new_features is None: + return True, 1.0 + + return char.check_visual_similarity(new_features) diff --git a/pixelle_video/services/quality/feature_extractor.py b/pixelle_video/services/quality/feature_extractor.py new file mode 100644 index 0000000..a0726f3 --- /dev/null +++ b/pixelle_video/services/quality/feature_extractor.py @@ -0,0 +1,261 @@ +# Copyright (C) 2025 AIDC-AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FeatureExtractor - CLIP-based visual feature extraction + +Provides: +1. Image-to-vector encoding for similarity comparison +2. Text-to-vector encoding for CLIP score calculation +3. Lazy loading of CLIP model (optional dependency) + +Note: CLIP is an optional dependency. Install with: + pip install pixelle-video[quality] +""" + +from dataclasses import dataclass +from typing import Optional, Union +from pathlib import Path + +import numpy as np +from loguru import logger + + +@dataclass +class FeatureExtractorConfig: + """Configuration for feature extraction""" + + # Model settings + model_name: str = "ViT-B/32" + device: str = "auto" # "auto", "cpu", "cuda", "mps" + + # Performance settings + batch_size: int = 8 + cache_features: bool = True + + # Similarity thresholds + character_similarity_threshold: float = 0.75 + style_similarity_threshold: float = 0.70 + + +class FeatureExtractor: + """ + CLIP-based feature extraction for quality evaluation + + Features: + - Lazy loading: CLIP model only loaded when first needed + - Graceful degradation: Returns None if CLIP unavailable + - Caching: Optional feature caching for performance + + Example: + >>> extractor = FeatureExtractor() + >>> if extractor.is_available: + ... score = extractor.calculate_clip_score( + ... image_path="frame_001.png", + ... text="A sunset over mountains" + ... ) + """ + + def __init__(self, config: Optional[FeatureExtractorConfig] = None): + self.config = config or FeatureExtractorConfig() + self._model = None + self._preprocess = None + self._device = None + self._available: Optional[bool] = None + self._feature_cache: dict = {} + + @property + def is_available(self) -> bool: + """Check if CLIP is available (lazy check)""" + if self._available is None: + self._available = self._check_availability() + return self._available + + def _check_availability(self) -> bool: + """Check if CLIP dependencies are installed""" + try: + import torch + import clip + return True + except ImportError: + logger.warning( + "CLIP not available. 
Install with: " + "pip install torch clip-by-openai" + ) + return False + + def _load_model(self): + """Lazy load CLIP model""" + if self._model is not None: + return + + if not self.is_available: + return + + import torch + import clip + + # Determine device + if self.config.device == "auto": + if torch.cuda.is_available(): + self._device = "cuda" + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + self._device = "mps" + else: + self._device = "cpu" + else: + self._device = self.config.device + + logger.info(f"Loading CLIP model {self.config.model_name} on {self._device}") + self._model, self._preprocess = clip.load( + self.config.model_name, + device=self._device + ) + logger.info("CLIP model loaded successfully") + + def extract_image_features( + self, + image_path: Union[str, Path] + ) -> Optional[np.ndarray]: + """ + Extract CLIP features from an image + + Args: + image_path: Path to image file + + Returns: + Normalized feature vector (512-dim for ViT-B/32) or None + """ + if not self.is_available: + return None + + self._load_model() + + # Check cache + cache_key = str(image_path) + if self.config.cache_features and cache_key in self._feature_cache: + return self._feature_cache[cache_key] + + try: + import torch + from PIL import Image + + image = Image.open(image_path).convert("RGB") + image_input = self._preprocess(image).unsqueeze(0).to(self._device) + + with torch.no_grad(): + features = self._model.encode_image(image_input) + features = features / features.norm(dim=-1, keepdim=True) + features = features.cpu().numpy().flatten() + + # Cache result + if self.config.cache_features: + self._feature_cache[cache_key] = features + + return features + + except Exception as e: + logger.warning(f"Failed to extract image features: {e}") + return None + + def extract_text_features(self, text: str) -> Optional[np.ndarray]: + """ + Extract CLIP features from text + + Args: + text: Text to encode + + Returns: + Normalized feature vector or None + """ + if not self.is_available: + return None + + self._load_model() + + try: + import torch + import clip + + # Truncate text if too long (CLIP max is 77 tokens) + text = text[:300] + + text_input = clip.tokenize([text]).to(self._device) + + with torch.no_grad(): + features = self._model.encode_text(text_input) + features = features / features.norm(dim=-1, keepdim=True) + features = features.cpu().numpy().flatten() + + return features + + except Exception as e: + logger.warning(f"Failed to extract text features: {e}") + return None + + def calculate_clip_score( + self, + image_path: Union[str, Path], + text: str + ) -> Optional[float]: + """ + Calculate CLIP score (image-text similarity) + + Args: + image_path: Path to image + text: Text prompt to compare + + Returns: + Similarity score (0.0-1.0) or None if unavailable + """ + image_features = self.extract_image_features(image_path) + text_features = self.extract_text_features(text) + + if image_features is None or text_features is None: + return None + + # Cosine similarity (features are already normalized) + similarity = float(np.dot(image_features, text_features)) + + # Convert from [-1, 1] to [0, 1] range + score = (similarity + 1) / 2 + + return score + + def calculate_image_similarity( + self, + image_path_1: Union[str, Path], + image_path_2: Union[str, Path] + ) -> Optional[float]: + """ + Calculate similarity between two images + + Args: + image_path_1: Path to first image + image_path_2: Path to second image + + Returns: + Similarity score (0.0-1.0) or None + """ + 
features_1 = self.extract_image_features(image_path_1) + features_2 = self.extract_image_features(image_path_2) + + if features_1 is None or features_2 is None: + return None + + similarity = float(np.dot(features_1, features_2)) + return (similarity + 1) / 2 + + def clear_cache(self): + """Clear feature cache""" + self._feature_cache.clear() + logger.debug("Feature cache cleared") diff --git a/pixelle_video/services/quality/models.py b/pixelle_video/services/quality/models.py index f3bce61..12607fc 100644 --- a/pixelle_video/services/quality/models.py +++ b/pixelle_video/services/quality/models.py @@ -15,7 +15,7 @@ Data models for quality assurance """ from dataclasses import dataclass, field -from typing import List, Optional +from typing import List, Optional, Any from enum import Enum @@ -43,6 +43,11 @@ class QualityScore: # Diagnostics issues: List[str] = field(default_factory=list) # Detected problems evaluation_time_ms: float = 0.0 # Time taken for evaluation + + # Extended metrics (NEW - for HybridQualityGate) + clip_score: Optional[float] = None # CLIP image-text similarity + technical_metrics: Optional[Any] = None # TechnicalMetrics object + vlm_used: bool = True # Whether VLM was used for evaluation @property def level(self) -> QualityLevel: diff --git a/pixelle_video/services/quality/objective_metrics.py b/pixelle_video/services/quality/objective_metrics.py new file mode 100644 index 0000000..6f3c135 --- /dev/null +++ b/pixelle_video/services/quality/objective_metrics.py @@ -0,0 +1,232 @@ +# Copyright (C) 2025 AIDC-AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ObjectiveMetrics - Technical image quality metrics + +Provides fast, local computation of: +1. Sharpness/clarity detection +2. Color distribution analysis +3. Exposure detection (over/under) + +Only depends on PIL and numpy (no heavy ML dependencies). 
+""" + +from dataclasses import dataclass, field +from typing import List, Tuple + +import numpy as np +from loguru import logger + + +@dataclass +class TechnicalMetrics: + """Technical quality metrics for an image""" + + # Sharpness (0.0-1.0, higher = sharper) + sharpness_score: float = 0.0 + + # Color metrics + color_variance: float = 0.0 + saturation_score: float = 0.0 + + # Exposure metrics + brightness_score: float = 0.0 + is_overexposed: bool = False + is_underexposed: bool = False + + # Overall technical score + overall_technical: float = 0.0 + + # Issues detected + issues: List[str] = field(default_factory=list) + + def to_dict(self) -> dict: + """Convert to dictionary""" + return { + "sharpness_score": self.sharpness_score, + "color_variance": self.color_variance, + "saturation_score": self.saturation_score, + "brightness_score": self.brightness_score, + "is_overexposed": self.is_overexposed, + "is_underexposed": self.is_underexposed, + "overall_technical": self.overall_technical, + "issues": self.issues, + } + + +class ObjectiveMetricsCalculator: + """ + Calculate objective technical quality metrics + + All computations are local and fast (no external API calls). + Uses PIL and numpy only. + + Example: + >>> calculator = ObjectiveMetricsCalculator() + >>> metrics = calculator.analyze_image("frame_001.png") + >>> print(f"Sharpness: {metrics.sharpness_score:.2f}") + """ + + def __init__( + self, + sharpness_threshold: float = 0.3, + overexposure_threshold: float = 0.85, + underexposure_threshold: float = 0.15, + ): + self.sharpness_threshold = sharpness_threshold + self.overexposure_threshold = overexposure_threshold + self.underexposure_threshold = underexposure_threshold + + def analyze_image(self, image_path: str) -> TechnicalMetrics: + """ + Analyze image and return technical metrics + + Args: + image_path: Path to image file + + Returns: + TechnicalMetrics with all computed values + """ + try: + from PIL import Image + + with Image.open(image_path) as img: + img_rgb = img.convert("RGB") + img_array = np.array(img_rgb) + + # Calculate individual metrics + sharpness = self._calculate_sharpness(img_rgb) + color_var, saturation = self._calculate_color_metrics(img_array) + brightness, overexp, underexp = self._calculate_exposure(img_array) + + # Detect issues + issues = [] + if sharpness < self.sharpness_threshold: + issues.append("Image appears blurry") + if overexp: + issues.append("Image is overexposed") + if underexp: + issues.append("Image is underexposed") + if color_var < 0.1: + issues.append("Low color diversity") + + # Calculate overall technical score + overall = self._calculate_overall( + sharpness, color_var, saturation, brightness + ) + + return TechnicalMetrics( + sharpness_score=sharpness, + color_variance=color_var, + saturation_score=saturation, + brightness_score=brightness, + is_overexposed=overexp, + is_underexposed=underexp, + overall_technical=overall, + issues=issues + ) + + except Exception as e: + logger.warning(f"Failed to analyze image: {e}") + return TechnicalMetrics(issues=[f"Analysis failed: {str(e)}"]) + + def _calculate_sharpness(self, img) -> float: + """Calculate sharpness using edge detection""" + from PIL import ImageFilter + + # Convert to grayscale + gray = img.convert("L") + + # Apply edge detection + edges = gray.filter(ImageFilter.FIND_EDGES) + edge_array = np.array(edges).astype(np.float32) + + # Variance of edge response + variance = np.var(edge_array) + + # Normalize to 0-1 (empirical scaling) + sharpness = min(1.0, variance / 2000) + + 
return float(sharpness) + + def _calculate_color_metrics( + self, + img_array: np.ndarray + ) -> Tuple[float, float]: + """Calculate color variance and saturation""" + r, g, b = img_array[:, :, 0], img_array[:, :, 1], img_array[:, :, 2] + + max_rgb = np.maximum(np.maximum(r, g), b) + min_rgb = np.minimum(np.minimum(r, g), b) + + # Saturation + delta = max_rgb - min_rgb + saturation = np.where(max_rgb > 0, delta / (max_rgb + 1e-7), 0) + avg_saturation = np.mean(saturation) + + # Color variance (diversity) + color_std = np.std(img_array.reshape(-1, 3), axis=0) + color_variance = np.mean(color_std) / 128 + + return float(color_variance), float(avg_saturation) + + def _calculate_exposure( + self, + img_array: np.ndarray + ) -> Tuple[float, bool, bool]: + """Calculate brightness and detect exposure issues""" + # Calculate luminance + luminance = ( + 0.299 * img_array[:, :, 0] + + 0.587 * img_array[:, :, 1] + + 0.114 * img_array[:, :, 2] + ) / 255.0 + + avg_brightness = float(np.mean(luminance)) + + # Check for over/under exposure + overexposed_pixels = np.mean(luminance > 0.95) + underexposed_pixels = np.mean(luminance < 0.05) + + is_overexposed = ( + avg_brightness > self.overexposure_threshold or + overexposed_pixels > 0.3 + ) + is_underexposed = ( + avg_brightness < self.underexposure_threshold or + underexposed_pixels > 0.3 + ) + + return avg_brightness, is_overexposed, is_underexposed + + def _calculate_overall( + self, + sharpness: float, + color_var: float, + saturation: float, + brightness: float + ) -> float: + """Calculate overall technical score""" + # Brightness penalty (ideal is 0.5) + brightness_score = 1.0 - abs(brightness - 0.5) * 2 + brightness_score = max(0, brightness_score) + + # Weighted combination + overall = ( + sharpness * 0.4 + + min(1.0, color_var * 2) * 0.2 + + saturation * 0.2 + + brightness_score * 0.2 + ) + + return float(overall) diff --git a/pixelle_video/services/quality/quality_gate.py b/pixelle_video/services/quality/quality_gate.py index eb7b59e..aaba9d8 100644 --- a/pixelle_video/services/quality/quality_gate.py +++ b/pixelle_video/services/quality/quality_gate.py @@ -17,9 +17,12 @@ Evaluates images and videos based on: - Aesthetic quality (visual appeal) - Text-to-image matching (semantic alignment) - Technical quality (clarity, no artifacts) + +Includes HybridQualityGate for combined objective + VLM evaluation. """ import time +from dataclasses import dataclass from typing import Optional from pathlib import Path @@ -28,6 +31,27 @@ from loguru import logger from pixelle_video.services.quality.models import QualityScore, QualityConfig +@dataclass +class HybridQualityConfig(QualityConfig): + """Extended configuration for hybrid quality evaluation""" + + # CLIP settings + enable_clip_score: bool = True + clip_model: str = "ViT-B/32" + clip_weight: float = 0.5 + + # Technical metrics settings + enable_technical_metrics: bool = True + sharpness_threshold: float = 0.3 + + # Smart VLM skip + enable_smart_skip: bool = True + smart_skip_threshold: float = 0.75 + + # Feature caching + cache_features: bool = True + + class QualityGate: """ Quality evaluation gate for AI-generated content @@ -361,3 +385,160 @@ Respond in JSON format: "issues": ["list of any problems found"] }} """ + + +class HybridQualityGate(QualityGate): + """ + Hybrid quality gate combining objective metrics with VLM evaluation + + Evaluation flow: + 1. Calculate technical metrics (fast, local) + 2. Calculate CLIP score if enabled (local, requires CLIP) + 3. 
If smart_skip enabled and objective score >= threshold, skip VLM + 4. Otherwise, call VLM for subjective evaluation + 5. Combine scores with configurable weights + + Example: + >>> gate = HybridQualityGate(llm_service, config) + >>> score = await gate.evaluate_image( + ... image_path="frame_001.png", + ... prompt="A sunset over mountains" + ... ) + """ + + def __init__( + self, + llm_service=None, + config: Optional[HybridQualityConfig] = None + ): + parent_config = config or HybridQualityConfig() + super().__init__(llm_service, parent_config) + + self.hybrid_config = parent_config + self._feature_extractor = None + self._metrics_calculator = None + self._vlm_evaluator = None + + @property + def feature_extractor(self): + """Lazy-load feature extractor""" + if self._feature_extractor is None: + from pixelle_video.services.quality.feature_extractor import ( + FeatureExtractor, FeatureExtractorConfig + ) + self._feature_extractor = FeatureExtractor( + FeatureExtractorConfig( + model_name=self.hybrid_config.clip_model, + cache_features=self.hybrid_config.cache_features + ) + ) + return self._feature_extractor + + @property + def metrics_calculator(self): + """Lazy-load metrics calculator""" + if self._metrics_calculator is None: + from pixelle_video.services.quality.objective_metrics import ( + ObjectiveMetricsCalculator + ) + self._metrics_calculator = ObjectiveMetricsCalculator( + sharpness_threshold=self.hybrid_config.sharpness_threshold + ) + return self._metrics_calculator + + @property + def vlm_evaluator(self): + """Lazy-load VLM evaluator""" + if self._vlm_evaluator is None: + from pixelle_video.services.quality.vlm_evaluator import VLMEvaluator + self._vlm_evaluator = VLMEvaluator(self.llm_service) + return self._vlm_evaluator + + async def evaluate_image( + self, + image_path: str, + prompt: str, + narration: Optional[str] = None, + ) -> QualityScore: + """Evaluate image quality using hybrid approach""" + start_time = time.time() + issues = [] + + if not Path(image_path).exists(): + return QualityScore( + passed=False, + issues=["Image file not found"], + evaluation_time_ms=(time.time() - start_time) * 1000 + ) + + # Step 1: Technical metrics (fast, local) + technical_score = 0.7 + technical_metrics = None + + if self.hybrid_config.enable_technical_metrics: + technical_metrics = self.metrics_calculator.analyze_image(image_path) + technical_score = technical_metrics.overall_technical + issues.extend(technical_metrics.issues) + + # Step 2: CLIP score (if available) + clip_score = None + text_match_score = 0.7 + + if self.hybrid_config.enable_clip_score: + clip_score = self.feature_extractor.calculate_clip_score( + image_path, prompt + ) + if clip_score is not None: + text_match_score = clip_score + + # Step 3: Determine if VLM needed + objective_score = (technical_score + text_match_score) / 2 + use_vlm = True + aesthetic_score = 0.7 + + if self.hybrid_config.enable_smart_skip: + if objective_score >= self.hybrid_config.smart_skip_threshold: + use_vlm = False + logger.debug(f"Smart skip: {objective_score:.2f} >= threshold") + + # Step 4: VLM evaluation (if needed) + if use_vlm and self.config.use_vlm_evaluation and self.llm_service: + vlm_result = await self.vlm_evaluator.evaluate_image( + image_path, prompt, narration + ) + aesthetic_score = vlm_result.aesthetic_score or 0.7 + + if clip_score is not None: + text_match_score = ( + clip_score * self.hybrid_config.clip_weight + + vlm_result.text_match_score * (1 - self.hybrid_config.clip_weight) + ) + else: + text_match_score = 
vlm_result.text_match_score or 0.7 + + issues.extend(vlm_result.issues) + + # Step 5: Calculate overall + overall = ( + aesthetic_score * self.config.aesthetic_weight + + text_match_score * self.config.text_match_weight + + technical_score * self.config.technical_weight + ) + + score = QualityScore( + aesthetic_score=aesthetic_score, + text_match_score=text_match_score, + technical_score=technical_score, + overall_score=overall, + issues=issues, + evaluation_time_ms=(time.time() - start_time) * 1000 + ) + + score.passed = overall >= self.config.overall_threshold + + logger.debug( + f"Hybrid eval: overall={overall:.2f}, clip={clip_score}, " + f"vlm_used={use_vlm}, time={score.evaluation_time_ms:.0f}ms" + ) + + return score diff --git a/pixelle_video/services/quality/vlm_evaluator.py b/pixelle_video/services/quality/vlm_evaluator.py new file mode 100644 index 0000000..3070d3d --- /dev/null +++ b/pixelle_video/services/quality/vlm_evaluator.py @@ -0,0 +1,243 @@ +# Copyright (C) 2025 AIDC-AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VLMEvaluator - Vision Language Model based image quality evaluation + +Supports multiple VLM providers: +- OpenAI: gpt-4-vision-preview, gpt-4o +- Qwen-VL: qwen-vl-max, qwen-vl-plus +- GLM-4V: via OpenAI compatible API +""" + +import base64 +import json +import re +from dataclasses import dataclass, field +from typing import Optional, List +from pathlib import Path + +from loguru import logger + + +@dataclass +class VLMEvaluationResult: + """Result from VLM evaluation""" + aesthetic_score: float = 0.0 + text_match_score: float = 0.0 + technical_score: float = 0.0 + issues: List[str] = field(default_factory=list) + raw_response: Optional[str] = None + + def to_dict(self) -> dict: + return { + "aesthetic_score": self.aesthetic_score, + "text_match_score": self.text_match_score, + "technical_score": self.technical_score, + "issues": self.issues, + } + + +@dataclass +class VLMEvaluatorConfig: + """Configuration for VLM evaluator""" + provider: str = "auto" # "openai", "qwen", "auto" + model: Optional[str] = None # Auto-select if None + max_image_size: int = 1024 # Max image dimension + timeout: int = 30 + temperature: float = 0.1 # Low for consistent evaluation + + +class VLMEvaluator: + """ + VLM-based image quality evaluator + + Example: + >>> evaluator = VLMEvaluator(llm_service) + >>> result = await evaluator.evaluate_image( + ... image_path="frame_001.png", + ... prompt="A sunset over mountains" + ... ) + """ + + EVALUATION_PROMPT = """请评估这张AI生成的图片质量。 + +生成提示词: {prompt} +{narration_section} + +请从以下三个维度评分(0.0-1.0): + +1. **美学质量** (aesthetic_score): 构图、色彩搭配、视觉吸引力 +2. **图文匹配** (text_match_score): 图片与提示词的语义对齐程度 +3. 
**技术质量** (technical_score): 清晰度、无伪影、无变形 + +同时列出发现的问题(如有)。 + +请以JSON格式返回: +```json +{{ + "aesthetic_score": 0.0-1.0, + "text_match_score": 0.0-1.0, + "technical_score": 0.0-1.0, + "issues": ["问题1", "问题2"] +}} +```""" + + def __init__( + self, + llm_service=None, + config: Optional[VLMEvaluatorConfig] = None + ): + self.llm_service = llm_service + self.config = config or VLMEvaluatorConfig() + + def _encode_image_base64(self, image_path: str) -> str: + """Encode image to base64, with optional resizing""" + from PIL import Image + import io + + with Image.open(image_path) as img: + # Resize if too large + max_size = self.config.max_image_size + if max(img.size) > max_size: + ratio = max_size / max(img.size) + new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio)) + img = img.resize(new_size, Image.Resampling.LANCZOS) + + # Convert to RGB if needed + if img.mode in ('RGBA', 'P'): + img = img.convert('RGB') + + # Encode to base64 + buffer = io.BytesIO() + img.save(buffer, format='JPEG', quality=85) + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + def _parse_response(self, response: str) -> VLMEvaluationResult: + """Parse VLM response to extract scores""" + result = VLMEvaluationResult(raw_response=response) + + try: + # Try to extract JSON from response + json_match = re.search(r'```json\s*([\s\S]*?)\s*```', response) + if json_match: + json_str = json_match.group(1) + else: + # Try to find raw JSON + brace_start = response.find('{') + brace_end = response.rfind('}') + if brace_start != -1 and brace_end > brace_start: + json_str = response[brace_start:brace_end + 1] + else: + logger.warning("No JSON found in VLM response") + return result + + data = json.loads(json_str) + + result.aesthetic_score = float(data.get('aesthetic_score', 0.0)) + result.text_match_score = float(data.get('text_match_score', 0.0)) + result.technical_score = float(data.get('technical_score', 0.0)) + result.issues = data.get('issues', []) + + # Clamp scores to valid range + result.aesthetic_score = max(0.0, min(1.0, result.aesthetic_score)) + result.text_match_score = max(0.0, min(1.0, result.text_match_score)) + result.technical_score = max(0.0, min(1.0, result.technical_score)) + + except (json.JSONDecodeError, ValueError) as e: + logger.warning(f"Failed to parse VLM response: {e}") + + return result + + async def evaluate_image( + self, + image_path: str, + prompt: str, + narration: Optional[str] = None + ) -> VLMEvaluationResult: + """ + Evaluate image quality using VLM + + Args: + image_path: Path to image file + prompt: Generation prompt + narration: Optional narration text + + Returns: + VLMEvaluationResult with scores + """ + if not Path(image_path).exists(): + return VLMEvaluationResult(issues=["Image file not found"]) + + if not self.llm_service: + logger.warning("No LLM service provided for VLM evaluation") + return VLMEvaluationResult(issues=["No LLM service"]) + + try: + # Encode image + image_b64 = self._encode_image_base64(image_path) + + # Build prompt + narration_section = f"旁白文案: {narration}" if narration else "" + eval_prompt = self.EVALUATION_PROMPT.format( + prompt=prompt, + narration_section=narration_section + ) + + # Call VLM via LLM service with vision + response = await self._call_vlm(image_b64, eval_prompt) + + return self._parse_response(response) + + except Exception as e: + logger.error(f"VLM evaluation failed: {e}") + return VLMEvaluationResult(issues=[f"Evaluation error: {str(e)}"]) + + async def _call_vlm(self, image_b64: str, prompt: str) -> str: + """Call VLM 
with image and prompt""" + from openai import AsyncOpenAI + + # Get config from LLM service + base_url = self.llm_service._get_config_value("base_url") + api_key = self.llm_service._get_config_value("api_key") + model = self.config.model or self.llm_service._get_config_value("model") + + client = AsyncOpenAI(api_key=api_key, base_url=base_url) + + # Build message with image + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_b64}" + } + }, + { + "type": "text", + "text": prompt + } + ] + } + ] + + response = await client.chat.completions.create( + model=model, + messages=messages, + temperature=self.config.temperature, + max_tokens=500, + timeout=self.config.timeout + ) + + return response.choices[0].message.content diff --git a/pyproject.toml b/pyproject.toml index 0826c7b..7b018a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,11 @@ dev = [ "pytest-asyncio>=0.23.0", "ruff>=0.6.0", ] +quality = [ + "torch>=2.0.0", + "ftfy>=6.1.0", + "regex>=2023.0.0", +] [project.scripts] pixelle-video = "pixelle_video.cli:main" diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py new file mode 100644 index 0000000..d760913 --- /dev/null +++ b/tests/test_feature_extractor.py @@ -0,0 +1,47 @@ +# Copyright (C) 2025 AIDC-AI +# Tests for FeatureExtractor + +import pytest + +from pixelle_video.services.quality.feature_extractor import ( + FeatureExtractor, + FeatureExtractorConfig, +) + + +class TestFeatureExtractorConfig: + """Tests for FeatureExtractorConfig""" + + def test_default_values(self): + config = FeatureExtractorConfig() + assert config.model_name == "ViT-B/32" + assert config.device == "auto" + assert config.cache_features is True + + +class TestFeatureExtractor: + """Tests for FeatureExtractor""" + + def test_init_default(self): + extractor = FeatureExtractor() + assert extractor.config.model_name == "ViT-B/32" + + def test_is_available_check(self): + """Test availability check (may be True or False)""" + extractor = FeatureExtractor() + # Just check it returns a boolean + assert isinstance(extractor.is_available, bool) + + def test_extract_without_clip(self): + """Test graceful degradation when CLIP unavailable""" + extractor = FeatureExtractor() + if not extractor.is_available: + result = extractor.extract_image_features("/fake/path.png") + assert result is None + + def test_clear_cache(self): + """Test cache clearing""" + extractor = FeatureExtractor() + extractor._feature_cache["test"] = "value" + extractor.clear_cache() + assert len(extractor._feature_cache) == 0 diff --git a/tests/test_hybrid_quality_gate.py b/tests/test_hybrid_quality_gate.py new file mode 100644 index 0000000..c23c0e5 --- /dev/null +++ b/tests/test_hybrid_quality_gate.py @@ -0,0 +1,53 @@ +# Copyright (C) 2025 AIDC-AI +# Tests for HybridQualityGate + +import pytest + +from pixelle_video.services.quality.quality_gate import ( + QualityGate, + HybridQualityGate, + HybridQualityConfig, +) +from pixelle_video.services.quality.models import QualityScore + + +class TestHybridQualityConfig: + """Tests for HybridQualityConfig""" + + def test_default_values(self): + config = HybridQualityConfig() + assert config.enable_clip_score is True + assert config.enable_smart_skip is True + assert config.smart_skip_threshold == 0.75 + + def test_inherits_quality_config(self): + config = HybridQualityConfig() + assert hasattr(config, "overall_threshold") + assert config.overall_threshold == 0.6 + + +class 
TestHybridQualityGate: + """Tests for HybridQualityGate""" + + def test_init_default(self): + gate = HybridQualityGate() + assert gate.hybrid_config is not None + + def test_inherits_quality_gate(self): + gate = HybridQualityGate() + assert isinstance(gate, QualityGate) + + def test_lazy_load_metrics_calculator(self): + gate = HybridQualityGate() + calc = gate.metrics_calculator + assert calc is not None + + @pytest.mark.asyncio + async def test_evaluate_nonexistent_image(self): + gate = HybridQualityGate() + score = await gate.evaluate_image( + "/nonexistent/path.png", + "test prompt" + ) + assert score.passed is False + assert "not found" in score.issues[0].lower() diff --git a/tests/test_objective_metrics.py b/tests/test_objective_metrics.py new file mode 100644 index 0000000..8843032 --- /dev/null +++ b/tests/test_objective_metrics.py @@ -0,0 +1,64 @@ +# Copyright (C) 2025 AIDC-AI +# Tests for ObjectiveMetricsCalculator + +import pytest +from pathlib import Path + +from pixelle_video.services.quality.objective_metrics import ( + ObjectiveMetricsCalculator, + TechnicalMetrics, +) + + +class TestTechnicalMetrics: + """Tests for TechnicalMetrics dataclass""" + + def test_default_values(self): + metrics = TechnicalMetrics() + assert metrics.sharpness_score == 0.0 + assert metrics.overall_technical == 0.0 + assert metrics.issues == [] + + def test_to_dict(self): + metrics = TechnicalMetrics( + sharpness_score=0.8, + brightness_score=0.5, + issues=["test issue"] + ) + d = metrics.to_dict() + assert d["sharpness_score"] == 0.8 + assert "test issue" in d["issues"] + + +class TestObjectiveMetricsCalculator: + """Tests for ObjectiveMetricsCalculator""" + + def test_init_default(self): + calc = ObjectiveMetricsCalculator() + assert calc.sharpness_threshold == 0.3 + + def test_init_custom(self): + calc = ObjectiveMetricsCalculator(sharpness_threshold=0.5) + assert calc.sharpness_threshold == 0.5 + + def test_analyze_nonexistent_image(self): + calc = ObjectiveMetricsCalculator() + metrics = calc.analyze_image("/nonexistent/path.png") + assert len(metrics.issues) > 0 + assert "failed" in metrics.issues[0].lower() + + def test_analyze_real_image(self, tmp_path): + """Test with a real image file""" + from PIL import Image + + # Create test image + img = Image.new("RGB", (256, 256), color=(128, 128, 128)) + img_path = tmp_path / "test.png" + img.save(img_path) + + calc = ObjectiveMetricsCalculator() + metrics = calc.analyze_image(str(img_path)) + + assert 0.0 <= metrics.sharpness_score <= 1.0 + assert 0.0 <= metrics.brightness_score <= 1.0 + assert 0.0 <= metrics.overall_technical <= 1.0
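
Usage sketches for the new quality modules follow. They exercise only the APIs introduced in this patch; file paths, prompts, and the injected `llm_service` object are illustrative placeholders rather than part of the change.

A minimal sketch of the fast, local path: `ObjectiveMetricsCalculator` computes sharpness, color, and exposure metrics with PIL and numpy only. The frame path is hypothetical.

```python
from pixelle_video.services.quality.objective_metrics import ObjectiveMetricsCalculator

# Thresholds shown are the defaults; tighten sharpness_threshold to flag soft frames earlier.
calc = ObjectiveMetricsCalculator(
    sharpness_threshold=0.3,
    overexposure_threshold=0.85,
    underexposure_threshold=0.15,
)

metrics = calc.analyze_image("outputs/frame_001.png")  # hypothetical frame path

print(
    f"sharpness={metrics.sharpness_score:.2f} "
    f"brightness={metrics.brightness_score:.2f} "
    f"overall={metrics.overall_technical:.2f}"
)
if metrics.issues:
    print("detected issues:", "; ".join(metrics.issues))
```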
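
CLIP-based scoring goes through `FeatureExtractor`, which degrades gracefully when the optional `quality` extra (torch + CLIP) is not installed. A sketch with hypothetical frame paths:

```python
from pixelle_video.services.quality.feature_extractor import (
    FeatureExtractor,
    FeatureExtractorConfig,
)

extractor = FeatureExtractor(
    FeatureExtractorConfig(model_name="ViT-B/32", device="auto", cache_features=True)
)

if extractor.is_available:
    # Image-text alignment: cosine similarity mapped into [0, 1]; None on failure.
    clip_score = extractor.calculate_clip_score(
        "outputs/frame_001.png", "A sunset over mountains"
    )
    # Frame-to-frame visual similarity, also mapped into [0, 1].
    frame_sim = extractor.calculate_image_similarity(
        "outputs/frame_001.png", "outputs/frame_002.png"
    )
    print(f"clip_score={clip_score}, frame_similarity={frame_sim}")
else:
    # CLIP is optional; the hybrid gate falls back to technical metrics + VLM.
    print("CLIP not installed; skipping objective CLIP scoring")
```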
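
`VLMEvaluator` needs an injected LLM service that exposes `_get_config_value("base_url" / "api_key" / "model")`, since `_call_vlm` builds an OpenAI-compatible client from it. A hedged sketch; `llm_service` and the frame path are assumptions:

```python
from pixelle_video.services.quality.vlm_evaluator import VLMEvaluator, VLMEvaluatorConfig


async def vlm_check(llm_service):
    # llm_service is the project's existing LLM wrapper (not defined in this patch).
    evaluator = VLMEvaluator(
        llm_service,
        VLMEvaluatorConfig(max_image_size=1024, temperature=0.1, timeout=30),
    )
    result = await evaluator.evaluate_image(
        image_path="outputs/frame_001.png",  # hypothetical frame
        prompt="A sunset over mountains",
        narration="The sun dips slowly behind the ridge",
    )
    # Scores are clamped to [0, 1]; issues lists anything the VLM flagged.
    print(result.to_dict())
```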
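
`HybridQualityGate` ties the pieces together: technical metrics first, then the CLIP score, and a VLM pass only when the objective average falls below `smart_skip_threshold`. A sketch of an objective-only run (passing `llm_service=None` skips the VLM branch); the frame path is a placeholder:

```python
import asyncio

from pixelle_video.services.quality.quality_gate import (
    HybridQualityGate,
    HybridQualityConfig,
)


async def evaluate_frame(llm_service=None):
    gate = HybridQualityGate(
        llm_service,
        HybridQualityConfig(
            enable_clip_score=True,
            enable_technical_metrics=True,
            enable_smart_skip=True,
            smart_skip_threshold=0.75,
        ),
    )
    return await gate.evaluate_image(
        image_path="outputs/frame_001.png",  # hypothetical frame
        prompt="A sunset over mountains",
        narration="The sun dips slowly behind the ridge",
    )


score = asyncio.run(evaluate_frame())
print(score.passed, f"overall={score.overall_score:.2f}", score.issues)
```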
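
Character consistency uses the new `CharacterMemory` feature methods: extract CLIP features once from a reference frame, then compare later frames against them. A sketch assuming the constructor takes a config object and that `register_character` (pre-existing API, not shown in this patch) accepts `name`/`description` keywords:

```python
import asyncio

from pixelle_video.services.quality.character_memory import (
    CharacterMemory,
    CharacterMemoryConfig,
)


async def main():
    memory = CharacterMemory(
        CharacterMemoryConfig(
            enable_visual_features=True,
            visual_similarity_threshold=0.75,
        )
    )
    # Assumed signature for the pre-existing registration API.
    memory.register_character(name="Li Wei", description="young woman in a red coat")

    # Cache CLIP features from the character's first rendered frame
    # (returns False and logs a warning when CLIP is unavailable).
    await memory.extract_character_features(
        "Li Wei", "outputs/frame_001.png", frame_index=1
    )

    # Later frames are compared against the cached reference features;
    # returns (True, 1.0) whenever features or CLIP are missing.
    is_match, similarity = await memory.check_character_consistency(
        "Li Wei", "outputs/frame_007.png"
    )
    print(f"consistent={is_match}, similarity={similarity:.2f}")


asyncio.run(main())
```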