Files
AI-Video/pixelle_video/services/quality/output_validator.py

337 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OutputValidator - LLM output validation and quality control
Validates LLM-generated content for:
- Narrations: length, relevance, coherence
- Image prompts: format, language, prompt-narration alignment
"""
import re
from dataclasses import dataclass, field
from typing import List, Optional
from loguru import logger
@dataclass
class ValidationConfig:
    """Configuration for output validation thresholds."""

    # --- Narration validation ---
    # Length bounds. NOTE(review): validate_narrations compares these
    # against len(narration) — a character count (字 for Chinese text) —
    # despite the "words" naming.
    min_narration_words: int = 5
    max_narration_words: int = 50
    # Minimum LLM-judged topic relevance (0.0-1.0) before an issue is raised.
    relevance_threshold: float = 0.6
    # Minimum LLM-judged coherence between narrations (0.0-1.0).
    coherence_threshold: float = 0.6

    # --- Image prompt validation ---
    # Prompt length bounds, in whitespace-separated words.
    min_prompt_words: int = 10
    max_prompt_words: int = 100
    # When True, prompts with >30% Chinese characters are flagged.
    require_english_prompts: bool = True
    # Threshold for prompt-narration alignment.
    # NOTE(review): not referenced anywhere in this file — confirm callers use it.
    prompt_match_threshold: float = 0.5
@dataclass
class ValidationResult:
    """Outcome of a validation pass over generated content."""

    # Whether the validated batch is acceptable overall.
    passed: bool
    # Human-readable problems found (empty when clean).
    issues: List[str] = field(default_factory=list)
    # Aggregate quality score: 1.0 = perfect, 0.0 = failed.
    score: float = 1.0
    # Optional remediation hints for the caller.
    suggestions: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Serialize this result to a plain dict (e.g. for JSON or logging)."""
        return dict(
            passed=self.passed,
            score=self.score,
            issues=self.issues,
            suggestions=self.suggestions,
        )
