This commit is contained in:
Qing
2022-09-15 22:21:27 +08:00
parent 3ac6ee7f44
commit 32854d40da
52 changed files with 2258 additions and 205 deletions

View File

@@ -2,27 +2,23 @@ from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.mat import MAT
from lama_cleaner.model.sd import SD14
from lama_cleaner.model.zits import ZITS
from lama_cleaner.schema import Config
models = {
'lama': LaMa,
'ldm': LDM,
'zits': ZITS,
'mat': MAT,
'fcf': FcF
}
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.4": SD14}
class ModelManager:
def __init__(self, name: str, device):
def __init__(self, name: str, device, **kwargs):
self.name = name
self.device = device
self.model = self.init_model(name, device)
self.kwargs = kwargs
self.model = self.init_model(name, device, **kwargs)
def init_model(self, name: str, device):
def init_model(self, name: str, device, **kwargs):
if name in models:
model = models[name](device)
model = models[name](device, **kwargs)
else:
raise NotImplementedError(f"Not supported model: {name}")
return model
@@ -40,7 +36,7 @@ class ModelManager:
if new_name == self.name:
return
try:
self.model = self.init_model(new_name, self.device)
self.model = self.init_model(new_name, self.device, **self.kwargs)
self.name = new_name
except NotImplementedError as e:
raise e