add ZITS
This commit is contained in:
@@ -1,31 +1,31 @@
|
||||
from lama_cleaner.model.lama import LaMa
|
||||
from lama_cleaner.model.ldm import LDM
|
||||
from lama_cleaner.model.zits import ZITS
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
models = {
|
||||
'lama': LaMa,
|
||||
'ldm': LDM,
|
||||
'zits': ZITS
|
||||
}
|
||||
|
||||
|
||||
class ModelManager:
|
||||
LAMA = 'lama'
|
||||
LDM = 'ldm'
|
||||
|
||||
def __init__(self, name: str, device):
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.model = self.init_model(name, device)
|
||||
|
||||
def init_model(self, name: str, device):
|
||||
if name == self.LAMA:
|
||||
model = LaMa(device)
|
||||
elif name == self.LDM:
|
||||
model = LDM(device)
|
||||
if name in models:
|
||||
model = models[name](device)
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported model: {name}")
|
||||
return model
|
||||
|
||||
def is_downloaded(self, name: str) -> bool:
|
||||
if name == self.LAMA:
|
||||
return LaMa.is_downloaded()
|
||||
elif name == self.LDM:
|
||||
return LDM.is_downloaded()
|
||||
if name in models:
|
||||
return models[name].is_downloaded()
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported model: {name}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user