# Source: AI-Video/pixelle_video/services/quality/retry_manager.py (297 lines, 10 KiB, Python)

# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RetryManager - Smart retry logic with quality-based decisions
Provides unified retry management for:
- Media generation (images, videos)
- LLM calls
- Any async operations with quality evaluation
"""
import asyncio
from typing import Callable, TypeVar, Optional, Tuple, Any
from dataclasses import dataclass
from loguru import logger
from pixelle_video.services.quality.models import (
QualityScore,
QualityConfig,
RetryConfig,
QualityError,
)
# Generic type variable. NOTE(review): nothing in the code visible here
# references it — confirm whether other modules import it before removing.
T = TypeVar("T")
@dataclass
class RetryResult:
    """
    Outcome of a retried operation, as returned by RetryManager.

    The first three fields are required; the last two default to None.
    """

    # True when the operation's result was accepted.
    success: bool
    # Value returned by the operation (or by the fallback operation).
    result: Any
    # Attempt counter reported by the retry manager for this outcome.
    attempts: int
    # Score from the quality evaluator; None when no evaluation was done.
    quality_score: Optional[QualityScore] = None
    # Optional exception associated with this outcome.
    error: Optional[Exception] = None
class RetryManager:
    """
    Smart retry manager with quality-based decisions

    Features:
    - Exponential backoff with configurable delays
    - Quality-aware retry (retries when the quality check does not pass)
    - Fallback strategies
    - Detailed logging

    Example:
        >>> retry_manager = RetryManager()
        >>>
        >>> async def generate():
        ...     return await media_service.generate_image(prompt)
        >>>
        >>> async def evaluate(image_path):
        ...     return await quality_gate.evaluate_image(image_path, prompt)
        >>>
        >>> result = await retry_manager.execute_with_retry(
        ...     operation=generate,
        ...     quality_evaluator=evaluate,
        ...     config=RetryConfig(max_retries=3)
        ... )
    """

    def __init__(self, config: Optional[RetryConfig] = None):
        """
        Initialize RetryManager

        Args:
            config: Default retry configuration; a fresh RetryConfig() is
                used when omitted.
        """
        self.default_config = config or RetryConfig()

    @staticmethod
    def _backoff_delay_ms(cfg: RetryConfig, attempt: int) -> float:
        """Return the capped exponential delay (ms) to wait after 1-based `attempt`."""
        return min(
            cfg.initial_delay_ms * (cfg.backoff_factor ** (attempt - 1)),
            cfg.max_delay_ms,
        )

    async def execute_with_retry(
        self,
        operation: Callable[[], Any],
        quality_evaluator: Optional[Callable[[Any], QualityScore]] = None,
        config: Optional[RetryConfig] = None,
        fallback_operation: Optional[Callable[[], Any]] = None,
        operation_name: str = "operation",
    ) -> RetryResult:
        """
        Execute operation with automatic retry and quality evaluation

        Args:
            operation: Async callable to execute
            quality_evaluator: Optional async callable to evaluate result quality
            config: Retry configuration (uses default if not provided)
            fallback_operation: Optional async fallback tried once after all
                attempts are exhausted (only when cfg.enable_fallback is set)
            operation_name: Name for logging purposes

        Returns:
            RetryResult with success status, result, and quality score.
            success is False when the quality check fails and
            cfg.retry_on_quality_fail is disabled, or when the fallback
            result does not pass evaluation.

        Raises:
            QualityError: When all retries fail and no fallback succeeded.
                The most recent exception, if any, is attached as the cause.
            Exception: The original exception from the operation or the
                evaluator is re-raised immediately when cfg.retry_on_error
                is disabled.
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None
        last_score: Optional[QualityScore] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                # Execute the operation
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()

                # If no quality evaluator, accept the result
                if quality_evaluator is None:
                    logger.debug(f"{operation_name}: completed (no quality check)")
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                    )

                # Evaluate quality
                score = await quality_evaluator(result)
                last_score = score

                if score.passed:
                    logger.info(
                        f"{operation_name}: passed quality check "
                        f"(score={score.overall_score:.2f}, attempt={attempt})"
                    )
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

                # Quality check failed
                logger.warning(
                    f"{operation_name}: quality check failed "
                    f"(score={score.overall_score:.2f}, threshold={cfg.quality_threshold}, "
                    f"issues={score.issues})"
                )

                if not cfg.retry_on_quality_fail:
                    # Don't retry on quality failure; hand back the rejected result
                    return RetryResult(
                        success=False,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )
            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed with error: {e}")
                if not cfg.retry_on_error:
                    raise

            # Back off before the next attempt (no wait after the final one)
            if attempt < cfg.max_retries:
                delay_ms = self._backoff_delay_ms(cfg, attempt)
                logger.debug(f"{operation_name}: waiting {delay_ms:.0f}ms before retry")
                await asyncio.sleep(delay_ms / 1000)

        # All retries exhausted
        logger.warning(f"{operation_name}: all {cfg.max_retries} attempts failed")

        # Try fallback if available
        if cfg.enable_fallback and fallback_operation:
            try:
                logger.info(f"{operation_name}: trying fallback strategy")
                result = await fallback_operation()

                # Evaluate fallback result if evaluator available
                if quality_evaluator is not None:
                    score = await quality_evaluator(result)
                    return RetryResult(
                        success=score.passed if score else True,
                        result=result,
                        attempts=cfg.max_retries + 1,
                        quality_score=score,
                    )

                return RetryResult(
                    success=True,
                    result=result,
                    attempts=cfg.max_retries + 1,
                )
            except Exception as e:
                logger.error(f"{operation_name}: fallback also failed: {e}")
                last_error = e

        # Nothing worked
        if last_error is not None:
            # Chain the underlying exception so its traceback is not lost
            raise QualityError(
                f"{operation_name} failed after {cfg.max_retries} attempts: {last_error}",
                attempts=cfg.max_retries,
                last_score=last_score,
            ) from last_error

        raise QualityError(
            f"{operation_name} failed to meet quality threshold after {cfg.max_retries} attempts",
            attempts=cfg.max_retries,
            last_score=last_score,
        )

    async def execute_simple(
        self,
        operation: Callable[[], Any],
        config: Optional[RetryConfig] = None,
        operation_name: str = "operation",
    ) -> Any:
        """
        Execute operation with simple retry (no quality evaluation)

        Args:
            operation: Async callable to execute
            config: Retry configuration (uses default if not provided)
            operation_name: Name for logging

        Returns:
            Operation result

        Raises:
            Exception: Last exception if all retries fail
            RuntimeError: If no attempt was made because cfg.max_retries is
                less than 1
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()
                logger.debug(f"{operation_name}: completed on attempt {attempt}")
                return result
            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed: {e}")

            # Back off before the next attempt (no wait after the final one)
            if attempt < cfg.max_retries:
                delay_ms = self._backoff_delay_ms(cfg, attempt)
                logger.debug(f"{operation_name}: waiting {delay_ms:.0f}ms before retry")
                await asyncio.sleep(delay_ms / 1000)

        if last_error is None:
            # The loop body never ran, so there is no exception to re-raise;
            # `raise None` would surface as an opaque TypeError.
            raise RuntimeError(
                f"{operation_name}: no attempts were made (max_retries={cfg.max_retries})"
            )
        raise last_error

    @staticmethod
    def create_prompt_simplifier(
        original_prompt: str,
        max_length: int = 150,
    ) -> Callable[[], str]:
        """
        Create a fallback that simplifies the prompt

        The returned callable removes a fixed set of modifier phrases
        (matched case-insensitively), collapses whitespace and the empty
        comma-separated segments left behind by the removal, and truncates
        the text to at most max_length characters, cutting at the last space
        inside that limit when there is one.

        Args:
            original_prompt: Original complex prompt
            max_length: Maximum length of the simplified prompt

        Returns:
            Callable that returns a simplified prompt
        """
        # Local import: `re` is not part of this module's import block.
        import re

        # Complex modifiers to strip from the prompt
        removable_phrases = (
            "highly detailed",
            "ultra realistic",
            "8k resolution",
            "masterpiece",
            "best quality",
        )
        phrase_pattern = re.compile(
            "|".join(re.escape(phrase) for phrase in removable_phrases),
            re.IGNORECASE,
        )
        # Matches runs like ", ," produced when a phrase between commas is removed
        repeated_commas = re.compile(r",(?:\s*,)+")

        def simplify() -> str:
            text = phrase_pattern.sub("", original_prompt)
            # Normalise whitespace first so the length limit applies to the
            # text that is actually returned
            text = " ".join(text.split())
            text = repeated_commas.sub(",", text).strip(" ,")

            # Truncate if too long, preferring a word boundary
            if len(text) > max_length:
                text = text[:max_length].rsplit(" ", 1)[0].rstrip(" ,")
            return text

        return simplify