diff --git a/api/routers/quality.py b/api/routers/quality.py
index 6012179..b8fe2e0 100644
--- a/api/routers/quality.py
+++ b/api/routers/quality.py
@@ -254,8 +254,13 @@ async def extract_style(
             return style_schema
 
         from pixelle_video.services.quality.style_guard import StyleGuard
+        from api.dependencies import get_pixelle_video
 
-        style_guard = StyleGuard()
+        # Get LLM service for VLM-based style extraction
+        pixelle_video = await get_pixelle_video()
+        llm_service = pixelle_video.llm if pixelle_video else None
+
+        style_guard = StyleGuard(llm_service=llm_service)
         anchor = await style_guard.extract_style_anchor(actual_path)
 
         style_schema = StyleAnchorSchema(
diff --git a/pixelle_video/services/quality/style_guard.py b/pixelle_video/services/quality/style_guard.py
index 1ccc698..2001a77 100644
--- a/pixelle_video/services/quality/style_guard.py
+++ b/pixelle_video/services/quality/style_guard.py
@@ -164,28 +164,136 @@ class StyleGuard:
     async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
-        """Extract style using Vision Language Model"""
+        """Extract a style anchor from an image using a Vision Language Model.
+
+        Sends the image as a base64 data URL, together with a style-analysis
+        prompt, to ``self.llm_service`` and parses the JSON object in the reply.
+        Falls back to ``_extract_basic`` when no LLM service is configured, the
+        image is missing, the VLM call fails, or the reply is neither a JSON
+        object nor a short plain-text description.
+        """
         try:
-            # TODO: Implement VLM call when vision-capable LLM is integrated
-            # For now, return a placeholder
-            logger.debug("VLM style extraction: using placeholder (VLM not yet integrated)")
+            if not self.llm_service:
+                logger.warning("No LLM service available, using basic extraction")
+                return self._extract_basic(image_path)
 
-            # Placeholder extraction based on common styles
-            return StyleAnchor(
-                art_style="consistent artistic",
-                color_palette="harmonious colors",
-                lighting="balanced",
-                style_prefix="maintaining visual consistency, same artistic style as previous frames",
-                reference_image=image_path,
-            )
+            import base64
+            import json
+            import os
+            import re
+
+            # Read and encode image
+            if not os.path.exists(image_path):
+                logger.warning(f"Image not found: {image_path}")
+                return self._extract_basic(image_path)
+
+            with open(image_path, "rb") as f:
+                image_data = base64.b64encode(f.read()).decode("utf-8")
+
+            # Determine image type; unrecognized extensions default to JPEG
+            ext = os.path.splitext(image_path)[1].lower()
+            media_type = {
+                ".png": "image/png",
+                ".gif": "image/gif",
+                ".webp": "image/webp",
+            }.get(ext, "image/jpeg")
+
+            # Call VLM to analyze style
+            style_prompt = """Analyze this image and extract its visual style characteristics.
+
+Provide a concise style description that could be used as a prefix for image generation prompts to maintain visual consistency.
+
+Output format (JSON):
+{
+    "art_style": "specific art style (e.g., oil painting, digital illustration, anime, photorealistic, watercolor, line art)",
+    "color_palette": "dominant colors and mood (e.g., warm earth tones, vibrant neon, muted pastels)",
+    "lighting": "lighting style (e.g., soft natural light, dramatic shadows, studio lighting)",
+    "texture": "visual texture (e.g., smooth, grainy, brushstroke visible)",
+    "style_prefix": "A complete prompt prefix combining all elements (30-50 words)"
+}
+
+Focus on creating a specific, reproducible style_prefix that will generate visually consistent images."""
+
+            # Try to call LLM with vision capability
+            try:
+                response = await self.llm_service(
+                    prompt=style_prompt,
+                    images=[f"data:{media_type};base64,{image_data}"],
+                    temperature=0.3,
+                    max_tokens=500,
+                )
+            except Exception as e:
+                # No text-only retry is attempted: without the image there is
+                # nothing to analyze, so use the filename-based fallback
+                logger.warning(f"VLM call failed, using basic extraction: {e}")
+                return self._extract_basic(image_path)
+
+            # Parse response
+            try:
+                # Try to extract JSON from response
+                match = re.search(r'\{[\s\S]*\}', response)
+                data = json.loads(match.group() if match else response)
+                if not isinstance(data, dict):
+                    raise ValueError("VLM response is not a JSON object")
+            except ValueError as e:
+                # json.JSONDecodeError is a subclass of ValueError
+                logger.warning(f"Failed to parse VLM response: {e}")
+                # Use the raw response as style_prefix if it looks reasonable
+                raw_response = response.strip()
+                if 20 < len(raw_response) < 200:
+                    return StyleAnchor(
+                        style_prefix=raw_response,
+                        reference_image=image_path,
+                    )
+                return self._extract_basic(image_path)
+
+            def _field(key: str) -> str:
+                # The model may emit null or non-string values
+                value = data.get(key)
+                return str(value) if value else ""
+
+            anchor = StyleAnchor(
+                art_style=_field("art_style"),
+                color_palette=_field("color_palette"),
+                lighting=_field("lighting"),
+                texture=_field("texture"),
+                style_prefix=_field("style_prefix"),
+                reference_image=image_path,
+            )
+
+            logger.info(f"VLM extracted style: {anchor.style_prefix[:80]}...")
+            return anchor
 
         except Exception as e:
             logger.warning(f"VLM style extraction failed: {e}")
             return self._extract_basic(image_path)
 
     def _extract_basic(self, image_path: str) -> StyleAnchor:
-        """Basic style extraction without VLM"""
-        # Return generic style anchor
+        """Basic style extraction without VLM.
+
+        Infers a generic style prefix from keywords in the image filename;
+        the image content itself is never read.
+        """
+        import os
+        import re
+
+        filename = os.path.basename(image_path).lower()
+        # Match keywords against the start of each alphabetic word, so that e.g.
+        # "pipeline" or "timeline" is not mistaken for a "line" (sketch) hint
+        words = re.findall(r"[a-z]+", filename)
+
+        # (keywords, style) pairs checked in order; the first match wins
+        keyword_styles = (
+            (("anime", "cartoon"), "anime style illustration"),
+            (("realistic", "photo"), "photorealistic style"),
+            (("sketch", "line", "outline"), "sketch style, clean lines"),
+        )
+        style_prefix = "consistent visual style, high quality"
+        for keywords, style in keyword_styles:
+            if any(word.startswith(keywords) for word in words):
+                style_prefix = style
+                break
+
         return StyleAnchor(
-            style_prefix="consistent visual style",
+            style_prefix=style_prefix,
             reference_image=image_path,
         )
 