# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
RetryManager - Smart retry logic with quality-based decisions

Provides unified retry management for:
- Media generation (images, videos)
- LLM calls
- Any async operations with quality evaluation
"""

import asyncio
import re
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar

from loguru import logger

from pixelle_video.services.quality.models import (
    QualityScore,
    QualityConfig,
    RetryConfig,
    QualityError,
)

T = TypeVar("T")


@dataclass
class RetryResult:
    """Result of a retry operation"""

    # True when the returned result was accepted (passed the quality check,
    # or no quality evaluator was supplied)
    success: bool
    # Value produced by the operation (or by the fallback); may be a result
    # that failed the quality check when success is False
    result: Any
    # Attempts consumed; cfg.max_retries + 1 when the fallback produced `result`
    attempts: int
    # Score of `result` when a quality evaluator was supplied
    quality_score: Optional[QualityScore] = None
    # Not populated by RetryManager in this module
    error: Optional[Exception] = None


class RetryManager:
    """
    Smart retry manager with quality-based decisions

    Features:
    - Exponential backoff with configurable delays
    - Quality-aware retry (retries when quality below threshold)
    - Fallback strategies
    - Detailed logging and metrics

    Example:
        >>> retry_manager = RetryManager()
        >>>
        >>> async def generate():
        ...     return await media_service.generate_image(prompt)
        >>>
        >>> async def evaluate(image_path):
        ...     return await quality_gate.evaluate_image(image_path, prompt)
        >>>
        >>> result = await retry_manager.execute_with_retry(
        ...     operation=generate,
        ...     quality_evaluator=evaluate,
        ...     config=RetryConfig(max_retries=3)
        ... )
    """

    def __init__(self, config: Optional[RetryConfig] = None):
        """
        Initialize RetryManager

        Args:
            config: Default retry configuration (a fresh RetryConfig() if omitted)
        """
        self.default_config = config or RetryConfig()

    @staticmethod
    def _backoff_delay_ms(cfg: RetryConfig, attempt: int) -> float:
        """Return the delay (ms) to wait after failed 1-based `attempt`.

        Exponential backoff: initial_delay_ms * backoff_factor ** (attempt - 1),
        capped at max_delay_ms.
        """
        return min(
            cfg.initial_delay_ms * (cfg.backoff_factor ** (attempt - 1)),
            cfg.max_delay_ms,
        )

    async def execute_with_retry(
        self,
        operation: Callable[[], Awaitable[Any]],
        quality_evaluator: Optional[Callable[[Any], Awaitable[QualityScore]]] = None,
        config: Optional[RetryConfig] = None,
        fallback_operation: Optional[Callable[[], Awaitable[Any]]] = None,
        operation_name: str = "operation",
    ) -> RetryResult:
        """
        Execute operation with automatic retry and quality evaluation

        Args:
            operation: Async callable to execute
            quality_evaluator: Optional async callable to evaluate result quality
            config: Retry configuration (uses default if not provided)
            fallback_operation: Optional fallback to try when all retries fail
                (only used when cfg.enable_fallback is true)
            operation_name: Name for logging purposes

        Returns:
            RetryResult with success status, result, and quality score.
            success is False (no exception) when the quality check fails and
            cfg.retry_on_quality_fail is false, or when the fallback's result
            fails the quality check.

        Raises:
            QualityError: When all retries fail and no fallback succeeds; the
                underlying exception (if any) is chained as __cause__.
            Exception: The operation's own exception, re-raised immediately
                when cfg.retry_on_error is false.
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None
        last_score: Optional[QualityScore] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                # Execute the operation
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()

                # If no quality evaluator, accept the result
                if quality_evaluator is None:
                    logger.debug(f"{operation_name}: completed (no quality check)")
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                    )

                # Evaluate quality
                score = await quality_evaluator(result)
                last_score = score
                # This attempt ran to completion, so an exception from an earlier
                # attempt no longer describes why the operation is failing.
                last_error = None

                if score.passed:
                    logger.info(
                        f"{operation_name}: passed quality check "
                        f"(score={score.overall_score:.2f}, attempt={attempt})"
                    )
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

                # Quality check failed
                logger.warning(
                    f"{operation_name}: quality check failed "
                    f"(score={score.overall_score:.2f}, threshold={cfg.quality_threshold}, "
                    f"issues={score.issues})"
                )

                if not cfg.retry_on_quality_fail:
                    # Don't retry on quality failure: hand back the rejected result
                    return RetryResult(
                        success=False,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed with error: {e}")

                if not cfg.retry_on_error:
                    raise

            # Back off before the next attempt (no sleep after the final one)
            if attempt < cfg.max_retries:
                delay_ms = self._backoff_delay_ms(cfg, attempt)
                logger.debug(f"{operation_name}: waiting {delay_ms:.0f}ms before retry")
                await asyncio.sleep(delay_ms / 1000)

        # All retries exhausted
        logger.warning(f"{operation_name}: all {cfg.max_retries} attempts failed")

        # Try fallback if available
        if cfg.enable_fallback and fallback_operation:
            try:
                logger.info(f"{operation_name}: trying fallback strategy")
                result = await fallback_operation()

                # Evaluate fallback result if evaluator available
                if quality_evaluator:
                    score = await quality_evaluator(result)
                    return RetryResult(
                        success=score.passed if score else True,
                        result=result,
                        attempts=cfg.max_retries + 1,
                        quality_score=score,
                    )

                return RetryResult(
                    success=True,
                    result=result,
                    attempts=cfg.max_retries + 1,
                )

            except Exception as e:
                logger.error(f"{operation_name}: fallback also failed: {e}")
                last_error = e

        # Nothing worked
        if last_error is not None:
            raise QualityError(
                f"{operation_name} failed after {cfg.max_retries} attempts: {last_error}",
                attempts=cfg.max_retries,
                last_score=last_score,
            ) from last_error

        raise QualityError(
            f"{operation_name} failed to meet quality threshold after {cfg.max_retries} attempts",
            attempts=cfg.max_retries,
            last_score=last_score,
        )

    async def execute_simple(
        self,
        operation: Callable[[], Awaitable[Any]],
        config: Optional[RetryConfig] = None,
        operation_name: str = "operation",
    ) -> Any:
        """
        Execute operation with simple retry (no quality evaluation)

        Every exception is retried (cfg.retry_on_error is not consulted here).

        Args:
            operation: Async callable to execute
            config: Retry configuration
            operation_name: Name for logging

        Returns:
            Operation result

        Raises:
            Exception: Last exception if all retries fail
            ValueError: If cfg.max_retries < 1, so no attempt could be made
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()
                logger.debug(f"{operation_name}: completed on attempt {attempt}")
                return result

            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed: {e}")

                if attempt < cfg.max_retries:
                    await asyncio.sleep(self._backoff_delay_ms(cfg, attempt) / 1000)

        if last_error is None:
            # Only reachable when the loop never ran; `raise None` would surface
            # as an unrelated TypeError, so report the misconfiguration instead.
            raise ValueError(
                f"{operation_name}: max_retries must be >= 1 (got {cfg.max_retries})"
            )
        raise last_error

    @staticmethod
    def create_prompt_simplifier(original_prompt: str) -> Callable[[], str]:
        """
        Create a fallback that simplifies the prompt

        The simplifier removes common quality-booster phrases (case-insensitively),
        truncates to at most 150 characters on a word boundary, and collapses
        whitespace.

        Args:
            original_prompt: Original complex prompt

        Returns:
            Callable that returns a simplified prompt
        """
        # Complex modifiers to strip; escaped so they match literally, and
        # IGNORECASE so "Masterpiece" / "8K resolution" are removed as well.
        remove_patterns = [
            re.compile(re.escape(phrase), re.IGNORECASE)
            for phrase in (
                "highly detailed",
                "ultra realistic",
                "8k resolution",
                "masterpiece",
                "best quality",
            )
        ]

        def simplify() -> str:
            simplified = original_prompt

            # Remove complex modifiers
            for pattern in remove_patterns:
                simplified = pattern.sub("", simplified)

            # Truncate if too long, backing off to the last whole word
            if len(simplified) > 150:
                simplified = simplified[:150].rsplit(" ", 1)[0]

            # Clean up runs of whitespace left by removals
            simplified = " ".join(simplified.split())

            return simplified

        return simplify