抽象pipeline逻辑
This commit is contained in:
351
pixelle_video/utils/content_generators.py
Normal file
351
pixelle_video/utils/content_generators.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Content generation utility functions
|
||||
|
||||
Pure/stateless functions for generating content using LLM.
|
||||
These functions are reusable across different pipelines.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Literal
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
async def generate_title(
|
||||
llm_service,
|
||||
content: str,
|
||||
strategy: Literal["auto", "direct", "llm"] = "auto",
|
||||
max_length: int = 15
|
||||
) -> str:
|
||||
"""
|
||||
Generate title from content
|
||||
|
||||
Args:
|
||||
llm_service: LLM service instance
|
||||
content: Source content (topic or script)
|
||||
strategy: Generation strategy
|
||||
- "auto": Auto-decide based on content length (default)
|
||||
- "direct": Use content directly (truncated if needed)
|
||||
- "llm": Always use LLM to generate title
|
||||
max_length: Maximum title length (default: 15)
|
||||
|
||||
Returns:
|
||||
Generated title
|
||||
"""
|
||||
if strategy == "direct":
|
||||
content = content.strip()
|
||||
return content[:max_length] if len(content) > max_length else content
|
||||
|
||||
if strategy == "auto":
|
||||
if len(content.strip()) <= 15:
|
||||
return content.strip()
|
||||
# Fall through to LLM
|
||||
|
||||
# Use LLM to generate title
|
||||
from pixelle_video.prompts import build_title_generation_prompt
|
||||
|
||||
prompt = build_title_generation_prompt(content, max_length=500)
|
||||
response = await llm_service(prompt, temperature=0.7, max_tokens=50)
|
||||
|
||||
# Clean up response
|
||||
title = response.strip()
|
||||
|
||||
# Remove quotes if present
|
||||
if title.startswith('"') and title.endswith('"'):
|
||||
title = title[1:-1]
|
||||
if title.startswith("'") and title.endswith("'"):
|
||||
title = title[1:-1]
|
||||
|
||||
# Limit to max_length (safety)
|
||||
if len(title) > max_length:
|
||||
title = title[:max_length]
|
||||
|
||||
logger.debug(f"Generated title: '{title}' (length: {len(title)})")
|
||||
return title
|
||||
|
||||
|
||||
async def generate_narrations_from_topic(
|
||||
llm_service,
|
||||
topic: str,
|
||||
n_scenes: int = 5,
|
||||
min_words: int = 5,
|
||||
max_words: int = 20
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate narrations from topic using LLM
|
||||
|
||||
Args:
|
||||
llm_service: LLM service instance
|
||||
topic: Topic/theme to generate narrations from
|
||||
n_scenes: Number of narrations to generate
|
||||
min_words: Minimum narration length
|
||||
max_words: Maximum narration length
|
||||
|
||||
Returns:
|
||||
List of narration texts
|
||||
"""
|
||||
from pixelle_video.prompts import build_topic_narration_prompt
|
||||
|
||||
logger.info(f"Generating {n_scenes} narrations from topic: {topic}")
|
||||
|
||||
prompt = build_topic_narration_prompt(
|
||||
topic=topic,
|
||||
n_storyboard=n_scenes,
|
||||
min_words=min_words,
|
||||
max_words=max_words
|
||||
)
|
||||
|
||||
response = await llm_service(
|
||||
prompt=prompt,
|
||||
temperature=0.8,
|
||||
max_tokens=2000
|
||||
)
|
||||
|
||||
logger.debug(f"LLM response: {response[:200]}...")
|
||||
|
||||
# Parse JSON
|
||||
result = _parse_json(response)
|
||||
|
||||
if "narrations" not in result:
|
||||
raise ValueError("Invalid response format: missing 'narrations' key")
|
||||
|
||||
narrations = result["narrations"]
|
||||
|
||||
# Validate count
|
||||
if len(narrations) > n_scenes:
|
||||
logger.warning(f"Got {len(narrations)} narrations, taking first {n_scenes}")
|
||||
narrations = narrations[:n_scenes]
|
||||
elif len(narrations) < n_scenes:
|
||||
raise ValueError(f"Expected {n_scenes} narrations, got only {len(narrations)}")
|
||||
|
||||
logger.info(f"Generated {len(narrations)} narrations successfully")
|
||||
return narrations
|
||||
|
||||
|
||||
async def generate_narrations_from_content(
|
||||
llm_service,
|
||||
content: str,
|
||||
n_scenes: int = 5,
|
||||
min_words: int = 5,
|
||||
max_words: int = 20
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate narrations from user-provided content using LLM
|
||||
|
||||
Args:
|
||||
llm_service: LLM service instance
|
||||
content: User-provided content
|
||||
n_scenes: Number of narrations to generate
|
||||
min_words: Minimum narration length
|
||||
max_words: Maximum narration length
|
||||
|
||||
Returns:
|
||||
List of narration texts
|
||||
"""
|
||||
from pixelle_video.prompts import build_content_narration_prompt
|
||||
|
||||
logger.info(f"Generating {n_scenes} narrations from content ({len(content)} chars)")
|
||||
|
||||
prompt = build_content_narration_prompt(
|
||||
content=content,
|
||||
n_storyboard=n_scenes,
|
||||
min_words=min_words,
|
||||
max_words=max_words
|
||||
)
|
||||
|
||||
response = await llm_service(
|
||||
prompt=prompt,
|
||||
temperature=0.8,
|
||||
max_tokens=2000
|
||||
)
|
||||
|
||||
# Parse JSON
|
||||
result = _parse_json(response)
|
||||
|
||||
if "narrations" not in result:
|
||||
raise ValueError("Invalid response format: missing 'narrations' key")
|
||||
|
||||
narrations = result["narrations"]
|
||||
|
||||
# Validate count
|
||||
if len(narrations) > n_scenes:
|
||||
logger.warning(f"Got {len(narrations)} narrations, taking first {n_scenes}")
|
||||
narrations = narrations[:n_scenes]
|
||||
elif len(narrations) < n_scenes:
|
||||
raise ValueError(f"Expected {n_scenes} narrations, got only {len(narrations)}")
|
||||
|
||||
logger.info(f"Generated {len(narrations)} narrations successfully")
|
||||
return narrations
|
||||
|
||||
|
||||
async def split_narration_script(
|
||||
script: str,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split user-provided narration script into segments by lines
|
||||
|
||||
Args:
|
||||
script: Fixed narration script (each line is a narration)
|
||||
|
||||
Returns:
|
||||
List of narration segments
|
||||
"""
|
||||
logger.info(f"Splitting script by lines (length: {len(script)} chars)")
|
||||
|
||||
# Split by newline, filter empty lines
|
||||
narrations = [line.strip() for line in script.split('\n') if line.strip()]
|
||||
|
||||
logger.info(f"✅ Split script into {len(narrations)} segments (by lines)")
|
||||
|
||||
# Log statistics
|
||||
if narrations:
|
||||
lengths = [len(s) for s in narrations]
|
||||
logger.info(f" Min: {min(lengths)} chars, Max: {max(lengths)} chars, Avg: {sum(lengths)//len(lengths)} chars")
|
||||
|
||||
return narrations
|
||||
|
||||
|
||||
async def generate_image_prompts(
|
||||
llm_service,
|
||||
narrations: List[str],
|
||||
min_words: int = 30,
|
||||
max_words: int = 60,
|
||||
batch_size: int = 10,
|
||||
max_retries: int = 3,
|
||||
progress_callback: Optional[callable] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate image prompts from narrations (with batching and retry)
|
||||
|
||||
Args:
|
||||
llm_service: LLM service instance
|
||||
narrations: List of narrations
|
||||
min_words: Min image prompt length
|
||||
max_words: Max image prompt length
|
||||
batch_size: Max narrations per batch (default: 10)
|
||||
max_retries: Max retry attempts per batch (default: 3)
|
||||
progress_callback: Optional callback(completed, total, message) for progress updates
|
||||
|
||||
Returns:
|
||||
List of image prompts (base prompts, without prefix applied)
|
||||
"""
|
||||
from pixelle_video.prompts import build_image_prompt_prompt
|
||||
|
||||
logger.info(f"Generating image prompts for {len(narrations)} narrations (batch_size={batch_size})")
|
||||
|
||||
# Split narrations into batches
|
||||
batches = [narrations[i:i + batch_size] for i in range(0, len(narrations), batch_size)]
|
||||
logger.info(f"Split into {len(batches)} batches")
|
||||
|
||||
all_prompts = []
|
||||
|
||||
# Process each batch
|
||||
for batch_idx, batch_narrations in enumerate(batches, 1):
|
||||
logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")
|
||||
|
||||
# Retry logic for this batch
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
# Generate prompts for this batch
|
||||
prompt = build_image_prompt_prompt(
|
||||
narrations=batch_narrations,
|
||||
min_words=min_words,
|
||||
max_words=max_words
|
||||
)
|
||||
|
||||
response = await llm_service(
|
||||
prompt=prompt,
|
||||
temperature=0.7,
|
||||
max_tokens=8192
|
||||
)
|
||||
|
||||
logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")
|
||||
|
||||
# Parse JSON
|
||||
result = _parse_json(response)
|
||||
|
||||
if "image_prompts" not in result:
|
||||
raise KeyError("Invalid response format: missing 'image_prompts'")
|
||||
|
||||
batch_prompts = result["image_prompts"]
|
||||
|
||||
# Validate count
|
||||
if len(batch_prompts) != len(batch_narrations):
|
||||
error_msg = (
|
||||
f"Batch {batch_idx} prompt count mismatch (attempt {attempt}/{max_retries}):\n"
|
||||
f" Expected: {len(batch_narrations)} prompts\n"
|
||||
f" Got: {len(batch_prompts)} prompts"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
|
||||
if attempt < max_retries:
|
||||
logger.info(f"Retrying batch {batch_idx}...")
|
||||
continue
|
||||
else:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Success!
|
||||
logger.info(f"✅ Batch {batch_idx} completed successfully ({len(batch_prompts)} prompts)")
|
||||
all_prompts.extend(batch_prompts)
|
||||
|
||||
# Report progress
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
len(all_prompts),
|
||||
len(narrations),
|
||||
f"Batch {batch_idx}/{len(batches)} completed"
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Batch {batch_idx} JSON parse error (attempt {attempt}/{max_retries}): {e}")
|
||||
if attempt >= max_retries:
|
||||
raise
|
||||
logger.info(f"Retrying batch {batch_idx}...")
|
||||
|
||||
logger.info(f"✅ Generated {len(all_prompts)} image prompts")
|
||||
return all_prompts
|
||||
|
||||
|
||||
def _parse_json(text: str) -> dict:
|
||||
"""
|
||||
Parse JSON from text, with fallback to extract JSON from markdown code blocks
|
||||
|
||||
Args:
|
||||
text: Text containing JSON
|
||||
|
||||
Returns:
|
||||
Parsed JSON dict
|
||||
|
||||
Raises:
|
||||
json.JSONDecodeError: If no valid JSON found
|
||||
"""
|
||||
# Try direct parsing first
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try to extract JSON from markdown code block
|
||||
json_pattern = r'```(?:json)?\s*([\s\S]+?)\s*```'
|
||||
match = re.search(json_pattern, text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try to find any JSON object in the text
|
||||
json_pattern = r'\{[^{}]*(?:"narrations"|"image_prompts")\s*:\s*\[[^\]]*\][^{}]*\}'
|
||||
match = re.search(json_pattern, text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If all fails, raise error
|
||||
raise json.JSONDecodeError("No valid JSON found", text, 0)
|
||||
|
||||
Reference in New Issue
Block a user