# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
OutputValidator - LLM output validation and quality control

Validates LLM-generated content for:
- Narrations: length, relevance, coherence
- Image prompts: format, language, prompt-narration alignment
"""

import re
from dataclasses import dataclass, field
from typing import List, Optional

from loguru import logger


@dataclass
class ValidationConfig:
    """Configuration for output validation"""

    # Narration validation (narrations are CJK text, so these limits
    # are character counts rather than whitespace-delimited words)
    min_narration_words: int = 5
    max_narration_words: int = 50
    relevance_threshold: float = 0.6
    coherence_threshold: float = 0.6

    # Image prompt validation (prompts are English, so whitespace words)
    min_prompt_words: int = 10
    max_prompt_words: int = 100
    require_english_prompts: bool = True
    prompt_match_threshold: float = 0.5


@dataclass
class ValidationResult:
    """Result of validation"""

    passed: bool
    issues: List[str] = field(default_factory=list)
    score: float = 1.0  # 1.0 = perfect, 0.0 = failed
    suggestions: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        return {
            "passed": self.passed,
            "score": self.score,
            "issues": self.issues,
            "suggestions": self.suggestions,
        }


class OutputValidator:
    """
    Validator for LLM-generated outputs

    Validates narrations and image prompts to ensure they meet quality
    standards before proceeding with media generation.

    Example (with Chinese sample inputs, the intended use case):
        >>> validator = OutputValidator(llm_service)
        >>> result = await validator.validate_narrations(
        ...     narrations=["旁白1", "旁白2"],
        ...     topic="人生哲理",
        ...     config=ValidationConfig()
        ... )
        >>> if not result.passed:
        ...     print(f"Validation failed: {result.issues}")
    """

    def __init__(self, llm_service=None):
        """
        Initialize OutputValidator

        Args:
            llm_service: Optional LLM service for semantic validation
        """
        self.llm_service = llm_service

    async def validate_narrations(
        self,
        narrations: List[str],
        topic: str,
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated narrations

        Checks:
        1. Length constraints
        2. Non-empty content
        3. Topic relevance (if LLM available)
        4. Coherence between narrations (if LLM available)

        Args:
            narrations: List of narration texts
            topic: Original topic/theme
            config: Validation configuration

        Returns:
            ValidationResult with pass/fail and issues
        """
        cfg = config or ValidationConfig()
        issues = []
        suggestions = []
        scores = []

        if not narrations:
            return ValidationResult(
                passed=False, issues=["No narrations provided"], score=0.0
            )

        # 1. Length validation
        for i, narration in enumerate(narrations, 1):
            # Character count: narrations are CJK text, where len() is a
            # better length measure than whitespace word splitting
            word_count = len(narration)

            if not narration.strip():
                issues.append(f"Scene {i}: content is empty")
                scores.append(0.0)
                continue

            if word_count < cfg.min_narration_words:
                issues.append(
                    f"Scene {i}: too short "
                    f"({word_count} chars, minimum {cfg.min_narration_words})"
                )
                scores.append(0.5)
            elif word_count > cfg.max_narration_words:
                issues.append(
                    f"Scene {i}: too long "
                    f"({word_count} chars, maximum {cfg.max_narration_words})"
                )
                suggestions.append(f"Consider splitting scene {i} into shorter sentences")
                scores.append(0.7)
            else:
                scores.append(1.0)
        # 2. Semantic validation (if LLM available)
        if self.llm_service:
            try:
                relevance = await self._check_relevance(narrations, topic)
                if relevance < cfg.relevance_threshold:
                    issues.append(
                        f'Content relevance to topic "{topic}" is low ({relevance:.0%})'
                    )
                    suggestions.append("Consider regenerating so the content stays on topic")
                    scores.append(relevance)

                coherence = await self._check_coherence(narrations)
                if coherence < cfg.coherence_threshold:
                    issues.append(f"Content coherence is low ({coherence:.0%})")
                    suggestions.append("Check the logical transitions between paragraphs")
                    scores.append(coherence)
            except Exception as e:
                logger.warning(f"Semantic validation failed: {e}")
                # Don't fail on semantic check errors

        # Calculate overall score
        overall_score = sum(scores) / len(scores) if scores else 0.0
        passed = len(issues) == 0 or overall_score >= 0.7

        logger.debug(f"Narration validation: score={overall_score:.2f}, issues={len(issues)}")

        return ValidationResult(
            passed=passed,
            issues=issues,
            score=overall_score,
            suggestions=suggestions,
        )

    async def validate_image_prompts(
        self,
        prompts: List[str],
        narrations: List[str],
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated image prompts

        Checks:
        1. Length constraints
        2. Language (should be English)
        3. Prompt-narration count alignment

        Args:
            prompts: List of image prompts
            narrations: Corresponding narrations
            config: Validation configuration

        Returns:
            ValidationResult with pass/fail and issues
        """
        cfg = config or ValidationConfig()
        issues = []
        suggestions = []
        scores = []

        if not prompts:
            return ValidationResult(
                passed=False, issues=["No image prompts provided"], score=0.0
            )

        if len(prompts) != len(narrations):
            issues.append(
                f"Number of prompts ({len(prompts)}) does not match "
                f"number of narrations ({len(narrations)})"
            )

        for i, prompt in enumerate(prompts, 1):
            if not prompt.strip():
                issues.append(f"Prompt {i}: content is empty")
                scores.append(0.0)
                continue

            word_count = len(prompt.split())

            # Length check
            if word_count < cfg.min_prompt_words:
                issues.append(
                    f"Prompt {i}: too short "
                    f"({word_count} words, minimum {cfg.min_prompt_words})"
                )
                scores.append(0.5)
            elif word_count > cfg.max_prompt_words:
                issues.append(
                    f"Prompt {i}: too long "
                    f"({word_count} words, maximum {cfg.max_prompt_words})"
                )
                scores.append(0.8)
            else:
                scores.append(1.0)

            # English check
            if cfg.require_english_prompts:
                chinese_ratio = self._get_chinese_ratio(prompt)
                if chinese_ratio > 0.3:  # More than 30% Chinese characters
                    issues.append(
                        f"Prompt {i}: should be in English "
                        f"(currently {chinese_ratio:.0%} Chinese)"
                    )
                    suggestions.append(
                        f"Translate prompt {i} into English for better generation results"
                    )
                    scores[-1] *= 0.5

        overall_score = sum(scores) / len(scores) if scores else 0.0
        passed = len(issues) == 0 or overall_score >= 0.7

        logger.debug(f"Image prompt validation: score={overall_score:.2f}, issues={len(issues)}")

        return ValidationResult(
            passed=passed,
            issues=issues,
            score=overall_score,
            suggestions=suggestions,
        )

    async def _check_relevance(
        self,
        narrations: List[str],
        topic: str,
    ) -> float:
        """
        Check relevance of narrations to topic using LLM

        Returns:
            Relevance score 0.0-1.0
        """
        if not self.llm_service:
            return 0.8  # Default score when LLM not available

        try:
            combined_text = "\n".join(narrations[:3])  # Check first 3 for efficiency

            prompt = f"""Rate how relevant the following content is to the topic "{topic}".

Content:
{combined_text}

Give a relevance score from 0 to 100 and output only the number.
The more relevant the content, the higher the score."""

            response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)

            # Parse score from response
            score_match = re.search(r'\d+', response)
            if score_match:
                score = int(score_match.group()) / 100
                return min(1.0, max(0.0, score))

            return 0.7  # Default if parsing fails
        except Exception as e:
            logger.warning(f"Relevance check failed: {e}")
            return 0.7
    async def _check_coherence(
        self,
        narrations: List[str],
    ) -> float:
        """
        Check coherence between narrations using LLM

        Returns:
            Coherence score 0.0-1.0
        """
        if not self.llm_service or len(narrations) < 2:
            return 0.8  # Default score

        try:
            numbered = "\n".join(f"{i+1}. {n}" for i, n in enumerate(narrations[:5]))

            prompt = f"""Rate the logical coherence between the following paragraphs.

{numbered}

Give a coherence score from 0 to 100 and output only the number.
Smooth logic and natural transitions between paragraphs score high."""

            response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)

            score_match = re.search(r'\d+', response)
            if score_match:
                score = int(score_match.group()) / 100
                return min(1.0, max(0.0, score))

            return 0.7
        except Exception as e:
            logger.warning(f"Coherence check failed: {e}")
            return 0.7

    def _get_chinese_ratio(self, text: str) -> float:
        """Calculate ratio of Chinese characters in text"""
        if not text:
            return 0.0

        # Count CJK Unified Ideographs against all non-space characters
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        total_chars = len(text.replace(' ', ''))

        if total_chars == 0:
            return 0.0

        return chinese_chars / total_chars
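

# Minimal usage sketch so the module can be exercised standalone. This is
# illustrative, not part of the library API: `_fake_llm` is a hypothetical
# stub that always answers "85", standing in for a real service with the
# `await llm_service(prompt, temperature=..., max_tokens=...)` calling
# convention assumed by the semantic checks above.
if __name__ == "__main__":
    import asyncio

    async def _fake_llm(prompt: str, temperature: float = 0.1, max_tokens: int = 10) -> str:
        # Hypothetical stand-in for a real LLM call; always reports 85/100.
        return "85"

    async def _demo() -> None:
        validator = OutputValidator(llm_service=_fake_llm)

        # Chinese narrations, as in the class docstring example.
        result = await validator.validate_narrations(
            narrations=["人生是一场漫长的旅行", "重要的不是终点而是沿途的风景"],
            topic="人生哲理",
        )
        print(result.to_dict())

        # One English image prompt per narration.
        result = await validator.validate_image_prompts(
            prompts=["a winding mountain road at sunrise, cinematic lighting, wide shot"],
            narrations=["人生是一场漫长的旅行"],
        )
        print(result.to_dict())

    asyncio.run(_demo())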