# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Content generation utility functions

Pure/stateless functions for generating content using LLM.
These functions are reusable across different pipelines.
"""

import json
import re
from typing import Callable, Iterator, List, Literal, Optional, Tuple

from loguru import logger


# Characters trimmed from the end of a generated title: ASCII punctuation and
# quotes, whitespace, and the full-width CJK equivalents 。,!?;:、
_TITLE_TRAILING_CHARS = ".,!?;:'\" \t\r\n\u3002\uff0c\uff01\uff1f\uff1b\uff1a\u3001"

# Quote pairs an LLM may wrap a title in (ASCII and typographic quotes).
_TITLE_QUOTE_PAIRS = (
    ('"', '"'),
    ("'", "'"),
    ("\u201c", "\u201d"),
    ("\u2018", "\u2019"),
)

# A paragraph break: a newline, optional whitespace, then another newline.
_PARAGRAPH_BREAK_RE = re.compile(r"\n\s*\n")

# Sentence terminators for "sentence" split mode: Chinese 。 and full-width !?,
# plus ASCII .!?
_SENTENCE_END = "\u3002.!?\uff01\uff1f"

# Split point right after a sentence terminator (consuming following
# whitespace), but never inside a run of terminators ("...", "?!") and never at
# the "." of a decimal number such as 3.14.
_SENTENCE_BOUNDARY_RE = re.compile(
    rf"(?<=[{_SENTENCE_END}])"
    rf"(?![{_SENTENCE_END}])"
    r"(?:(?<!\d\.)|(?!\d))"
    r"\s*"
)

# Fenced markdown code block, optionally tagged ``json``.
_CODE_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]+?)\s*```")

# A flat (non-nested) JSON object containing one of the keys this module asks
# the LLM to return (including analysis keys).
_KNOWN_KEYS_OBJECT_RE = re.compile(
    r'\{[^{}]*(?:"narrations"|"image_prompts"|"video_prompts"|"merged_paragraphs"'
    r'|"groups"|"recommended_segments"|"scene_boundaries")\s*:\s*[^{}]*\}',
    re.DOTALL,
)


async def generate_title(
    llm_service,
    content: str,
    strategy: Literal["auto", "direct", "llm"] = "auto",
    max_length: int = 15,
) -> str:
    """
    Generate title from content

    Args:
        llm_service: Async LLM callable; awaited with the prompt and must
            return the generated text as a string
        content: Source content (topic or script)
        strategy: Generation strategy
            - "auto": Use the content itself when it already fits within
              max_length, otherwise ask the LLM (default)
            - "direct": Use content directly (truncated to max_length)
            - "llm": Always use LLM to generate title
        max_length: Maximum title length in characters (default: 15)

    Returns:
        Generated title. If the LLM returns nothing usable, the stripped
        content truncated to max_length is returned instead.
    """
    stripped = content.strip()

    if strategy == "direct":
        return stripped[:max_length]

    # Short content already is a valid title. Compare against max_length (not
    # a fixed constant) so the returned title always honours the limit.
    if strategy == "auto" and len(stripped) <= max_length:
        return stripped

    # Use LLM to generate title
    from pixelle_video.prompts import build_title_generation_prompt

    # Pass max_length to prompt so LLM knows the character limit
    prompt = build_title_generation_prompt(content, max_length=max_length)
    response = await llm_service(prompt, temperature=0.7, max_tokens=50)

    title = _clean_title(response, max_length)
    if not title:
        logger.warning("LLM returned an empty title, falling back to truncated content")
        title = stripped[:max_length]

    logger.debug(f"Generated title: '{title}' (length: {len(title)})")
    return title


def _clean_title(raw: str, max_length: int) -> str:
    """Remove wrapping quotes and trailing punctuation from an LLM title and enforce max_length."""
    title = raw.strip()

    # Unwrap quotes the LLM may have added (each pair checked once, in order)
    for open_quote, close_quote in _TITLE_QUOTE_PAIRS:
        if len(title) >= 2 and title.startswith(open_quote) and title.endswith(close_quote):
            title = title[1:-1].strip()

    title = title.rstrip(_TITLE_TRAILING_CHARS)

    # Safety: if still over limit, truncate smartly
    if len(title) > max_length:
        truncated = title[:max_length]
        last_space = truncated.rfind(" ")
        # Only cut at a word boundary if it keeps at least 60% of max_length
        title = truncated[:last_space] if last_space > max_length * 0.6 else truncated
        title = title.rstrip(_TITLE_TRAILING_CHARS)

    return title


async def generate_narrations_from_topic(
    llm_service,
    topic: str,
    n_scenes: int = 5,
    min_words: int = 5,
    max_words: int = 20,
) -> List[str]:
    """
    Generate narrations from topic using LLM

    Args:
        llm_service: Async LLM callable returning the response text
        topic: Topic/theme to generate narrations from
        n_scenes: Number of narrations to generate
        min_words: Minimum narration length
        max_words: Maximum narration length

    Returns:
        List of exactly n_scenes narration texts (extra ones are dropped)

    Raises:
        ValueError: If the response has no usable 'narrations' list or has
            fewer than n_scenes items (json.JSONDecodeError, a ValueError
            subclass, if no JSON object can be found at all)
    """
    from pixelle_video.prompts import build_topic_narration_prompt

    logger.info(f"Generating {n_scenes} narrations from topic: {topic}")

    prompt = build_topic_narration_prompt(
        topic=topic,
        n_storyboard=n_scenes,
        min_words=min_words,
        max_words=max_words,
    )

    response = await llm_service(prompt=prompt, temperature=0.8, max_tokens=2000)

    logger.debug(f"LLM response: {response[:200]}...")

    return _extract_narrations(response, n_scenes)


async def generate_narrations_from_content(
    llm_service,
    content: str,
    n_scenes: int = 5,
    min_words: int = 5,
    max_words: int = 20,
) -> List[str]:
    """
    Generate narrations from user-provided content using LLM

    Args:
        llm_service: Async LLM callable returning the response text
        content: User-provided content
        n_scenes: Number of narrations to generate
        min_words: Minimum narration length
        max_words: Maximum narration length

    Returns:
        List of exactly n_scenes narration texts (extra ones are dropped)

    Raises:
        ValueError: If the response has no usable 'narrations' list or has
            fewer than n_scenes items (json.JSONDecodeError, a ValueError
            subclass, if no JSON object can be found at all)
    """
    from pixelle_video.prompts import build_content_narration_prompt

    logger.info(f"Generating {n_scenes} narrations from content ({len(content)} chars)")

    prompt = build_content_narration_prompt(
        content=content,
        n_storyboard=n_scenes,
        min_words=min_words,
        max_words=max_words,
    )

    response = await llm_service(prompt=prompt, temperature=0.8, max_tokens=2000)

    return _extract_narrations(response, n_scenes)


def _extract_narrations(response: str, n_scenes: int) -> List[str]:
    """Parse the LLM's narration JSON and trim/validate it to exactly n_scenes items."""
    result = _parse_json(response)

    if "narrations" not in result:
        raise ValueError("Invalid response format: missing 'narrations' key")

    narrations = result["narrations"]
    if not isinstance(narrations, list):
        raise ValueError("Invalid response format: 'narrations' must be a list")

    # Validate count
    if len(narrations) > n_scenes:
        logger.warning(f"Got {len(narrations)} narrations, taking first {n_scenes}")
        narrations = narrations[:n_scenes]
    elif len(narrations) < n_scenes:
        raise ValueError(f"Expected {n_scenes} narrations, got only {len(narrations)}")

    logger.info(f"Generated {len(narrations)} narrations successfully")
    return narrations


