wip: controlnet
This commit is contained in:
@@ -12,6 +12,7 @@ from lama_cleaner.model.mat import MAT
|
||||
from lama_cleaner.model.paint_by_example import PaintByExample
|
||||
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
|
||||
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model.zits import ZITS
|
||||
from lama_cleaner.model.opencv2 import OpenCV2
|
||||
from lama_cleaner.schema import Config
|
||||
@@ -59,7 +60,7 @@ class ModelManager:
|
||||
def __call__(self, image, mask, config: Config):
|
||||
return self.model(image, mask, config)
|
||||
|
||||
def switch(self, new_name: str):
|
||||
def switch(self, new_name: str, **kwargs):
|
||||
if new_name == self.name:
|
||||
return
|
||||
try:
|
||||
@@ -75,3 +76,17 @@ class ModelManager:
|
||||
self.name = new_name
|
||||
except NotImplementedError as e:
|
||||
raise e
|
||||
|
||||
def switch_controlnet_method(self, control_method: str):
|
||||
if not self.kwargs.get("sd_controlnet"):
|
||||
return
|
||||
if self.kwargs["sd_controlnet_method"] == control_method:
|
||||
return
|
||||
|
||||
del self.model
|
||||
torch_gc()
|
||||
|
||||
self.kwargs["sd_controlnet_method"] = control_method
|
||||
self.model = self.init_model(
|
||||
self.name, switch_mps_device(self.name, self.device), **self.kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user