feat: Enhance StyleGuard with VLM-based style extraction for specific style_prefix

This commit is contained in:
empty
2026-01-07 00:16:57 +08:00
parent a3ab12e87c
commit 297f3ccda4
2 changed files with 104 additions and 15 deletions

View File

@@ -254,8 +254,13 @@ async def extract_style(
return style_schema
from pixelle_video.services.quality.style_guard import StyleGuard
from api.dependencies import get_pixelle_video
style_guard = StyleGuard()
# Get LLM service for VLM-based style extraction
pixelle_video = await get_pixelle_video()
llm_service = pixelle_video.llm if pixelle_video else None
style_guard = StyleGuard(llm_service=llm_service)
anchor = await style_guard.extract_style_anchor(actual_path)
style_schema = StyleAnchorSchema(

View File

@@ -164,28 +164,112 @@ class StyleGuard:
async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
"""Extract style using Vision Language Model"""
try:
# TODO: Implement VLM call when vision-capable LLM is integrated
# For now, return a placeholder
logger.debug("VLM style extraction: using placeholder (VLM not yet integrated)")
if not self.llm_service:
logger.warning("No LLM service available, using basic extraction")
return self._extract_basic(image_path)
# Placeholder extraction based on common styles
return StyleAnchor(
art_style="consistent artistic",
color_palette="harmonious colors",
lighting="balanced",
style_prefix="maintaining visual consistency, same artistic style as previous frames",
reference_image=image_path,
)
import base64
import os
# Read and encode image
if not os.path.exists(image_path):
logger.warning(f"Image not found: {image_path}")
return self._extract_basic(image_path)
with open(image_path, "rb") as f:
image_data = base64.b64encode(f.read()).decode("utf-8")
# Determine image type
ext = os.path.splitext(image_path)[1].lower()
media_type = "image/png" if ext == ".png" else "image/jpeg"
# Call VLM to analyze style
style_prompt = """Analyze this image and extract its visual style characteristics.
Provide a concise style description that could be used as a prefix for image generation prompts to maintain visual consistency.
Output format (JSON):
{
"art_style": "specific art style (e.g., oil painting, digital illustration, anime, photorealistic, watercolor, line art)",
"color_palette": "dominant colors and mood (e.g., warm earth tones, vibrant neon, muted pastels)",
"lighting": "lighting style (e.g., soft natural light, dramatic shadows, studio lighting)",
"texture": "visual texture (e.g., smooth, grainy, brushstroke visible)",
"style_prefix": "A complete prompt prefix combining all elements (30-50 words)"
}
Focus on creating a specific, reproducible style_prefix that will generate visually consistent images."""
# Try to call LLM with vision capability
try:
response = await self.llm_service(
prompt=style_prompt,
images=[f"data:{media_type};base64,{image_data}"],
temperature=0.3,
max_tokens=500
)
except Exception as e:
# Fallback: try without image (text-only LLM)
logger.warning(f"VLM call failed, using basic extraction: {e}")
return self._extract_basic(image_path)
# Parse response
import json
import re
try:
# Try to extract JSON from response
match = re.search(r'\{[\s\S]*\}', response)
if match:
data = json.loads(match.group())
else:
data = json.loads(response)
anchor = StyleAnchor(
art_style=data.get("art_style", ""),
color_palette=data.get("color_palette", ""),
lighting=data.get("lighting", ""),
texture=data.get("texture", ""),
style_prefix=data.get("style_prefix", ""),
reference_image=image_path,
)
logger.info(f"VLM extracted style: {anchor.style_prefix[:80]}...")
return anchor
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Failed to parse VLM response: {e}")
# Use the raw response as style_prefix if it looks reasonable
if len(response) < 200 and len(response) > 20:
return StyleAnchor(
style_prefix=response.strip(),
reference_image=image_path,
)
return self._extract_basic(image_path)
except Exception as e:
logger.warning(f"VLM style extraction failed: {e}")
return self._extract_basic(image_path)
def _extract_basic(self, image_path: str) -> StyleAnchor:
"""Basic style extraction without VLM"""
# Return generic style anchor
"""Basic style extraction without VLM - analyze filename for hints"""
import os
filename = os.path.basename(image_path).lower()
# Try to infer style from filename or path
style_hints = []
if "anime" in filename or "cartoon" in filename:
style_hints.append("anime style illustration")
elif "realistic" in filename or "photo" in filename:
style_hints.append("photorealistic style")
elif "sketch" in filename or "line" in filename:
style_hints.append("sketch style, clean lines")
else:
style_hints.append("consistent visual style, high quality")
return StyleAnchor(
style_prefix="consistent visual style",
style_prefix=", ".join(style_hints),
reference_image=image_path,
)