# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
|
||
OutputValidator - LLM output validation and quality control
|
||
|
||
Validates LLM-generated content for:
|
||
- Narrations: length, relevance, coherence
|
||
- Image prompts: format, language, prompt-narration alignment
|
||
"""
|
||
|
||
import re
|
||
from dataclasses import dataclass, field
|
||
from typing import List, Optional
|
||
|
||
from loguru import logger
|
||
|
||
|
||
@dataclass
class ValidationConfig:
    """Configuration for output validation"""

    # Narration validation. Narration limits are measured in characters (字),
    # the natural length unit for the Chinese narrations this module targets.
    min_narration_words: int = 5
    max_narration_words: int = 50
    relevance_threshold: float = 0.6
    coherence_threshold: float = 0.6

    # Image prompt validation. Prompt limits are measured in
    # whitespace-separated words, since prompts are expected to be English.
    min_prompt_words: int = 10
    max_prompt_words: int = 100
    require_english_prompts: bool = True
    prompt_match_threshold: float = 0.5
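
# Example (illustrative values, not recommendations shipped with this module):
# the dataclass defaults can be overridden per call, e.g. to tighten limits
# for short-form videos.
#
#     strict_cfg = ValidationConfig(
#         max_narration_words=30,  # cap each scene at 30 characters
#         min_prompt_words=15,     # require more descriptive image prompts
#     )
#     result = await validator.validate_narrations(narrations, topic, config=strict_cfg)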


@dataclass
class ValidationResult:
    """Result of validation"""
    passed: bool
    issues: List[str] = field(default_factory=list)
    score: float = 1.0  # 1.0 = perfect, 0.0 = failed
    suggestions: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Serialize the result to a plain dict (e.g. for logging or JSON)."""
        return {
            "passed": self.passed,
            "score": self.score,
            "issues": self.issues,
            "suggestions": self.suggestions,
        }


class OutputValidator:
    """
    Validator for LLM-generated outputs

    Validates narrations and image prompts to ensure they meet quality standards
    before proceeding with media generation.

    Example:
        >>> validator = OutputValidator(llm_service)
        >>> result = await validator.validate_narrations(
        ...     narrations=["旁白1", "旁白2"],  # "narration 1", "narration 2"
        ...     topic="人生哲理",  # "philosophy of life"
        ...     config=ValidationConfig()
        ... )
        >>> if not result.passed:
        ...     print(f"Validation failed: {result.issues}")
    """

    def __init__(self, llm_service=None):
        """
        Initialize OutputValidator

        Args:
            llm_service: Optional LLM service for semantic validation.
                As used below, it should be an async callable of the form
                ``response = await llm_service(prompt, temperature=..., max_tokens=...)``
                returning the model's text output.
        """
        self.llm_service = llm_service

    async def validate_narrations(
        self,
        narrations: List[str],
        topic: str,
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated narrations

        Checks:
        1. Length constraints
        2. Non-empty content
        3. Topic relevance (if LLM available)
        4. Coherence between narrations (if LLM available)

        Args:
            narrations: List of narration texts
            topic: Original topic/theme
            config: Validation configuration

        Returns:
            ValidationResult with pass/fail and issues
        """
        cfg = config or ValidationConfig()
        issues = []
        suggestions = []
        scores = []

        if not narrations:
            return ValidationResult(
                passed=False,
                issues=["No narrations provided"],
                score=0.0
            )

        # 1. Length validation
        for i, narration in enumerate(narrations, 1):
            # len() counts characters, the appropriate length unit for Chinese text.
            char_count = len(narration)

            if not narration.strip():
                issues.append(f"分镜{i}: 内容为空")  # "Scene {i}: content is empty"
                scores.append(0.0)
                continue

            if char_count < cfg.min_narration_words:
                # "Scene {i}: too short ({n} chars, minimum {m})"
                issues.append(f"分镜{i}: 内容过短 ({char_count}字,最少{cfg.min_narration_words}字)")
                scores.append(0.5)
            elif char_count > cfg.max_narration_words:
                # "Scene {i}: too long ({n} chars, maximum {m})"
                issues.append(f"分镜{i}: 内容过长 ({char_count}字,最多{cfg.max_narration_words}字)")
                # "Consider splitting scene {i} into several short sentences"
                suggestions.append(f"考虑将分镜{i}拆分为多个短句")
                scores.append(0.7)
            else:
                scores.append(1.0)

        # 2. Semantic validation (if LLM available)
        if self.llm_service:
            try:
                relevance = await self._check_relevance(narrations, topic)
                if relevance < cfg.relevance_threshold:
                    # "Content is not sufficiently relevant to the topic"
                    issues.append(f'内容与主题"{topic}"相关性不足 ({relevance:.0%})')
                    # "Regenerate, keeping the content focused on the topic"
                    suggestions.append("建议重新生成,确保内容紧扣主题")
                scores.append(relevance)

                coherence = await self._check_coherence(narrations)
                if coherence < cfg.coherence_threshold:
                    # "Content lacks coherence"
                    issues.append(f"内容连贯性不足 ({coherence:.0%})")
                    # "Check the logical transitions between paragraphs"
                    suggestions.append("建议检查段落之间的逻辑衔接")
                scores.append(coherence)

            except Exception as e:
                logger.warning(f"Semantic validation failed: {e}")
                # Don't fail on semantic check errors

        # Calculate overall score. Pass if nothing was flagged, or if the
        # flagged issues are minor enough that the average score still
        # clears 0.7.
        overall_score = sum(scores) / len(scores) if scores else 0.0
        passed = len(issues) == 0 or overall_score >= 0.7

        logger.debug(f"Narration validation: score={overall_score:.2f}, issues={len(issues)}")

        return ValidationResult(
            passed=passed,
            issues=issues,
            score=overall_score,
            suggestions=suggestions,
        )

    async def validate_image_prompts(
        self,
        prompts: List[str],
        narrations: List[str],
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated image prompts

        Checks:
        1. Length constraints
        2. Language (should be English)
        3. Prompt-narration alignment

        Args:
            prompts: List of image prompts
            narrations: Corresponding narrations
            config: Validation configuration

        Returns:
            ValidationResult with pass/fail and issues
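
        Example (illustrative data; assumes a ``validator`` instance as above):
            >>> result = await validator.validate_image_prompts(
            ...     prompts=["A misty mountain lake at dawn, soft golden light, cinematic wide shot"],
            ...     narrations=["清晨的山间湖泊"],  # "a mountain lake in the early morning"
            ... )
            >>> result.passed
            True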
        """
        cfg = config or ValidationConfig()
        issues = []
        suggestions = []
        scores = []

        if not prompts:
            return ValidationResult(
                passed=False,
                issues=["No image prompts provided"],
                score=0.0
            )

        if len(prompts) != len(narrations):
            # "Prompt count ({n}) does not match narration count ({m})"
            issues.append(f"提示词数量({len(prompts)})与旁白数量({len(narrations)})不匹配")

        for i, prompt in enumerate(prompts, 1):
            if not prompt.strip():
                issues.append(f"提示词{i}: 内容为空")  # "Prompt {i}: content is empty"
                scores.append(0.0)
                continue

            # Prompts are expected to be English, so a whitespace split
            # gives a meaningful word count.
            word_count = len(prompt.split())

            # Length check
            if word_count < cfg.min_prompt_words:
                # "Prompt {i}: too short ({n} words, minimum {m})"
                issues.append(f"提示词{i}: 过短 ({word_count}词,最少{cfg.min_prompt_words}词)")
                scores.append(0.5)
            elif word_count > cfg.max_prompt_words:
                # "Prompt {i}: too long ({n} words, maximum {m})"
                issues.append(f"提示词{i}: 过长 ({word_count}词,最多{cfg.max_prompt_words}词)")
                scores.append(0.8)
            else:
                scores.append(1.0)

            # English check
            if cfg.require_english_prompts:
                chinese_ratio = self._get_chinese_ratio(prompt)
                if chinese_ratio > 0.3:  # More than 30% Chinese characters
                    # "Prompt {i}: should be in English (currently {p}% Chinese)"
                    issues.append(f"提示词{i}: 应使用英文 (当前含{chinese_ratio:.0%}中文)")
                    # "Translate prompt {i} into English for better generation results"
                    suggestions.append(f"将提示词{i}翻译为英文以获得更好的生成效果")
                    # Halve this prompt's length score rather than zeroing it.
                    scores[-1] *= 0.5

        overall_score = sum(scores) / len(scores) if scores else 0.0
        passed = len(issues) == 0 or overall_score >= 0.7

        logger.debug(f"Image prompt validation: score={overall_score:.2f}, issues={len(issues)}")

        return ValidationResult(
            passed=passed,
            issues=issues,
            score=overall_score,
            suggestions=suggestions,
        )

    async def _check_relevance(
        self,
        narrations: List[str],
        topic: str,
    ) -> float:
        """
        Check relevance of narrations to topic using LLM

        Returns:
            Relevance score 0.0-1.0
        """
        if not self.llm_service:
            return 0.8  # Default score when LLM not available

        try:
            combined_text = "\n".join(narrations[:3])  # Check first 3 for efficiency

            # Chinese prompt, matching the content language: "Rate the relevance
            # of the following content to the topic on a 0-100 scale; output
            # only the number. The more relevant, the higher the score."
            prompt = f"""评估以下内容与主题"{topic}"的相关性。

内容:
{combined_text}

请用0-100的分数评估相关性,只输出数字。
相关性越高,分数越高。"""

            response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)

            # Parse the first integer in the response and clamp to [0.0, 1.0]
            score_match = re.search(r'\d+', response)
            if score_match:
                score = int(score_match.group()) / 100
                return min(1.0, max(0.0, score))

            return 0.7  # Default if parsing fails

        except Exception as e:
            logger.warning(f"Relevance check failed: {e}")
            return 0.7

    async def _check_coherence(
        self,
        narrations: List[str],
    ) -> float:
        """
        Check coherence between narrations using LLM

        Returns:
            Coherence score 0.0-1.0
        """
        if not self.llm_service or len(narrations) < 2:
            return 0.8  # Default score

        try:
            numbered = "\n".join(f"{i+1}. {n}" for i, n in enumerate(narrations[:5]))

            # Chinese prompt: "Rate the logical coherence between the following
            # paragraphs on a 0-100 scale; output only the number. Smooth,
            # natural transitions score high."
            prompt = f"""评估以下段落之间的逻辑连贯性。

{numbered}

请用0-100的分数评估连贯性,只输出数字。
段落之间逻辑顺畅、衔接自然则分数高。"""

            response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)

            # Parse the first integer in the response and clamp to [0.0, 1.0]
            score_match = re.search(r'\d+', response)
            if score_match:
                score = int(score_match.group()) / 100
                return min(1.0, max(0.0, score))

            return 0.7

        except Exception as e:
            logger.warning(f"Coherence check failed: {e}")
            return 0.7

    def _get_chinese_ratio(self, text: str) -> float:
        """Calculate the ratio of Chinese (CJK) characters in text"""
        if not text:
            return 0.0

        # Count characters in the CJK Unified Ideographs block.
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        total_chars = len(text.replace(' ', ''))  # Exclude spaces from the denominator

        if total_chars == 0:
            return 0.0

        return chinese_chars / total_chars
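

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's API): runs both
# validators against sample data with a stubbed LLM service. ``fake_llm`` and
# the sample narration/prompt are hypothetical; the stub's call signature
# mirrors how self.llm_service is invoked above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def fake_llm(prompt: str, temperature: float = 0.1, max_tokens: int = 10) -> str:
        # Stand-in for a real LLM call: always report a high 0-100 score.
        return "85"

    async def main() -> None:
        validator = OutputValidator(llm_service=fake_llm)

        narration_result = await validator.validate_narrations(
            narrations=["人生如逆旅,我亦是行人。"],  # "Life is a journey; I too am a traveler."
            topic="人生哲理",  # "philosophy of life"
        )
        print(narration_result.to_dict())

        prompt_result = await validator.validate_image_prompts(
            prompts=["A lone traveler walking a mountain road at dusk, cinematic warm light"],
            narrations=["人生如逆旅,我亦是行人。"],
        )
        print(prompt_result.to_dict())

    asyncio.run(main())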