- Add FeatureExtractor for CLIP-based image/text feature extraction - Add ObjectiveMetricsCalculator for technical quality metrics - Add VLMEvaluator for vision language model evaluation - Add HybridQualityGate combining objective + VLM evaluation - Enhance CharacterMemory with visual feature support - Add quality optional dependency (torch, ftfy, regex) - Add unit tests for new modules 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
146 lines
5.2 KiB
Python
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
|
|
Data models for quality assurance
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional


class QualityLevel(Enum):
    """Quality level bucket derived from an overall score in [0.0, 1.0]."""

    EXCELLENT = "excellent"  # >= 0.8
    GOOD = "good"            # >= 0.6
    ACCEPTABLE = "acceptable"  # >= 0.4
    POOR = "poor"            # < 0.4

    @classmethod
    def from_score(cls, score: float) -> "QualityLevel":
        """Map a numeric overall score to its quality level bucket.

        Centralizes the thresholds that were previously documented only in
        the member comments above, so callers (e.g. QualityScore.level) can
        share one implementation.

        Args:
            score: Overall quality score, expected in [0.0, 1.0].

        Returns:
            The QualityLevel whose lower bound the score meets.
        """
        if score >= 0.8:
            return cls.EXCELLENT
        if score >= 0.6:
            return cls.GOOD
        if score >= 0.4:
            return cls.ACCEPTABLE
        return cls.POOR

@dataclass
class QualityScore:
    """Quality evaluation result for generated content.

    Individual scores are floats in [0.0, 1.0]; ``overall_score`` is the
    weighted combination computed by the evaluator, and ``passed`` records
    whether the configured threshold was met.
    """

    # Individual scores (0.0 - 1.0)
    aesthetic_score: float = 0.0  # Visual appeal / beauty
    text_match_score: float = 0.0  # How well image matches the prompt
    technical_score: float = 0.0  # Technical quality (clarity, no artifacts)

    # Overall
    overall_score: float = 0.0  # Weighted average
    passed: bool = False  # Whether it meets threshold

    # Diagnostics
    issues: List[str] = field(default_factory=list)  # Detected problems
    evaluation_time_ms: float = 0.0  # Time taken for evaluation

    # Extended metrics (NEW - for HybridQualityGate)
    clip_score: Optional[float] = None  # CLIP image-text similarity
    technical_metrics: Optional[Any] = None  # TechnicalMetrics object
    vlm_used: bool = True  # Whether VLM was used for evaluation

    @property
    def level(self) -> QualityLevel:
        """Bucket the overall score into a QualityLevel.

        Lower bounds: >=0.8 EXCELLENT, >=0.6 GOOD, >=0.4 ACCEPTABLE,
        otherwise POOR.
        """
        thresholds = (
            (0.8, QualityLevel.EXCELLENT),
            (0.6, QualityLevel.GOOD),
            (0.4, QualityLevel.ACCEPTABLE),
        )
        for lower_bound, bucket in thresholds:
            if self.overall_score >= lower_bound:
                return bucket
        return QualityLevel.POOR

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization.

        Only the core scores, pass flag, level and issues are serialized;
        the extended metrics and timing are intentionally left out.
        """
        summary = {
            name: getattr(self, name)
            for name in (
                "aesthetic_score",
                "text_match_score",
                "technical_score",
                "overall_score",
                "passed",
            )
        }
        summary["level"] = self.level.value
        summary["issues"] = self.issues
        return summary

@dataclass
class QualityConfig:
    """Configuration for quality evaluation.

    Thresholds gate individual scores; weights control how the individual
    scores combine into the overall score. Weights that do not sum to 1.0
    are normalized in ``__post_init__``.
    """

    # Thresholds (0.0 - 1.0)
    overall_threshold: float = 0.6  # Minimum overall score to pass
    aesthetic_threshold: float = 0.5  # Minimum aesthetic score
    text_match_threshold: float = 0.6  # Minimum text-match score
    technical_threshold: float = 0.7  # Minimum technical score

    # Weights for overall score calculation
    aesthetic_weight: float = 0.3
    text_match_weight: float = 0.4
    technical_weight: float = 0.3

    # Evaluation settings
    use_vlm_evaluation: bool = True  # Use VLM for evaluation (vs local models)
    vlm_model: Optional[str] = None  # VLM model to use (None = use default LLM)
    skip_on_static_template: bool = True  # Skip image quality check for static templates

    def __post_init__(self):
        """Validate and normalize the score weights to sum to 1.0.

        Raises:
            ValueError: If the weights sum to zero or a negative value —
                normalization would be meaningless (a zero sum previously
                crashed with an unhelpful ZeroDivisionError).
        """
        total = self.aesthetic_weight + self.text_match_weight + self.technical_weight
        if total <= 0:
            raise ValueError(
                f"Quality weights must sum to a positive value, got {total}"
            )
        # Only rescale when the sum is meaningfully off from 1.0, to avoid
        # injecting float noise into already-valid configurations.
        if abs(total - 1.0) > 0.01:
            self.aesthetic_weight /= total
            self.text_match_weight /= total
            self.technical_weight /= total

@dataclass
class RetryConfig:
    """Configuration for retry behavior.

    Groups the knobs a retry loop reads: attempt limits, backoff timing
    (presumably initial_delay_ms * backoff_factor**attempt, capped at
    max_delay_ms — confirm against the consuming retry loop), the quality
    bar, and which failure modes trigger a retry.
    """

    # Attempt limits and backoff timing
    max_retries: int = 3  # Maximum retry attempts
    backoff_factor: float = 1.5  # Exponential backoff multiplier
    initial_delay_ms: int = 500  # Initial delay before first retry
    max_delay_ms: int = 10000  # Maximum delay between retries

    # Quality-based retry
    quality_threshold: float = 0.6  # Quality score threshold for pass

    # Fallback behavior
    enable_fallback: bool = True  # Enable fallback strategy on failure
    fallback_prompt_simplify: bool = True  # Simplify prompt on retry

    # Retry conditions
    retry_on_quality_fail: bool = True  # Retry when quality below threshold
    retry_on_error: bool = True  # Retry on generation errors

class QualityError(Exception):
    """Raised when quality standards are not met after all retries.

    Carries the number of attempts made and the last QualityScore observed
    (if any) so callers can log or report why generation gave up.
    """

    def __init__(
        self,
        message: str,
        attempts: int = 0,
        last_score: Optional[QualityScore] = None
    ):
        super().__init__(message)
        # Number of generation attempts made before giving up.
        self.attempts = attempts
        # Most recent QualityScore, or None if no evaluation completed.
        self.last_score = last_score

    def __str__(self) -> str:
        """Render the message with attempt count and last score appended."""
        details = [f"attempts={self.attempts}"]
        if self.last_score:
            details.append(f"last_score={self.last_score.overall_score:.2f}")
        return f"{super().__str__()} ({', '.join(details)})"