update
This commit is contained in:
@@ -3,7 +3,6 @@ from typing import List, Dict
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import DEFAULT_SD_CONTROLNET_METHOD
|
||||
from lama_cleaner.download import scan_models
|
||||
from lama_cleaner.helper import switch_mps_device
|
||||
from lama_cleaner.model import models, ControlNet, SD, SDXL
|
||||
@@ -19,16 +18,25 @@ class ModelManager:
|
||||
self.available_models: Dict[str, ModelInfo] = {}
|
||||
self.scan_models()
|
||||
|
||||
self.sd_controlnet = kwargs.get("sd_controlnet", False)
|
||||
self.sd_controlnet_method = kwargs.get(
|
||||
"sd_controlnet_method", DEFAULT_SD_CONTROLNET_METHOD
|
||||
)
|
||||
self.sd_controlnet = False
|
||||
self.sd_controlnet_method = ""
|
||||
self.model = self.init_model(name, device, **kwargs)
|
||||
|
||||
def init_model(self, name: str, device, **kwargs):
|
||||
def _map_old_name(self, name: str) -> str:
|
||||
for old_name, model_cls in models.items():
|
||||
if name == old_name and hasattr(model_cls, "model_id_or_path"):
|
||||
name = model_cls.model_id_or_path
|
||||
break
|
||||
return name
|
||||
|
||||
@property
|
||||
def current_model(self) -> Dict:
|
||||
name = self._map_old_name(self.name)
|
||||
return self.available_models[name].model_dump()
|
||||
|
||||
def init_model(self, name: str, device, **kwargs):
|
||||
name = self._map_old_name(name)
|
||||
logger.info(f"Loading model: {name}")
|
||||
if name not in self.available_models:
|
||||
raise NotImplementedError(f"Unsupported model: {name}")
|
||||
|
||||
@@ -86,6 +94,7 @@ class ModelManager:
|
||||
):
|
||||
self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
|
||||
try:
|
||||
# TODO: enable/disable controlnet without reload model
|
||||
del self.model
|
||||
torch_gc()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user