Optimize the generation logic
This commit is contained in:
@@ -4,12 +4,12 @@ Image prompt generation service
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List
|
||||
from typing import List, Optional, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from reelforge.models.storyboard import StoryboardConfig
|
||||
from reelforge.prompts.image_prompt_template import build_image_prompt_prompt
|
||||
from reelforge.prompts import build_image_prompt_prompt
|
||||
|
||||
|
||||
class ImagePromptGeneratorService:
|
||||
@@ -29,62 +29,93 @@ class ImagePromptGeneratorService:
|
||||
narrations: List[str],
|
||||
config: StoryboardConfig,
|
||||
image_style_preset: str = None,
|
||||
image_style_description: str = None
|
||||
image_style_description: str = None,
|
||||
batch_size: int = 10,
|
||||
max_retries: int = 3,
|
||||
progress_callback: Optional[Callable] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate image prompts based on narrations
|
||||
Generate image prompts based on narrations (with batching and retry)
|
||||
|
||||
Args:
|
||||
narrations: List of narrations
|
||||
config: Storyboard configuration
|
||||
image_style_preset: Preset style name (e.g., "minimal", "futuristic")
|
||||
image_style_description: Custom style description (overrides preset)
|
||||
batch_size: Max narrations per batch (default: 10)
|
||||
max_retries: Max retry attempts per batch (default: 3)
|
||||
progress_callback: Optional callback(completed, total, message) for progress updates
|
||||
|
||||
Returns:
|
||||
List of image prompts with style applied
|
||||
|
||||
Raises:
|
||||
ValueError: If generated prompt count doesn't match narrations
|
||||
ValueError: If batch fails after max_retries
|
||||
json.JSONDecodeError: If unable to parse JSON
|
||||
"""
|
||||
logger.info(f"Generating image prompts for {len(narrations)} narrations")
|
||||
logger.info(f"Generating image prompts for {len(narrations)} narrations (batch_size={batch_size}, max_retries={max_retries})")
|
||||
|
||||
# 1. Build prompt (no style info - generate base scene descriptions)
|
||||
prompt = build_image_prompt_prompt(
|
||||
narrations=narrations,
|
||||
min_words=config.min_image_prompt_words,
|
||||
max_words=config.max_image_prompt_words,
|
||||
image_style_preset=None, # Don't include style in LLM prompt
|
||||
image_style_description=None
|
||||
)
|
||||
# Split narrations into batches
|
||||
batches = [narrations[i:i + batch_size] for i in range(0, len(narrations), batch_size)]
|
||||
logger.info(f"Split into {len(batches)} batches")
|
||||
|
||||
# 2. Call LLM to generate base scene descriptions
|
||||
response = await self.core.llm(
|
||||
prompt=prompt,
|
||||
temperature=0.9, # Higher temperature for more visual creativity
|
||||
max_tokens=2000
|
||||
)
|
||||
all_base_prompts = []
|
||||
|
||||
logger.debug(f"LLM response: {response[:200]}...")
|
||||
# Process each batch
|
||||
for batch_idx, batch_narrations in enumerate(batches, 1):
|
||||
logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")
|
||||
|
||||
# Retry logic for this batch
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# Generate prompts for this batch
|
||||
batch_prompts = await self._generate_batch_prompts(
|
||||
batch_narrations,
|
||||
config,
|
||||
batch_idx,
|
||||
attempt
|
||||
)
|
||||
|
||||
# Validate count
|
||||
if len(batch_prompts) != len(batch_narrations):
|
||||
error_msg = (
|
||||
f"Batch {batch_idx} prompt count mismatch (attempt {attempt}/{max_retries}):\n"
|
||||
f" Expected: {len(batch_narrations)} prompts\n"
|
||||
f" Got: {len(batch_prompts)} prompts\n"
|
||||
f" Difference: {abs(len(batch_prompts) - len(batch_narrations))} "
|
||||
f"{'missing' if len(batch_prompts) < len(batch_narrations) else 'extra'}"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
|
||||
if attempt < max_retries:
|
||||
logger.info(f"Retrying batch {batch_idx}...")
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Batch {batch_idx} failed after {max_retries} attempts")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Success!
|
||||
logger.info(f"✅ Batch {batch_idx} completed successfully ({len(batch_prompts)} prompts)")
|
||||
all_base_prompts.extend(batch_prompts)
|
||||
|
||||
# Report progress
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
len(all_base_prompts),
|
||||
len(narrations),
|
||||
f"Batch {batch_idx}/{len(batches)} completed"
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Batch {batch_idx} JSON parse error (attempt {attempt}/{max_retries}): {e}")
|
||||
if attempt >= max_retries:
|
||||
raise
|
||||
logger.info(f"Retrying batch {batch_idx}...")
|
||||
|
||||
# 3. Parse JSON
|
||||
try:
|
||||
result = self._parse_json(response)
|
||||
base_prompts = result["image_prompts"]
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse LLM response: {e}")
|
||||
logger.error(f"Response: {response}")
|
||||
raise
|
||||
except KeyError:
|
||||
logger.error("Response missing 'image_prompts' key")
|
||||
raise ValueError("Invalid response format")
|
||||
|
||||
# 4. Validate count matches narrations
|
||||
if len(base_prompts) != len(narrations):
|
||||
raise ValueError(
|
||||
f"Expected {len(narrations)} image prompts, "
|
||||
f"got {len(base_prompts)}"
|
||||
)
|
||||
base_prompts = all_base_prompts
|
||||
logger.info(f"✅ All batches completed. Total prompts: {len(base_prompts)}")
|
||||
|
||||
# 5. Apply style to each prompt using FinalImagePromptService
|
||||
from reelforge.services.final_image_prompt import StylePreset
|
||||
@@ -110,6 +141,58 @@ class ImagePromptGeneratorService:
|
||||
logger.info(f"Generated {len(final_prompts)} final image prompts with style applied")
|
||||
return final_prompts
|
||||
|
||||
async def _generate_batch_prompts(
|
||||
self,
|
||||
batch_narrations: List[str],
|
||||
config: StoryboardConfig,
|
||||
batch_idx: int,
|
||||
attempt: int
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate image prompts for a single batch of narrations
|
||||
|
||||
Args:
|
||||
batch_narrations: Batch of narrations
|
||||
config: Storyboard configuration
|
||||
batch_idx: Batch index (for logging)
|
||||
attempt: Attempt number (for logging)
|
||||
|
||||
Returns:
|
||||
List of image prompts for this batch
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: If unable to parse JSON
|
||||
KeyError: If response format is invalid
|
||||
"""
|
||||
logger.debug(f"Batch {batch_idx} attempt {attempt}: Generating prompts for {len(batch_narrations)} narrations")
|
||||
|
||||
# 1. Build prompt
|
||||
prompt = build_image_prompt_prompt(
|
||||
narrations=batch_narrations,
|
||||
min_words=config.min_image_prompt_words,
|
||||
max_words=config.max_image_prompt_words,
|
||||
image_style_preset=None,
|
||||
image_style_description=None
|
||||
)
|
||||
|
||||
# 2. Call LLM
|
||||
response = await self.core.llm(
|
||||
prompt=prompt,
|
||||
temperature=0.7,
|
||||
max_tokens=8192
|
||||
)
|
||||
|
||||
logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")
|
||||
|
||||
# 3. Parse JSON
|
||||
result = self._parse_json(response)
|
||||
|
||||
if "image_prompts" not in result:
|
||||
logger.error("Response missing 'image_prompts' key")
|
||||
raise KeyError("Invalid response format: missing 'image_prompts'")
|
||||
|
||||
return result["image_prompts"]
|
||||
|
||||
def _parse_json(self, text: str) -> dict:
|
||||
"""
|
||||
Parse JSON from text, with fallback to extract JSON from markdown code blocks
|
||||
@@ -127,7 +210,7 @@ class ImagePromptGeneratorService:
|
||||
pass
|
||||
|
||||
# Try to extract JSON from markdown code block
|
||||
json_pattern = r'```(?:json)?\s*(\{.*?\})\s*```'
|
||||
json_pattern = r'```(?:json)?\s*([\s\S]+?)\s*```'
|
||||
match = re.search(json_pattern, text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user