# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Content generation utility functions

Pure/stateless functions for generating content using LLM.
These functions are reusable across different pipelines.
"""

import json
import logging
import re
from typing import Callable, List, Literal, Optional

try:
    from loguru import logger
except ImportError:
    # loguru is the logger this module is written against; fall back to the
    # stdlib logger so the helpers stay importable where it is not installed.
    logger = logging.getLogger(__name__)


# Characters stripped from the end of an LLM-produced title.
_TITLE_TRAILING_CHARS = '.,!?;:\'"'

# Response keys the generators in this module ask the LLM to produce.
_EXPECTED_JSON_KEYS = ("narrations", "image_prompts", "video_prompts")

# Markdown code fence, optionally tagged as json.
_JSON_FENCE_RE = re.compile(r'```(?:json)?\s*([\s\S]+?)\s*```')

# One or more blank lines separate paragraphs.
_PARAGRAPH_BREAK_RE = re.compile(r'\n\s*\n')

# Sentence terminators: ideographic full stop, ASCII . ! ? and the
# full-width exclamation/question marks (U+FF01, U+FF1F).
_SENTENCE_END_CHARS = "。.!?\uFF01\uFF1F"

# Split right after a terminator, but not in the middle of a run of
# terminators ("?!", "..."), swallowing any whitespace that follows.
_SENTENCE_BREAK_RE = re.compile(
    f"(?<=[{_SENTENCE_END_CHARS}])(?![{_SENTENCE_END_CHARS}])\\s*"
)


async def generate_title(
    llm_service,
    content: str,
    strategy: Literal["auto", "direct", "llm"] = "auto",
    max_length: int = 15,
) -> str:
    """
    Generate title from content

    Args:
        llm_service: Awaitable callable invoked as
            ``llm_service(prompt, temperature=..., max_tokens=...)``
        content: Source content (topic or script)
        strategy: Generation strategy
            - "auto": Use the stripped content directly when it fits in
              ``max_length``, otherwise ask the LLM (default)
            - "direct": Use content directly (truncated if needed)
            - "llm": Always use LLM to generate title
        max_length: Maximum title length (default: 15)

    Returns:
        Generated title
    """
    stripped = content.strip()

    if strategy == "direct":
        return stripped[:max_length]

    # "auto" honours the caller's limit instead of a hard-coded 15.
    if strategy == "auto" and len(stripped) <= max_length:
        return stripped

    # Use LLM to generate title
    from pixelle_video.prompts import build_title_generation_prompt

    # Pass max_length to prompt so LLM knows the character limit
    prompt = build_title_generation_prompt(content, max_length=max_length)
    response = await llm_service(prompt, temperature=0.7, max_tokens=50)

    title = _clean_title(response, max_length)
    logger.debug(f"Generated title: '{title}' (length: {len(title)})")
    return title


def _clean_title(response: str, max_length: int) -> str:
    """Strip quotes/trailing punctuation from an LLM title and cap its length."""
    title = response.strip()

    # Remove one layer of wrapping double quotes, then single quotes
    if title.startswith('"') and title.endswith('"'):
        title = title[1:-1]
    if title.startswith("'") and title.endswith("'"):
        title = title[1:-1]

    title = title.rstrip(_TITLE_TRAILING_CHARS)

    if len(title) > max_length:
        truncated = title[:max_length]
        last_space = truncated.rfind(' ')
        # Cut at a word boundary only if that keeps at least 60% of the budget
        if last_space > max_length * 0.6:
            truncated = truncated[:last_space]
        # Truncation can expose punctuation or whitespace at the new end
        title = truncated.rstrip(_TITLE_TRAILING_CHARS + " ")

    return title


async def generate_narrations_from_topic(
    llm_service,
    topic: str,
    n_scenes: int = 5,
    min_words: int = 5,
    max_words: int = 20,
) -> List[str]:
    """
    Generate narrations from topic using LLM

    Args:
        llm_service: Awaitable callable invoked with ``prompt``,
            ``temperature`` and ``max_tokens`` keyword arguments
        topic: Topic/theme to generate narrations from
        n_scenes: Number of narrations to generate
        min_words: Minimum narration length
        max_words: Maximum narration length

    Returns:
        List of narration texts (exactly ``n_scenes`` items)

    Raises:
        ValueError: If the response has no usable 'narrations' list or
            contains fewer than ``n_scenes`` narrations
            (json.JSONDecodeError, a ValueError subclass, when no JSON
            object can be found at all)
    """
    from pixelle_video.prompts import build_topic_narration_prompt

    logger.info(f"Generating {n_scenes} narrations from topic: {topic}")

    prompt = build_topic_narration_prompt(
        topic=topic,
        n_storyboard=n_scenes,
        min_words=min_words,
        max_words=max_words,
    )
    response = await llm_service(
        prompt=prompt,
        temperature=0.8,
        max_tokens=2000,
    )
    logger.debug(f"LLM response: {response[:200]}...")

    return _extract_narrations(response, n_scenes)


async def generate_narrations_from_content(
    llm_service,
    content: str,
    n_scenes: int = 5,
    min_words: int = 5,
    max_words: int = 20,
) -> List[str]:
    """
    Generate narrations from user-provided content using LLM

    Args:
        llm_service: Awaitable callable invoked with ``prompt``,
            ``temperature`` and ``max_tokens`` keyword arguments
        content: User-provided content
        n_scenes: Number of narrations to generate
        min_words: Minimum narration length
        max_words: Maximum narration length

    Returns:
        List of narration texts (exactly ``n_scenes`` items)

    Raises:
        ValueError: If the response has no usable 'narrations' list or
            contains fewer than ``n_scenes`` narrations
            (json.JSONDecodeError, a ValueError subclass, when no JSON
            object can be found at all)
    """
    from pixelle_video.prompts import build_content_narration_prompt

    logger.info(f"Generating {n_scenes} narrations from content ({len(content)} chars)")

    prompt = build_content_narration_prompt(
        content=content,
        n_storyboard=n_scenes,
        min_words=min_words,
        max_words=max_words,
    )
    response = await llm_service(
        prompt=prompt,
        temperature=0.8,
        max_tokens=2000,
    )

    return _extract_narrations(response, n_scenes)


def _extract_narrations(response: str, n_scenes: int) -> List[str]:
    """Parse an LLM response and return exactly ``n_scenes`` narrations."""
    result = _parse_json(response)

    if "narrations" not in result:
        raise ValueError("Invalid response format: missing 'narrations' key")

    narrations = result["narrations"]
    if not isinstance(narrations, list):
        # A string would otherwise be counted and sliced per character.
        raise ValueError("Invalid response format: 'narrations' must be a list")

    # Surplus narrations are dropped; a shortfall cannot be repaired here.
    if len(narrations) > n_scenes:
        logger.warning(f"Got {len(narrations)} narrations, taking first {n_scenes}")
        narrations = narrations[:n_scenes]
    elif len(narrations) < n_scenes:
        raise ValueError(f"Expected {n_scenes} narrations, got only {len(narrations)}")

    logger.info(f"Generated {len(narrations)} narrations successfully")
    return narrations


async def split_narration_script(
    script: str,
    split_mode: Literal["paragraph", "line", "sentence"] = "paragraph",
) -> List[str]:
    """
    Split user-provided narration script into segments

    Args:
        script: Fixed narration script
        split_mode: Splitting strategy
            - "paragraph": Split by double newline (\\n\\n), preserve single
              newlines within paragraphs
            - "line": Split by single newline (\\n), each line is a segment
            - "sentence": Split after sentence-ending punctuation
              (。 . ! ? and their full-width forms); a run of terminators
              such as "?!" or "..." stays attached to its sentence
            Any other value falls back to "line" with a warning.

    Returns:
        List of non-empty narration segments, each stripped of
        leading/trailing whitespace
    """
    logger.info(f"Splitting script (mode={split_mode}, length={len(script)} chars)")

    if split_mode == "paragraph":
        # Strip only the outer whitespace; newlines inside a paragraph stay.
        stripped_paragraphs = (part.strip() for part in _PARAGRAPH_BREAK_RE.split(script))
        narrations = [part for part in stripped_paragraphs if part]
        logger.info(f"✅ Split script into {len(narrations)} segments (by paragraph)")

    elif split_mode == "line":
        narrations = _split_lines(script)
        logger.info(f"✅ Split script into {len(narrations)} segments (by line)")

    elif split_mode == "sentence":
        # Collapse all whitespace first so line breaks do not end sentences.
        # NOTE: a "." between digits (e.g. "3.14") is still a split point.
        cleaned = re.sub(r'\s+', ' ', script.strip())
        stripped_sentences = (part.strip() for part in _SENTENCE_BREAK_RE.split(cleaned))
        narrations = [part for part in stripped_sentences if part]
        logger.info(f"✅ Split script into {len(narrations)} segments (by sentence)")

    else:
        # Fallback to line mode
        logger.warning(f"Unknown split_mode '{split_mode}', falling back to 'line'")
        narrations = _split_lines(script)

    # Log statistics
    if narrations:
        lengths = [len(segment) for segment in narrations]
        logger.info(f" Min: {min(lengths)} chars, Max: {max(lengths)} chars, Avg: {sum(lengths)//len(lengths)} chars")

    return narrations


def _split_lines(script: str) -> List[str]:
    """Return the non-blank lines of ``script``, each stripped."""
    stripped_lines = (line.strip() for line in script.split('\n'))
    return [line for line in stripped_lines if line]


def _split_into_batches(items: List[str], batch_size: int) -> List[List[str]]:
    """Chunk ``items`` into consecutive lists of at most ``batch_size`` items."""
    if batch_size < 1:
        # A negative step would silently yield no batches at all.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


async def generate_image_prompts(
    llm_service,
    narrations: List[str],
    min_words: int = 30,
    max_words: int = 60,
    batch_size: int = 10,
    max_retries: int = 3,
    progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> List[str]:
    """
    Generate image prompts from narrations (with batching and retry)

    A batch is retried when the response cannot be parsed as JSON, lacks
    the 'image_prompts' key, or returns the wrong number of prompts.
    Errors raised by ``llm_service`` itself are not retried here.

    Args:
        llm_service: Awaitable callable invoked with ``prompt``,
            ``temperature`` and ``max_tokens`` keyword arguments
        narrations: List of narrations
        min_words: Min image prompt length
        max_words: Max image prompt length
        batch_size: Max narrations per batch (default: 10)
        max_retries: Max attempts per batch (default: 3); values below 1
            are treated as a single attempt
        progress_callback: Optional callback(completed, total, message)
            for progress updates, called synchronously after each batch

    Returns:
        List of image prompts (base prompts, without prefix applied)

    Raises:
        ValueError: If ``batch_size`` < 1, or a batch still has a count
            mismatch / unparsable JSON (json.JSONDecodeError) on its
            last attempt
        KeyError: If 'image_prompts' is still missing on the last attempt
    """
    from pixelle_video.prompts import build_image_prompt_prompt

    logger.info(f"Generating image prompts for {len(narrations)} narrations (batch_size={batch_size})")

    batches = _split_into_batches(narrations, batch_size)
    logger.info(f"Split into {len(batches)} batches")

    # Always try each batch at least once so no narration is silently skipped.
    attempts = max(1, max_retries)
    all_prompts: List[str] = []

    for batch_idx, batch_narrations in enumerate(batches, 1):
        logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")

        # The prompt does not change between attempts, so build it once.
        prompt = build_image_prompt_prompt(
            narrations=batch_narrations,
            min_words=min_words,
            max_words=max_words,
        )

        for attempt in range(1, attempts + 1):
            try:
                response = await llm_service(
                    prompt=prompt,
                    temperature=0.7,
                    max_tokens=8192,
                )
                logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")

                result = _parse_json(response)
                if "image_prompts" not in result:
                    raise KeyError("Invalid response format: missing 'image_prompts'")

                batch_prompts = result["image_prompts"]
                if len(batch_prompts) != len(batch_narrations):
                    raise ValueError(
                        f"Batch {batch_idx} prompt count mismatch (attempt {attempt}/{attempts}):\n"
                        f" Expected: {len(batch_narrations)} prompts\n"
                        f" Got: {len(batch_prompts)} prompts"
                    )
            except (KeyError, ValueError) as e:
                # json.JSONDecodeError is a ValueError subclass.
                if isinstance(e, json.JSONDecodeError):
                    logger.error(f"Batch {batch_idx} JSON parse error (attempt {attempt}/{attempts}): {e}")
                elif isinstance(e, KeyError):
                    logger.warning(f"Batch {batch_idx} invalid response (attempt {attempt}/{attempts}): {e}")
                else:
                    logger.warning(str(e))

                if attempt >= attempts:
                    raise
                logger.info(f"Retrying batch {batch_idx}...")
            else:
                logger.info(f"✅ Batch {batch_idx} completed successfully ({len(batch_prompts)} prompts)")
                all_prompts.extend(batch_prompts)

                if progress_callback:
                    progress_callback(
                        len(all_prompts),
                        len(narrations),
                        f"Batch {batch_idx}/{len(batches)} completed",
                    )
                break

    logger.info(f"✅ Generated {len(all_prompts)} image prompts")
    return all_prompts


async def generate_video_prompts(
    llm_service,
    narrations: List[str],
    min_words: int = 30,
    max_words: int = 60,
    batch_size: int = 10,
    max_retries: int = 3,
    progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> List[str]:
    """
    Generate video prompts from narrations (with batching and retry)

    Any exception raised while requesting or validating a batch
    (including errors from ``llm_service`` and ``progress_callback``)
    triggers a retry; the last attempt's exception is re-raised.

    Args:
        llm_service: Awaitable callable invoked with ``prompt``,
            ``temperature`` and ``max_tokens`` keyword arguments
        narrations: List of narrations
        min_words: Min video prompt length
        max_words: Max video prompt length
        batch_size: Max narrations per batch (default: 10)
        max_retries: Max attempts per batch (default: 3); values below 1
            are treated as a single attempt
        progress_callback: Optional callback(completed, total, message)
            for progress updates, called synchronously after each batch

    Returns:
        List of video prompts (base prompts, without prefix applied)

    Raises:
        ValueError: If ``batch_size`` < 1, or a batch still has a count
            mismatch / unparsable JSON (json.JSONDecodeError) on its
            last attempt
        KeyError: If 'video_prompts' is still missing on the last attempt
    """
    from pixelle_video.prompts.video_generation import build_video_prompt_prompt

    logger.info(f"Generating video prompts for {len(narrations)} narrations (batch_size={batch_size})")

    batches = _split_into_batches(narrations, batch_size)
    logger.info(f"Split into {len(batches)} batches")

    # Always try each batch at least once so no narration is silently skipped.
    attempts = max(1, max_retries)
    all_prompts: List[str] = []

    for batch_idx, batch_narrations in enumerate(batches, 1):
        logger.info(f"Processing batch {batch_idx}/{len(batches)} ({len(batch_narrations)} narrations)")

        # The prompt does not change between attempts, so build it once.
        prompt = build_video_prompt_prompt(
            narrations=batch_narrations,
            min_words=min_words,
            max_words=max_words,
        )

        for attempt in range(1, attempts + 1):
            try:
                response = await llm_service(
                    prompt=prompt,
                    temperature=0.7,
                    max_tokens=8192,
                )
                logger.debug(f"Batch {batch_idx} attempt {attempt}: LLM response length: {len(response)} chars")

                result = _parse_json(response)
                if "video_prompts" not in result:
                    raise KeyError("Invalid response format: missing 'video_prompts'")

                batch_prompts = result["video_prompts"]
                if len(batch_prompts) != len(batch_narrations):
                    raise ValueError(
                        f"Prompt count mismatch: expected {len(batch_narrations)}, got {len(batch_prompts)}"
                    )

                all_prompts.extend(batch_prompts)
                logger.info(f"✓ Batch {batch_idx} completed: {len(batch_prompts)} video prompts")

                if progress_callback:
                    progress_callback(
                        len(all_prompts),
                        len(narrations),
                        f"Batch {batch_idx}/{len(batches)} completed",
                    )
                break  # Success, move to next batch

            except Exception as e:
                # Deliberately broad: this is the retry boundary, and the
                # final failure is re-raised unchanged.
                logger.warning(f"✗ Batch {batch_idx} attempt {attempt} failed: {e}")
                if attempt >= attempts:
                    raise
                logger.info(f"Retrying batch {batch_idx}...")

    logger.info(f"✅ Generated {len(all_prompts)} video prompts")
    return all_prompts


def _parse_json(text: str) -> dict:
    """
    Parse a JSON object from text, tolerating surrounding prose

    Tried in order:
        1. the whole text
        2. the contents of the first markdown code fence
        3. the first embedded JSON object that carries one of the keys
           'narrations', 'image_prompts' or 'video_prompts'

    Args:
        text: Text containing JSON

    Returns:
        Parsed JSON dict

    Raises:
        json.JSONDecodeError: If no JSON object is found
    """
    candidates = [text]
    fence = _JSON_FENCE_RE.search(text)
    if fence:
        candidates.append(fence.group(1))

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Valid JSON that is not an object (list, string, ...) is unusable.
        if isinstance(parsed, dict):
            return parsed

    # Let the real JSON parser find where an embedded object ends, so
    # brackets inside strings and nested objects are handled correctly.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict) and any(key in parsed for key in _EXPECTED_JSON_KEYS):
            return parsed
        start = text.find("{", start + 1)

    raise json.JSONDecodeError("No valid JSON found", text, 0)