lots update

This commit is contained in:
Qing
2023-12-27 22:00:07 +08:00
parent 0ba6c121e0
commit f0b852725f
33 changed files with 4085 additions and 1000 deletions

View File

@@ -7,7 +7,8 @@ from lama_cleaner.download import scan_models
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model import models, ControlNet, SD, SDXL
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.schema import Config, ModelInfo, ModelType
from lama_cleaner.model_info import ModelInfo, ModelType
from lama_cleaner.schema import Config
class ModelManager:
@@ -18,13 +19,20 @@ class ModelManager:
self.available_models: Dict[str, ModelInfo] = {}
self.scan_models()
self.sd_controlnet = False
self.sd_controlnet_method = ""
self.enable_controlnet = kwargs.get("enable_controlnet", False)
controlnet_method = kwargs.get("controlnet_method", None)
if (
controlnet_method is None
and name in self.available_models
and self.available_models[name].support_controlnet
):
controlnet_method = self.available_models[name].controlnets[0]
self.controlnet_method = controlnet_method
self.model = self.init_model(name, device, **kwargs)
@property
def current_model(self) -> Dict:
return self.available_models[name].model_dump()
return self.available_models[self.name].model_dump()
def init_model(self, name: str, device, **kwargs):
logger.info(f"Loading model: {name}")
@@ -35,15 +43,14 @@ class ModelManager:
kwargs = {
**kwargs,
"model_info": model_info,
"sd_controlnet": self.sd_controlnet,
"sd_controlnet_method": self.sd_controlnet_method,
"enable_controlnet": self.enable_controlnet,
"controlnet_method": self.controlnet_method,
}
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
return models[name](device, **kwargs)
if self.sd_controlnet:
if model_info.support_controlnet and self.enable_controlnet:
return ControlNet(device, **kwargs)
elif model_info.name in models:
return models[name](device, **kwargs)
else:
if model_info.model_type in [
ModelType.DIFFUSERS_SD_INPAINT,
@@ -75,15 +82,15 @@ class ModelManager:
return
old_name = self.name
old_sd_controlnet_method = self.sd_controlnet_method
old_controlnet_method = self.controlnet_method
self.name = new_name
if (
self.available_models[new_name].support_controlnet
and self.sd_controlnet_method
and self.controlnet_method
not in self.available_models[new_name].controlnets
):
self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
self.controlnet_method = self.available_models[new_name].controlnets[0]
try:
# TODO: enable/disable controlnet without reload model
del self.model
@@ -94,7 +101,7 @@ class ModelManager:
)
except Exception as e:
self.name = old_name
self.sd_controlnet_method = old_sd_controlnet_method
self.controlnet_method = old_controlnet_method
logger.info(f"Switch model from {old_name} to {new_name} failed, rollback")
self.model = self.init_model(
old_name, switch_mps_device(old_name, self.device), **self.kwargs
@@ -106,24 +113,24 @@ class ModelManager:
return
if (
self.sd_controlnet
self.enable_controlnet
and config.controlnet_method
and self.sd_controlnet_method != config.controlnet_method
and self.controlnet_method != config.controlnet_method
):
old_sd_controlnet_method = self.sd_controlnet_method
self.sd_controlnet_method = config.controlnet_method
old_controlnet_method = self.controlnet_method
self.controlnet_method = config.controlnet_method
self.model.switch_controlnet_method(config.controlnet_method)
logger.info(
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
f"Switch Controlnet method from {old_controlnet_method} to {config.controlnet_method}"
)
elif self.sd_controlnet != config.controlnet_enabled:
self.sd_controlnet = config.controlnet_enabled
self.sd_controlnet_method = config.controlnet_method
elif self.enable_controlnet != config.enable_controlnet:
self.enable_controlnet = config.enable_controlnet
self.controlnet_method = config.controlnet_method
self.model = self.init_model(
self.name, switch_mps_device(self.name, self.device), **self.kwargs
)
if not config.controlnet_enabled:
if not config.enable_controlnet:
logger.info(f"Disable controlnet")
else:
logger.info(f"Enable controlnet: {config.controlnet_method}")