async def split_narration_script(
    script: str,
    split_mode: Literal["paragraph", "line", "sentence", "smart"] = "paragraph",
    llm_service=None,
    target_segments: Optional[int] = 8,
) -> List[str]:
    """
    Split user-provided narration script into segments

    Args:
        script: Fixed narration script
        split_mode: Splitting strategy
            - "paragraph": Split on blank lines; single newlines inside a
              paragraph are preserved, surrounding whitespace is stripped
            - "line": Split by single newline (\\n), each line is a segment
            - "sentence": Split after sentence-ending punctuation (Chinese 。
              and full-width !?, English .!?). Runs such as "..." and
              decimals such as "3.14" are not split.
            - "smart": First split by paragraph, then use LLM to
              intelligently merge related paragraphs
            Any other value falls back to "line" with a warning.
        llm_service: LLM service instance (required for "smart" mode)
        target_segments: Target number of segments for "smart" mode
            (default: 8). None lets the LLM recommend a count.

    Returns:
        List of non-empty narration segments

    Raises:
        ValueError: If split_mode is "smart" and llm_service is None
    """
    logger.info(f"Splitting script (mode={split_mode}, length={len(script)} chars)")

    if split_mode == "smart":
        if llm_service is None:
            raise ValueError("llm_service is required for 'smart' split mode")

        # Step 1: Split by paragraph first
        paragraphs = _split_paragraphs(script)
        logger.info(f" Initial split: {len(paragraphs)} paragraphs")

        # Step 2: Merge intelligently using LLM (None target => auto-analysis)
        if target_segments is not None and len(paragraphs) <= target_segments:
            logger.info(f" Paragraphs count ({len(paragraphs)}) <= target ({target_segments}), no merge needed")
            narrations = paragraphs
        else:
            narrations = await merge_paragraphs_smart(
                llm_service=llm_service,
                paragraphs=paragraphs,
                target_segments=target_segments,
            )
        logger.info(f"✅ Smart split: {len(paragraphs)} paragraphs -> {len(narrations)} segments")

    elif split_mode == "paragraph":
        narrations = _split_paragraphs(script)
        logger.info(f"✅ Split script into {len(narrations)} segments (by paragraph)")

    elif split_mode == "line":
        narrations = _split_lines(script)
        logger.info(f"✅ Split script into {len(narrations)} segments (by line)")

    elif split_mode == "sentence":
        # Collapse all whitespace first so sentences spanning lines stay whole
        collapsed = re.sub(r"\s+", " ", script.strip())
        narrations = [s.strip() for s in _SENTENCE_BOUNDARY_RE.split(collapsed) if s.strip()]
        logger.info(f"✅ Split script into {len(narrations)} segments (by sentence)")

    else:
        logger.warning(f"Unknown split_mode '{split_mode}', falling back to 'line'")
        narrations = _split_lines(script)

    # Log statistics
    if narrations:
        lengths = [len(s) for s in narrations]
        logger.info(f" Min: {min(lengths)} chars, Max: {max(lengths)} chars, Avg: {sum(lengths)//len(lengths)} chars")

    return narrations


