# Copyright (C) 2025 AIDC-AI
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
RetryManager - Smart retry logic with quality-based decisions
|
|
|
|
Provides unified retry management for:
|
|
- Media generation (images, videos)
|
|
- LLM calls
|
|
- Any async operations with quality evaluation
|
|
"""
|
|
|
|
import asyncio
import re
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar

from loguru import logger

from pixelle_video.services.quality.models import (
    QualityScore,
    QualityConfig,
    RetryConfig,
    QualityError,
)
|
|
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
@dataclass
class RetryResult:
    """Result of a retry operation.

    Returned by RetryManager.execute_with_retry to report the outcome,
    the produced value, and (when a quality evaluator was used) the
    quality score of the final attempt.
    """

    # True when an attempt succeeded (and, if evaluated, passed the quality gate).
    success: bool
    # Value produced by the final attempt (type depends on the wrapped operation).
    result: Any
    # Number of attempts performed, 1-based; fallback runs count as max_retries + 1.
    attempts: int
    # Quality evaluation of the final result; None when no evaluator was supplied.
    quality_score: Optional[QualityScore] = None
    # Last exception raised, if the operation failed with an error.
    error: Optional[Exception] = None
|
|
|
|
class RetryManager:
    """
    Smart retry manager with quality-based decisions

    Features:
    - Exponential backoff with configurable delays
    - Quality-aware retry (retries when quality below threshold)
    - Fallback strategies
    - Detailed logging and metrics

    Example:
        >>> retry_manager = RetryManager()
        >>>
        >>> async def generate():
        ...     return await media_service.generate_image(prompt)
        >>>
        >>> async def evaluate(image_path):
        ...     return await quality_gate.evaluate_image(image_path, prompt)
        >>>
        >>> result = await retry_manager.execute_with_retry(
        ...     operation=generate,
        ...     quality_evaluator=evaluate,
        ...     config=RetryConfig(max_retries=3)
        ... )
    """

    def __init__(self, config: Optional[RetryConfig] = None):
        """
        Initialize RetryManager

        Args:
            config: Default retry configuration; a fresh RetryConfig() is
                used when not provided.
        """
        self.default_config = config or RetryConfig()

    @staticmethod
    def _backoff_delay_ms(cfg: RetryConfig, attempt: int) -> float:
        """Exponential backoff delay (ms) for a 1-based attempt, capped at max_delay_ms."""
        return min(
            cfg.initial_delay_ms * (cfg.backoff_factor ** (attempt - 1)),
            cfg.max_delay_ms,
        )

    async def execute_with_retry(
        self,
        operation: Callable[[], Awaitable[Any]],
        quality_evaluator: Optional[Callable[[Any], Awaitable[QualityScore]]] = None,
        config: Optional[RetryConfig] = None,
        fallback_operation: Optional[Callable[[], Awaitable[Any]]] = None,
        operation_name: str = "operation",
    ) -> RetryResult:
        """
        Execute operation with automatic retry and quality evaluation

        Args:
            operation: Async callable to execute
            quality_evaluator: Optional async callable to evaluate result quality
            config: Retry configuration (uses default if not provided)
            fallback_operation: Optional fallback to try when all retries fail
                (only used when config.enable_fallback is true)
            operation_name: Name for logging purposes

        Returns:
            RetryResult with success status, result, and quality score

        Raises:
            QualityError: When all retries (and the fallback, if any) fail
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None
        last_score: Optional[QualityScore] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                # Execute the operation
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()

                # If no quality evaluator, accept the result
                if quality_evaluator is None:
                    logger.debug(f"{operation_name}: completed (no quality check)")
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                    )

                # Evaluate quality
                score = await quality_evaluator(result)
                last_score = score

                if score.passed:
                    logger.info(
                        f"{operation_name}: passed quality check "
                        f"(score={score.overall_score:.2f}, attempt={attempt})"
                    )
                    return RetryResult(
                        success=True,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

                # Quality check failed
                logger.warning(
                    f"{operation_name}: quality check failed "
                    f"(score={score.overall_score:.2f}, threshold={cfg.quality_threshold}, "
                    f"issues={score.issues})"
                )

                if not cfg.retry_on_quality_fail:
                    # Don't retry on quality failure: return the sub-par
                    # result so the caller can decide what to do with it.
                    return RetryResult(
                        success=False,
                        result=result,
                        attempts=attempt,
                        quality_score=score,
                    )

            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed with error: {e}")

                if not cfg.retry_on_error:
                    raise

            # Back off before the next attempt (skipped after the final one)
            if attempt < cfg.max_retries:
                delay_ms = self._backoff_delay_ms(cfg, attempt)
                logger.debug(f"{operation_name}: waiting {delay_ms:.0f}ms before retry")
                await asyncio.sleep(delay_ms / 1000)

        # All retries exhausted
        logger.warning(f"{operation_name}: all {cfg.max_retries} attempts failed")

        # Try fallback if available
        if cfg.enable_fallback and fallback_operation:
            try:
                logger.info(f"{operation_name}: trying fallback strategy")
                result = await fallback_operation()

                # Evaluate fallback result if evaluator available
                if quality_evaluator:
                    score = await quality_evaluator(result)
                    return RetryResult(
                        success=score.passed if score else True,
                        result=result,
                        attempts=cfg.max_retries + 1,
                        quality_score=score,
                    )

                return RetryResult(
                    success=True,
                    result=result,
                    attempts=cfg.max_retries + 1,
                )

            except Exception as e:
                logger.error(f"{operation_name}: fallback also failed: {e}")
                last_error = e

        # Nothing worked — raise with the most informative message available
        if last_error:
            raise QualityError(
                f"{operation_name} failed after {cfg.max_retries} attempts: {last_error}",
                attempts=cfg.max_retries,
                last_score=last_score,
            )

        raise QualityError(
            f"{operation_name} failed to meet quality threshold after {cfg.max_retries} attempts",
            attempts=cfg.max_retries,
            last_score=last_score,
        )

    async def execute_simple(
        self,
        operation: Callable[[], Awaitable[Any]],
        config: Optional[RetryConfig] = None,
        operation_name: str = "operation",
    ) -> Any:
        """
        Execute operation with simple retry (no quality evaluation)

        Args:
            operation: Async callable to execute
            config: Retry configuration
            operation_name: Name for logging

        Returns:
            Operation result

        Raises:
            Exception: Last exception if all retries fail
            RuntimeError: If the configuration allows zero attempts
        """
        cfg = config or self.default_config
        last_error: Optional[Exception] = None

        for attempt in range(1, cfg.max_retries + 1):
            try:
                logger.debug(f"{operation_name}: attempt {attempt}/{cfg.max_retries}")
                result = await operation()
                logger.debug(f"{operation_name}: completed on attempt {attempt}")
                return result

            except Exception as e:
                last_error = e
                logger.warning(f"{operation_name}: attempt {attempt} failed: {e}")

                if attempt < cfg.max_retries:
                    await asyncio.sleep(self._backoff_delay_ms(cfg, attempt) / 1000)

        # Guard against `raise None` (TypeError) when max_retries < 1
        if last_error is None:
            raise RuntimeError(
                f"{operation_name}: no attempts were made (max_retries={cfg.max_retries})"
            )
        raise last_error

    @staticmethod
    def create_prompt_simplifier(original_prompt: str) -> Callable[[], str]:
        """
        Create a fallback that simplifies the prompt

        Args:
            original_prompt: Original complex prompt

        Returns:
            Callable that returns a simplified prompt
        """

        def simplify() -> str:
            # Simple prompt simplification strategies
            simplified = original_prompt

            # Remove complex modifiers (case-insensitive, so "Highly Detailed"
            # is stripped as well as "highly detailed")
            remove_phrases = [
                "highly detailed",
                "ultra realistic",
                "8k resolution",
                "masterpiece",
                "best quality",
            ]
            for phrase in remove_phrases:
                simplified = re.sub(re.escape(phrase), "", simplified, flags=re.IGNORECASE)

            # Truncate if too long, breaking at a word boundary
            if len(simplified) > 150:
                simplified = simplified[:150].rsplit(" ", 1)[0]

            # Collapse repeated whitespace left behind by the removals
            simplified = " ".join(simplified.split())

            return simplified

        return simplify