# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""
Unit tests for the CharacterMemory module.
"""
|
|
|
|
import pytest
|
|
from pixelle_video.services.quality import CharacterMemory, Character, CharacterType, CharacterMemoryConfig
|
|
|
|
|
|
class TestCharacter:
    """Tests for the ``Character`` dataclass.

    Covers construction defaults, prompt-injection text, name/alias
    matching, and reference-image bookkeeping.
    """

    def test_character_creation(self):
        """Constructor stores the given fields and applies defaults."""
        char = Character(
            id="char_0_test",
            name="TestChar",
            appearance_description="young man with black hair",
            clothing_description="blue shirt",
        )

        assert char.id == "char_0_test"
        assert char.name == "TestChar"
        assert char.appearance_description == "young man with black hair"
        assert char.clothing_description == "blue shirt"
        # Fields not passed to the constructor fall back to their defaults.
        # ``is True`` (not ``== True``, E712) pins the flag to a real bool.
        assert char.is_active is True
        assert char.character_type == CharacterType.PERSON

    def test_character_prompt_injection(self):
        """Prompt injection text mentions name, appearance and clothing."""
        char = Character(
            id="char_1_hero",
            name="Hero",
            appearance_description="tall warrior with long hair",
            clothing_description="silver armor",
        )

        injection = char.get_prompt_injection()

        assert "Hero" in injection
        assert "tall warrior" in injection
        assert "silver armor" in injection

    def test_character_matches_name(self):
        """Name matching accepts the primary name and every alias."""
        char = Character(
            id="char_2_xiaoming",
            name="小明",
            aliases=["Ming", "小明同学"],
        )

        assert char.matches_name("小明") is True
        assert char.matches_name("Ming") is True
        assert char.matches_name("小明同学") is True
        # A different name that is neither the primary name nor an alias.
        assert char.matches_name("小红") is False

    def test_character_reference_image(self):
        """Adding reference images records them and tracks the primary one."""
        char = Character(id="char_3_test", name="Test")

        char.add_reference_image("/path/to/image1.png", set_as_primary=True)
        char.add_reference_image("/path/to/image2.png")

        assert len(char.reference_images) == 2
        # Only the first image was flagged as primary.
        assert char.primary_reference == "/path/to/image1.png"

    def test_character_is_active_default(self):
        """``is_active`` defaults to True when not supplied."""
        char = Character(id="char_4_active", name="ActiveChar")
        assert char.is_active is True
|
|
|
|
|
|
class TestCharacterMemory:
    """Behavioural tests for the ``CharacterMemory`` registry.

    Exercises registration, lookup by name, prompt enhancement,
    reference-image collection, reset, and the textual summary.
    """

    def test_memory_initialization(self):
        """A freshly constructed memory is empty and carries a config."""
        mem = CharacterMemory()

        assert mem.characters == []
        assert mem.config is not None

    def test_register_character(self):
        """Registering a character returns it and makes it retrievable."""
        mem = CharacterMemory()

        registered = mem.register_character(
            name="小美",
            appearance_description="young woman with long black hair",
            clothing_description="red dress",
            distinctive_features=["glasses", "necklace"],
        )

        assert registered.name == "小美"
        assert len(mem.characters) == 1
        assert mem.get_character("小美") == registered

    def test_get_character_by_name(self):
        """Lookup returns registered characters and None for unknown names."""
        mem = CharacterMemory()
        for label, look in (("Hero", "tall"), ("Villain", "dark")):
            mem.register_character(name=label, appearance_description=look)

        hero, villain, stranger = (
            mem.get_character(wanted) for wanted in ("Hero", "Villain", "Unknown")
        )

        assert hero is not None
        assert hero.name == "Hero"
        assert villain is not None
        assert stranger is None

    def test_apply_to_prompt(self):
        """Registered character details are woven into the scene prompt."""
        mem = CharacterMemory()
        mem.register_character(
            name="Alice",
            appearance_description="blonde girl",
            clothing_description="blue dress",
        )

        enhanced = mem.apply_to_prompt("A girl walking in the park")

        # Character name, appearance and the original scene must all survive.
        for fragment in ("Alice", "blonde girl", "walking in the park"):
            assert fragment in enhanced

    def test_apply_to_prompt_with_specific_characters(self):
        """Only the characters named in ``character_names`` are injected."""
        mem = CharacterMemory()
        for label, look in (("Bob", "tall man"), ("Charlie", "short man")):
            mem.register_character(name=label, appearance_description=look)

        enhanced = mem.apply_to_prompt("Two people talking", character_names=["Bob"])

        assert "Bob" in enhanced
        assert "Charlie" not in enhanced

    def test_apply_to_prompt_disabled(self):
        """With injection turned off the prompt comes back unchanged."""
        mem = CharacterMemory(
            config=CharacterMemoryConfig(inject_character_prompts=False)
        )
        mem.register_character(name="Test", appearance_description="test desc")

        baseline = "Original prompt"
        assert mem.apply_to_prompt(baseline) == baseline

    def test_get_reference_images(self):
        """Reference images added to a character are exposed by the memory."""
        mem = CharacterMemory()
        photo = mem.register_character(name="Photo")
        photo.add_reference_image("/path/ref1.png", set_as_primary=True)

        refs = mem.get_reference_images()

        assert len(refs) == 1
        assert refs[0] == "/path/ref1.png"

    def test_reset_memory(self):
        """``reset`` drops every registered character."""
        mem = CharacterMemory()
        for label in ("Temp1", "Temp2"):
            mem.register_character(name=label)
        assert len(mem.characters) == 2

        mem.reset()

        assert len(mem.characters) == 0

    def test_consistency_summary(self):
        """The summary reports the character count and each name."""
        mem = CharacterMemory()
        for label in ("Char1", "Char2"):
            mem.register_character(name=label)

        summary = mem.get_consistency_summary()

        for expected in ("Characters (2):", "Char1", "Char2"):
            assert expected in summary
|
|
|
|
|
|
class TestCharacterMemoryConfig:
    """Tests for ``CharacterMemoryConfig`` defaults and overrides."""

    def test_default_config(self):
        """Defaults enable detection/injection and inject at the start."""
        config = CharacterMemoryConfig()

        # ``is`` rather than ``== True`` (E712): pins the flags to real
        # booleans instead of any value that merely compares equal to True.
        assert config.auto_detect_characters is True
        assert config.use_llm_detection is True
        assert config.inject_character_prompts is True
        assert config.prompt_injection_position == "start"

    def test_custom_config(self):
        """Keyword overrides are stored on the config as given."""
        config = CharacterMemoryConfig(
            auto_detect_characters=False,
            inject_character_prompts=False,
            prompt_injection_position="end",
        )

        assert config.auto_detect_characters is False
        assert config.inject_character_prompts is False
        assert config.prompt_injection_position == "end"
|
|
|
|
|
|
# Async tests for character detection from narration (disabled and non-LLM paths)
|
|
class TestCharacterMemoryAsync:
    """Async tests for ``CharacterMemory.detect_characters_from_narration``.

    The ``pytest.mark.asyncio`` marker comes from the pytest-asyncio plugin,
    which must be installed for these coroutine tests to be run.
    """

    @pytest.mark.asyncio
    async def test_detect_characters_disabled(self):
        """Detection yields an empty list when auto-detection is off."""
        config = CharacterMemoryConfig(auto_detect_characters=False)
        memory = CharacterMemory(config=config)

        result = await memory.detect_characters_from_narration("小明在公园散步")

        assert result == []

    @pytest.mark.asyncio
    async def test_basic_detection_without_llm(self):
        """Detection without an LLM service still completes and returns a list."""
        config = CharacterMemoryConfig(auto_detect_characters=True, use_llm_detection=False)
        memory = CharacterMemory(config=config)

        # With use_llm_detection=False this is expected to go through the
        # pattern-based ``_detect_basic`` path rather than an LLM call.
        result = await memory.detect_characters_from_narration("小明走在路上")

        # NOTE(review): only the return type is checked here. Whether "小明"
        # is actually detected is NOT asserted, because this test does not
        # pin down the structure of the returned items.
        assert isinstance(result, list)
|