add powerpaint v2

This commit is contained in:
Qing
2024-04-24 20:22:29 +08:00
parent ccea072dc5
commit 911f7224b6
14 changed files with 8082 additions and 2318 deletions

View File

@@ -8,6 +8,7 @@ from iopaint.download import scan_models
from iopaint.helper import switch_mps_device
from iopaint.model import models, ControlNet, SD, SDXL
from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
from iopaint.model.power_paint.power_paint_v2 import PowerPaintV2
from iopaint.model.utils import torch_gc, is_local_files_only
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
@@ -23,9 +24,9 @@ class ModelManager:
self.enable_controlnet = kwargs.get("enable_controlnet", False)
controlnet_method = kwargs.get("controlnet_method", None)
if (
controlnet_method is None
and name in self.available_models
and self.available_models[name].support_controlnet
controlnet_method is None
and name in self.available_models
and self.available_models[name].support_controlnet
):
controlnet_method = self.available_models[name].controlnets[0]
self.controlnet_method = controlnet_method
@@ -33,6 +34,8 @@ class ModelManager:
self.enable_brushnet = kwargs.get("enable_brushnet", False)
self.brushnet_method = kwargs.get("brushnet_method", None)
self.enable_powerpaint_v2 = kwargs.get("enable_powerpaint_v2", False)
self.model = self.init_model(name, device, **kwargs)
@property
@@ -62,6 +65,9 @@ class ModelManager:
if model_info.support_brushnet and self.enable_brushnet:
return BrushNetWrapper(device, **kwargs)
if model_info.support_powerpaint_v2 and self.enable_powerpaint_v2:
return PowerPaintV2(device, **kwargs)
if model_info.name in models:
return models[name](device, **kwargs)
@@ -91,10 +97,12 @@ class ModelManager:
Returns:
BGR image
"""
if not config.enable_brushnet:
if config.enable_controlnet:
self.switch_controlnet_method(config)
if not config.enable_controlnet:
if config.enable_brushnet:
self.switch_brushnet_method(config)
self.enable_disable_powerpaint_v2(config)
self.enable_disable_freeu(config)
self.enable_disable_lcm_lora(config)
return self.model(image, mask, config).astype(np.uint8)
@@ -113,9 +121,9 @@ class ModelManager:
self.name = new_name
if (
self.available_models[new_name].support_controlnet
and self.controlnet_method
not in self.available_models[new_name].controlnets
self.available_models[new_name].support_controlnet
and self.controlnet_method
not in self.available_models[new_name].controlnets
):
self.controlnet_method = self.available_models[new_name].controlnets[0]
try:
@@ -140,9 +148,9 @@ class ModelManager:
return
if (
self.enable_brushnet
and config.brushnet_method
and self.brushnet_method != config.brushnet_method
self.enable_brushnet
and config.brushnet_method
and self.brushnet_method != config.brushnet_method
):
old_brushnet_method = self.brushnet_method
self.brushnet_method = config.brushnet_method
@@ -180,9 +188,9 @@ class ModelManager:
return
if (
self.enable_controlnet
and config.controlnet_method
and self.controlnet_method != config.controlnet_method
self.enable_controlnet
and config.controlnet_method
and self.controlnet_method != config.controlnet_method
):
old_controlnet_method = self.controlnet_method
self.controlnet_method = config.controlnet_method
@@ -213,6 +221,25 @@ class ModelManager:
else:
logger.info(f"Enable controlnet: {config.controlnet_method}")
def enable_disable_powerpaint_v2(self, config: InpaintRequest):
if not self.available_models[self.name].support_powerpaint_v2:
return
if self.enable_powerpaint_v2 != config.enable_powerpaint_v2:
self.enable_powerpaint_v2 = config.enable_powerpaint_v2
pipe_components = {"vae": self.model.model.vae}
self.model = self.init_model(
self.name,
switch_mps_device(self.name, self.device),
pipe_components=pipe_components,
**self.kwargs,
)
if config.enable_powerpaint_v2:
logger.info("Enable PowerPaintV2")
else:
logger.info("Disable PowerPaintV2")
def enable_disable_freeu(self, config: InpaintRequest):
if str(self.model.device) == "mps":
return