backend add freeu
This commit is contained in:
@@ -3,7 +3,7 @@ import gc
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import SD15_MODELS
|
||||
from lama_cleaner.const import SD15_MODELS, MODELS_SUPPORT_FREEU
|
||||
from lama_cleaner.helper import switch_mps_device
|
||||
from lama_cleaner.model.controlnet import ControlNet
|
||||
from lama_cleaner.model.fcf import FcF
|
||||
@@ -65,6 +65,7 @@ class ModelManager:
|
||||
|
||||
def __call__(self, image, mask, config: Config):
|
||||
self.switch_controlnet_method(control_method=config.controlnet_method)
|
||||
self.enable_disable_freeu(config)
|
||||
return self.model(image, mask, config)
|
||||
|
||||
def switch(self, new_name: str, **kwargs):
|
||||
@@ -120,3 +121,19 @@ class ModelManager:
|
||||
self.name, switch_mps_device(self.name, self.device), **self.kwargs
|
||||
)
|
||||
logger.info(f"Switch ControlNet method from {old_method} to {control_method}")
|
||||
|
||||
def enable_disable_freeu(self, config: Config):
|
||||
if str(self.model.device) == "mps":
|
||||
return
|
||||
|
||||
if self.name in MODELS_SUPPORT_FREEU:
|
||||
if config.sd_freeu:
|
||||
freeu_config = config.sd_freeu_config
|
||||
self.model.model.enable_freeu(
|
||||
s1=freeu_config.s1,
|
||||
s2=freeu_config.s2,
|
||||
b1=freeu_config.b1,
|
||||
b2=freeu_config.b2,
|
||||
)
|
||||
else:
|
||||
self.model.model.disable_freeu()
|
||||
|
||||
Reference in New Issue
Block a user