feat: Enhance StyleGuard with VLM-based style extraction for specific style_prefix
This commit is contained in:
@@ -254,8 +254,13 @@ async def extract_style(
|
||||
return style_schema
|
||||
|
||||
from pixelle_video.services.quality.style_guard import StyleGuard
|
||||
from api.dependencies import get_pixelle_video
|
||||
|
||||
style_guard = StyleGuard()
|
||||
# Get LLM service for VLM-based style extraction
|
||||
pixelle_video = await get_pixelle_video()
|
||||
llm_service = pixelle_video.llm if pixelle_video else None
|
||||
|
||||
style_guard = StyleGuard(llm_service=llm_service)
|
||||
anchor = await style_guard.extract_style_anchor(actual_path)
|
||||
|
||||
style_schema = StyleAnchorSchema(
|
||||
|
||||
@@ -164,28 +164,112 @@ class StyleGuard:
|
||||
async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
|
||||
"""Extract style using Vision Language Model"""
|
||||
try:
|
||||
# TODO: Implement VLM call when vision-capable LLM is integrated
|
||||
# For now, return a placeholder
|
||||
logger.debug("VLM style extraction: using placeholder (VLM not yet integrated)")
|
||||
if not self.llm_service:
|
||||
logger.warning("No LLM service available, using basic extraction")
|
||||
return self._extract_basic(image_path)
|
||||
|
||||
# Placeholder extraction based on common styles
|
||||
return StyleAnchor(
|
||||
art_style="consistent artistic",
|
||||
color_palette="harmonious colors",
|
||||
lighting="balanced",
|
||||
style_prefix="maintaining visual consistency, same artistic style as previous frames",
|
||||
reference_image=image_path,
|
||||
)
|
||||
import base64
|
||||
import os
|
||||
|
||||
# Read and encode image
|
||||
if not os.path.exists(image_path):
|
||||
logger.warning(f"Image not found: {image_path}")
|
||||
return self._extract_basic(image_path)
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
# Determine image type
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
media_type = "image/png" if ext == ".png" else "image/jpeg"
|
||||
|
||||
# Call VLM to analyze style
|
||||
style_prompt = """Analyze this image and extract its visual style characteristics.
|
||||
|
||||
Provide a concise style description that could be used as a prefix for image generation prompts to maintain visual consistency.
|
||||
|
||||
Output format (JSON):
|
||||
{
|
||||
"art_style": "specific art style (e.g., oil painting, digital illustration, anime, photorealistic, watercolor, line art)",
|
||||
"color_palette": "dominant colors and mood (e.g., warm earth tones, vibrant neon, muted pastels)",
|
||||
"lighting": "lighting style (e.g., soft natural light, dramatic shadows, studio lighting)",
|
||||
"texture": "visual texture (e.g., smooth, grainy, brushstroke visible)",
|
||||
"style_prefix": "A complete prompt prefix combining all elements (30-50 words)"
|
||||
}
|
||||
|
||||
Focus on creating a specific, reproducible style_prefix that will generate visually consistent images."""
|
||||
|
||||
# Try to call LLM with vision capability
|
||||
try:
|
||||
response = await self.llm_service(
|
||||
prompt=style_prompt,
|
||||
images=[f"data:{media_type};base64,{image_data}"],
|
||||
temperature=0.3,
|
||||
max_tokens=500
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback: try without image (text-only LLM)
|
||||
logger.warning(f"VLM call failed, using basic extraction: {e}")
|
||||
return self._extract_basic(image_path)
|
||||
|
||||
# Parse response
|
||||
import json
|
||||
import re
|
||||
|
||||
try:
|
||||
# Try to extract JSON from response
|
||||
match = re.search(r'\{[\s\S]*\}', response)
|
||||
if match:
|
||||
data = json.loads(match.group())
|
||||
else:
|
||||
data = json.loads(response)
|
||||
|
||||
anchor = StyleAnchor(
|
||||
art_style=data.get("art_style", ""),
|
||||
color_palette=data.get("color_palette", ""),
|
||||
lighting=data.get("lighting", ""),
|
||||
texture=data.get("texture", ""),
|
||||
style_prefix=data.get("style_prefix", ""),
|
||||
reference_image=image_path,
|
||||
)
|
||||
|
||||
logger.info(f"VLM extracted style: {anchor.style_prefix[:80]}...")
|
||||
return anchor
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Failed to parse VLM response: {e}")
|
||||
# Use the raw response as style_prefix if it looks reasonable
|
||||
if len(response) < 200 and len(response) > 20:
|
||||
return StyleAnchor(
|
||||
style_prefix=response.strip(),
|
||||
reference_image=image_path,
|
||||
)
|
||||
return self._extract_basic(image_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"VLM style extraction failed: {e}")
|
||||
return self._extract_basic(image_path)
|
||||
|
||||
def _extract_basic(self, image_path: str) -> StyleAnchor:
|
||||
"""Basic style extraction without VLM"""
|
||||
# Return generic style anchor
|
||||
"""Basic style extraction without VLM - analyze filename for hints"""
|
||||
import os
|
||||
|
||||
filename = os.path.basename(image_path).lower()
|
||||
|
||||
# Try to infer style from filename or path
|
||||
style_hints = []
|
||||
|
||||
if "anime" in filename or "cartoon" in filename:
|
||||
style_hints.append("anime style illustration")
|
||||
elif "realistic" in filename or "photo" in filename:
|
||||
style_hints.append("photorealistic style")
|
||||
elif "sketch" in filename or "line" in filename:
|
||||
style_hints.append("sketch style, clean lines")
|
||||
else:
|
||||
style_hints.append("consistent visual style, high quality")
|
||||
|
||||
return StyleAnchor(
|
||||
style_prefix="consistent visual style",
|
||||
style_prefix=", ".join(style_hints),
|
||||
reference_image=image_path,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user