wip: controlnet

This commit is contained in:
Qing
2023-05-11 21:51:58 +08:00
parent e5ac6a105a
commit 87f54bb87e
10 changed files with 117 additions and 29 deletions

View File

@@ -12,6 +12,7 @@ from lama_cleaner.model.mat import MAT
from lama_cleaner.model.paint_by_example import PaintByExample
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config
@@ -59,7 +60,7 @@ class ModelManager:
def __call__(self, image, mask, config: Config):
return self.model(image, mask, config)
def switch(self, new_name: str):
def switch(self, new_name: str, **kwargs):
if new_name == self.name:
return
try:
@@ -75,3 +76,17 @@ class ModelManager:
self.name = new_name
except NotImplementedError as e:
raise e
def switch_controlnet_method(self, control_method: str):
if not self.kwargs.get("sd_controlnet"):
return
if self.kwargs["sd_controlnet_method"] == control_method:
return
del self.model
torch_gc()
self.kwargs["sd_controlnet_method"] = control_method
self.model = self.init_model(
self.name, switch_mps_device(self.name, self.device), **self.kwargs
)