Files
AI-Video/pixelle_video/services/quality/style_guard.py

389 lines
14 KiB
Python

# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
StyleGuard - Visual style consistency engine
Ensures consistent visual style across all frames in a video by:
1. Extracting style anchor from the first generated frame
2. Injecting style constraints into subsequent frame prompts
3. (Optional) Using style reference techniques like IP-Adapter
"""
from dataclasses import dataclass, field
from typing import List, Optional
from loguru import logger
@dataclass
class StyleAnchor:
    """Visual style description captured from a reference frame.

    Stores the individual style traits, an optional ready-made prompt
    prefix, and the path of the image the style was taken from.
    """

    # Individual style traits
    color_palette: str = ""  # e.g., "warm earth tones", "cool blues"
    art_style: str = ""  # e.g., "minimalist", "realistic", "anime"
    composition_style: str = ""  # e.g., "centered", "rule of thirds"
    texture: str = ""  # e.g., "smooth", "grainy", "watercolor"
    lighting: str = ""  # e.g., "soft ambient", "dramatic shadows"

    # Ready-made prompt prefix; when non-empty it wins over the traits above
    style_prefix: str = ""

    # Path of the reference image (for IP-Adapter style techniques)
    reference_image: Optional[str] = None

    def to_prompt_prefix(self) -> str:
        """Build the style text that gets attached to image prompts.

        Returns:
            ``style_prefix`` unchanged when it is set. Otherwise the art
            style, color palette, lighting and texture that are set, joined
            with ", " in that order (``composition_style`` is not part of
            the output). An empty string when nothing is set.
        """
        if self.style_prefix:
            return self.style_prefix

        # (trait value, rendered fragment) pairs, listed in output order
        fragments = (
            (self.art_style, f"{self.art_style} style"),
            (self.color_palette, f"{self.color_palette}"),
            (self.lighting, f"{self.lighting} lighting"),
            (self.texture, f"{self.texture} texture"),
        )
        return ", ".join(text for trait, text in fragments if trait)

    def to_dict(self) -> dict:
        """Return every field as a plain dictionary keyed by field name."""
        field_names = (
            "color_palette",
            "art_style",
            "composition_style",
            "texture",
            "lighting",
            "style_prefix",
            "reference_image",
        )
        return {name: getattr(self, name) for name in field_names}
@dataclass
class StyleGuardConfig:
    """Settings that control how StyleGuard extracts and applies a style."""

    # --- Extraction ---
    # Take the style from the first generated frame
    extract_from_first_frame: bool = True
    # Ask a vision-language model to describe the style; StyleGuard only
    # does so when it was also given an LLM service
    use_vlm_extraction: bool = True

    # --- Application ---
    # Apply the extracted style to every frame
    apply_to_all_frames: bool = True
    # Where the style text goes relative to the prompt: "start" or "end"
    # (StyleGuard.apply_style appends for any value other than "start")
    prefix_position: str = "start"

    # --- Optional external style reference ---
    # Path of an external image to use as the style reference
    external_style_image: Optional[str] = None
    # Ready-made style text; when set, StyleGuard.extract_style_anchor uses
    # it directly instead of analyzing the image
    custom_style_prefix: Optional[str] = None
class StyleGuard:
    """
    Style consistency guardian for video generation

    Ensures all frames in a video maintain visual consistency by:
    1. Analyzing the first frame (or reference image) to extract style
    2. Applying style constraints to all subsequent frame prompts

    Example:
        >>> style_guard = StyleGuard(llm_service)
        >>>
        >>> # Extract style from first frame
        >>> anchor = await style_guard.extract_style_anchor(
        ...     image_path="output/frame_001.png"
        ... )
        >>>
        >>> # Apply to subsequent prompts
        >>> styled_prompt = style_guard.apply_style(
        ...     prompt="A cat sitting on a windowsill",
        ...     style_anchor=anchor
        ... )
    """

    # Instruction sent to the VLM together with the reference image.
    _STYLE_PROMPT = """Analyze this image and extract its visual style characteristics.
