make brushnet work
This commit is contained in:
@@ -7,6 +7,7 @@ import numpy as np
|
||||
from iopaint.download import scan_models
|
||||
from iopaint.helper import switch_mps_device
|
||||
from iopaint.model import models, ControlNet, SD, SDXL
|
||||
from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
|
||||
from iopaint.model.utils import torch_gc, is_local_files_only
|
||||
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
||||
|
||||
@@ -22,12 +23,16 @@ class ModelManager:
|
||||
self.enable_controlnet = kwargs.get("enable_controlnet", False)
|
||||
controlnet_method = kwargs.get("controlnet_method", None)
|
||||
if (
|
||||
controlnet_method is None
|
||||
and name in self.available_models
|
||||
and self.available_models[name].support_controlnet
|
||||
controlnet_method is None
|
||||
and name in self.available_models
|
||||
and self.available_models[name].support_controlnet
|
||||
):
|
||||
controlnet_method = self.available_models[name].controlnets[0]
|
||||
self.controlnet_method = controlnet_method
|
||||
|
||||
self.enable_brushnet = kwargs.get("enable_brushnet", False)
|
||||
self.brushnet_method = kwargs.get("brushnet_method", None)
|
||||
|
||||
self.model = self.init_model(name, device, **kwargs)
|
||||
|
||||
@property
|
||||
@@ -47,24 +52,30 @@ class ModelManager:
|
||||
"model_info": model_info,
|
||||
"enable_controlnet": self.enable_controlnet,
|
||||
"controlnet_method": self.controlnet_method,
|
||||
"enable_brushnet": self.enable_brushnet,
|
||||
"brushnet_method": self.brushnet_method,
|
||||
}
|
||||
|
||||
if model_info.support_controlnet and self.enable_controlnet:
|
||||
return ControlNet(device, **kwargs)
|
||||
elif model_info.name in models:
|
||||
return models[name](device, **kwargs)
|
||||
else:
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]:
|
||||
return SD(device, **kwargs)
|
||||
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
]:
|
||||
return SDXL(device, **kwargs)
|
||||
if model_info.support_brushnet and self.enable_brushnet:
|
||||
return BrushNetWrapper(device, **kwargs)
|
||||
|
||||
if model_info.name in models:
|
||||
return models[name](device, **kwargs)
|
||||
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]:
|
||||
return SD(device, **kwargs)
|
||||
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
]:
|
||||
return SDXL(device, **kwargs)
|
||||
|
||||
raise NotImplementedError(f"Unsupported model: {name}")
|
||||
|
||||
@@ -80,7 +91,10 @@ class ModelManager:
|
||||
Returns:
|
||||
BGR image
|
||||
"""
|
||||
self.switch_controlnet_method(config)
|
||||
if not config.enable_brushnet:
|
||||
self.switch_controlnet_method(config)
|
||||
if not config.enable_controlnet:
|
||||
self.switch_brushnet_method(config)
|
||||
self.enable_disable_freeu(config)
|
||||
self.enable_disable_lcm_lora(config)
|
||||
return self.model(image, mask, config).astype(np.uint8)
|
||||
@@ -99,9 +113,9 @@ class ModelManager:
|
||||
self.name = new_name
|
||||
|
||||
if (
|
||||
self.available_models[new_name].support_controlnet
|
||||
and self.controlnet_method
|
||||
not in self.available_models[new_name].controlnets
|
||||
self.available_models[new_name].support_controlnet
|
||||
and self.controlnet_method
|
||||
not in self.available_models[new_name].controlnets
|
||||
):
|
||||
self.controlnet_method = self.available_models[new_name].controlnets[0]
|
||||
try:
|
||||
@@ -121,14 +135,54 @@ class ModelManager:
|
||||
)
|
||||
raise e
|
||||
|
||||
def switch_brushnet_method(self, config):
|
||||
if not self.available_models[self.name].support_brushnet:
|
||||
return
|
||||
|
||||
if (
|
||||
self.enable_brushnet
|
||||
and config.brushnet_method
|
||||
and self.brushnet_method != config.brushnet_method
|
||||
):
|
||||
old_brushnet_method = self.brushnet_method
|
||||
self.brushnet_method = config.brushnet_method
|
||||
self.model.switch_brushnet_method(config.brushnet_method)
|
||||
logger.info(
|
||||
f"Switch Brushnet method from {old_brushnet_method} to {config.brushnet_method}"
|
||||
)
|
||||
|
||||
elif self.enable_brushnet != config.enable_brushnet:
|
||||
self.enable_brushnet = config.enable_brushnet
|
||||
self.brushnet_method = config.brushnet_method
|
||||
|
||||
pipe_components = {
|
||||
"vae": self.model.model.vae,
|
||||
"text_encoder": self.model.model.text_encoder,
|
||||
"unet": self.model.model.unet,
|
||||
}
|
||||
if hasattr(self.model.model, "text_encoder_2"):
|
||||
pipe_components["text_encoder_2"] = self.model.model.text_encoder_2
|
||||
|
||||
self.model = self.init_model(
|
||||
self.name,
|
||||
switch_mps_device(self.name, self.device),
|
||||
pipe_components=pipe_components,
|
||||
**self.kwargs,
|
||||
)
|
||||
|
||||
if not config.enable_brushnet:
|
||||
logger.info("BrushNet Disabled")
|
||||
else:
|
||||
logger.info("BrushNet Enabled")
|
||||
|
||||
def switch_controlnet_method(self, config):
|
||||
if not self.available_models[self.name].support_controlnet:
|
||||
return
|
||||
|
||||
if (
|
||||
self.enable_controlnet
|
||||
and config.controlnet_method
|
||||
and self.controlnet_method != config.controlnet_method
|
||||
self.enable_controlnet
|
||||
and config.controlnet_method
|
||||
and self.controlnet_method != config.controlnet_method
|
||||
):
|
||||
old_controlnet_method = self.controlnet_method
|
||||
self.controlnet_method = config.controlnet_method
|
||||
@@ -155,7 +209,7 @@ class ModelManager:
|
||||
**self.kwargs,
|
||||
)
|
||||
if not config.enable_controlnet:
|
||||
logger.info(f"Disable controlnet")
|
||||
logger.info("Disable controlnet")
|
||||
else:
|
||||
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user