Files
AI-Video/pixelle_video/utils/content_generators.py
empty 3d3aba3670
Some checks failed
Deploy Documentation / deploy (push) Has been cancelled
feat: Add smart paragraph merging mode with AI grouping
- Add "smart" split mode that uses LLM to intelligently merge related paragraphs
- Implement two-step approach: analyze text structure, then group by semantic relevance
- Add paragraph_merging.py with analysis and grouping prompts
- Update UI to support smart mode selection with auto-detect hint
- Add i18n translations for smart mode (en_US, zh_CN)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 00:19:46 +08:00

688 lines
25 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Content generation utility functions
Pure/stateless functions for generating content using LLM.
These functions are reusable across different pipelines.
"""
import json
import re
from typing import List, Optional, Literal
from loguru import logger
async def generate_title(
llm_service,
content: str,
strategy: Literal["auto", "direct", "llm"] = "auto",
max_length: int = 15
) -> str:
"""
Generate title from content
Args:
llm_service: LLM service instance
content: Source content (topic or script)
strategy: Generation strategy
- "auto": Auto-decide based on content length (default)
- "direct": Use content directly (truncated if needed)
- "llm": Always use LLM to generate title
max_length: Maximum title length (default: 15)
Returns:
Generated title
"""
if strategy == "direct":
content = content.strip()
return content[:max_length] if len(content) > max_length else content
if strategy == "auto":
if len(content.strip()) <= 15:
return content.strip()
# Fall through to LLM
# Use LLM to generate title
from pixelle_video.prompts import build_title_generation_prompt
# Pass max_length to prompt so LLM knows the character limit
prompt = build_title_generation_prompt(content, max_length=max_length)
response = await llm_service(prompt, temperature=0.7, max_tokens=50)
# Clean up response
title = response.strip()
# Remove quotes if present
if title.startswith('"') and title.endswith('"'):
title = title[1:-1]
if title.startswith("'") and title.endswith("'"):
title = title[1:-1]
# Remove trailing punctuation
title = title.rstrip('.,!?;:\'"')
# Safety: if still over limit, truncate smartly
if len(title) > max_length:
# Try to truncate at word boundary
truncated = title[:max_length]
last_space = truncated.rfind(' ')
# Only use word boundary if it's not too far back (at least 60% of max_length)
if last_space > max_length * 0.6:
title = truncated[:last_space]
else:
title = truncated
# Remove any trailing punctuation after truncation
title = title.rstrip('.,!?;:\'"')
logger.debug(f"Generated title: '{title}' (length: {len(title)})")
return title
async def generate_narrations_from_topic(
llm_service,
topic: str,
n_scenes: int = 5,
min_words: int = 5,
max_words: int = 20
) -> List[str]:
"""
Generate narrations from topic using LLM
Args:
llm_service: LLM service instance
topic: Topic/theme to generate narrations from
n_scenes: Number of narrations to generate
min_words: Minimum narration length
max_words: Maximum narration length
Returns:
List of narration texts
"""
from pixelle_video.prompts import build_topic_narration_prompt
logger.info(f"Generating {n_scenes} narrations from topic: {topic}")
prompt = build_topic_narration_prompt(
topic=topic,
n_storyboard=n_scenes,
min_words=min_words,
max_words=max_words
)
response = await llm_service(
prompt=prompt,
temperature=0.8,
max_tokens=2000
)
logger.debug(f"LLM response: {response[:200]}...")
# Parse JSON
result = _parse_json(response)
if "narrations" not in result:
raise ValueError("Invalid response format: missing 'narrations' key")
narrations = result["narrations"]
# Validate count
if len(narrations) > n_scenes:
logger.warning(f"Got {len(narrations)} narrations, taking first {n_scenes}")
narrations = narrations[:n_scenes]
elif len(narrations) < n_scenes:
raise ValueError(f"Expected {n_scenes} narrations, got only {len(narrations)}")
logger.info(f"Generated {len(narrations)} narrations successfully")
return narrations
async def generate_narrations_from_content(
llm_service,
content: str,
n_scenes: int = 5,
min_words: int = 5,
max_words: int = 20
) -> List[str]:
"""
Generate narrations from user-provided content using LLM
Args:
llm_service: LLM service instance
content: User-provided content
n_scenes: Number of narrations to generate
min_words: Minimum narration length
max_words: Maximum narration length
Returns:
List of narration texts
"""
from pixelle_video.prompts import build_content_narration_prompt
logger.info(f"Generating {n_scenes} narrations from content ({len(content)} chars)")
prompt = build_content_narration_prompt(
content=content,
n_storyboard=n_scenes,
min_words=min_words,
max_words=max_words
)
response = await llm_service(
prompt=prompt,
temperature=0.8,
max_tokens=2000
)
# Parse JSON
result = _parse_json(response)
if "narrations" not in result:
raise ValueError("Invalid response format: missing 'narrations' key")
narrations = result["narrations"]
# Validate count
if len(narrations) > n_scenes:
logger.warning(f"Got {len(narrations)} narrations, taking first {n_scenes}")
narrations = narrations[:n_scenes]
elif len(narrations) < n_scenes:
raise ValueError(f"Expected {n_scenes} narrations, got only {len(narrations)}")
logger.info(f"Generated {len(narrations)} narrations successfully")
return narrations
async def split_narration_script(
script: str,
split_mode: Literal["paragraph", "line", "sentence", "smart"] = "paragraph",
llm_service = None,
target_segments: int = 8,
) -> List[str]:
"""
Split user-provided narration script into segments
Args:
script: Fixed narration script
split_mode: Splitting strategy
- "paragraph": Split by double newline (\\n\\n), preserve single newlines within paragraphs
- "line": Split by single newline (\\n), each line is a segment
- "sentence": Split by sentence-ending punctuation (。.!?)
- "smart": First split by paragraph, then use LLM to intelligently merge related paragraphs
llm_service: LLM service instance (required for "smart" mode)
target_segments: Target number of segments for "smart" mode (default: 8)
Returns:
List of narration segments
"""
logger.info(f"Splitting script (mode={split_mode}, length={len(script)} chars)")
narrations = []
if split_mode == "smart":
# Smart mode: first split by paragraph, then merge intelligently
if llm_service is None:
raise ValueError("llm_service is required for 'smart' split mode")
# Step 1: Split by paragraph first
paragraphs = re.split(r'\n\s*\n', script)
paragraphs = [p.strip() for p in paragraphs if p.strip()]
logger.info(f" Initial split: {len(paragraphs)} paragraphs")
# Step 2: Merge intelligently using LLM
# If target_segments is None, merge_paragraphs_smart will auto-analyze
if target_segments is not None and len(paragraphs) <= target_segments:
# No need to merge if already within target
logger.info(f" Paragraphs count ({len(paragraphs)}) <= target ({target_segments}), no merge needed")
narrations = paragraphs
else:
narrations = await merge_paragraphs_smart(
llm_service=llm_service,
paragraphs=paragraphs,
target_segments=target_segments # Can be None for auto-analysis
)
logger.info(f"✅ Smart split: {len(paragraphs)} paragraphs -> {len(narrations)} segments")
elif split_mode == "paragraph":
# Split by double newline (paragraph mode)
# Preserve single newlines within paragraphs
paragraphs = re.split(r'\n\s*\n', script)
for para in paragraphs:
# Only strip leading/trailing whitespace, preserve internal newlines
cleaned = para.strip()
if cleaned:
narrations.append(para)
logger.info(f"✅ Split script into {len(narrations)} segments (by paragraph)")
elif split_mode == "line":
# Split by single newline (original behavior)
narrations = [line.strip() for line in script.split('\n') if line.strip()]
logger.info(f"✅ Split script into {len(narrations)} segments (by line)")
elif split_mode == "sentence":
# Split by sentence-ending punctuation
# Supports Chinese (。!?) and English (.!?)
# Use regex to split while keeping sentences intact
cleaned = re.sub(r'\s+', ' ', script.strip())
# Split on sentence-ending punctuation, keeping the punctuation with the sentence
sentences = re.split(r'(?<=[。.!?])\s*', cleaned)
narrations = [s.strip() for s in sentences if s.strip()]
logger.info(f"✅ Split script into {len(narrations)} segments (by sentence)")
else:
# Fallback to line mode
logger.warning(f"Unknown split_mode '{split_mode}', falling back to 'line'")
narrations = [line.strip() for line in script.split('\n') if line.strip()]
# Log statistics
if narrations:
lengths = [len(s) for s in narrations]
logger.info(f" Min: {min(lengths)} chars, Max: {max(lengths)} chars, Avg: {sum(lengths)//len(lengths)} chars")
return narrations
async def merge_paragraphs_smart(
llm_service,
paragraphs: List[str],
target_segments: int = None, # Now optional - auto-analyze if not provided
max_retries: int = 3,
) -> List[str]:
"""
Use LLM to intelligently merge paragraphs based on semantic relevance.
Two-step approach:
1. If target_segments is not provided, first analyze text to recommend optimal count
2. Then group paragraphs based on the target count
Args:
llm_service: LLM service instance
paragraphs: List of original paragraphs
target_segments: Target number of merged segments (auto-analyzed if None)
max_retries: Maximum retry attempts for each step
Returns:
List of merged paragraphs
"""
from pixelle_video.prompts import (
build_paragraph_analysis_prompt,
build_paragraph_grouping_prompt
)
# ========================================
# Step 1: Analyze and recommend segment count (if not provided)
# ========================================
if target_segments is None:
logger.info(f"Analyzing {len(paragraphs)} paragraphs to recommend segment count...")
analysis_prompt = build_paragraph_analysis_prompt(paragraphs)
analysis_result = None
for attempt in range(1, max_retries + 1):
try:
response = await llm_service(
prompt=analysis_prompt,
temperature=0.3,
max_tokens=1500
)
logger.debug(f"Analysis response length: {len(response)} chars")
result = _parse_json(response)
if "recommended_segments" not in result:
raise KeyError("Missing 'recommended_segments' in analysis")
target_segments = result["recommended_segments"]
analysis_result = result
# Validate range
if target_segments < 3:
target_segments = 3
elif target_segments > 15:
target_segments = 15
reasoning = result.get("reasoning", "N/A")
logger.info(f"✅ Analysis complete: recommended {target_segments} segments")
logger.info(f" Reasoning: {reasoning[:100]}...")
break
except Exception as e:
logger.error(f"Analysis attempt {attempt} failed: {e}")
if attempt >= max_retries:
# Fallback: use simple heuristic
target_segments = max(3, min(12, len(paragraphs) // 3))
logger.warning(f"Using fallback: {target_segments} segments (paragraphs/3)")
analysis_result = None
break
logger.info("Retrying analysis...")
else:
analysis_result = None
logger.info(f"Using provided target: {target_segments} segments")
# ========================================
# Step 2: Group paragraphs
# ========================================
logger.info(f"Grouping {len(paragraphs)} paragraphs into {target_segments} segments...")
grouping_prompt = build_paragraph_grouping_prompt(
paragraphs=paragraphs,
target_segments=target_segments,
analysis_result=analysis_result
)
for attempt in range(1, max_retries + 1):
try:
response = await llm_service(
prompt=grouping_prompt,
temperature=0.3,
max_tokens=2000
)
logger.debug(f"Grouping response length: {len(response)} chars")
result = _parse_json(response)
if "groups" not in result:
raise KeyError("Invalid response format: missing 'groups'")
groups = result["groups"]
# Validate count
if len(groups) != target_segments:
logger.warning(
f"Grouping attempt {attempt}: expected {target_segments} groups, got {len(groups)}"
)
if attempt < max_retries:
continue
logger.warning(f"Accepting {len(groups)} groups after {max_retries} attempts")
# Validate group boundaries
for i, group in enumerate(groups):
if "start" not in group or "end" not in group:
raise ValueError(f"Group {i} missing 'start' or 'end'")
if group["start"] > group["end"]:
raise ValueError(f"Group {i} has invalid range: start > end")
if group["start"] < 0 or group["end"] >= len(paragraphs):
raise ValueError(f"Group {i} has out-of-bounds indices")
# Merge paragraphs based on groups
merged = []
for group in groups:
start, end = group["start"], group["end"]
merged_text = "\n\n".join(paragraphs[start:end + 1])
merged.append(merged_text)
logger.info(f"✅ Successfully merged into {len(merged)} segments")
return merged
except Exception as e:
logger.error(f"Grouping attempt {attempt} failed: {e}")
if attempt >= max_retries:
raise
logger.info("Retrying grouping...")
# Fallback: should not reach here
return paragraphs
async def generate_image_prompts(
llm_service,
narrations: List[str],
min_words: int = 30,
max_words: int = 60,
batch_size: int = 10,
max_retries: int = 3,
progress_callback: Optional[callable] = None
) -> List[str]:
"""
Generate image prompts from narrations (with batching and retry)
Args:
llm_service: LLM service instance
narrations: List of narrations
min_words: Min image prompt length
max_words: Max image prompt length
batch_size: Max narrations per batch (default: 10)
max_retries: Max retry attempts per batch (default: 3)
progress_callback: Optional callback(completed, total, message) for progress updates
Returns:
List of image prompts (base prompts, without prefix applied)
"""
from pixelle_video.prompts import build_image_prompt_prompt
logger.info(f"Generating image prompts for {len(narrations)} narrations (batch_size={batch_size})")
# Split narrations into batches
batches = [narrations[i:i + batch_size] for i in range(0, len(narrations), batch_size)]
logger.info(f"Split into {len(batches)} batches")
all_prompts = []
# Process each batch
for batch_idx, batch_narrations in enumerate(batches, 1):
logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")
# Retry logic for this batch
for attempt in range(1, max_retries + 1):
try:
# Generate prompts for this batch
prompt = build_image_prompt_prompt(
narrations=batch_narrations,
min_words=min_words,
max_words=max_words
)
response = await llm_service(
prompt=prompt,
temperature=0.7,
max_tokens=8192
)
logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")
# Parse JSON
result = _parse_json(response)
if "image_prompts" not in result:
raise KeyError("Invalid response format: missing 'image_prompts'")
batch_prompts = result["image_prompts"]
# Validate count
if len(batch_prompts) != len(batch_narrations):
error_msg = (
f"Batch {batch_idx} prompt count mismatch (attempt {attempt}/{max_retries}):\n"
f" Expected: {len(batch_narrations)} prompts\n"
f" Got: {len(batch_prompts)} prompts"
)
logger.warning(error_msg)
if attempt < max_retries:
logger.info(f"Retrying batch {batch_idx}...")
continue
else:
raise ValueError(error_msg)
# Success!
logger.info(f"✅ Batch {batch_idx} completed successfully ({len(batch_prompts)} prompts)")
all_prompts.extend(batch_prompts)
# Report progress
if progress_callback:
progress_callback(
len(all_prompts),
len(narrations),
f"Batch {batch_idx}/{len(batches)} completed"
)
break
except json.JSONDecodeError as e:
logger.error(f"Batch {batch_idx} JSON parse error (attempt {attempt}/{max_retries}): {e}")
if attempt >= max_retries:
raise
logger.info(f"Retrying batch {batch_idx}...")
logger.info(f"✅ Generated {len(all_prompts)} image prompts")
return all_prompts
async def generate_video_prompts(
llm_service,
narrations: List[str],
min_words: int = 30,
max_words: int = 60,
batch_size: int = 10,
max_retries: int = 3,
progress_callback: Optional[callable] = None
) -> List[str]:
"""
Generate video prompts from narrations (with batching and retry)
Args:
llm_service: LLM service instance
narrations: List of narrations
min_words: Min video prompt length
max_words: Max video prompt length
batch_size: Max narrations per batch (default: 10)
max_retries: Max retry attempts per batch (default: 3)
progress_callback: Optional callback(completed, total, message) for progress updates
Returns:
List of video prompts (base prompts, without prefix applied)
"""
from pixelle_video.prompts.video_generation import build_video_prompt_prompt
logger.info(f"Generating video prompts for {len(narrations)} narrations (batch_size={batch_size})")
# Split narrations into batches
batches = [narrations[i:i + batch_size] for i in range(0, len(narrations), batch_size)]
logger.info(f"Split into {len(batches)} batches")
all_prompts = []
# Process each batch
for batch_idx, batch_narrations in enumerate(batches, 1):
logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")
# Retry logic for this batch
for attempt in range(1, max_retries + 1):
try:
# Generate prompts for this batch
prompt = build_video_prompt_prompt(
narrations=batch_narrations,
min_words=min_words,
max_words=max_words
)
response = await llm_service(
prompt=prompt,
temperature=0.7,
max_tokens=8192
)
logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")
# Parse JSON
result = _parse_json(response)
if "video_prompts" not in result:
raise KeyError("Invalid response format: missing 'video_prompts'")
batch_prompts = result["video_prompts"]
# Validate batch result
if len(batch_prompts) != len(batch_narrations):
raise ValueError(
f"Prompt count mismatch: expected {len(batch_narrations)}, got {len(batch_prompts)}"
)
# Success - add to all_prompts
all_prompts.extend(batch_prompts)
logger.info(f"✓ Batch {batch_idx} completed: {len(batch_prompts)} video prompts")
# Report progress
if progress_callback:
completed = len(all_prompts)
total = len(narrations)
progress_callback(completed, total, f"Batch {batch_idx}/{len(batches)} completed")
break # Success, move to next batch
except Exception as e:
logger.warning(f"✗ Batch {batch_idx} attempt {attempt} failed: {e}")
if attempt >= max_retries:
raise
logger.info(f"Retrying batch {batch_idx}...")
logger.info(f"✅ Generated {len(all_prompts)} video prompts")
return all_prompts
def _parse_json(text: str) -> dict:
"""
Parse JSON from text, with fallback to extract JSON from markdown code blocks
Args:
text: Text containing JSON
Returns:
Parsed JSON dict
Raises:
json.JSONDecodeError: If no valid JSON found
"""
# Try direct parsing first
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try to extract JSON from markdown code block
json_pattern = r'```(?:json)?\s*([\s\S]+?)\s*```'
match = re.search(json_pattern, text, re.DOTALL)
if match:
try:
return json.loads(match.group(1))
except json.JSONDecodeError:
pass
# Try to find any JSON object with known keys (including analysis keys)
json_pattern = r'\{[^{}]*(?:"narrations"|"image_prompts"|"video_prompts"|"merged_paragraphs"|"groups"|"recommended_segments"|"scene_boundaries")\s*:\s*[^{}]*\}'
match = re.search(json_pattern, text, re.DOTALL)
if match:
try:
return json.loads(match.group(0))
except json.JSONDecodeError:
pass
# Try to find any JSON object that looks like it contains an array
# This is a more aggressive fallback for complex nested arrays
json_start = text.find('{')
json_end = text.rfind('}')
if json_start != -1 and json_end != -1 and json_end > json_start:
potential_json = text[json_start:json_end + 1]
try:
return json.loads(potential_json)
except json.JSONDecodeError:
pass
# If all fails, raise error
raise json.JSONDecodeError("No valid JSON found", text, 0)