def _split_paragraphs(script: str) -> List[str]:
    """Split on blank lines, stripping each paragraph and dropping empty ones."""
    return [p.strip() for p in _PARAGRAPH_BREAK_RE.split(script) if p.strip()]


def _split_lines(script: str) -> List[str]:
    """Split on newlines, stripping each line and dropping empty ones."""
    return [line.strip() for line in script.split("\n") if line.strip()]


async def merge_paragraphs_smart(
    llm_service,
    paragraphs: List[str],
    target_segments: Optional[int] = None,
    max_retries: int = 3,
) -> List[str]:
    """
    Use LLM to intelligently merge paragraphs based on semantic relevance.

    Two-step approach:
    1. If target_segments is not provided, first ask the LLM to recommend a
       segment count (clamped to 3..15; heuristic fallback on failure)
    2. Then ask the LLM to group consecutive paragraphs into that many segments

    Args:
        llm_service: Async LLM callable returning the response text
        paragraphs: List of original paragraphs
        target_segments: Target number of merged segments (auto-analyzed if None)
        max_retries: Maximum attempts for each step (values below 1 count as 1)

    Returns:
        List of merged paragraphs, in original order, joined with blank lines.
        Every input paragraph appears in exactly one merged segment. If the
        target is not smaller than the paragraph count, the paragraphs are
        returned unmerged.

    Raises:
        Exception: The last grouping error if every grouping attempt fails
            (including groups that leave gaps or overlap)
    """
    from pixelle_video.prompts import build_paragraph_grouping_prompt

    if not paragraphs:
        return []

    max_retries = max(1, max_retries)

    # Step 1: Analyze and recommend segment count (if not provided)
    if target_segments is None:
        target_segments, analysis_result = await _recommend_segment_count(
            llm_service, paragraphs, max_retries
        )
    else:
        analysis_result = None
        logger.info(f"Using provided target: {target_segments} segments")

    target_segments = max(1, target_segments)
    if target_segments >= len(paragraphs):
        # Cannot produce more segments than paragraphs; nothing to merge
        logger.info(f"Target ({target_segments}) >= paragraph count ({len(paragraphs)}), no merge needed")
        return list(paragraphs)

    # Step 2: Group paragraphs
    logger.info(f"Grouping {len(paragraphs)} paragraphs into {target_segments} segments...")

    grouping_prompt = build_paragraph_grouping_prompt(
        paragraphs=paragraphs,
        target_segments=target_segments,
        analysis_result=analysis_result,
    )

    for attempt in range(1, max_retries + 1):
        try:
            response = await llm_service(
                prompt=grouping_prompt,
                temperature=0.3,
                max_tokens=2000,
            )
            logger.debug(f"Grouping response length: {len(response)} chars")

            result = _parse_json(response)
            if "groups" not in result:
                raise KeyError("Invalid response format: missing 'groups'")

            groups = result["groups"]
            if not isinstance(groups, list):
                raise ValueError("Invalid response format: 'groups' must be a list")

            # Validate count; accept a mismatch only on the last attempt
            if len(groups) != target_segments:
                logger.warning(
                    f"Grouping attempt {attempt}: expected {target_segments} groups, got {len(groups)}"
                )
                if attempt < max_retries:
                    continue
                logger.warning(f"Accepting {len(groups)} groups after {max_retries} attempts")

            merged = _merge_by_groups(paragraphs, groups)
            logger.info(f"✅ Successfully merged into {len(merged)} segments")
            return merged

        except Exception as e:
            logger.error(f"Grouping attempt {attempt} failed: {e}")
            if attempt >= max_retries:
                raise
            logger.info("Retrying grouping...")

    # Defensive only: with max_retries >= 1 the last attempt returns or raises.
    return list(paragraphs)


