Files
AI-Video/api/routers/quality.py
2026-01-06 23:54:35 +08:00

394 lines
12 KiB
Python

# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
"""
Quality API router for AI quality features
Provides endpoints for:
- Character memory management
- Content filtering
- Style guard
- Quality gate evaluation
"""
from fastapi import APIRouter, HTTPException, Path, Body
from pydantic import BaseModel, Field
from typing import List, Optional
from loguru import logger
# Router for the quality-feature endpoints defined in this module. Route paths
# below are relative to the "/quality" prefix, and every route is grouped under
# the "Quality" tag.
router = APIRouter(prefix="/quality", tags=["Quality"])
# ============================================================
# Schemas
# ============================================================
class CharacterSchema(BaseModel):
    """A character registered for a storyboard.

    Used as the response model of the character endpoints; instances are
    stored in the per-storyboard character store as ``model_dump()`` dicts.
    """
    id: str  # assigned server-side ("char_" + 8 hex digits) when created via the API
    name: str
    appearance_description: str = ""  # free-text visual appearance
    clothing_description: str = ""  # free-text clothing description
    distinctive_features: List[str] = []  # pydantic copies mutable defaults per instance
    character_type: str = "person"
    reference_image: Optional[str] = None  # not set by the create/update endpoints in this module
class CharacterCreateRequest(BaseModel):
    """Request body for creating or updating a character.

    Carries only the user-editable fields; the character id comes from the
    server (on create) or from the URL path (on update).
    """
    name: str = Field(..., description="Character name")  # required
    appearance_description: str = Field("", description="Visual appearance")
    clothing_description: str = Field("", description="Clothing description")
    distinctive_features: List[str] = Field(default_factory=list)
    character_type: str = Field("person")
class ContentCheckRequest(BaseModel):
    """Request body for the content safety check endpoint."""
    text: str = Field(..., description="Text to check")  # required
class ContentCheckResponse(BaseModel):
    """Result of a content safety check."""
    passed: bool  # whether the text passed the filter
    category: str  # one of: safe, sensitive, blocked (filled from the filter's category enum value)
    flagged_items: List[str] = []  # items the filter flagged, if any
    reason: Optional[str] = None  # filter-supplied explanation, if any
class StyleAnchorSchema(BaseModel):
    """Visual style descriptors for a storyboard.

    Either extracted from a reference image or filled with defaults when
    extraction is not possible; cached per storyboard as a ``model_dump()``
    dict.
    """
    color_palette: str = ""
    art_style: str = ""
    composition_style: str = ""
    texture: str = ""
    lighting: str = ""
    style_prefix: str = ""  # presumably prepended to image prompts by StyleGuard.apply_style — confirm
    reference_image: Optional[str] = None  # only populated when copied from an extracted anchor
class QualityScoreSchema(BaseModel):
    """Quality evaluation result for a generated image."""
    overall_score: float  # NOTE(review): scale not defined here; presumably 0..1 — confirm
    aesthetic_score: float = 0.0
    alignment_score: float = 0.0
    technical_score: float = 0.0
    passed: bool  # whether the image passed the quality gate
    issues: List[str] = []  # issues reported by the quality gate, if any
# ============================================================
# In-memory storage (per storyboard)
# ============================================================
# Module-level, in-memory state keyed by storyboard id. Nothing here is
# persisted, so the contents are lost when the process exits.
# NOTE(review): each worker process holds its own copy — confirm the API runs
# with a single worker, or move this to shared storage.
_character_stores: dict = {}  # storyboard_id -> {char_id -> CharacterSchema.model_dump() dict}
_style_anchors: dict = {}  # storyboard_id -> StyleAnchorSchema.model_dump() dict
# ============================================================
# Character Memory Endpoints
# ============================================================
@router.get(
    "/characters/{storyboard_id}",
    response_model=List[CharacterSchema],
)
async def get_characters(
    storyboard_id: str = Path(..., description="Storyboard ID"),
):
    """List the characters registered for a storyboard.

    An unknown storyboard is not an error: it simply has no characters, so
    an empty list is returned.
    """
    # Entries are stored as model_dump() dicts; response_model re-validates them.
    stored = _character_stores.get(storyboard_id, {})
    return list(stored.values())
