# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
StyleGuard - Visual style consistency engine

Ensures consistent visual style across all frames in a video by:
1. Extracting a style anchor from the first generated frame
2. Injecting style constraints into subsequent frame prompts
3. (Optional) Using style reference techniques such as IP-Adapter
"""

from dataclasses import dataclass
from typing import List, Optional

from loguru import logger


@dataclass
class StyleAnchor:
    """Style anchor extracted from a reference frame"""

    # Core style elements
    color_palette: str = ""      # e.g., "warm earth tones", "cool blues"
    art_style: str = ""          # e.g., "minimalist", "realistic", "anime"
    composition_style: str = ""  # e.g., "centered", "rule of thirds"
    texture: str = ""            # e.g., "smooth", "grainy", "watercolor"
    lighting: str = ""           # e.g., "soft ambient", "dramatic shadows"

    # Combined style prefix for prompts
    style_prefix: str = ""

    # Reference image path (for IP-Adapter style techniques)
    reference_image: Optional[str] = None

    def to_prompt_prefix(self) -> str:
        """Generate a style prefix for image prompts"""
        if self.style_prefix:
            return self.style_prefix

        elements = []
        if self.art_style:
            elements.append(f"{self.art_style} style")
        if self.color_palette:
            elements.append(self.color_palette)
        if self.lighting:
            elements.append(f"{self.lighting} lighting")
        if self.texture:
            elements.append(f"{self.texture} texture")

        return ", ".join(elements) if elements else ""

    def to_dict(self) -> dict:
        return {
            "color_palette": self.color_palette,
            "art_style": self.art_style,
            "composition_style": self.composition_style,
            "texture": self.texture,
            "lighting": self.lighting,
            "style_prefix": self.style_prefix,
            "reference_image": self.reference_image,
        }
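
# Example (illustrative sketch): when no combined style_prefix is set,
# to_prompt_prefix() composes one from the individual fields in a fixed order.
#
#     >>> StyleAnchor(art_style="watercolor", color_palette="muted pastels",
#     ...             lighting="soft ambient").to_prompt_prefix()
#     'watercolor style, muted pastels, soft ambient lighting'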


@dataclass
class StyleGuardConfig:
    """Configuration for StyleGuard"""

    # Extraction settings
    extract_from_first_frame: bool = True
    use_vlm_extraction: bool = True

    # Application settings
    apply_to_all_frames: bool = True
    prefix_position: str = "start"  # "start" or "end"

    # Optional external style reference
    external_style_image: Optional[str] = None
    custom_style_prefix: Optional[str] = None
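
# Example (illustrative sketch): a configuration that skips VLM analysis and
# pins every frame to a hand-written prefix appended at the end of each prompt.
# The prefix text below is a hypothetical placeholder.
#
#     >>> config = StyleGuardConfig(
#     ...     use_vlm_extraction=False,
#     ...     prefix_position="end",
#     ...     custom_style_prefix="flat vector illustration, pastel palette",
#     ... )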


class StyleGuard:
    """
    Style consistency guardian for video generation

    Ensures all frames in a video maintain visual consistency by:
    1. Analyzing the first frame (or reference image) to extract style
    2. Applying style constraints to all subsequent frame prompts

    Example:
        >>> style_guard = StyleGuard(llm_service)
        >>>
        >>> # Extract style from first frame
        >>> anchor = await style_guard.extract_style_anchor(
        ...     image_path="output/frame_001.png"
        ... )
        >>>
        >>> # Apply to subsequent prompts
        >>> styled_prompt = style_guard.apply_style(
        ...     prompt="A cat sitting on a windowsill",
        ...     style_anchor=anchor
        ... )
    """

    def __init__(
        self,
        llm_service=None,
        config: Optional[StyleGuardConfig] = None
    ):
        """
        Initialize StyleGuard

        Args:
            llm_service: LLM service for VLM-based style extraction
            config: StyleGuard configuration
        """
        self.llm_service = llm_service
        self.config = config or StyleGuardConfig()
        self._current_anchor: Optional[StyleAnchor] = None

    async def extract_style_anchor(
        self,
        image_path: str,
    ) -> StyleAnchor:
        """
        Extract style anchor from reference image

        Args:
            image_path: Path to reference image

        Returns:
            StyleAnchor with extracted style elements
        """
        logger.info(f"Extracting style anchor from: {image_path}")

        if self.config.custom_style_prefix:
            # Use custom style prefix if provided
            anchor = StyleAnchor(
                style_prefix=self.config.custom_style_prefix,
                reference_image=image_path
            )
            self._current_anchor = anchor
            return anchor

        if self.config.use_vlm_extraction and self.llm_service:
            anchor = await self._extract_with_vlm(image_path)
        else:
            anchor = self._extract_basic(image_path)

        self._current_anchor = anchor
        logger.info(f"Style anchor extracted: {anchor.to_prompt_prefix()}")

        return anchor
    async def _extract_with_vlm(self, image_path: str) -> StyleAnchor:
        """Extract style using Vision Language Model"""
        try:
            if not self.llm_service:
                logger.warning("No LLM service available, using basic extraction")
                return self._extract_basic(image_path)

            import base64
            import os
            from openai import AsyncOpenAI

            # Read and encode image
            if not os.path.exists(image_path):
                logger.warning(f"Image not found: {image_path}")
                return self._extract_basic(image_path)

            with open(image_path, "rb") as f:
                image_data = base64.b64encode(f.read()).decode("utf-8")

            # Determine image type
            ext = os.path.splitext(image_path)[1].lower()
            media_type = "image/png" if ext == ".png" else "image/jpeg"

            # Style extraction prompt
            style_prompt = """Analyze this image and extract its visual style characteristics.

Provide a concise style description that could be used as a prefix for image generation prompts to maintain visual consistency.

Output format (JSON):
{
    "art_style": "specific art style (e.g., oil painting, digital illustration, anime, photorealistic, watercolor, line art)",
    "color_palette": "dominant colors and mood (e.g., warm earth tones, vibrant neon, muted pastels)",
    "lighting": "lighting style (e.g., soft natural light, dramatic shadows, studio lighting)",
    "texture": "visual texture (e.g., smooth, grainy, brushstroke visible)",
    "style_prefix": "A complete prompt prefix combining all elements (30-50 words)"
}

Focus on creating a specific, reproducible style_prefix that will generate visually consistent images."""

            # Build multimodal message with image
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": style_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{media_type};base64,{image_data}"
                            }
                        }
                    ]
                }
            ]

            # Get LLM config for VLM call
            from pixelle_video.config import config_manager
            llm_config = config_manager.config.llm

            # Create OpenAI client directly for VLM call
            client = AsyncOpenAI(
                api_key=llm_config.api_key,
                base_url=llm_config.base_url
            )

            # Call VLM with multimodal message
            try:
                response = await client.chat.completions.create(
                    model=llm_config.model,
                    messages=messages,
                    temperature=0.3,
                    max_tokens=500
                )
                vlm_response = response.choices[0].message.content
                logger.debug(f"VLM style extraction response: {vlm_response[:100]}...")
            except Exception as e:
                logger.warning(f"VLM call failed, using basic extraction: {e}")
                return self._extract_basic(image_path)

            # Parse response
            import json
            import re

            try:
                # Try to extract JSON from response
                match = re.search(r'\{[\s\S]*\}', vlm_response)
                if match:
                    data = json.loads(match.group())
                else:
                    data = json.loads(vlm_response)

                anchor = StyleAnchor(
                    art_style=data.get("art_style", ""),
                    color_palette=data.get("color_palette", ""),
                    lighting=data.get("lighting", ""),
                    texture=data.get("texture", ""),
                    style_prefix=data.get("style_prefix", ""),
                    reference_image=image_path,
                )

                logger.info(f"VLM extracted style: {anchor.style_prefix[:80]}...")
                return anchor

            except (json.JSONDecodeError, KeyError) as e:
                logger.warning(f"Failed to parse VLM response: {e}")
                # Use the raw response as style_prefix if it looks reasonable
                if 20 < len(vlm_response) < 200:
                    return StyleAnchor(
                        style_prefix=vlm_response.strip(),
                        reference_image=image_path,
                    )
                return self._extract_basic(image_path)

        except Exception as e:
            logger.warning(f"VLM style extraction failed: {e}")
            return self._extract_basic(image_path)

    def _extract_basic(self, image_path: str) -> StyleAnchor:
        """Basic style extraction without VLM - analyze filename for hints"""
        import os

        filename = os.path.basename(image_path).lower()

        # Try to infer style from filename or path
        style_hints = []

        if "anime" in filename or "cartoon" in filename:
            style_hints.append("anime style illustration")
        elif "realistic" in filename or "photo" in filename:
            style_hints.append("photorealistic style")
        elif "sketch" in filename or "line" in filename:
            style_hints.append("sketch style, clean lines")
        else:
            style_hints.append("consistent visual style, high quality")

        return StyleAnchor(
            style_prefix=", ".join(style_hints),
            reference_image=image_path,
        )
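
    # Usage sketch: without a VLM, the anchor is inferred purely from filename
    # hints, so the reference image does not even need to be readable here.
    #
    #     >>> StyleGuard()._extract_basic("output/scene_03_anime.png").style_prefix
    #     'anime style illustration'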
    def apply_style(
        self,
        prompt: str,
        style_anchor: Optional[StyleAnchor] = None,
    ) -> str:
        """
        Apply style constraints to an image prompt

        Args:
            prompt: Original image prompt
            style_anchor: Style anchor to apply (uses current if not provided)

        Returns:
            Modified prompt with style constraints
        """
        anchor = style_anchor or self._current_anchor

        if not anchor:
            return prompt

        style_prefix = anchor.to_prompt_prefix()

        if not style_prefix:
            return prompt

        if self.config.prefix_position == "start":
            return f"{style_prefix}, {prompt}"
        else:
            return f"{prompt}, {style_prefix}"
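
    # Usage sketch: with the default prefix_position="start" the style prefix is
    # prepended; configure "end" to append it instead.
    #
    #     >>> anchor = StyleAnchor(style_prefix="watercolor style, muted pastels")
    #     >>> StyleGuard().apply_style("A cat sitting on a windowsill", anchor)
    #     'watercolor style, muted pastels, A cat sitting on a windowsill'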
    def apply_style_to_batch(
        self,
        prompts: List[str],
        style_anchor: Optional[StyleAnchor] = None,
        skip_first: bool = True,
    ) -> List[str]:
        """
        Apply style constraints to a batch of prompts

        Args:
            prompts: List of image prompts
            style_anchor: Style anchor to apply
            skip_first: Skip first prompt (used as reference)

        Returns:
            List of styled prompts
        """
        if not prompts:
            return prompts

        anchor = style_anchor or self._current_anchor

        if not anchor:
            return prompts

        result = []
        for i, prompt in enumerate(prompts):
            if skip_first and i == 0:
                result.append(prompt)
            else:
                result.append(self.apply_style(prompt, anchor))

        return result
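
    # Usage sketch: with skip_first=True the first prompt (the reference frame)
    # is left untouched and only the remaining prompts get the style prefix.
    #
    #     >>> guard = StyleGuard()
    #     >>> anchor = StyleAnchor(style_prefix="anime style")
    #     >>> guard.apply_style_to_batch(["a quiet harbor at dawn", "a busy street"], anchor)
    #     ['a quiet harbor at dawn', 'anime style, a busy street']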
    def get_consistency_prompt_suffix(self) -> str:
        """
        Get a consistency prompt suffix for LLM prompt generation

        This can be added to the LLM prompt when generating image prompts
        to encourage consistent style descriptions.
        """
        return (
            "Ensure all image prompts maintain consistent visual style, "
            "including similar color palette, art style, lighting, and composition. "
            "Each image should feel like it belongs to the same visual narrative."
        )

    @property
    def current_anchor(self) -> Optional[StyleAnchor]:
        """Get the current style anchor"""
        return self._current_anchor

    def reset(self):
        """Reset the current style anchor"""
        self._current_anchor = None
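

if __name__ == "__main__":
    # Minimal smoke-test sketch: no llm_service is wired up, so extraction falls
    # back to filename-based hints and no network call is made. The image path
    # is a hypothetical placeholder; it does not need to exist for this path.
    import asyncio

    async def _demo():
        guard = StyleGuard()
        anchor = await guard.extract_style_anchor("output/frame_001_anime.png")
        print(guard.apply_style("A cat sitting on a windowsill", anchor))

    asyncio.run(_demo())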