@@ -30,6 +30,9 @@ DIFFUSION_MODELS = [
|
|||||||
"Sanster/anything-4.0-inpainting",
|
"Sanster/anything-4.0-inpainting",
|
||||||
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||||
"Fantasy-Studio/Paint-by-Example",
|
"Fantasy-Studio/Paint-by-Example",
|
||||||
|
"RunDiffusion/Juggernaut-XI-v11",
|
||||||
|
"SG161222/RealVisXL_V5.0",
|
||||||
|
"eienmojiki/Anything-XL",
|
||||||
POWERPAINT_NAME,
|
POWERPAINT_NAME,
|
||||||
ANYTEXT_NAME,
|
ANYTEXT_NAME,
|
||||||
]
|
]
|
||||||
@@ -83,6 +86,10 @@ SDXL_CONTROLNET_CHOICES = [
|
|||||||
"diffusers/controlnet-depth-sdxl-1.0-small",
|
"diffusers/controlnet-depth-sdxl-1.0-small",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
SDXL_BRUSHNET_CHOICES = [
|
||||||
|
"Regulus0725/random_mask_brushnet_ckpt_sdxl_regulus_v1"
|
||||||
|
]
|
||||||
|
|
||||||
LOCAL_FILES_ONLY_HELP = """
|
LOCAL_FILES_ONLY_HELP = """
|
||||||
When loading diffusion models, using local files only, not connect to HuggingFace server.
|
When loading diffusion models, using local files only, not connect to HuggingFace server.
|
||||||
"""
|
"""
|
||||||
|
|||||||
181
iopaint/model/brushnet/brushnet_xl_wrapper.py
Normal file
181
iopaint/model/brushnet/brushnet_xl_wrapper.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
import PIL.Image
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..base import DiffusionInpaintModel
|
||||||
|
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
|
from ..original_sd_configs import get_config_files
|
||||||
|
from ..utils import (
|
||||||
|
handle_from_pretrained_exceptions,
|
||||||
|
get_torch_dtype,
|
||||||
|
enable_low_mem,
|
||||||
|
is_local_files_only,
|
||||||
|
)
|
||||||
|
from .brushnet import BrushNetModel
|
||||||
|
from .brushnet_unet_forward import brushnet_unet_forward
|
||||||
|
from .unet_2d_blocks import (
|
||||||
|
CrossAttnDownBlock2D_forward,
|
||||||
|
DownBlock2D_forward,
|
||||||
|
CrossAttnUpBlock2D_forward,
|
||||||
|
UpBlock2D_forward,
|
||||||
|
)
|
||||||
|
from ...schema import InpaintRequest, ModelType
|
||||||
|
from ...const import SDXL_BRUSHNET_CHOICES
|
||||||
|
|
||||||
|
|
||||||
|
class BrushNetXLWrapper(DiffusionInpaintModel):
|
||||||
|
name = "RunDiffusion/Juggernaut-XI-v11"
|
||||||
|
pad_mod = 8
|
||||||
|
min_size = 1024
|
||||||
|
model_id_or_path = "RunDiffusion/Juggernaut-XI-v11"
|
||||||
|
support_brushnet = True
|
||||||
|
support_lcm_lora = False
|
||||||
|
|
||||||
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
|
from .pipeline_brushnet_sd_xl import StableDiffusionXLBrushNetPipeline
|
||||||
|
|
||||||
|
self.model_info = kwargs["model_info"]
|
||||||
|
self.brushnet_xl_method = SDXL_BRUSHNET_CHOICES[0]
|
||||||
|
# self.brushnet_xl_method = kwargs["brushnet_xl_method"]
|
||||||
|
|
||||||
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
|
||||||
|
model_kwargs = {
|
||||||
|
**kwargs.get("pipe_components", {}),
|
||||||
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
|
}
|
||||||
|
self.local_files_only = model_kwargs["local_files_only"]
|
||||||
|
|
||||||
|
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||||
|
"cpu_offload", False
|
||||||
|
)
|
||||||
|
if disable_nsfw_checker:
|
||||||
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
|
model_kwargs.update(
|
||||||
|
dict(
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loading BrushNet model from {self.brushnet_xl_method}")
|
||||||
|
brushnet = BrushNetModel.from_pretrained(
|
||||||
|
self.brushnet_xl_method, torch_dtype=torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.model_info.is_single_file_diffusers:
|
||||||
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
|
model_kwargs["num_in_channels"] = 4
|
||||||
|
else:
|
||||||
|
model_kwargs["num_in_channels"] = 9
|
||||||
|
|
||||||
|
self.model = StableDiffusionXLBrushNetPipeline.from_single_file(
|
||||||
|
self.model_id_or_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
load_safety_checker=not disable_nsfw_checker,
|
||||||
|
original_config_file=get_config_files()["v1"],
|
||||||
|
brushnet=brushnet,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = handle_from_pretrained_exceptions(
|
||||||
|
StableDiffusionXLBrushNetPipeline.from_pretrained,
|
||||||
|
pretrained_model_name_or_path=self.model_id_or_path,
|
||||||
|
variant="fp16",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
brushnet=brushnet,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||||
|
|
||||||
|
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||||
|
logger.info("Enable sequential cpu offload")
|
||||||
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
else:
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
if kwargs["sd_cpu_textencoder"]:
|
||||||
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
|
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||||
|
self.model.text_encoder, torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
|
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||||
|
self.model.unet.forward = brushnet_unet_forward.__get__(
|
||||||
|
self.model.unet, self.model.unet.__class__
|
||||||
|
)
|
||||||
|
|
||||||
|
for down_block in self.model.brushnet.down_blocks:
|
||||||
|
down_block.forward = DownBlock2D_forward.__get__(
|
||||||
|
down_block, down_block.__class__
|
||||||
|
)
|
||||||
|
for up_block in self.model.brushnet.up_blocks:
|
||||||
|
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||||
|
|
||||||
|
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||||
|
for down_block in self.model.unet.down_blocks:
|
||||||
|
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||||
|
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
|
||||||
|
down_block, down_block.__class__
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
down_block.forward = DownBlock2D_forward.__get__(
|
||||||
|
down_block, down_block.__class__
|
||||||
|
)
|
||||||
|
|
||||||
|
for up_block in self.model.unet.up_blocks:
|
||||||
|
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||||
|
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
|
||||||
|
up_block, up_block.__class__
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
up_block.forward = UpBlock2D_forward.__get__(
|
||||||
|
up_block, up_block.__class__
|
||||||
|
)
|
||||||
|
|
||||||
|
def switch_brushnet_method(self, new_method: str):
|
||||||
|
self.brushnet_method = new_method
|
||||||
|
brushnet_xl = BrushNetModel.from_pretrained(
|
||||||
|
new_method,
|
||||||
|
local_files_only=self.local_files_only,
|
||||||
|
torch_dtype=self.torch_dtype,
|
||||||
|
).to(self.model.device)
|
||||||
|
self.model.brushnet = brushnet_xl
|
||||||
|
|
||||||
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
|
"""Input image and output image have same size
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
self.set_scheduler(config)
|
||||||
|
|
||||||
|
img_h, img_w = image.shape[:2]
|
||||||
|
normalized_mask = mask[:, :].astype("float32") / 255.0
|
||||||
|
image = image * (1 - normalized_mask)
|
||||||
|
image = image.astype(np.uint8)
|
||||||
|
output = self.model(
|
||||||
|
image=PIL.Image.fromarray(image),
|
||||||
|
prompt=config.prompt,
|
||||||
|
negative_prompt=config.negative_prompt,
|
||||||
|
mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"),
|
||||||
|
num_inference_steps=config.sd_steps,
|
||||||
|
# strength=config.sd_strength,
|
||||||
|
guidance_scale=config.sd_guidance_scale,
|
||||||
|
output_type="np",
|
||||||
|
callback_on_step_end=self.callback,
|
||||||
|
height=img_h,
|
||||||
|
width=img_w,
|
||||||
|
generator=torch.manual_seed(config.sd_seed),
|
||||||
|
brushnet_conditioning_scale=config.brushnet_conditioning_scale,
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
output = (output * 255).round().astype("uint8")
|
||||||
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
|
return output
|
||||||
1535
iopaint/model/brushnet/pipeline_brushnet_sd_xl.py
Normal file
1535
iopaint/model/brushnet/pipeline_brushnet_sd_xl.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,7 @@ class SD(DiffusionInpaintModel):
|
|||||||
**kwargs.get("pipe_components", {}),
|
**kwargs.get("pipe_components", {}),
|
||||||
"local_files_only": is_local_files_only(**kwargs),
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
}
|
}
|
||||||
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
disable_nsfw_checker = kwargs.get("disable_nsfw", False) or kwargs.get(
|
||||||
"cpu_offload", False
|
"cpu_offload", False
|
||||||
)
|
)
|
||||||
if disable_nsfw_checker:
|
if disable_nsfw_checker:
|
||||||
@@ -71,7 +71,7 @@ class SD(DiffusionInpaintModel):
|
|||||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
else:
|
else:
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
if kwargs["sd_cpu_textencoder"]:
|
if kwargs.get("sd_cpu_textencoder", False):
|
||||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
self.model.text_encoder = CPUTextEncoderWrapper(
|
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||||
self.model.text_encoder, torch_dtype
|
self.model.text_encoder, torch_dtype
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from iopaint.download import scan_models
|
|||||||
from iopaint.helper import switch_mps_device
|
from iopaint.helper import switch_mps_device
|
||||||
from iopaint.model import models, ControlNet, SD, SDXL
|
from iopaint.model import models, ControlNet, SD, SDXL
|
||||||
from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
|
from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
|
||||||
|
from iopaint.model.brushnet.brushnet_xl_wrapper import BrushNetXLWrapper
|
||||||
from iopaint.model.power_paint.power_paint_v2 import PowerPaintV2
|
from iopaint.model.power_paint.power_paint_v2 import PowerPaintV2
|
||||||
from iopaint.model.utils import torch_gc, is_local_files_only
|
from iopaint.model.utils import torch_gc, is_local_files_only
|
||||||
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
||||||
@@ -63,7 +64,10 @@ class ModelManager:
|
|||||||
return ControlNet(device, **kwargs)
|
return ControlNet(device, **kwargs)
|
||||||
|
|
||||||
if model_info.support_brushnet and self.enable_brushnet:
|
if model_info.support_brushnet and self.enable_brushnet:
|
||||||
|
if model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
return BrushNetWrapper(device, **kwargs)
|
return BrushNetWrapper(device, **kwargs)
|
||||||
|
elif model_info.model_type == ModelType.DIFFUSERS_SDXL:
|
||||||
|
return BrushNetXLWrapper(device, **kwargs)
|
||||||
|
|
||||||
if model_info.support_powerpaint_v2 and self.enable_powerpaint_v2:
|
if model_info.support_powerpaint_v2 and self.enable_powerpaint_v2:
|
||||||
return PowerPaintV2(device, **kwargs)
|
return PowerPaintV2(device, **kwargs)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from iopaint.const import (
|
|||||||
SD2_CONTROLNET_CHOICES,
|
SD2_CONTROLNET_CHOICES,
|
||||||
SD_CONTROLNET_CHOICES,
|
SD_CONTROLNET_CHOICES,
|
||||||
SD_BRUSHNET_CHOICES,
|
SD_BRUSHNET_CHOICES,
|
||||||
|
SDXL_BRUSHNET_CHOICES
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||||
|
|
||||||
@@ -70,6 +71,8 @@ class ModelInfo(BaseModel):
|
|||||||
def brushnets(self) -> List[str]:
|
def brushnets(self) -> List[str]:
|
||||||
if self.model_type in [ModelType.DIFFUSERS_SD]:
|
if self.model_type in [ModelType.DIFFUSERS_SD]:
|
||||||
return SD_BRUSHNET_CHOICES
|
return SD_BRUSHNET_CHOICES
|
||||||
|
if self.model_type in [ModelType.DIFFUSERS_SDXL]:
|
||||||
|
return SDXL_BRUSHNET_CHOICES
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@@ -117,6 +120,7 @@ class ModelInfo(BaseModel):
|
|||||||
def support_brushnet(self) -> bool:
|
def support_brushnet(self) -> bool:
|
||||||
return self.model_type in [
|
return self.model_type in [
|
||||||
ModelType.DIFFUSERS_SD,
|
ModelType.DIFFUSERS_SD,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
]
|
]
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
|
|||||||
@@ -142,14 +142,14 @@ const DiffusionOptions = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
{/* <RowContainer>
|
<RowContainer>
|
||||||
<Slider
|
<Slider
|
||||||
defaultValue={[100]}
|
defaultValue={[100]}
|
||||||
className="w-[180px]"
|
className="w-[180px]"
|
||||||
min={1}
|
min={1}
|
||||||
max={100}
|
max={100}
|
||||||
step={1}
|
step={1}
|
||||||
disabled={!settings.enableBrushNet || disable}
|
disabled={!settings.enableBrushNet}
|
||||||
value={[Math.floor(settings.brushnetConditioningScale * 100)]}
|
value={[Math.floor(settings.brushnetConditioningScale * 100)]}
|
||||||
onValueChange={(vals) =>
|
onValueChange={(vals) =>
|
||||||
updateSettings({ brushnetConditioningScale: vals[0] / 100 })
|
updateSettings({ brushnetConditioningScale: vals[0] / 100 })
|
||||||
@@ -159,12 +159,12 @@ const DiffusionOptions = () => {
|
|||||||
id="brushnet-weight"
|
id="brushnet-weight"
|
||||||
className="w-[50px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.brushnetConditioningScale}
|
numberValue={settings.brushnetConditioningScale}
|
||||||
allowFloat={false}
|
allowFloat
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
updateSettings({ brushnetConditioningScale: val })
|
updateSettings({ brushnetConditioningScale: val })
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</RowContainer> */}
|
</RowContainer>
|
||||||
|
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
<Select
|
<Select
|
||||||
@@ -240,7 +240,7 @@ const DiffusionOptions = () => {
|
|||||||
className="w-[50px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
disabled={!settings.enableControlnet}
|
disabled={!settings.enableControlnet}
|
||||||
numberValue={settings.controlnetConditioningScale}
|
numberValue={settings.controlnetConditioningScale}
|
||||||
allowFloat={false}
|
allowFloat
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
updateSettings({ controlnetConditioningScale: val })
|
updateSettings({ controlnetConditioningScale: val })
|
||||||
}}
|
}}
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ const SidePanel = () => {
|
|||||||
</SheetTrigger>
|
</SheetTrigger>
|
||||||
<SheetContent
|
<SheetContent
|
||||||
side="right"
|
side="right"
|
||||||
className="w-[286px] mt-[60px] outline-none pl-3 pr-1"
|
className="min-w-[286px] max-w-full mt-[60px] outline-none pl-3 pr-1"
|
||||||
onOpenAutoFocus={(event) => event.preventDefault()}
|
onOpenAutoFocus={(event) => event.preventDefault()}
|
||||||
onPointerDownOutside={(event) => event.preventDefault()}
|
onPointerDownOutside={(event) => event.preventDefault()}
|
||||||
>
|
>
|
||||||
|
|||||||
Reference in New Issue
Block a user