AI-Video/reelforge/services/image_prompt_generator.py

"""
Image prompt generation service
"""
import json
import re
from typing import List, Optional, Callable
from loguru import logger
from reelforge.models.storyboard import StoryboardConfig
from reelforge.prompts import build_image_prompt_prompt
class ImagePromptGeneratorService:
"""Image prompt generation service"""
def __init__(self, reelforge_core):
"""
Initialize
Args:
reelforge_core: ReelForgeCore instance
"""
self.core = reelforge_core
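
    # Note (inferred from how the core is used below, not a documented contract):
    # `reelforge_core` is expected to expose an async `llm(prompt=..., temperature=...,
    # max_tokens=...)` call returning the raw model text, and a `config` dict containing
    # config["comfyui"]["image"]["prompt_prefix"].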

    async def generate_image_prompts(
        self,
        narrations: List[str],
        config: StoryboardConfig,
        batch_size: int = 10,
        max_retries: int = 3,
        progress_callback: Optional[Callable] = None
    ) -> List[str]:
        """
        Generate image prompts based on narrations (with batching and retry)

        Args:
            narrations: List of narrations
            config: Storyboard configuration
            batch_size: Max narrations per batch (default: 10)
            max_retries: Max retry attempts per batch (default: 3)
            progress_callback: Optional callback(completed, total, message) for progress updates

        Returns:
            List of image prompts with prompt_prefix applied (from config)

        Raises:
            ValueError: If a batch fails after max_retries
            json.JSONDecodeError: If unable to parse JSON
        """
        logger.info(
            f"Generating image prompts for {len(narrations)} narrations "
            f"(batch_size={batch_size}, max_retries={max_retries})"
        )

        # Split narrations into batches
        batches = [narrations[i:i + batch_size] for i in range(0, len(narrations), batch_size)]
        logger.info(f"Split into {len(batches)} batches")

        all_base_prompts = []

        # Process each batch
        for batch_idx, batch_narrations in enumerate(batches, 1):
            logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")

            # Retry logic for this batch
            for attempt in range(1, max_retries + 1):
                try:
                    # Generate prompts for this batch
                    batch_prompts = await self._generate_batch_prompts(
                        batch_narrations,
                        config,
                        batch_idx,
                        attempt
                    )

                    # Validate count
                    if len(batch_prompts) != len(batch_narrations):
                        error_msg = (
                            f"Batch {batch_idx} prompt count mismatch (attempt {attempt}/{max_retries}):\n"
                            f" Expected: {len(batch_narrations)} prompts\n"
                            f" Got: {len(batch_prompts)} prompts\n"
                            f" Difference: {abs(len(batch_prompts) - len(batch_narrations))} "
                            f"{'missing' if len(batch_prompts) < len(batch_narrations) else 'extra'}"
                        )
                        logger.warning(error_msg)
                        if attempt < max_retries:
                            logger.info(f"Retrying batch {batch_idx}...")
                            continue
                        else:
                            logger.error(f"Batch {batch_idx} failed after {max_retries} attempts")
                            raise ValueError(error_msg)

                    # Success!
                    logger.info(f"✅ Batch {batch_idx} completed successfully ({len(batch_prompts)} prompts)")
                    all_base_prompts.extend(batch_prompts)

                    # Report progress
                    if progress_callback:
                        progress_callback(
                            len(all_base_prompts),
                            len(narrations),
                            f"Batch {batch_idx}/{len(batches)} completed"
                        )
                    break

                except json.JSONDecodeError as e:
                    logger.error(f"Batch {batch_idx} JSON parse error (attempt {attempt}/{max_retries}): {e}")
                    if attempt >= max_retries:
                        raise
                    logger.info(f"Retrying batch {batch_idx}...")

        base_prompts = all_base_prompts
        logger.info(f"✅ All batches completed. Total prompts: {len(base_prompts)}")

        # Apply prompt prefix to each base prompt
        from reelforge.utils.prompt_helper import build_image_prompt

        # Get prompt prefix from config (path: comfyui.image.prompt_prefix)
        image_config = self.core.config.get("comfyui", {}).get("image", {})
        prompt_prefix = image_config.get("prompt_prefix", "")

        final_prompts = []
        for base_prompt in base_prompts:
            final_prompt = build_image_prompt(base_prompt, prompt_prefix)
            final_prompts.append(final_prompt)

        logger.info(f"Generated {len(final_prompts)} final image prompts with prefix applied")
        return final_prompts
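
    # Illustrative example of the prefix step above (hypothetical values; the actual
    # joining of prefix and base prompt is defined by reelforge.utils.prompt_helper.build_image_prompt):
    #   prompt_prefix = "cinematic lighting, 4k"
    #   base_prompt   = "a fishing village at dawn, mist over the water"
    #   final_prompt  = build_image_prompt(base_prompt, prompt_prefix)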

    async def _generate_batch_prompts(
        self,
        batch_narrations: List[str],
        config: StoryboardConfig,
        batch_idx: int,
        attempt: int
    ) -> List[str]:
        """
        Generate image prompts for a single batch of narrations

        Args:
            batch_narrations: Batch of narrations
            config: Storyboard configuration
            batch_idx: Batch index (for logging)
            attempt: Attempt number (for logging)

        Returns:
            List of image prompts for this batch

        Raises:
            json.JSONDecodeError: If unable to parse JSON
            KeyError: If response format is invalid
        """
        logger.debug(f"Batch {batch_idx} attempt {attempt}: Generating prompts for {len(batch_narrations)} narrations")

        # 1. Build prompt
        prompt = build_image_prompt_prompt(
            narrations=batch_narrations,
            min_words=config.min_image_prompt_words,
            max_words=config.max_image_prompt_words
        )

        # 2. Call LLM
        response = await self.core.llm(
            prompt=prompt,
            temperature=0.7,
            max_tokens=8192
        )
        logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")

        # 3. Parse JSON
        result = self._parse_json(response)

        if "image_prompts" not in result:
            logger.error("Response missing 'image_prompts' key")
            raise KeyError("Invalid response format: missing 'image_prompts'")

        return result["image_prompts"]

    def _parse_json(self, text: str) -> dict:
        """
        Parse JSON from text, with fallback to extract JSON from markdown code blocks

        Args:
            text: Text containing JSON

        Returns:
            Parsed JSON dict

        Raises:
            json.JSONDecodeError: If no valid JSON can be extracted
        """
        # Try direct parsing first
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

        # Try to extract JSON from a markdown code block
        json_pattern = r'```(?:json)?\s*([\s\S]+?)\s*```'
        match = re.search(json_pattern, text, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass

        # Try to find any JSON object in the text
        json_pattern = r'\{[^{}]*"image_prompts"\s*:\s*\[[^\]]*\][^{}]*\}'
        match = re.search(json_pattern, text, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(0))
            except json.JSONDecodeError:
                pass

        # If all attempts fail, raise an error
        raise json.JSONDecodeError("No valid JSON found", text, 0)
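

if __name__ == "__main__":  # pragma: no cover
    # Minimal smoke-test sketch (illustrative, not part of the service). It assumes
    # StoryboardConfig accepts the word-count fields used above as keyword arguments,
    # which may differ in the real model; the stub core returns a canned JSON payload,
    # so no real LLM is called.
    import asyncio

    class _StubCore:
        config = {"comfyui": {"image": {"prompt_prefix": "cinematic, 4k"}}}

        async def llm(self, prompt: str, temperature: float, max_tokens: int) -> str:
            # Return exactly one prompt per narration in the demo below
            return json.dumps({"image_prompts": ["a harbor at dawn", "a storm rolling over the sea"]})

    async def _demo():
        service = ImagePromptGeneratorService(_StubCore())
        # Hypothetical constructor call; adjust to the real StoryboardConfig definition
        config = StoryboardConfig(min_image_prompt_words=10, max_image_prompt_words=40)
        prompts = await service.generate_image_prompts(
            ["Narration one.", "Narration two."],
            config,
            progress_callback=lambda done, total, msg: print(f"{done}/{total}: {msg}"),
        )
        print(prompts)

    asyncio.run(_demo())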