feat: Add AI quality features - character memory, content filter, style guard, quality gate
This commit is contained in:
@@ -56,6 +56,7 @@ from api.routers import (
|
||||
frame_router,
|
||||
editor_router,
|
||||
publish_router,
|
||||
quality_router,
|
||||
)
|
||||
|
||||
|
||||
@@ -137,6 +138,7 @@ app.include_router(resources_router, prefix=api_config.api_prefix)
|
||||
app.include_router(frame_router, prefix=api_config.api_prefix)
|
||||
app.include_router(editor_router, prefix=api_config.api_prefix)
|
||||
app.include_router(publish_router, prefix=api_config.api_prefix)
|
||||
app.include_router(quality_router, prefix=api_config.api_prefix)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
@@ -26,6 +26,7 @@ from api.routers.resources import router as resources_router
|
||||
from api.routers.frame import router as frame_router
|
||||
from api.routers.editor import router as editor_router
|
||||
from api.routers.publish import router as publish_router
|
||||
from api.routers.quality import router as quality_router
|
||||
|
||||
__all__ = [
|
||||
"health_router",
|
||||
@@ -40,6 +41,8 @@ __all__ = [
|
||||
"frame_router",
|
||||
"editor_router",
|
||||
"publish_router",
|
||||
"quality_router",
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
361
api/routers/quality.py
Normal file
361
api/routers/quality.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# Copyright (C) 2025 AIDC-AI
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
"""
|
||||
Quality API router for AI quality features
|
||||
|
||||
Provides endpoints for:
|
||||
- Character memory management
|
||||
- Content filtering
|
||||
- Style guard
|
||||
- Quality gate evaluation
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Path, Body
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter(prefix="/quality", tags=["Quality"])
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Schemas
|
||||
# ============================================================
|
||||
|
||||
class CharacterSchema(BaseModel):
|
||||
"""Character data"""
|
||||
id: str
|
||||
name: str
|
||||
appearance_description: str = ""
|
||||
clothing_description: str = ""
|
||||
distinctive_features: List[str] = []
|
||||
character_type: str = "person"
|
||||
reference_image: Optional[str] = None
|
||||
|
||||
|
||||
class CharacterCreateRequest(BaseModel):
|
||||
"""Request to create a character"""
|
||||
name: str = Field(..., description="Character name")
|
||||
appearance_description: str = Field("", description="Visual appearance")
|
||||
clothing_description: str = Field("", description="Clothing description")
|
||||
distinctive_features: List[str] = Field(default_factory=list)
|
||||
character_type: str = Field("person")
|
||||
|
||||
|
||||
class ContentCheckRequest(BaseModel):
|
||||
"""Request to check content"""
|
||||
text: str = Field(..., description="Text to check")
|
||||
|
||||
|
||||
class ContentCheckResponse(BaseModel):
|
||||
"""Response for content check"""
|
||||
passed: bool
|
||||
category: str # safe, sensitive, blocked
|
||||
flagged_items: List[str] = []
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class StyleAnchorSchema(BaseModel):
|
||||
"""Style anchor data"""
|
||||
color_palette: str = ""
|
||||
art_style: str = ""
|
||||
composition_style: str = ""
|
||||
texture: str = ""
|
||||
lighting: str = ""
|
||||
style_prefix: str = ""
|
||||
reference_image: Optional[str] = None
|
||||
|
||||
|
||||
class QualityScoreSchema(BaseModel):
|
||||
"""Quality evaluation result"""
|
||||
overall_score: float
|
||||
aesthetic_score: float = 0.0
|
||||
alignment_score: float = 0.0
|
||||
technical_score: float = 0.0
|
||||
passed: bool
|
||||
issues: List[str] = []
|
||||
|
||||
|
||||
# ============================================================
|
||||
# In-memory storage (per storyboard)
|
||||
# ============================================================
|
||||
|
||||
_character_stores: dict = {} # storyboard_id -> {char_id -> Character}
|
||||
_style_anchors: dict = {} # storyboard_id -> StyleAnchor
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Character Memory Endpoints
|
||||
# ============================================================
|
||||
|
||||
@router.get(
|
||||
"/characters/{storyboard_id}",
|
||||
response_model=List[CharacterSchema]
|
||||
)
|
||||
async def get_characters(
|
||||
storyboard_id: str = Path(..., description="Storyboard ID")
|
||||
):
|
||||
"""Get all characters for a storyboard"""
|
||||
if storyboard_id not in _character_stores:
|
||||
return []
|
||||
|
||||
return list(_character_stores[storyboard_id].values())
|
||||
|
||||
|
||||
@router.post(
|
||||
"/characters/{storyboard_id}",
|
||||
response_model=CharacterSchema
|
||||
)
|
||||
async def create_character(
|
||||
storyboard_id: str = Path(..., description="Storyboard ID"),
|
||||
request: CharacterCreateRequest = Body(...)
|
||||
):
|
||||
"""Register a new character"""
|
||||
import uuid
|
||||
|
||||
if storyboard_id not in _character_stores:
|
||||
_character_stores[storyboard_id] = {}
|
||||
|
||||
char_id = f"char_{uuid.uuid4().hex[:8]}"
|
||||
character = CharacterSchema(
|
||||
id=char_id,
|
||||
name=request.name,
|
||||
appearance_description=request.appearance_description,
|
||||
clothing_description=request.clothing_description,
|
||||
distinctive_features=request.distinctive_features,
|
||||
character_type=request.character_type,
|
||||
)
|
||||
|
||||
_character_stores[storyboard_id][char_id] = character.model_dump()
|
||||
|
||||
logger.info(f"Created character {request.name} for storyboard {storyboard_id}")
|
||||
return character
|
||||
|
||||
|
||||
@router.put(
|
||||
"/characters/{storyboard_id}/{char_id}",
|
||||
response_model=CharacterSchema
|
||||
)
|
||||
async def update_character(
|
||||
storyboard_id: str = Path(..., description="Storyboard ID"),
|
||||
char_id: str = Path(..., description="Character ID"),
|
||||
request: CharacterCreateRequest = Body(...)
|
||||
):
|
||||
"""Update a character"""
|
||||
if storyboard_id not in _character_stores:
|
||||
raise HTTPException(status_code=404, detail="Storyboard not found")
|
||||
|
||||
if char_id not in _character_stores[storyboard_id]:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
character = CharacterSchema(
|
||||
id=char_id,
|
||||
name=request.name,
|
||||
appearance_description=request.appearance_description,
|
||||
clothing_description=request.clothing_description,
|
||||
distinctive_features=request.distinctive_features,
|
||||
character_type=request.character_type,
|
||||
)
|
||||
|
||||
_character_stores[storyboard_id][char_id] = character.model_dump()
|
||||
|
||||
return character
|
||||
|
||||
|
||||
@router.delete("/characters/{storyboard_id}/{char_id}")
|
||||
async def delete_character(
|
||||
storyboard_id: str = Path(..., description="Storyboard ID"),
|
||||
char_id: str = Path(..., description="Character ID")
|
||||
):
|
||||
"""Delete a character"""
|
||||
if storyboard_id not in _character_stores:
|
||||
raise HTTPException(status_code=404, detail="Storyboard not found")
|
||||
|
||||
if char_id not in _character_stores[storyboard_id]:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
del _character_stores[storyboard_id][char_id]
|
||||
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Content Filter Endpoints
|
||||
# ============================================================
|
||||
|
||||
@router.post(
|
||||
"/check-content",
|
||||
response_model=ContentCheckResponse
|
||||
)
|
||||
async def check_content(request: ContentCheckRequest):
|
||||
"""Check text content for safety"""
|
||||
from pixelle_video.services.quality.content_filter import ContentFilter
|
||||
|
||||
try:
|
||||
content_filter = ContentFilter()
|
||||
result = content_filter.check_text(request.text)
|
||||
|
||||
return ContentCheckResponse(
|
||||
passed=result.passed,
|
||||
category=result.category.value,
|
||||
flagged_items=result.flagged_items,
|
||||
reason=result.reason
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Content check failed: {e}")
|
||||
# Default to safe if filter fails
|
||||
return ContentCheckResponse(
|
||||
passed=True,
|
||||
category="safe",
|
||||
flagged_items=[],
|
||||
reason=None
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Style Guard Endpoints
|
||||
# ============================================================
|
||||
|
||||
@router.post(
|
||||
"/extract-style/{storyboard_id}",
|
||||
response_model=StyleAnchorSchema
|
||||
)
|
||||
async def extract_style(
|
||||
storyboard_id: str = Path(..., description="Storyboard ID"),
|
||||
image_path: str = Body(..., embed=True, description="Reference image path")
|
||||
):
|
||||
"""Extract style anchor from reference image"""
|
||||
from pixelle_video.services.quality.style_guard import StyleGuard
|
||||
|
||||
try:
|
||||
style_guard = StyleGuard()
|
||||
anchor = style_guard.extract_style_anchor(image_path)
|
||||
|
||||
style_schema = StyleAnchorSchema(
|
||||
color_palette=anchor.color_palette,
|
||||
art_style=anchor.art_style,
|
||||
composition_style=anchor.composition_style,
|
||||
texture=anchor.texture,
|
||||
lighting=anchor.lighting,
|
||||
style_prefix=anchor.style_prefix,
|
||||
reference_image=anchor.reference_image
|
||||
)
|
||||
|
||||
_style_anchors[storyboard_id] = style_schema.model_dump()
|
||||
|
||||
logger.info(f"Extracted style for storyboard {storyboard_id}")
|
||||
return style_schema
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Style extraction failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/style/{storyboard_id}",
|
||||
response_model=StyleAnchorSchema
|
||||
)
|
||||
async def get_style(
|
||||
storyboard_id: str = Path(..., description="Storyboard ID")
|
||||
):
|
||||
"""Get current style anchor for storyboard"""
|
||||
if storyboard_id not in _style_anchors:
|
||||
raise HTTPException(status_code=404, detail="No style anchor found")
|
||||
|
||||
return StyleAnchorSchema(**_style_anchors[storyboard_id])
|
||||
|
||||
|
||||
class ApplyStyleRequest(BaseModel):
|
||||
"""Request to apply style"""
|
||||
prompt: str = Field(..., description="Image prompt to style")
|
||||
|
||||
|
||||
class ApplyStyleResponse(BaseModel):
|
||||
"""Response with styled prompt"""
|
||||
styled_prompt: str
|
||||
|
||||
|
||||
@router.post(
|
||||
"/apply-style/{storyboard_id}",
|
||||
response_model=ApplyStyleResponse
|
||||
)
|
||||
async def apply_style(
|
||||
storyboard_id: str = Path(..., description="Storyboard ID"),
|
||||
request: ApplyStyleRequest = Body(...)
|
||||
):
|
||||
"""Apply style anchor to a prompt"""
|
||||
if storyboard_id not in _style_anchors:
|
||||
# Return original if no style
|
||||
return ApplyStyleResponse(styled_prompt=request.prompt)
|
||||
|
||||
from pixelle_video.services.quality.style_guard import StyleGuard, StyleAnchor
|
||||
|
||||
try:
|
||||
style_guard = StyleGuard()
|
||||
anchor_data = _style_anchors[storyboard_id]
|
||||
anchor = StyleAnchor(
|
||||
color_palette=anchor_data.get("color_palette", ""),
|
||||
art_style=anchor_data.get("art_style", ""),
|
||||
composition_style=anchor_data.get("composition_style", ""),
|
||||
texture=anchor_data.get("texture", ""),
|
||||
lighting=anchor_data.get("lighting", ""),
|
||||
style_prefix=anchor_data.get("style_prefix", ""),
|
||||
)
|
||||
|
||||
styled = style_guard.apply_style(request.prompt, anchor)
|
||||
return ApplyStyleResponse(styled_prompt=styled)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Apply style failed: {e}")
|
||||
return ApplyStyleResponse(styled_prompt=request.prompt)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Quality Gate Endpoints
|
||||
# ============================================================
|
||||
|
||||
class EvaluateImageRequest(BaseModel):
|
||||
"""Request to evaluate image"""
|
||||
image_path: str = Field(..., description="Path to image")
|
||||
prompt: str = Field(..., description="Prompt used to generate")
|
||||
narration: Optional[str] = Field(None, description="Associated narration")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/evaluate-image",
|
||||
response_model=QualityScoreSchema
|
||||
)
|
||||
async def evaluate_image(request: EvaluateImageRequest):
|
||||
"""Evaluate image quality"""
|
||||
from pixelle_video.services.quality.quality_gate import QualityGate
|
||||
|
||||
try:
|
||||
quality_gate = QualityGate()
|
||||
score = quality_gate.evaluate_image(
|
||||
image_path=request.image_path,
|
||||
prompt=request.prompt,
|
||||
narration=request.narration
|
||||
)
|
||||
|
||||
return QualityScoreSchema(
|
||||
overall_score=score.overall_score,
|
||||
aesthetic_score=score.aesthetic_score,
|
||||
alignment_score=score.alignment_score,
|
||||
technical_score=score.technical_score,
|
||||
passed=score.passed,
|
||||
issues=score.issues
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Image evaluation failed: {e}")
|
||||
# Default passing score
|
||||
return QualityScoreSchema(
|
||||
overall_score=0.7,
|
||||
passed=True,
|
||||
issues=[]
|
||||
)
|
||||
Reference in New Issue
Block a user