class OutputValidator:
    """
    Validator for LLM-generated outputs.

    Validates narrations and image prompts to ensure they meet quality
    standards before proceeding with media generation. Structural checks
    (length, emptiness, language) always run; semantic checks (topic
    relevance, coherence) run only when an LLM service is available and
    never hard-fail the batch on errors.

    Example:
        >>> validator = OutputValidator(llm_service)
        >>> result = await validator.validate_narrations(
        ...     narrations=["旁白1", "旁白2"],
        ...     topic="人生哲理",
        ...     config=ValidationConfig()
        ... )
        >>> if not result.passed:
        ...     print(f"Validation failed: {result.issues}")
    """

    # A batch passes when it has no issues at all, or when its average
    # score stays at or above this threshold despite minor issues.
    PASS_SCORE_THRESHOLD = 0.7

    # Score assumed when a semantic check cannot run (no LLM / too little input).
    _DEFAULT_SEMANTIC_SCORE = 0.8
    # Score used when an LLM call fails or its response cannot be parsed.
    _FALLBACK_SEMANTIC_SCORE = 0.7

    def __init__(self, llm_service=None):
        """
        Initialize OutputValidator.

        Args:
            llm_service: Optional async callable
                ``(prompt, temperature=..., max_tokens=...) -> str`` used for
                semantic validation. When None, semantic checks are skipped
                and default scores apply.
        """
        self.llm_service = llm_service

    async def validate_narrations(
        self,
        narrations: List[str],
        topic: str,
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated narrations.

        Checks:
            1. Length constraints (character count — suits Chinese text)
            2. Non-empty content
            3. Topic relevance (if LLM available)
            4. Coherence between narrations (if LLM available)

        Args:
            narrations: List of narration texts.
            topic: Original topic/theme.
            config: Validation configuration (defaults to ValidationConfig()).

        Returns:
            ValidationResult with pass/fail, issues and suggestions.
        """
        cfg = config or ValidationConfig()
        issues: List[str] = []
        suggestions: List[str] = []
        scores: List[float] = []

        if not narrations:
            return ValidationResult(
                passed=False,
                issues=["No narrations provided"],
                score=0.0,
            )

        # 1. Structural validation: per-narration length.
        for i, narration in enumerate(narrations, 1):
            if not narration.strip():
                issues.append(f"分镜{i}: 内容为空")
                scores.append(0.0)
                continue
            # len() counts characters, the natural unit for Chinese narration;
            # the config fields are named "words" but the messages say 字.
            char_count = len(narration)
            if char_count < cfg.min_narration_words:
                issues.append(f"分镜{i}: 内容过短 ({char_count}字,最少{cfg.min_narration_words}字)")
                scores.append(0.5)
            elif char_count > cfg.max_narration_words:
                issues.append(f"分镜{i}: 内容过长 ({char_count}字,最多{cfg.max_narration_words}字)")
                suggestions.append(f"考虑将分镜{i}拆分为多个短句")
                scores.append(0.7)
            else:
                scores.append(1.0)

        # 2. Semantic validation (best-effort; errors never fail the batch).
        if self.llm_service:
            try:
                relevance = await self._check_relevance(narrations, topic)
                if relevance < cfg.relevance_threshold:
                    issues.append(f'内容与主题"{topic}"相关性不足 ({relevance:.0%})')
                    suggestions.append("建议重新生成,确保内容紧扣主题")
                scores.append(relevance)

                coherence = await self._check_coherence(narrations)
                if coherence < cfg.coherence_threshold:
                    issues.append(f"内容连贯性不足 ({coherence:.0%})")
                    suggestions.append("建议检查段落之间的逻辑衔接")
                scores.append(coherence)
            except Exception as e:
                # Deliberate: semantic checks are advisory only.
                logger.warning(f"Semantic validation failed: {e}")

        result = self._build_result(scores, issues, suggestions)
        logger.debug(f"Narration validation: score={result.score:.2f}, issues={len(issues)}")
        return result

    async def validate_image_prompts(
        self,
        prompts: List[str],
        narrations: List[str],
        config: Optional[ValidationConfig] = None,
    ) -> ValidationResult:
        """
        Validate generated image prompts.

        Checks:
            1. Length constraints (whitespace-separated word count)
            2. Language (should be English, when configured)
            3. Prompt/narration count alignment

        Args:
            prompts: List of image prompts.
            narrations: Corresponding narrations.
            config: Validation configuration (defaults to ValidationConfig()).

        Returns:
            ValidationResult with pass/fail, issues and suggestions.
        """
        cfg = config or ValidationConfig()
        issues: List[str] = []
        suggestions: List[str] = []
        scores: List[float] = []

        if not prompts:
            return ValidationResult(
                passed=False,
                issues=["No image prompts provided"],
                score=0.0,
            )

        # Count mismatch is reported but not scored, matching the narration
        # policy of letting the average score decide borderline cases.
        if len(prompts) != len(narrations):
            issues.append(f"提示词数量({len(prompts)})与旁白数量({len(narrations)})不匹配")

        for i, prompt in enumerate(prompts, 1):
            if not prompt.strip():
                issues.append(f"提示词{i}: 内容为空")
                scores.append(0.0)
                continue

            word_count = len(prompt.split())

            # Length check (English word count).
            if word_count < cfg.min_prompt_words:
                issues.append(f"提示词{i}: 过短 ({word_count}词,最少{cfg.min_prompt_words}词)")
                scores.append(0.5)
            elif word_count > cfg.max_prompt_words:
                issues.append(f"提示词{i}: 过长 ({word_count}词,最多{cfg.max_prompt_words}词)")
                scores.append(0.8)
            else:
                scores.append(1.0)

            # Language check: halve this prompt's score when mostly Chinese.
            if cfg.require_english_prompts:
                chinese_ratio = self._get_chinese_ratio(prompt)
                if chinese_ratio > 0.3:  # More than 30% Chinese characters
                    issues.append(f"提示词{i}: 应使用英文 (当前含{chinese_ratio:.0%}中文)")
                    suggestions.append(f"将提示词{i}翻译为英文以获得更好的生成效果")
                    scores[-1] *= 0.5

        result = self._build_result(scores, issues, suggestions)
        logger.debug(f"Image prompt validation: score={result.score:.2f}, issues={len(issues)}")
        return result

    def _build_result(
        self,
        scores: List[float],
        issues: List[str],
        suggestions: List[str],
    ) -> ValidationResult:
        """
        Aggregate per-item scores into a ValidationResult.

        Shared pass rule (previously duplicated in both validate_* methods):
        clean run passes outright; otherwise the average score must reach
        PASS_SCORE_THRESHOLD.
        """
        overall_score = sum(scores) / len(scores) if scores else 0.0
        passed = not issues or overall_score >= self.PASS_SCORE_THRESHOLD
        return ValidationResult(
            passed=passed,
            issues=issues,
            score=overall_score,
            suggestions=suggestions,
        )

    async def _check_relevance(
        self,
        narrations: List[str],
        topic: str,
    ) -> float:
        """
        Ask the LLM how relevant the narrations are to *topic*.

        Returns:
            Relevance score 0.0-1.0; defaults apply when the LLM is missing,
            errors out, or returns an unparseable response.
        """
        if not self.llm_service:
            return self._DEFAULT_SEMANTIC_SCORE
        try:
            combined_text = "\n".join(narrations[:3])  # Check first 3 for efficiency
            prompt = f"""评估以下内容与主题"{topic}"的相关性。
内容:
{combined_text}
请用0-100的分数评估相关性只输出数字。
相关性越高,分数越高。"""
            response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)
            score = self._parse_percentage_score(response)
            return score if score is not None else self._FALLBACK_SEMANTIC_SCORE
        except Exception as e:
            logger.warning(f"Relevance check failed: {e}")
            return self._FALLBACK_SEMANTIC_SCORE

    async def _check_coherence(
        self,
        narrations: List[str],
    ) -> float:
        """
        Ask the LLM to rate logical coherence between consecutive narrations.

        Returns:
            Coherence score 0.0-1.0; defaults apply when the LLM is missing,
            there are fewer than two narrations, or the call/parse fails.
        """
        if not self.llm_service or len(narrations) < 2:
            return self._DEFAULT_SEMANTIC_SCORE
        try:
            numbered = "\n".join(f"{i+1}. {n}" for i, n in enumerate(narrations[:5]))
            prompt = f"""评估以下段落之间的逻辑连贯性。
{numbered}
请用0-100的分数评估连贯性只输出数字。
段落之间逻辑顺畅、衔接自然则分数高。"""
            response = await self.llm_service(prompt, temperature=0.1, max_tokens=10)
            score = self._parse_percentage_score(response)
            return score if score is not None else self._FALLBACK_SEMANTIC_SCORE
        except Exception as e:
            logger.warning(f"Coherence check failed: {e}")
            return self._FALLBACK_SEMANTIC_SCORE

    @staticmethod
    def _parse_percentage_score(response: str) -> Optional[float]:
        """
        Extract the first integer from an LLM response and map 0-100 to 0.0-1.0.

        Returns None when no digits are present (caller substitutes a fallback).
        Shared by _check_relevance and _check_coherence (previously duplicated).
        """
        match = re.search(r'\d+', response)
        if match is None:
            return None
        return min(1.0, max(0.0, int(match.group()) / 100))

    def _get_chinese_ratio(self, text: str) -> float:
        """
        Calculate the ratio of CJK Unified Ideograph characters in *text*.

        All whitespace is excluded from the denominator. Fix: the previous
        implementation stripped only ASCII spaces, so newlines/tabs in
        multi-line prompts inflated the total and under-reported the
        Chinese ratio.
        """
        if not text:
            return 0.0
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        total_chars = sum(1 for c in text if not c.isspace())
        if total_chars == 0:
            return 0.0
        return chinese_chars / total_chars