@router.post(
    "/characters/{storyboard_id}",
    response_model=CharacterSchema
)
async def create_character(
    storyboard_id: str = Path(..., description="Storyboard ID"),
    request: CharacterCreateRequest = Body(...)
):
    """Register a new character for a storyboard and return it.

    The storyboard's character bucket is created on first use. The new id is
    ``char_`` followed by 8 hex digits taken from a random UUID4; it is
    regenerated if it already exists in the bucket, so an existing character
    is never silently overwritten.
    """
    import uuid

    characters = _character_stores.setdefault(storyboard_id, {})

    # Only 32 bits of randomness: guard against (rare) collisions rather than
    # clobbering an existing entry.
    char_id = f"char_{uuid.uuid4().hex[:8]}"
    while char_id in characters:
        char_id = f"char_{uuid.uuid4().hex[:8]}"

    character = CharacterSchema(
        id=char_id,
        name=request.name,
        appearance_description=request.appearance_description,
        clothing_description=request.clothing_description,
        distinctive_features=request.distinctive_features,
        character_type=request.character_type,
    )
    characters[char_id] = character.model_dump()
    logger.info(f"Created character {request.name} for storyboard {storyboard_id}")
    return character
@router.put(
    "/characters/{storyboard_id}/{char_id}",
    response_model=CharacterSchema,
)
async def update_character(
    storyboard_id: str = Path(..., description="Storyboard ID"),
    char_id: str = Path(..., description="Character ID"),
    request: CharacterCreateRequest = Body(...),
):
    """Replace an existing character with the fields from the request.

    This is a full replacement, not a partial update: every stored field is
    rebuilt from ``request``.

    Raises:
        HTTPException: 404 if the storyboard or the character is unknown.
    """
    characters = _character_stores.get(storyboard_id)
    if characters is None:
        raise HTTPException(status_code=404, detail="Storyboard not found")
    if char_id not in characters:
        raise HTTPException(status_code=404, detail="Character not found")

    updated = CharacterSchema(
        id=char_id,
        name=request.name,
        appearance_description=request.appearance_description,
        clothing_description=request.clothing_description,
        distinctive_features=request.distinctive_features,
        character_type=request.character_type,
    )
    characters[char_id] = updated.model_dump()
    return updated
@router.delete("/characters/{storyboard_id}/{char_id}")
async def delete_character(
    storyboard_id: str = Path(..., description="Storyboard ID"),
    char_id: str = Path(..., description="Character ID"),
):
    """Remove a character from a storyboard.

    Returns ``{"deleted": True}`` on success.

    Raises:
        HTTPException: 404 if the storyboard or the character is unknown.
    """
    characters = _character_stores.get(storyboard_id)
    if characters is None:
        raise HTTPException(status_code=404, detail="Storyboard not found")
    if char_id not in characters:
        raise HTTPException(status_code=404, detail="Character not found")

    characters.pop(char_id)
    return {"deleted": True}
# ============================================================
# Content Filter Endpoints
# ============================================================
@router.post(
    "/check-content",
    response_model=ContentCheckResponse
)
async def check_content(request: ContentCheckRequest):
    """Check text content for safety.

    Runs the text through ``ContentFilter.check_text`` and returns its
    verdict. If the filter raises at runtime, the error is logged with its
    traceback and the text is reported as safe (fail-open). If the filter
    module cannot be imported, the request fails instead.
    """
    from pixelle_video.services.quality.content_filter import ContentFilter
    try:
        content_filter = ContentFilter()
        result = content_filter.check_text(request.text)
        return ContentCheckResponse(
            passed=result.passed,
            category=result.category.value,
            flagged_items=result.flagged_items,
            reason=result.reason,
        )
    except Exception as e:
        # NOTE(review): this deliberately fails OPEN — a broken filter lets
        # all content through. Confirm that is the intended policy for a
        # safety check. Log the full traceback so such failures are visible.
        logger.exception(f"Content check failed: {e}")
        return ContentCheckResponse(
            passed=True,
            category="safe",
            flagged_items=[],
            reason=None,
        )
