Qing
2023-12-30 23:36:44 +08:00
parent 85c3397b97
commit c4abda3942
35 changed files with 969 additions and 854 deletions


@@ -8,7 +8,7 @@ from lama_cleaner.helper import switch_mps_device
 from lama_cleaner.model import models, ControlNet, SD, SDXL
 from lama_cleaner.model.utils import torch_gc
 from lama_cleaner.model_info import ModelInfo, ModelType
-from lama_cleaner.schema import Config
+from lama_cleaner.schema import InpaintRequest


 class ModelManager:
@@ -31,13 +31,15 @@ class ModelManager:
         self.model = self.init_model(name, device, **kwargs)

     @property
-    def current_model(self) -> Dict:
-        return self.available_models[self.name].model_dump()
+    def current_model(self) -> ModelInfo:
+        return self.available_models[self.name]

     def init_model(self, name: str, device, **kwargs):
         logger.info(f"Loading model: {name}")
         if name not in self.available_models:
-            raise NotImplementedError(f"Unsupported model: {name}. Available models: {self.available_models.keys()}")
+            raise NotImplementedError(
+                f"Unsupported model: {name}. Available models: {self.available_models.keys()}"
+            )

         model_info = self.available_models[name]
         kwargs = {
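With this hunk, current_model hands back the ModelInfo object itself instead of eagerly serializing it, so callers that still need a dict serialize it themselves. A minimal caller-side sketch, assuming ModelInfo is a Pydantic v2 model (the model_dump() call removed above implies this) and that it carries a name field:

    # Before: manager.current_model was already a plain dict.
    # name = manager.current_model["name"]

    # After: it is a ModelInfo object; serialize explicitly where needed.
    info = manager.current_model      # ModelInfo
    name = info.name                  # attribute access instead of key lookup
    payload = info.model_dump()       # dict, only where serialization is required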
@@ -66,7 +68,17 @@ class ModelManager:
         raise NotImplementedError(f"Unsupported model: {name}")

-    def __call__(self, image, mask, config: Config):
+    def __call__(self, image, mask, config: InpaintRequest):
+        """
+
+        Args:
+            image: [H, W, C] RGB
+            mask: [H, W, 1] 255 means area to repaint
+            config:
+
+        Returns:
+        """
         self.switch_controlnet_method(config)
         self.enable_disable_freeu(config)
         self.enable_disable_lcm_lora(config)
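The added docstring pins down the expected array layout. A hypothetical end-to-end call built only from those shapes and the renamed schema; the constructor arguments, the "lama" model name, and the default-constructed InpaintRequest are assumptions, not part of this diff:

    import numpy as np

    from lama_cleaner.model_manager import ModelManager
    from lama_cleaner.schema import InpaintRequest

    manager = ModelManager(name="lama", device="cpu")  # args assumed from __init__ above

    image = np.zeros((512, 512, 3), dtype=np.uint8)  # [H, W, C] RGB
    mask = np.zeros((512, 512, 1), dtype=np.uint8)   # [H, W, 1]
    mask[100:200, 100:200] = 255                     # 255 marks the region to repaint

    result = manager(image, mask, InpaintRequest())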
@@ -135,7 +147,7 @@ class ModelManager:
         else:
             logger.info(f"Enable controlnet: {config.controlnet_method}")

-    def enable_disable_freeu(self, config: Config):
+    def enable_disable_freeu(self, config: InpaintRequest):
         if str(self.model.device) == "mps":
             return
@@ -151,7 +163,7 @@ class ModelManager:
         else:
             self.model.model.disable_freeu()

-    def enable_disable_lcm_lora(self, config: Config):
+    def enable_disable_lcm_lora(self, config: InpaintRequest):
         if self.available_models[self.name].support_lcm_lora:
             if config.sd_lcm_lora:
                 if not self.model.model.get_list_adapters():
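Because every toggle now rides on the per-request InpaintRequest, LCM-LoRA can be flipped call by call with no manager-level state. Continuing the sketch above; only the sd_lcm_lora field name appears in this diff, everything else is illustrative:

    preview_cfg = InpaintRequest(sd_lcm_lora=True)   # adapter loaded on first use
    final_cfg = InpaintRequest(sd_lcm_lora=False)    # disabled again on the next call

    preview = manager(image, mask, preview_cfg)      # fast, fewer-step preview pass
    final = manager(image, mask, final_cfg)          # full-quality render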