This commit is contained in:
Qing
2023-12-24 15:32:27 +08:00
parent 0e5e16ba20
commit 371db2d771
31 changed files with 441 additions and 439 deletions

View File

@@ -3,7 +3,6 @@ from typing import List, Dict
import torch
from loguru import logger
from lama_cleaner.const import DEFAULT_SD_CONTROLNET_METHOD
from lama_cleaner.download import scan_models
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model import models, ControlNet, SD, SDXL
@@ -19,16 +18,25 @@ class ModelManager:
self.available_models: Dict[str, ModelInfo] = {}
self.scan_models()
self.sd_controlnet = kwargs.get("sd_controlnet", False)
self.sd_controlnet_method = kwargs.get(
"sd_controlnet_method", DEFAULT_SD_CONTROLNET_METHOD
)
self.sd_controlnet = False
self.sd_controlnet_method = ""
self.model = self.init_model(name, device, **kwargs)
def init_model(self, name: str, device, **kwargs):
def _map_old_name(self, name: str) -> str:
for old_name, model_cls in models.items():
if name == old_name and hasattr(model_cls, "model_id_or_path"):
name = model_cls.model_id_or_path
break
return name
@property
def current_model(self) -> Dict:
name = self._map_old_name(self.name)
return self.available_models[name].model_dump()
def init_model(self, name: str, device, **kwargs):
name = self._map_old_name(name)
logger.info(f"Loading model: {name}")
if name not in self.available_models:
raise NotImplementedError(f"Unsupported model: {name}")
@@ -86,6 +94,7 @@ class ModelManager:
):
self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
try:
# TODO: enable/disable controlnet without reload model
del self.model
torch_gc()