# ============================================================
# Style Guard Endpoints
# ============================================================
def _default_style_schema(style_prefix: str) -> StyleAnchorSchema:
    """Build the fallback style anchor used when no style can be extracted.

    Args:
        style_prefix: Prompt prefix for the fallback anchor. The two fallback
            paths in ``extract_style`` use different prefixes, so the caller
            supplies it.
    """
    return StyleAnchorSchema(
        color_palette="vibrant",
        art_style="digital illustration",
        composition_style="centered",
        texture="smooth",
        lighting="soft natural light",
        style_prefix=style_prefix,
    )


def _resolve_image_path(
    image_path: str,
    files_route: str = "/api/files/",
    output_dir: str = "output/",
) -> str:
    """Translate a served-file URL into the local path the file lives at.

    ``http(s)://<any host>/api/files/<rel>`` maps to ``output/<rel>``, with
    percent-escapes decoded. Anything else — local paths, other URLs,
    unparsable strings — is returned unchanged.
    """
    from urllib.parse import unquote, urlparse

    try:
        parsed = urlparse(image_path)
    except ValueError:
        # e.g. a malformed IPv6 host; treat the input as an opaque path.
        return image_path
    if parsed.scheme in ("http", "https") and parsed.path.startswith(files_route):
        return output_dir + unquote(parsed.path[len(files_route):])
    return image_path


@router.post(
    "/extract-style/{storyboard_id}",
    response_model=StyleAnchorSchema
)
async def extract_style(
    storyboard_id: str = Path(..., description="Storyboard ID"),
    image_path: str = Body(..., embed=True, description="Reference image path")
):
    """Extract style anchor from reference image.

    ``image_path`` may be a local path or a URL served by this API under
    ``/api/files/`` (any scheme host/port), which is mapped into ``output/``.
    The resulting anchor is cached for the storyboard and returned.

    This endpoint does not fail on a bad image: if the file is missing, or
    extraction raises, a default anchor is cached and returned instead.
    """
    import os

    try:
        actual_path = _resolve_image_path(image_path)
        # NOTE(review): image_path comes straight from the request body and is
        # used as a filesystem path without restriction — confirm only
        # trusted clients can reach this endpoint.
        if not os.path.exists(actual_path):
            logger.warning(f"Image file not found: {actual_path}, using default style")
            style_schema = _default_style_schema("high quality, detailed, vibrant colors")
            _style_anchors[storyboard_id] = style_schema.model_dump()
            return style_schema

        from pixelle_video.services.quality.style_guard import StyleGuard

        style_guard = StyleGuard()
        anchor = await style_guard.extract_style_anchor(actual_path)
        style_schema = StyleAnchorSchema(
            color_palette=anchor.color_palette,
            art_style=anchor.art_style,
            composition_style=anchor.composition_style,
            texture=anchor.texture,
            lighting=anchor.lighting,
            style_prefix=anchor.style_prefix,
            reference_image=anchor.reference_image,
        )
        _style_anchors[storyboard_id] = style_schema.model_dump()
        logger.info(f"Extracted style for storyboard {storyboard_id}")
        return style_schema
    except Exception as e:
        # Best effort: fall back to a default anchor instead of failing the
        # request, but keep the traceback so the failure can be diagnosed.
        logger.exception(f"Style extraction failed: {e}")
        style_schema = _default_style_schema("high quality, detailed")
        _style_anchors[storyboard_id] = style_schema.model_dump()
        return style_schema
@router.get(
    "/style/{storyboard_id}",
    response_model=StyleAnchorSchema,
)
async def get_style(
    storyboard_id: str = Path(..., description="Storyboard ID"),
):
    """Return the cached style anchor of a storyboard.

    Raises:
        HTTPException: 404 if no anchor has been extracted for it yet.
    """
    anchor_data = _style_anchors.get(storyboard_id)
    if anchor_data is None:
        raise HTTPException(status_code=404, detail="No style anchor found")
    # The cache holds model_dump() dicts; rebuild the schema from them.
    return StyleAnchorSchema(**anchor_data)
