fix: Use OpenAI multimodal message format for VLM style extraction
This commit is contained in:
@@ -170,6 +170,7 @@ class StyleGuard:
|
||||
|
||||
import base64
|
||||
import os
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Read and encode image
|
||||
if not os.path.exists(image_path):
|
||||
@@ -183,7 +184,7 @@ class StyleGuard:
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
media_type = "image/png" if ext == ".png" else "image/jpeg"
|
||||
|
||||
# Call VLM to analyze style
|
||||
# Style extraction prompt
|
||||
style_prompt = """Analyze this image and extract its visual style characteristics.
|
||||
|
||||
Provide a concise style description that could be used as a prefix for image generation prompts to maintain visual consistency.
|
||||
@@ -199,16 +200,43 @@ Output format (JSON):
|
||||
|
||||
Focus on creating a specific, reproducible style_prefix that will generate visually consistent images."""
|
||||
|
||||
# Try to call LLM with vision capability
|
||||
# Build multimodal message with image
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": style_prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{media_type};base64,{image_data}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Get LLM config for VLM call
|
||||
from pixelle_video.config import config_manager
|
||||
llm_config = config_manager.config.llm
|
||||
|
||||
# Create OpenAI client directly for VLM call
|
||||
client = AsyncOpenAI(
|
||||
api_key=llm_config.api_key,
|
||||
base_url=llm_config.base_url
|
||||
)
|
||||
|
||||
# Call VLM with multimodal message
|
||||
try:
|
||||
response = await self.llm_service(
|
||||
prompt=style_prompt,
|
||||
images=[f"data:{media_type};base64,{image_data}"],
|
||||
response = await client.chat.completions.create(
|
||||
model=llm_config.model,
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
max_tokens=500
|
||||
)
|
||||
vlm_response = response.choices[0].message.content
|
||||
logger.debug(f"VLM style extraction response: {vlm_response[:100]}...")
|
||||
except Exception as e:
|
||||
# Fallback: try without image (text-only LLM)
|
||||
logger.warning(f"VLM call failed, using basic extraction: {e}")
|
||||
return self._extract_basic(image_path)
|
||||
|
||||
@@ -218,11 +246,11 @@ Focus on creating a specific, reproducible style_prefix that will generate visua
|
||||
|
||||
try:
|
||||
# Try to extract JSON from response
|
||||
match = re.search(r'\{[\s\S]*\}', response)
|
||||
match = re.search(r'\{[\s\S]*\}', vlm_response)
|
||||
if match:
|
||||
data = json.loads(match.group())
|
||||
else:
|
||||
data = json.loads(response)
|
||||
data = json.loads(vlm_response)
|
||||
|
||||
anchor = StyleAnchor(
|
||||
art_style=data.get("art_style", ""),
|
||||
@@ -239,9 +267,9 @@ Focus on creating a specific, reproducible style_prefix that will generate visua
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Failed to parse VLM response: {e}")
|
||||
# Use the raw response as style_prefix if it looks reasonable
|
||||
if len(response) < 200 and len(response) > 20:
|
||||
if len(vlm_response) < 200 and len(vlm_response) > 20:
|
||||
return StyleAnchor(
|
||||
style_prefix=response.strip(),
|
||||
style_prefix=vlm_response.strip(),
|
||||
reference_image=image_path,
|
||||
)
|
||||
return self._extract_basic(image_path)
|
||||
|
||||
Reference in New Issue
Block a user