async def _recommend_segment_count(
    llm_service,
    paragraphs: List[str],
    max_retries: int,
) -> Tuple[int, Optional[dict]]:
    """Ask the LLM for a segment count (clamped to 3..15); return (count, analysis) or a heuristic fallback with None."""
    from pixelle_video.prompts import build_paragraph_analysis_prompt

    logger.info(f"Analyzing {len(paragraphs)} paragraphs to recommend segment count...")
    analysis_prompt = build_paragraph_analysis_prompt(paragraphs)

    for attempt in range(1, max_retries + 1):
        try:
            response = await llm_service(
                prompt=analysis_prompt,
                temperature=0.3,
                max_tokens=1500,
            )
            logger.debug(f"Analysis response length: {len(response)} chars")

            result = _parse_json(response)
            if "recommended_segments" not in result:
                raise KeyError("Missing 'recommended_segments' in analysis")

            # LLMs sometimes emit numbers as strings ("6") or floats (6.0)
            recommended = min(15, max(3, int(result["recommended_segments"])))

            # str() guards against an explicit null "reasoning"
            reasoning = str(result.get("reasoning", "N/A"))
            logger.info(f"✅ Analysis complete: recommended {recommended} segments")
            logger.info(f" Reasoning: {reasoning[:100]}...")
            return recommended, result

        except Exception as e:
            logger.error(f"Analysis attempt {attempt} failed: {e}")
            if attempt < max_retries:
                logger.info("Retrying analysis...")

    # Fallback: use simple heuristic
    fallback = max(3, min(12, len(paragraphs) // 3))
    logger.warning(f"Using fallback: {fallback} segments (paragraphs/3)")
    return fallback, None


def _merge_by_groups(paragraphs: List[str], groups: list) -> List[str]:
    """Validate LLM groups (contiguous, non-overlapping, covering every paragraph) and join each group's paragraphs."""
    spans = []
    for i, group in enumerate(groups):
        if not isinstance(group, dict) or "start" not in group or "end" not in group:
            raise ValueError(f"Group {i} missing 'start' or 'end'")
        start, end = int(group["start"]), int(group["end"])
        if start > end:
            raise ValueError(f"Group {i} has invalid range: start > end")
        if start < 0 or end >= len(paragraphs):
            raise ValueError(f"Group {i} has out-of-bounds indices")
        spans.append((start, end))

    # Every paragraph must land in exactly one group; otherwise script text
    # would be silently dropped or duplicated. Sorting restores document order.
    spans.sort()
    next_start = 0
    for start, end in spans:
        if start != next_start:
            raise ValueError(
                f"Groups are not contiguous: expected a group starting at {next_start}, got {start}"
            )
        next_start = end + 1
    if next_start != len(paragraphs):
        raise ValueError(
            f"Groups cover paragraphs 0..{next_start - 1} but there are {len(paragraphs)} paragraphs"
        )

    return ["\n\n".join(paragraphs[start:end + 1]) for start, end in spans]


async def generate_image_prompts(
    llm_service,
    narrations: List[str],
    min_words: int = 30,
    max_words: int = 60,
    batch_size: int = 10,
    max_retries: int = 3,
    progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> List[str]:
    """
    Generate image prompts from narrations (with batching and retry)

    Args:
        llm_service: Async LLM callable returning the response text
        narrations: List of narrations
        min_words: Min image prompt length
        max_words: Max image prompt length
        batch_size: Max narrations per batch (default: 10, must be >= 1)
        max_retries: Max attempts per batch (default: 3; values below 1 count as 1)
        progress_callback: Optional callback(completed, total, message) for progress updates

    Returns:
        List of image prompts (base prompts, without prefix applied), one per narration

    Raises:
        ValueError: If batch_size < 1, or a batch still has the wrong prompt
            count after max_retries attempts
        json.JSONDecodeError / KeyError / TypeError: If a batch response is
            still malformed after max_retries attempts
    """
    from pixelle_video.prompts import build_image_prompt_prompt

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    max_retries = max(1, max_retries)

    logger.info(f"Generating image prompts for {len(narrations)} narrations (batch_size={batch_size})")

    batches = [narrations[i:i + batch_size] for i in range(0, len(narrations), batch_size)]
    logger.info(f"Split into {len(batches)} batches")

    all_prompts = []

    for batch_idx, batch_narrations in enumerate(batches, 1):
        logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")

        # The prompt is identical on every retry, so build it once
        prompt = build_image_prompt_prompt(
            narrations=batch_narrations,
            min_words=min_words,
            max_words=max_words,
        )

        for attempt in range(1, max_retries + 1):
            # Only the LLM call and response parsing are retried; the success
            # path below stays outside the try so it can never run twice.
            try:
                response = await llm_service(
                    prompt=prompt,
                    temperature=0.7,
                    max_tokens=8192,
                )
                logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")

                result = _parse_json(response)
                if "image_prompts" not in result:
                    raise KeyError("Invalid response format: missing 'image_prompts'")

                batch_prompts = result["image_prompts"]
                if not isinstance(batch_prompts, list):
                    raise TypeError("Invalid response format: 'image_prompts' must be a list")
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                logger.error(f"Batch {batch_idx} invalid LLM response (attempt {attempt}/{max_retries}): {e}")
                if attempt >= max_retries:
                    raise
                logger.info(f"Retrying batch {batch_idx}...")
                continue

            # Validate count
            if len(batch_prompts) != len(batch_narrations):
                error_msg = (
                    f"Batch {batch_idx} prompt count mismatch (attempt {attempt}/{max_retries}):\n"
                    f" Expected: {len(batch_narrations)} prompts\n"
                    f" Got: {len(batch_prompts)} prompts"
                )
                logger.warning(error_msg)
                if attempt >= max_retries:
                    raise ValueError(error_msg)
                logger.info(f"Retrying batch {batch_idx}...")
                continue

            logger.info(f"✅ Batch {batch_idx} completed successfully ({len(batch_prompts)} prompts)")
            all_prompts.extend(batch_prompts)

            if progress_callback:
                progress_callback(
                    len(all_prompts),
                    len(narrations),
                    f"Batch {batch_idx}/{len(batches)} completed",
                )
            break

    logger.info(f"✅ Generated {len(all_prompts)} image prompts")
    return all_prompts


async def generate_video_prompts(
    llm_service,
    narrations: List[str],
    min_words: int = 30,
    max_words: int = 60,
    batch_size: int = 10,
    max_retries: int = 3,
    progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> List[str]:
    """
    Generate video prompts from narrations (with batching and retry)

    Args:
        llm_service: Async LLM callable returning the response text
        narrations: List of narrations
        min_words: Min video prompt length
        max_words: Max video prompt length
        batch_size: Max narrations per batch (default: 10, must be >= 1)
        max_retries: Max attempts per batch (default: 3; values below 1 count as 1)
        progress_callback: Optional callback(completed, total, message) for progress updates

    Returns:
        List of video prompts (base prompts, without prefix applied), one per narration

    Raises:
        ValueError: If batch_size < 1
        Exception: The last error for a batch that fails max_retries times
    """
    from pixelle_video.prompts.video_generation import build_video_prompt_prompt

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    max_retries = max(1, max_retries)

    logger.info(f"Generating video prompts for {len(narrations)} narrations (batch_size={batch_size})")

    batches = [narrations[i:i + batch_size] for i in range(0, len(narrations), batch_size)]
    logger.info(f"Split into {len(batches)} batches")

    all_prompts = []

    for batch_idx, batch_narrations in enumerate(batches, 1):
        logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")

        # The prompt is identical on every retry, so build it once
        prompt = build_video_prompt_prompt(
            narrations=batch_narrations,
            min_words=min_words,
            max_words=max_words,
        )

        for attempt in range(1, max_retries + 1):
            # Only the LLM call and response validation are retried; keeping
            # extend()/progress_callback outside the try prevents a failing
            # callback from re-running the batch and duplicating its prompts.
            try:
                response = await llm_service(
                    prompt=prompt,
                    temperature=0.7,
                    max_tokens=8192,
                )
                logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")

                result = _parse_json(response)
                if "video_prompts" not in result:
                    raise KeyError("Invalid response format: missing 'video_prompts'")

                batch_prompts = result["video_prompts"]
                if not isinstance(batch_prompts, list):
                    raise TypeError("Invalid response format: 'video_prompts' must be a list")

                if len(batch_prompts) != len(batch_narrations):
                    raise ValueError(
                        f"Prompt count mismatch: expected {len(batch_narrations)}, got {len(batch_prompts)}"
                    )
            except Exception as e:
                logger.warning(f"✗ Batch {batch_idx} attempt {attempt} failed: {e}")
                if attempt >= max_retries:
                    raise
                logger.info(f"Retrying batch {batch_idx}...")
                continue

            all_prompts.extend(batch_prompts)
            logger.info(f"✓ Batch {batch_idx} completed: {len(batch_prompts)} video prompts")

            if progress_callback:
                progress_callback(
                    len(all_prompts),
                    len(narrations),
                    f"Batch {batch_idx}/{len(batches)} completed",
                )
            break

    logger.info(f"✅ Generated {len(all_prompts)} video prompts")
    return all_prompts


def _parse_json(text: str) -> dict:
    """
    Parse a JSON object from LLM output, with fallbacks for surrounding text

    Candidates are tried in order, and the first one that parses to a dict wins:
    1. The whole text
    2. Each fenced markdown code block (```json ... ``` or ``` ... ```)
    3. A flat object containing one of this module's known response keys
    4. The span from the first '{' to the last '}'

    Args:
        text: Text containing JSON

    Returns:
        Parsed JSON dict

    Raises:
        json.JSONDecodeError: If no candidate parses to a JSON object
    """
    for candidate in _iter_json_candidates(text):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers index by key, so a bare list or scalar is not a valid result
        if isinstance(parsed, dict):
            return parsed

    raise json.JSONDecodeError("No valid JSON found", text, 0)


def _iter_json_candidates(text: str) -> Iterator[str]:
    """Lazily yield substrings of text that may hold the JSON object, most specific first."""
    yield text

    for match in _CODE_FENCE_RE.finditer(text):
        yield match.group(1)

    match = _KNOWN_KEYS_OBJECT_RE.search(text)
    if match:
        yield match.group(0)

    # Aggressive fallback for complex nested arrays
    json_start = text.find("{")
    json_end = text.rfind("}")
    if json_start != -1 and json_end > json_start:
        yield text[json_start:json_end + 1]