class ApplyStyleRequest(BaseModel):
    """Request body for applying a storyboard's style anchor to a prompt."""
    prompt: str = Field(..., description="Image prompt to style")  # required
class ApplyStyleResponse(BaseModel):
    """Response carrying the styled prompt (or the original prompt unchanged)."""
    styled_prompt: str
@router.post(
    "/apply-style/{storyboard_id}",
    response_model=ApplyStyleResponse
)
async def apply_style(
    storyboard_id: str = Path(..., description="Storyboard ID"),
    request: ApplyStyleRequest = Body(...)
):
    """Apply the storyboard's style anchor to a prompt.

    Returns the prompt unchanged when the storyboard has no cached anchor, or
    when applying the style raises. In the error case the exception is logged
    with its traceback.
    """
    if storyboard_id not in _style_anchors:
        # No anchor extracted yet: pass the prompt through untouched.
        return ApplyStyleResponse(styled_prompt=request.prompt)

    from pixelle_video.services.quality.style_guard import StyleGuard, StyleAnchor
    try:
        style_guard = StyleGuard()
        anchor_data = _style_anchors[storyboard_id]
        # Rebuild the domain object from the cached model_dump() dict.
        # NOTE(review): reference_image is not forwarded — confirm
        # StyleGuard.apply_style does not need it.
        anchor = StyleAnchor(
            color_palette=anchor_data.get("color_palette", ""),
            art_style=anchor_data.get("art_style", ""),
            composition_style=anchor_data.get("composition_style", ""),
            texture=anchor_data.get("texture", ""),
            lighting=anchor_data.get("lighting", ""),
            style_prefix=anchor_data.get("style_prefix", ""),
        )
        styled = style_guard.apply_style(request.prompt, anchor)
        return ApplyStyleResponse(styled_prompt=styled)
    except Exception as e:
        # Best effort: fall back to the raw prompt, but keep the traceback.
        logger.exception(f"Apply style failed: {e}")
        return ApplyStyleResponse(styled_prompt=request.prompt)
# ============================================================
# Quality Gate Endpoints
# ============================================================
class EvaluateImageRequest(BaseModel):
    """Request body for evaluating a generated image."""
    image_path: str = Field(..., description="Path to image")  # passed as-is to QualityGate.evaluate_image
    prompt: str = Field(..., description="Prompt used to generate")
    narration: Optional[str] = Field(None, description="Associated narration")  # optional
@router.post(
    "/evaluate-image",
    response_model=QualityScoreSchema
)
async def evaluate_image(request: EvaluateImageRequest):
    """Evaluate image quality.

    Delegates to ``QualityGate.evaluate_image`` and returns its scores. If
    evaluation raises, the error is logged with its traceback and a default
    passing score (``overall_score=0.7``, ``passed=True``) is returned.
    """
    from pixelle_video.services.quality.quality_gate import QualityGate
    try:
        quality_gate = QualityGate()
        # NOTE(review): this synchronous call runs inside an async endpoint; if
        # evaluation is heavy it blocks the event loop — consider offloading
        # it to a worker thread.
        score = quality_gate.evaluate_image(
            image_path=request.image_path,
            prompt=request.prompt,
            narration=request.narration,
        )
        return QualityScoreSchema(
            overall_score=score.overall_score,
            aesthetic_score=score.aesthetic_score,
            alignment_score=score.alignment_score,
            technical_score=score.technical_score,
            passed=score.passed,
            issues=score.issues,
        )
    except Exception as e:
        # NOTE(review): fails OPEN — a broken quality gate reports every image
        # as passing. Keep the traceback so such failures are visible.
        logger.exception(f"Image evaluation failed: {e}")
        return QualityScoreSchema(
            overall_score=0.7,
            passed=True,
            issues=[],
        )