wip
This commit is contained in:
@@ -8,7 +8,7 @@ from lama_cleaner.helper import switch_mps_device
|
||||
from lama_cleaner.model import models, ControlNet, SD, SDXL
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_info import ModelInfo, ModelType
|
||||
from lama_cleaner.schema import Config
|
||||
from lama_cleaner.schema import InpaintRequest
|
||||
|
||||
|
||||
class ModelManager:
|
||||
@@ -31,13 +31,15 @@ class ModelManager:
|
||||
self.model = self.init_model(name, device, **kwargs)
|
||||
|
||||
@property
|
||||
def current_model(self) -> Dict:
|
||||
return self.available_models[self.name].model_dump()
|
||||
def current_model(self) -> ModelInfo:
|
||||
return self.available_models[self.name]
|
||||
|
||||
def init_model(self, name: str, device, **kwargs):
|
||||
logger.info(f"Loading model: {name}")
|
||||
if name not in self.available_models:
|
||||
raise NotImplementedError(f"Unsupported model: {name}. Available models: {self.available_models.keys()}")
|
||||
raise NotImplementedError(
|
||||
f"Unsupported model: {name}. Available models: {self.available_models.keys()}"
|
||||
)
|
||||
|
||||
model_info = self.available_models[name]
|
||||
kwargs = {
|
||||
@@ -66,7 +68,17 @@ class ModelManager:
|
||||
|
||||
raise NotImplementedError(f"Unsupported model: {name}")
|
||||
|
||||
def __call__(self, image, mask, config: Config):
|
||||
def __call__(self, image, mask, config: InpaintRequest):
|
||||
"""
|
||||
|
||||
Args:
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
config:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
self.switch_controlnet_method(config)
|
||||
self.enable_disable_freeu(config)
|
||||
self.enable_disable_lcm_lora(config)
|
||||
@@ -135,7 +147,7 @@ class ModelManager:
|
||||
else:
|
||||
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
||||
|
||||
def enable_disable_freeu(self, config: Config):
|
||||
def enable_disable_freeu(self, config: InpaintRequest):
|
||||
if str(self.model.device) == "mps":
|
||||
return
|
||||
|
||||
@@ -151,7 +163,7 @@ class ModelManager:
|
||||
else:
|
||||
self.model.model.disable_freeu()
|
||||
|
||||
def enable_disable_lcm_lora(self, config: Config):
|
||||
def enable_disable_lcm_lora(self, config: InpaintRequest):
|
||||
if self.available_models[self.name].support_lcm_lora:
|
||||
if config.sd_lcm_lora:
|
||||
if not self.model.model.get_list_adapters():
|
||||
|
||||
Reference in New Issue
Block a user