feat: Add hybrid quality evaluation system with CLIP and VLM support
- Add FeatureExtractor for CLIP-based image/text feature extraction
- Add ObjectiveMetricsCalculator for technical quality metrics
- Add VLMEvaluator for vision language model evaluation
- Add HybridQualityGate combining objective + VLM evaluation
- Enhance CharacterMemory with visual feature support
- Add quality optional dependency (torch, ftfy, regex)
- Add unit tests for new modules

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -65,3 +65,22 @@ template:
# - 1920x1080 (horizontal/landscape): image_film.html, image_full.html, etc.
# See templates/ directory for all available templates
default_template: "1080x1920/image_default.html"

# ==================== Quality Control Configuration ====================
# Configure quality evaluation for generated content
quality:
  # Enable quality checking (set to false to skip all quality checks)
  enable_quality_check: true

  # Hybrid evaluation settings
  hybrid:
    enable_clip_score: true          # Use CLIP for image-text matching
    clip_model: "ViT-B/32"           # CLIP model variant
    enable_technical_metrics: true   # Use technical quality metrics
    enable_smart_skip: true          # Skip VLM when objective scores are good
    smart_skip_threshold: 0.75       # Threshold for smart skip

  # Character consistency settings
  character:
    enable_visual_features: true         # Enable CLIP visual features for characters
    visual_similarity_threshold: 0.75    # Min similarity for character consistency
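For orientation, a minimal sketch (not part of this commit) of how the `quality.hybrid` block might be mapped onto the new `HybridQualityConfig`; the `load_hybrid_config` helper and the `config.yaml` filename are illustrative assumptions.

```python
# Hypothetical helper -- shows how the YAML keys line up with HybridQualityConfig fields.
import yaml  # assumes PyYAML is installed

from pixelle_video.services.quality.quality_gate import HybridQualityConfig


def load_hybrid_config(path: str = "config.yaml") -> HybridQualityConfig:
    """Map the quality.hybrid section of the config file onto HybridQualityConfig."""
    with open(path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}

    hybrid = raw.get("quality", {}).get("hybrid", {})
    return HybridQualityConfig(
        enable_clip_score=hybrid.get("enable_clip_score", True),
        clip_model=hybrid.get("clip_model", "ViT-B/32"),
        enable_technical_metrics=hybrid.get("enable_technical_metrics", True),
        enable_smart_skip=hybrid.get("enable_smart_skip", True),
        smart_skip_threshold=hybrid.get("smart_skip_threshold", 0.75),
    )
```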
@@ -28,7 +28,11 @@ from pixelle_video.services.quality.models import (
    RetryConfig,
    QualityError,
)
from pixelle_video.services.quality.quality_gate import QualityGate
from pixelle_video.services.quality.quality_gate import (
    QualityGate,
    HybridQualityGate,
    HybridQualityConfig,
)
from pixelle_video.services.quality.retry_manager import RetryManager
from pixelle_video.services.quality.output_validator import (
    OutputValidator,
@@ -52,6 +56,19 @@ from pixelle_video.services.quality.character_memory import (
    Character,
    CharacterType,
)
from pixelle_video.services.quality.feature_extractor import (
    FeatureExtractor,
    FeatureExtractorConfig,
)
from pixelle_video.services.quality.objective_metrics import (
    ObjectiveMetricsCalculator,
    TechnicalMetrics,
)
from pixelle_video.services.quality.vlm_evaluator import (
    VLMEvaluator,
    VLMEvaluatorConfig,
    VLMEvaluationResult,
)

__all__ = [
    # Quality evaluation
@@ -60,6 +77,8 @@ __all__ = [
    "RetryConfig",
    "QualityError",
    "QualityGate",
    "HybridQualityGate",
    "HybridQualityConfig",
    "RetryManager",
    # Output validation
    "OutputValidator",
@@ -79,6 +98,16 @@ __all__ = [
    "CharacterMemoryConfig",
    "Character",
    "CharacterType",
    # Feature extraction (NEW)
    "FeatureExtractor",
    "FeatureExtractorConfig",
    # Objective metrics (NEW)
    "ObjectiveMetricsCalculator",
    "TechnicalMetrics",
    # VLM evaluation (NEW)
    "VLMEvaluator",
    "VLMEvaluatorConfig",
    "VLMEvaluationResult",
]
@@ -21,10 +21,11 @@ Maintains consistent character appearance across video frames by:
"""

from dataclasses import dataclass, field
from typing import List, Dict, Optional, Set
from typing import List, Dict, Optional, Set, Tuple, Any
from datetime import datetime
from enum import Enum

import numpy as np
from loguru import logger

@@ -71,6 +72,12 @@ class Character:
    appearance_frames: List[int] = field(default_factory=list)  # All frames with this character
    created_at: Optional[datetime] = None

    # Visual features (NEW - for cross-frame consistency)
    visual_features: Optional[Any] = None  # CLIP feature vector (np.ndarray)
    feature_extraction_frame: Optional[int] = None  # Frame where features were extracted
    similarity_history: List[float] = field(default_factory=list)  # Similarity scores history
    min_similarity_threshold: float = 0.75  # Minimum similarity for consistency

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.now()
@@ -112,6 +119,23 @@ class Character:
            return True
        return any(alias.lower() == name_lower for alias in self.aliases)

    def set_visual_features(self, features: np.ndarray, frame_index: int):
        """Set visual features from reference frame"""
        self.visual_features = features
        self.feature_extraction_frame = frame_index

    def check_visual_similarity(self, other_features: np.ndarray) -> Tuple[bool, float]:
        """Check if other features match this character"""
        if self.visual_features is None:
            return True, 1.0

        similarity = float(np.dot(self.visual_features, other_features))
        similarity = (similarity + 1) / 2  # Normalize to 0-1

        self.similarity_history.append(similarity)
        is_match = similarity >= self.min_similarity_threshold
        return is_match, similarity

    def to_dict(self) -> dict:
        return {
            "id": self.id,
@@ -149,6 +173,11 @@ class CharacterMemoryConfig:
    include_clothing: bool = True  # Include clothing in prompts
    include_features: bool = True  # Include distinctive features

    # Visual feature settings (NEW)
    enable_visual_features: bool = True  # Enable CLIP visual features
    visual_similarity_threshold: float = 0.75  # Min similarity for consistency
    extract_features_on_first: bool = True  # Extract features on first appearance


class CharacterMemory:
    """
@@ -191,6 +220,7 @@ class CharacterMemory:
        self.config = config or CharacterMemoryConfig()
        self._characters: Dict[str, Character] = {}
        self._name_index: Dict[str, str] = {}  # name -> character_id mapping
        self._feature_extractor = None  # Lazy-loaded

    def register_character(
        self,
@@ -528,3 +558,60 @@ class CharacterMemory:
        self._characters.clear()
        self._name_index.clear()
        logger.info("Character memory cleared")

    @property
    def feature_extractor(self):
        """Lazy-load feature extractor"""
        if self._feature_extractor is None and self.config.enable_visual_features:
            from pixelle_video.services.quality.feature_extractor import FeatureExtractor
            self._feature_extractor = FeatureExtractor()
        return self._feature_extractor

    async def extract_character_features(
        self,
        character_name: str,
        image_path: str,
        frame_index: int = 0
    ) -> bool:
        """Extract and store visual features for a character"""
        if not self.config.enable_visual_features:
            return False

        char = self.get_character(character_name)
        if not char:
            logger.warning(f"Character not found: {character_name}")
            return False

        extractor = self.feature_extractor
        if extractor is None or not extractor.is_available:
            logger.debug("Feature extractor not available")
            return False

        features = extractor.extract_image_features(image_path)
        if features is None:
            return False

        char.set_visual_features(features, frame_index)
        char.add_reference_image(image_path, set_as_primary=True)
        logger.info(f"Extracted visual features for {character_name}")
        return True

    async def check_character_consistency(
        self,
        character_name: str,
        image_path: str
    ) -> Tuple[bool, float]:
        """Check if image maintains character consistency"""
        char = self.get_character(character_name)
        if not char or char.visual_features is None:
            return True, 1.0

        extractor = self.feature_extractor
        if extractor is None or not extractor.is_available:
            return True, 1.0

        new_features = extractor.extract_image_features(image_path)
        if new_features is None:
            return True, 1.0

        return char.check_visual_similarity(new_features)
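A minimal usage sketch of the new consistency flow (not part of the diff); the character name, file paths, and the call to `register_character` are illustrative assumptions, since that method's full signature is elided above.

```python
import asyncio

from pixelle_video.services.quality.character_memory import CharacterMemory


async def main():
    memory = CharacterMemory()
    # Assumed registration call -- the full register_character signature is elided in the diff.
    memory.register_character("Alice")

    # Cache CLIP features from the character's first appearance.
    await memory.extract_character_features("Alice", "frames/frame_001.png", frame_index=1)

    # Later frames are compared against the cached features.
    is_consistent, similarity = await memory.check_character_consistency("Alice", "frames/frame_007.png")
    print(f"consistent={is_consistent}, similarity={similarity:.2f}")


asyncio.run(main())
```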
pixelle_video/services/quality/feature_extractor.py (new file, 261 lines)
@@ -0,0 +1,261 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
FeatureExtractor - CLIP-based visual feature extraction

Provides:
1. Image-to-vector encoding for similarity comparison
2. Text-to-vector encoding for CLIP score calculation
3. Lazy loading of CLIP model (optional dependency)

Note: CLIP is an optional dependency. Install with:
    pip install pixelle-video[quality]
"""

from dataclasses import dataclass
from typing import Optional, Union
from pathlib import Path

import numpy as np
from loguru import logger


@dataclass
class FeatureExtractorConfig:
    """Configuration for feature extraction"""

    # Model settings
    model_name: str = "ViT-B/32"
    device: str = "auto"  # "auto", "cpu", "cuda", "mps"

    # Performance settings
    batch_size: int = 8
    cache_features: bool = True

    # Similarity thresholds
    character_similarity_threshold: float = 0.75
    style_similarity_threshold: float = 0.70


class FeatureExtractor:
    """
    CLIP-based feature extraction for quality evaluation

    Features:
    - Lazy loading: CLIP model only loaded when first needed
    - Graceful degradation: Returns None if CLIP unavailable
    - Caching: Optional feature caching for performance

    Example:
        >>> extractor = FeatureExtractor()
        >>> if extractor.is_available:
        ...     score = extractor.calculate_clip_score(
        ...         image_path="frame_001.png",
        ...         text="A sunset over mountains"
        ...     )
    """

    def __init__(self, config: Optional[FeatureExtractorConfig] = None):
        self.config = config or FeatureExtractorConfig()
        self._model = None
        self._preprocess = None
        self._device = None
        self._available: Optional[bool] = None
        self._feature_cache: dict = {}

    @property
    def is_available(self) -> bool:
        """Check if CLIP is available (lazy check)"""
        if self._available is None:
            self._available = self._check_availability()
        return self._available

    def _check_availability(self) -> bool:
        """Check if CLIP dependencies are installed"""
        try:
            import torch
            import clip
            return True
        except ImportError:
            logger.warning(
                "CLIP not available. Install with: "
                "pip install torch clip-by-openai"
            )
            return False

    def _load_model(self):
        """Lazy load CLIP model"""
        if self._model is not None:
            return

        if not self.is_available:
            return

        import torch
        import clip

        # Determine device
        if self.config.device == "auto":
            if torch.cuda.is_available():
                self._device = "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                self._device = "mps"
            else:
                self._device = "cpu"
        else:
            self._device = self.config.device

        logger.info(f"Loading CLIP model {self.config.model_name} on {self._device}")
        self._model, self._preprocess = clip.load(
            self.config.model_name,
            device=self._device
        )
        logger.info("CLIP model loaded successfully")

    def extract_image_features(
        self,
        image_path: Union[str, Path]
    ) -> Optional[np.ndarray]:
        """
        Extract CLIP features from an image

        Args:
            image_path: Path to image file

        Returns:
            Normalized feature vector (512-dim for ViT-B/32) or None
        """
        if not self.is_available:
            return None

        self._load_model()

        # Check cache
        cache_key = str(image_path)
        if self.config.cache_features and cache_key in self._feature_cache:
            return self._feature_cache[cache_key]

        try:
            import torch
            from PIL import Image

            image = Image.open(image_path).convert("RGB")
            image_input = self._preprocess(image).unsqueeze(0).to(self._device)

            with torch.no_grad():
                features = self._model.encode_image(image_input)
                features = features / features.norm(dim=-1, keepdim=True)
                features = features.cpu().numpy().flatten()

            # Cache result
            if self.config.cache_features:
                self._feature_cache[cache_key] = features

            return features

        except Exception as e:
            logger.warning(f"Failed to extract image features: {e}")
            return None

    def extract_text_features(self, text: str) -> Optional[np.ndarray]:
        """
        Extract CLIP features from text

        Args:
            text: Text to encode

        Returns:
            Normalized feature vector or None
        """
        if not self.is_available:
            return None

        self._load_model()

        try:
            import torch
            import clip

            # Rough character-level guard; clip.tokenize(truncate=True) below
            # enforces CLIP's 77-token context window
            text = text[:300]

            text_input = clip.tokenize([text], truncate=True).to(self._device)

            with torch.no_grad():
                features = self._model.encode_text(text_input)
                features = features / features.norm(dim=-1, keepdim=True)
                features = features.cpu().numpy().flatten()

            return features

        except Exception as e:
            logger.warning(f"Failed to extract text features: {e}")
            return None

    def calculate_clip_score(
        self,
        image_path: Union[str, Path],
        text: str
    ) -> Optional[float]:
        """
        Calculate CLIP score (image-text similarity)

        Args:
            image_path: Path to image
            text: Text prompt to compare

        Returns:
            Similarity score (0.0-1.0) or None if unavailable
        """
        image_features = self.extract_image_features(image_path)
        text_features = self.extract_text_features(text)

        if image_features is None or text_features is None:
            return None

        # Cosine similarity (features are already normalized)
        similarity = float(np.dot(image_features, text_features))

        # Convert from [-1, 1] to [0, 1] range
        score = (similarity + 1) / 2

        return score

    def calculate_image_similarity(
        self,
        image_path_1: Union[str, Path],
        image_path_2: Union[str, Path]
    ) -> Optional[float]:
        """
        Calculate similarity between two images

        Args:
            image_path_1: Path to first image
            image_path_2: Path to second image

        Returns:
            Similarity score (0.0-1.0) or None
        """
        features_1 = self.extract_image_features(image_path_1)
        features_2 = self.extract_image_features(image_path_2)

        if features_1 is None or features_2 is None:
            return None

        similarity = float(np.dot(features_1, features_2))
        return (similarity + 1) / 2

    def clear_cache(self):
        """Clear feature cache"""
        self._feature_cache.clear()
        logger.debug("Feature cache cleared")
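A short sketch (not from the commit) of how the graceful-degradation contract is meant to be consumed; the file names are placeholders.

```python
from pixelle_video.services.quality.feature_extractor import FeatureExtractor, FeatureExtractorConfig

extractor = FeatureExtractor(FeatureExtractorConfig(device="cpu"))

# All public methods return None instead of raising when CLIP is not installed,
# so callers can branch on the result rather than wrapping calls in try/except.
clip_score = extractor.calculate_clip_score("frame_001.png", "A sunset over mountains")
frame_similarity = extractor.calculate_image_similarity("frame_001.png", "frame_002.png")

if clip_score is None or frame_similarity is None:
    print("CLIP unavailable or image missing -- falling back to technical metrics only")
else:
    print(f"CLIP score: {clip_score:.2f}, frame-to-frame similarity: {frame_similarity:.2f}")
```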
@@ -15,7 +15,7 @@ Data models for quality assurance
"""

from dataclasses import dataclass, field
from typing import List, Optional
from typing import List, Optional, Any
from enum import Enum


@@ -44,6 +44,11 @@ class QualityScore:
    issues: List[str] = field(default_factory=list)  # Detected problems
    evaluation_time_ms: float = 0.0  # Time taken for evaluation

    # Extended metrics (NEW - for HybridQualityGate)
    clip_score: Optional[float] = None  # CLIP image-text similarity
    technical_metrics: Optional[Any] = None  # TechnicalMetrics object
    vlm_used: bool = True  # Whether VLM was used for evaluation

    @property
    def level(self) -> QualityLevel:
        """Get quality level based on overall score"""
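A hedged illustration of how the extended fields are meant to be filled in; the numeric values and the image path are made up, only the field names come from the diff.

```python
from pixelle_video.services.quality.models import QualityScore
from pixelle_video.services.quality.objective_metrics import ObjectiveMetricsCalculator

metrics = ObjectiveMetricsCalculator().analyze_image("frame_001.png")  # placeholder path

# A score produced without calling the VLM: clip_score and technical_metrics carry
# the objective evidence, and vlm_used records that the VLM step was skipped.
score = QualityScore(
    technical_score=metrics.overall_technical,
    clip_score=0.71,          # illustrative value
    technical_metrics=metrics,
    vlm_used=False,
    issues=list(metrics.issues),
)
```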
pixelle_video/services/quality/objective_metrics.py (new file, 232 lines)
@@ -0,0 +1,232 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
ObjectiveMetrics - Technical image quality metrics

Provides fast, local computation of:
1. Sharpness/clarity detection
2. Color distribution analysis
3. Exposure detection (over/under)

Only depends on PIL and numpy (no heavy ML dependencies).
"""

from dataclasses import dataclass, field
from typing import List, Tuple

import numpy as np
from loguru import logger


@dataclass
class TechnicalMetrics:
    """Technical quality metrics for an image"""

    # Sharpness (0.0-1.0, higher = sharper)
    sharpness_score: float = 0.0

    # Color metrics
    color_variance: float = 0.0
    saturation_score: float = 0.0

    # Exposure metrics
    brightness_score: float = 0.0
    is_overexposed: bool = False
    is_underexposed: bool = False

    # Overall technical score
    overall_technical: float = 0.0

    # Issues detected
    issues: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to dictionary"""
        return {
            "sharpness_score": self.sharpness_score,
            "color_variance": self.color_variance,
            "saturation_score": self.saturation_score,
            "brightness_score": self.brightness_score,
            "is_overexposed": self.is_overexposed,
            "is_underexposed": self.is_underexposed,
            "overall_technical": self.overall_technical,
            "issues": self.issues,
        }


class ObjectiveMetricsCalculator:
    """
    Calculate objective technical quality metrics

    All computations are local and fast (no external API calls).
    Uses PIL and numpy only.

    Example:
        >>> calculator = ObjectiveMetricsCalculator()
        >>> metrics = calculator.analyze_image("frame_001.png")
        >>> print(f"Sharpness: {metrics.sharpness_score:.2f}")
    """

    def __init__(
        self,
        sharpness_threshold: float = 0.3,
        overexposure_threshold: float = 0.85,
        underexposure_threshold: float = 0.15,
    ):
        self.sharpness_threshold = sharpness_threshold
        self.overexposure_threshold = overexposure_threshold
        self.underexposure_threshold = underexposure_threshold

    def analyze_image(self, image_path: str) -> TechnicalMetrics:
        """
        Analyze image and return technical metrics

        Args:
            image_path: Path to image file

        Returns:
            TechnicalMetrics with all computed values
        """
        try:
            from PIL import Image

            with Image.open(image_path) as img:
                img_rgb = img.convert("RGB")
                img_array = np.array(img_rgb)

            # Calculate individual metrics
            sharpness = self._calculate_sharpness(img_rgb)
            color_var, saturation = self._calculate_color_metrics(img_array)
            brightness, overexp, underexp = self._calculate_exposure(img_array)

            # Detect issues
            issues = []
            if sharpness < self.sharpness_threshold:
                issues.append("Image appears blurry")
            if overexp:
                issues.append("Image is overexposed")
            if underexp:
                issues.append("Image is underexposed")
            if color_var < 0.1:
                issues.append("Low color diversity")

            # Calculate overall technical score
            overall = self._calculate_overall(
                sharpness, color_var, saturation, brightness
            )

            return TechnicalMetrics(
                sharpness_score=sharpness,
                color_variance=color_var,
                saturation_score=saturation,
                brightness_score=brightness,
                is_overexposed=overexp,
                is_underexposed=underexp,
                overall_technical=overall,
                issues=issues
            )

        except Exception as e:
            logger.warning(f"Failed to analyze image: {e}")
            return TechnicalMetrics(issues=[f"Analysis failed: {str(e)}"])

    def _calculate_sharpness(self, img) -> float:
        """Calculate sharpness using edge detection"""
        from PIL import ImageFilter

        # Convert to grayscale
        gray = img.convert("L")

        # Apply edge detection
        edges = gray.filter(ImageFilter.FIND_EDGES)
        edge_array = np.array(edges).astype(np.float32)

        # Variance of edge response
        variance = np.var(edge_array)

        # Normalize to 0-1 (empirical scaling)
        sharpness = min(1.0, variance / 2000)

        return float(sharpness)

    def _calculate_color_metrics(
        self,
        img_array: np.ndarray
    ) -> Tuple[float, float]:
        """Calculate color variance and saturation"""
        r, g, b = img_array[:, :, 0], img_array[:, :, 1], img_array[:, :, 2]

        max_rgb = np.maximum(np.maximum(r, g), b)
        min_rgb = np.minimum(np.minimum(r, g), b)

        # Saturation
        delta = max_rgb - min_rgb
        saturation = np.where(max_rgb > 0, delta / (max_rgb + 1e-7), 0)
        avg_saturation = np.mean(saturation)

        # Color variance (diversity)
        color_std = np.std(img_array.reshape(-1, 3), axis=0)
        color_variance = np.mean(color_std) / 128

        return float(color_variance), float(avg_saturation)

    def _calculate_exposure(
        self,
        img_array: np.ndarray
    ) -> Tuple[float, bool, bool]:
        """Calculate brightness and detect exposure issues"""
        # Calculate luminance
        luminance = (
            0.299 * img_array[:, :, 0] +
            0.587 * img_array[:, :, 1] +
            0.114 * img_array[:, :, 2]
        ) / 255.0

        avg_brightness = float(np.mean(luminance))

        # Check for over/under exposure
        overexposed_pixels = np.mean(luminance > 0.95)
        underexposed_pixels = np.mean(luminance < 0.05)

        is_overexposed = (
            avg_brightness > self.overexposure_threshold or
            overexposed_pixels > 0.3
        )
        is_underexposed = (
            avg_brightness < self.underexposure_threshold or
            underexposed_pixels > 0.3
        )

        return avg_brightness, is_overexposed, is_underexposed

    def _calculate_overall(
        self,
        sharpness: float,
        color_var: float,
        saturation: float,
        brightness: float
    ) -> float:
        """Calculate overall technical score"""
        # Brightness penalty (ideal is 0.5)
        brightness_score = 1.0 - abs(brightness - 0.5) * 2
        brightness_score = max(0, brightness_score)

        # Weighted combination
        overall = (
            sharpness * 0.4 +
            min(1.0, color_var * 2) * 0.2 +
            saturation * 0.2 +
            brightness_score * 0.2
        )

        return float(overall)
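To make the weighting in `_calculate_overall` concrete, a worked example with invented input values:

```python
# Illustrative numbers only: sharpness=0.6, color_var=0.4, saturation=0.5, brightness=0.55
brightness_score = max(0, 1.0 - abs(0.55 - 0.5) * 2)  # 0.9
overall = 0.6 * 0.4 + min(1.0, 0.4 * 2) * 0.2 + 0.5 * 0.2 + brightness_score * 0.2
print(round(overall, 2))  # 0.68 -- sharpness dominates with its 0.4 weight
```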
@@ -17,9 +17,12 @@ Evaluates images and videos based on:
- Aesthetic quality (visual appeal)
- Text-to-image matching (semantic alignment)
- Technical quality (clarity, no artifacts)

Includes HybridQualityGate for combined objective + VLM evaluation.
"""

import time
from dataclasses import dataclass
from typing import Optional
from pathlib import Path

@@ -28,6 +31,27 @@ from loguru import logger
from pixelle_video.services.quality.models import QualityScore, QualityConfig


@dataclass
class HybridQualityConfig(QualityConfig):
    """Extended configuration for hybrid quality evaluation"""

    # CLIP settings
    enable_clip_score: bool = True
    clip_model: str = "ViT-B/32"
    clip_weight: float = 0.5

    # Technical metrics settings
    enable_technical_metrics: bool = True
    sharpness_threshold: float = 0.3

    # Smart VLM skip
    enable_smart_skip: bool = True
    smart_skip_threshold: float = 0.75

    # Feature caching
    cache_features: bool = True


class QualityGate:
    """
    Quality evaluation gate for AI-generated content
@@ -361,3 +385,160 @@ Respond in JSON format:
    "issues": ["list of any problems found"]
}}
"""


class HybridQualityGate(QualityGate):
    """
    Hybrid quality gate combining objective metrics with VLM evaluation

    Evaluation flow:
    1. Calculate technical metrics (fast, local)
    2. Calculate CLIP score if enabled (local, requires CLIP)
    3. If smart_skip enabled and objective score >= threshold, skip VLM
    4. Otherwise, call VLM for subjective evaluation
    5. Combine scores with configurable weights

    Example:
        >>> gate = HybridQualityGate(llm_service, config)
        >>> score = await gate.evaluate_image(
        ...     image_path="frame_001.png",
        ...     prompt="A sunset over mountains"
        ... )
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[HybridQualityConfig] = None
    ):
        parent_config = config or HybridQualityConfig()
        super().__init__(llm_service, parent_config)

        self.hybrid_config = parent_config
        self._feature_extractor = None
        self._metrics_calculator = None
        self._vlm_evaluator = None

    @property
    def feature_extractor(self):
        """Lazy-load feature extractor"""
        if self._feature_extractor is None:
            from pixelle_video.services.quality.feature_extractor import (
                FeatureExtractor, FeatureExtractorConfig
            )
            self._feature_extractor = FeatureExtractor(
                FeatureExtractorConfig(
                    model_name=self.hybrid_config.clip_model,
                    cache_features=self.hybrid_config.cache_features
                )
            )
        return self._feature_extractor

    @property
    def metrics_calculator(self):
        """Lazy-load metrics calculator"""
        if self._metrics_calculator is None:
            from pixelle_video.services.quality.objective_metrics import (
                ObjectiveMetricsCalculator
            )
            self._metrics_calculator = ObjectiveMetricsCalculator(
                sharpness_threshold=self.hybrid_config.sharpness_threshold
            )
        return self._metrics_calculator

    @property
    def vlm_evaluator(self):
        """Lazy-load VLM evaluator"""
        if self._vlm_evaluator is None:
            from pixelle_video.services.quality.vlm_evaluator import VLMEvaluator
            self._vlm_evaluator = VLMEvaluator(self.llm_service)
        return self._vlm_evaluator

    async def evaluate_image(
        self,
        image_path: str,
        prompt: str,
        narration: Optional[str] = None,
    ) -> QualityScore:
        """Evaluate image quality using hybrid approach"""
        start_time = time.time()
        issues = []

        if not Path(image_path).exists():
            return QualityScore(
                passed=False,
                issues=["Image file not found"],
                evaluation_time_ms=(time.time() - start_time) * 1000
            )

        # Step 1: Technical metrics (fast, local)
        technical_score = 0.7
        technical_metrics = None

        if self.hybrid_config.enable_technical_metrics:
            technical_metrics = self.metrics_calculator.analyze_image(image_path)
            technical_score = technical_metrics.overall_technical
            issues.extend(technical_metrics.issues)

        # Step 2: CLIP score (if available)
        clip_score = None
        text_match_score = 0.7

        if self.hybrid_config.enable_clip_score:
            clip_score = self.feature_extractor.calculate_clip_score(
                image_path, prompt
            )
            if clip_score is not None:
                text_match_score = clip_score

        # Step 3: Determine if VLM needed
        objective_score = (technical_score + text_match_score) / 2
        use_vlm = True
        aesthetic_score = 0.7

        if self.hybrid_config.enable_smart_skip:
            if objective_score >= self.hybrid_config.smart_skip_threshold:
                use_vlm = False
                logger.debug(f"Smart skip: {objective_score:.2f} >= threshold")

        # Step 4: VLM evaluation (if needed)
        if use_vlm and self.config.use_vlm_evaluation and self.llm_service:
            vlm_result = await self.vlm_evaluator.evaluate_image(
                image_path, prompt, narration
            )
            aesthetic_score = vlm_result.aesthetic_score or 0.7

            if clip_score is not None:
                text_match_score = (
                    clip_score * self.hybrid_config.clip_weight +
                    vlm_result.text_match_score * (1 - self.hybrid_config.clip_weight)
                )
            else:
                text_match_score = vlm_result.text_match_score or 0.7

            issues.extend(vlm_result.issues)

        # Step 5: Calculate overall
        overall = (
            aesthetic_score * self.config.aesthetic_weight +
            text_match_score * self.config.text_match_weight +
            technical_score * self.config.technical_weight
        )

        score = QualityScore(
            aesthetic_score=aesthetic_score,
            text_match_score=text_match_score,
            technical_score=technical_score,
            overall_score=overall,
            issues=issues,
            evaluation_time_ms=(time.time() - start_time) * 1000
        )

        score.passed = overall >= self.config.overall_threshold

        logger.debug(
            f"Hybrid eval: overall={overall:.2f}, clip={clip_score}, "
            f"vlm_used={use_vlm}, time={score.evaluation_time_ms:.0f}ms"
        )

        return score
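A sketch of the end-to-end call pattern (not from the diff); the file path is a placeholder, and passing `llm_service=None` causes the VLM step to be skipped so only technical metrics and, if installed, CLIP contribute to the score.

```python
import asyncio

from pixelle_video.services.quality.quality_gate import HybridQualityGate, HybridQualityConfig


async def main():
    config = HybridQualityConfig(enable_smart_skip=True, smart_skip_threshold=0.75)
    # Without an LLM service the gate degrades to objective-only scoring.
    gate = HybridQualityGate(llm_service=None, config=config)

    score = await gate.evaluate_image(
        image_path="frame_001.png",   # placeholder path
        prompt="A sunset over mountains",
    )
    print(score.passed, round(score.overall_score, 2), score.issues)


asyncio.run(main())
```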
pixelle_video/services/quality/vlm_evaluator.py (new file, 243 lines)
@@ -0,0 +1,243 @@
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
VLMEvaluator - Vision Language Model based image quality evaluation

Supports multiple VLM providers:
- OpenAI: gpt-4-vision-preview, gpt-4o
- Qwen-VL: qwen-vl-max, qwen-vl-plus
- GLM-4V: via OpenAI compatible API
"""

import base64
import json
import re
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path

from loguru import logger


@dataclass
class VLMEvaluationResult:
    """Result from VLM evaluation"""
    aesthetic_score: float = 0.0
    text_match_score: float = 0.0
    technical_score: float = 0.0
    issues: List[str] = field(default_factory=list)
    raw_response: Optional[str] = None

    def to_dict(self) -> dict:
        return {
            "aesthetic_score": self.aesthetic_score,
            "text_match_score": self.text_match_score,
            "technical_score": self.technical_score,
            "issues": self.issues,
        }


@dataclass
class VLMEvaluatorConfig:
    """Configuration for VLM evaluator"""
    provider: str = "auto"  # "openai", "qwen", "auto"
    model: Optional[str] = None  # Auto-select if None
    max_image_size: int = 1024  # Max image dimension
    timeout: int = 30
    temperature: float = 0.1  # Low for consistent evaluation


class VLMEvaluator:
    """
    VLM-based image quality evaluator

    Example:
        >>> evaluator = VLMEvaluator(llm_service)
        >>> result = await evaluator.evaluate_image(
        ...     image_path="frame_001.png",
        ...     prompt="A sunset over mountains"
        ... )
    """

    # Chinese-language evaluation prompt: it asks the VLM to rate aesthetic quality,
    # text-image match, and technical quality (each 0.0-1.0), list any issues found,
    # and reply in JSON.
    EVALUATION_PROMPT = """请评估这张AI生成的图片质量。

生成提示词: {prompt}
{narration_section}

请从以下三个维度评分(0.0-1.0):

1. **美学质量** (aesthetic_score): 构图、色彩搭配、视觉吸引力
2. **图文匹配** (text_match_score): 图片与提示词的语义对齐程度
3. **技术质量** (technical_score): 清晰度、无伪影、无变形

同时列出发现的问题(如有)。

请以JSON格式返回:
```json
{{
    "aesthetic_score": 0.0-1.0,
    "text_match_score": 0.0-1.0,
    "technical_score": 0.0-1.0,
    "issues": ["问题1", "问题2"]
}}
```"""

    def __init__(
        self,
        llm_service=None,
        config: Optional[VLMEvaluatorConfig] = None
    ):
        self.llm_service = llm_service
        self.config = config or VLMEvaluatorConfig()

    def _encode_image_base64(self, image_path: str) -> str:
        """Encode image to base64, with optional resizing"""
        from PIL import Image
        import io

        with Image.open(image_path) as img:
            # Resize if too large
            max_size = self.config.max_image_size
            if max(img.size) > max_size:
                ratio = max_size / max(img.size)
                new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
                img = img.resize(new_size, Image.Resampling.LANCZOS)

            # Convert to RGB if needed
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')

            # Encode to base64
            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=85)
            return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def _parse_response(self, response: str) -> VLMEvaluationResult:
        """Parse VLM response to extract scores"""
        result = VLMEvaluationResult(raw_response=response)

        try:
            # Try to extract JSON from response
            json_match = re.search(r'```json\s*([\s\S]*?)\s*```', response)
            if json_match:
                json_str = json_match.group(1)
            else:
                # Try to find raw JSON
                brace_start = response.find('{')
                brace_end = response.rfind('}')
                if brace_start != -1 and brace_end > brace_start:
                    json_str = response[brace_start:brace_end + 1]
                else:
                    logger.warning("No JSON found in VLM response")
                    return result

            data = json.loads(json_str)

            result.aesthetic_score = float(data.get('aesthetic_score', 0.0))
            result.text_match_score = float(data.get('text_match_score', 0.0))
            result.technical_score = float(data.get('technical_score', 0.0))
            result.issues = data.get('issues', [])

            # Clamp scores to valid range
            result.aesthetic_score = max(0.0, min(1.0, result.aesthetic_score))
            result.text_match_score = max(0.0, min(1.0, result.text_match_score))
            result.technical_score = max(0.0, min(1.0, result.technical_score))

        except (json.JSONDecodeError, ValueError) as e:
            logger.warning(f"Failed to parse VLM response: {e}")

        return result

    async def evaluate_image(
        self,
        image_path: str,
        prompt: str,
        narration: Optional[str] = None
    ) -> VLMEvaluationResult:
        """
        Evaluate image quality using VLM

        Args:
            image_path: Path to image file
            prompt: Generation prompt
            narration: Optional narration text

        Returns:
            VLMEvaluationResult with scores
        """
        if not Path(image_path).exists():
            return VLMEvaluationResult(issues=["Image file not found"])

        if not self.llm_service:
            logger.warning("No LLM service provided for VLM evaluation")
            return VLMEvaluationResult(issues=["No LLM service"])

        try:
            # Encode image
            image_b64 = self._encode_image_base64(image_path)

            # Build prompt (the Chinese label reads "Narration text:")
            narration_section = f"旁白文案: {narration}" if narration else ""
            eval_prompt = self.EVALUATION_PROMPT.format(
                prompt=prompt,
                narration_section=narration_section
            )

            # Call VLM via LLM service with vision
            response = await self._call_vlm(image_b64, eval_prompt)

            return self._parse_response(response)

        except Exception as e:
            logger.error(f"VLM evaluation failed: {e}")
            return VLMEvaluationResult(issues=[f"Evaluation error: {str(e)}"])

    async def _call_vlm(self, image_b64: str, prompt: str) -> str:
        """Call VLM with image and prompt"""
        from openai import AsyncOpenAI

        # Get config from LLM service
        base_url = self.llm_service._get_config_value("base_url")
        api_key = self.llm_service._get_config_value("api_key")
        model = self.config.model or self.llm_service._get_config_value("model")

        client = AsyncOpenAI(api_key=api_key, base_url=base_url)

        # Build message with image
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_b64}"
                        }
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ]
            }
        ]

        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=500,
            timeout=self.config.timeout
        )

        return response.choices[0].message.content
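A small sketch (not part of the commit) of what `_parse_response` expects back from the model; the sample response string is fabricated.

```python
from pixelle_video.services.quality.vlm_evaluator import VLMEvaluator

# Raw-JSON form; _parse_response also accepts the same JSON wrapped in a markdown code fence.
sample_response = '{"aesthetic_score": 0.82, "text_match_score": 0.74, "technical_score": 0.9, "issues": ["slight banding in the sky"]}'

result = VLMEvaluator()._parse_response(sample_response)
print(result.aesthetic_score, result.issues)  # 0.82 ['slight banding in the sky']
```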
@@ -34,6 +34,11 @@ dev = [
    "pytest-asyncio>=0.23.0",
    "ruff>=0.6.0",
]
quality = [
    "torch>=2.0.0",
    "ftfy>=6.1.0",
    "regex>=2023.0.0",
]

[project.scripts]
pixelle-video = "pixelle_video.cli:main"
tests/test_feature_extractor.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# Copyright (C) 2025 AIDC-AI
# Tests for FeatureExtractor

import pytest

from pixelle_video.services.quality.feature_extractor import (
    FeatureExtractor,
    FeatureExtractorConfig,
)


class TestFeatureExtractorConfig:
    """Tests for FeatureExtractorConfig"""

    def test_default_values(self):
        config = FeatureExtractorConfig()
        assert config.model_name == "ViT-B/32"
        assert config.device == "auto"
        assert config.cache_features is True


class TestFeatureExtractor:
    """Tests for FeatureExtractor"""

    def test_init_default(self):
        extractor = FeatureExtractor()
        assert extractor.config.model_name == "ViT-B/32"

    def test_is_available_check(self):
        """Test availability check (may be True or False)"""
        extractor = FeatureExtractor()
        # Just check it returns a boolean
        assert isinstance(extractor.is_available, bool)

    def test_extract_without_clip(self):
        """Test graceful degradation when CLIP unavailable"""
        extractor = FeatureExtractor()
        if not extractor.is_available:
            result = extractor.extract_image_features("/fake/path.png")
            assert result is None

    def test_clear_cache(self):
        """Test cache clearing"""
        extractor = FeatureExtractor()
        extractor._feature_cache["test"] = "value"
        extractor.clear_cache()
        assert len(extractor._feature_cache) == 0
tests/test_hybrid_quality_gate.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright (C) 2025 AIDC-AI
# Tests for HybridQualityGate

import pytest

from pixelle_video.services.quality.quality_gate import (
    QualityGate,
    HybridQualityGate,
    HybridQualityConfig,
)
from pixelle_video.services.quality.models import QualityScore


class TestHybridQualityConfig:
    """Tests for HybridQualityConfig"""

    def test_default_values(self):
        config = HybridQualityConfig()
        assert config.enable_clip_score is True
        assert config.enable_smart_skip is True
        assert config.smart_skip_threshold == 0.75

    def test_inherits_quality_config(self):
        config = HybridQualityConfig()
        assert hasattr(config, "overall_threshold")
        assert config.overall_threshold == 0.6


class TestHybridQualityGate:
    """Tests for HybridQualityGate"""

    def test_init_default(self):
        gate = HybridQualityGate()
        assert gate.hybrid_config is not None

    def test_inherits_quality_gate(self):
        gate = HybridQualityGate()
        assert isinstance(gate, QualityGate)

    def test_lazy_load_metrics_calculator(self):
        gate = HybridQualityGate()
        calc = gate.metrics_calculator
        assert calc is not None

    @pytest.mark.asyncio
    async def test_evaluate_nonexistent_image(self):
        gate = HybridQualityGate()
        score = await gate.evaluate_image(
            "/nonexistent/path.png",
            "test prompt"
        )
        assert score.passed is False
        assert "not found" in score.issues[0].lower()
tests/test_objective_metrics.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright (C) 2025 AIDC-AI
# Tests for ObjectiveMetricsCalculator

import pytest
from pathlib import Path

from pixelle_video.services.quality.objective_metrics import (
    ObjectiveMetricsCalculator,
    TechnicalMetrics,
)


class TestTechnicalMetrics:
    """Tests for TechnicalMetrics dataclass"""

    def test_default_values(self):
        metrics = TechnicalMetrics()
        assert metrics.sharpness_score == 0.0
        assert metrics.overall_technical == 0.0
        assert metrics.issues == []

    def test_to_dict(self):
        metrics = TechnicalMetrics(
            sharpness_score=0.8,
            brightness_score=0.5,
            issues=["test issue"]
        )
        d = metrics.to_dict()
        assert d["sharpness_score"] == 0.8
        assert "test issue" in d["issues"]


class TestObjectiveMetricsCalculator:
    """Tests for ObjectiveMetricsCalculator"""

    def test_init_default(self):
        calc = ObjectiveMetricsCalculator()
        assert calc.sharpness_threshold == 0.3

    def test_init_custom(self):
        calc = ObjectiveMetricsCalculator(sharpness_threshold=0.5)
        assert calc.sharpness_threshold == 0.5

    def test_analyze_nonexistent_image(self):
        calc = ObjectiveMetricsCalculator()
        metrics = calc.analyze_image("/nonexistent/path.png")
        assert len(metrics.issues) > 0
        assert "failed" in metrics.issues[0].lower()

    def test_analyze_real_image(self, tmp_path):
        """Test with a real image file"""
        from PIL import Image

        # Create test image
        img = Image.new("RGB", (256, 256), color=(128, 128, 128))
        img_path = tmp_path / "test.png"
        img.save(img_path)

        calc = ObjectiveMetricsCalculator()
        metrics = calc.analyze_image(str(img_path))

        assert 0.0 <= metrics.sharpness_score <= 1.0
        assert 0.0 <= metrics.brightness_score <= 1.0
        assert 0.0 <= metrics.overall_technical <= 1.0