Provide a concise style description that could be used as a prefix for image generation prompts to maintain visual consistency.
Output format (JSON):
{
"art_style": "specific art style (e.g., oil painting, digital illustration, anime, photorealistic, watercolor, line art)",
"color_palette": "dominant colors and mood (e.g., warm earth tones, vibrant neon, muted pastels)",
"lighting": "lighting style (e.g., soft natural light, dramatic shadows, studio lighting)",
"texture": "visual texture (e.g., smooth, grainy, brushstroke visible)",
"style_prefix": "A complete prompt prefix combining all elements (30-50 words)"
}
Focus on creating a specific, reproducible style_prefix that will generate visually consistent images."""

    def __init__(
        self,
        llm_service=None,
        config: Optional[StyleGuardConfig] = None
    ):
        """
        Initialize StyleGuard

        Args:
            llm_service: LLM service for VLM-based style extraction; when it
                is missing, extraction falls back to filename hints
            config: StyleGuard configuration (defaults are used when omitted)
        """
        self.llm_service = llm_service
        self.config = config or StyleGuardConfig()
        self._current_anchor: Optional[StyleAnchor] = None

    async def extract_style_anchor(
        self,
        image_path: str,
    ) -> StyleAnchor:
        """
        Extract style anchor from reference image

        The result is also stored as the current anchor, so later calls to
        apply_style / apply_style_to_batch can omit the anchor argument.

        Args:
            image_path: Path to reference image

        Returns:
            StyleAnchor with extracted style elements
        """
        logger.info(f"Extracting style anchor from: {image_path}")

        if self.config.custom_style_prefix:
            # A user-supplied prefix wins over any image analysis
            anchor = StyleAnchor(
                style_prefix=self.config.custom_style_prefix,
                reference_image=image_path
            )
            self._current_anchor = anchor
            return anchor

        if self.config.use_vlm_extraction and self.llm_service:
            anchor = await self._extract_with_vlm(image_path)
        else:
            anchor = self._extract_basic(image_path)

        self._current_anchor = anchor
        logger.info(f"Style anchor extracted: {anchor.to_prompt_prefix()}")
        return anchor

    @staticmethod
    def _guess_media_type(image_path: str) -> str:
        """Return the image MIME type for a path, defaulting to image/jpeg."""
        import mimetypes

        guessed, _ = mimetypes.guess_type(image_path)
        if guessed and guessed.startswith("image/"):
            return guessed
        # Unknown or non-image extension: keep the historical default
        return "image/jpeg"

    @staticmethod
    def _parse_style_response(
        vlm_response: str,
        image_path: str,
    ) -> Optional[StyleAnchor]:
        """Turn the raw VLM reply into a StyleAnchor, or None if unusable."""
        import json
        import re

        # The model may wrap the JSON in prose or a code fence, so take the
        # outermost {...} span when there is one
        match = re.search(r'\{[\s\S]*\}', vlm_response)
        candidate = match.group() if match else vlm_response

        try:
            data = json.loads(candidate)
        except json.JSONDecodeError as e:
            logger.warning(f"Failed to parse VLM response: {e}")
            # Use the raw response as style_prefix if it looks reasonable
            raw = vlm_response.strip()
            if 20 < len(raw) < 200:
                return StyleAnchor(
                    style_prefix=raw,
                    reference_image=image_path,
                )
            return None

        if not isinstance(data, dict):
            logger.warning("VLM response is not a JSON object, ignoring it")
            return None

        def text(key: str) -> str:
            # JSON null or non-string values must not leak into the anchor
            value = data.get(key)
            return value.strip() if isinstance(value, str) else ""

        anchor = StyleAnchor(
            art_style=text("art_style"),
            color_palette=text("color_palette"),
            lighting=text("lighting"),
            texture=text("texture"),
            style_prefix=text("style_prefix"),
            reference_image=image_path,
        )
        if not anchor.to_prompt_prefix():
            # An empty anchor would silently apply no style at all
            logger.warning("VLM response contained no usable style fields")
            return None
        return anchor

    async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
        """
        Extract style using Vision Language Model

        Sends the image as a base64 data URL to the chat-completions endpoint
        configured in ``config_manager.config.llm``. This method does not
        raise for ordinary failures: a missing image, a failed call, or an
        unusable reply all fall back to ``_extract_basic``.

        Args:
            image_path: Path to reference image

        Returns:
            StyleAnchor built from the VLM reply, or the basic fallback
        """
        if not self.llm_service:
            logger.warning("No LLM service available, using basic extraction")
            return self._extract_basic(image_path)

        try:
            import base64
            import os

            from openai import AsyncOpenAI

            # Read and encode image
            if not os.path.exists(image_path):
                logger.warning(f"Image not found: {image_path}")
                return self._extract_basic(image_path)

            with open(image_path, "rb") as f:
                image_data = base64.b64encode(f.read()).decode("utf-8")

            # Label the payload with its real type (webp, gif, ...) instead
            # of calling everything that is not PNG a JPEG
            media_type = self._guess_media_type(image_path)

            # Build multimodal message with image
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": self._STYLE_PROMPT},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{media_type};base64,{image_data}"
                            }
                        }
                    ]
                }
            ]

            # Get LLM config for VLM call
            from pixelle_video.config import config_manager
            llm_config = config_manager.config.llm

            # Call VLM with multimodal message
            try:
                # The context manager closes the client's HTTP resources
                # once the single request is done
                async with AsyncOpenAI(
                    api_key=llm_config.api_key,
                    base_url=llm_config.base_url
                ) as client:
                    response = await client.chat.completions.create(
                        model=llm_config.model,
                        messages=messages,
                        temperature=0.3,
                        max_tokens=500
                    )
                # content can be None
                vlm_response = response.choices[0].message.content or ""
            except Exception as e:
                logger.warning(f"VLM call failed, using basic extraction: {e}")
                return self._extract_basic(image_path)

            if not vlm_response.strip():
                logger.warning("VLM returned an empty response, using basic extraction")
                return self._extract_basic(image_path)
            logger.debug(f"VLM style extraction response: {vlm_response[:100]}...")

            # Parse response
            anchor = self._parse_style_response(vlm_response, image_path)
            if anchor is None:
                return self._extract_basic(image_path)

            logger.info(f"VLM extracted style: {anchor.to_prompt_prefix()[:80]}...")
            return anchor
        except Exception as e:
            # Style extraction is best-effort: never break video generation
            logger.warning(f"VLM style extraction failed: {e}")
            return self._extract_basic(image_path)

    def _extract_basic(self, image_path: str) -> StyleAnchor:
        """
        Basic style extraction without VLM - analyze filename for hints

        Only the base name of the path is inspected (case-insensitively);
        the image itself is not opened.

        Args:
            image_path: Path to reference image

        Returns:
            StyleAnchor whose style_prefix is a generic hint
        """
        import os

        filename = os.path.basename(image_path).lower()

        # First matching keyword group wins
        if "anime" in filename or "cartoon" in filename:
            style_prefix = "anime style illustration"
        elif "realistic" in filename or "photo" in filename:
            style_prefix = "photorealistic style"
        elif "sketch" in filename or "line" in filename:
            style_prefix = "sketch style, clean lines"
        else:
            style_prefix = "consistent visual style, high quality"

        return StyleAnchor(
            style_prefix=style_prefix,
            reference_image=image_path,
        )

    def apply_style(
        self,
        prompt: str,
        style_anchor: Optional[StyleAnchor] = None,
    ) -> str:
        """
        Apply style constraints to an image prompt

        Args:
            prompt: Original image prompt
            style_anchor: Style anchor to apply (uses current if not provided)

        Returns:
            Modified prompt with style constraints. The prompt is returned
            unchanged when there is no anchor or the anchor has no style
            text; the style text alone is returned for an empty prompt.
        """
        anchor = style_anchor if style_anchor is not None else self._current_anchor
        if anchor is None:
            return prompt

        style_prefix = anchor.to_prompt_prefix()
        if not style_prefix:
            return prompt

        if not prompt or not prompt.strip():
            # Avoid a dangling ", " separator around an empty prompt
            return style_prefix

        if self.config.prefix_position == "start":
            return f"{style_prefix}, {prompt}"
        # Any other value is treated as "end"
        return f"{prompt}, {style_prefix}"

    def apply_style_to_batch(
        self,
        prompts: List[str],
        style_anchor: Optional[StyleAnchor] = None,
        skip_first: bool = True,
    ) -> List[str]:
        """
        Apply style constraints to a batch of prompts

        Args:
            prompts: List of image prompts
            style_anchor: Style anchor to apply (uses current if not provided)
            skip_first: Skip first prompt (used as reference)

        Returns:
            List of styled prompts; an empty input is returned as-is, and a
            copy of the prompts is returned when no anchor is available
        """
        if not prompts:
            return prompts

        anchor = style_anchor if style_anchor is not None else self._current_anchor
        if anchor is None:
            # Return a copy so the result never aliases the caller's list
            return list(prompts)

        return [
            prompt if (skip_first and index == 0) else self.apply_style(prompt, anchor)
            for index, prompt in enumerate(prompts)
        ]

    def get_consistency_prompt_suffix(self) -> str:
        """
        Get a consistency prompt suffix for LLM prompt generation

        This can be added to the LLM prompt when generating image prompts
        to encourage consistent style descriptions.
        """
        return (
            "Ensure all image prompts maintain consistent visual style, "
            "including similar color palette, art style, lighting, and composition. "
            "Each image should feel like it belongs to the same visual narrative."
        )

    @property
    def current_anchor(self) -> Optional[StyleAnchor]:
        """Get the current style anchor"""
        return self._current_anchor

    def reset(self):
        """Reset the current style anchor"""
        self._current_anchor = None