update
@@ -1,6 +1,9 @@
+from itertools import chain
+
 import PIL.Image
 import cv2
 import torch
+from iopaint.model.original_sd_configs import get_config_files
 from loguru import logger
 from transformers import CLIPTextModel, CLIPTokenizer
 import numpy as np
@@ -14,9 +17,15 @@ from ..utils import (
     handle_from_pretrained_exceptions,
 )
 from .powerpaint_tokenizer import task_to_prompt
-from iopaint.schema import InpaintRequest
+from iopaint.schema import InpaintRequest, ModelType
 from .v2.BrushNet_CA import BrushNetModel
-from .v2.unet_2d_condition import UNet2DConditionModel
+from .v2.unet_2d_condition import UNet2DConditionModel_forward
+from .v2.unet_2d_blocks import (
+    CrossAttnDownBlock2D_forward,
+    DownBlock2D_forward,
+    CrossAttnUpBlock2D_forward,
+    UpBlock2D_forward,
+)
 
 
 class PowerPaintV2(DiffusionInpaintModel):
@@ -50,14 +59,7 @@ class PowerPaintV2(DiffusionInpaintModel):
             torch_dtype=torch_dtype,
             local_files_only=model_kwargs["local_files_only"],
         )
-        unet = handle_from_pretrained_exceptions(
-            UNet2DConditionModel.from_pretrained,
-            pretrained_model_name_or_path=self.model_id_or_path,
-            subfolder="unet",
-            variant="fp16",
-            torch_dtype=torch_dtype,
-            local_files_only=model_kwargs["local_files_only"],
-        )
 
         brushnet = BrushNetModel.from_pretrained(
             self.hf_model_id,
             subfolder="PowerPaint_Brushnet",
@@ -65,16 +67,32 @@ class PowerPaintV2(DiffusionInpaintModel):
             torch_dtype=torch_dtype,
             local_files_only=model_kwargs["local_files_only"],
         )
-        pipe = handle_from_pretrained_exceptions(
-            StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
-            pretrained_model_name_or_path=self.model_id_or_path,
-            torch_dtype=torch_dtype,
-            unet=unet,
-            brushnet=brushnet,
-            text_encoder_brushnet=text_encoder_brushnet,
-            variant="fp16",
-            **model_kwargs,
-        )
+
+        if self.model_info.is_single_file_diffusers:
+            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
+                model_kwargs["num_in_channels"] = 4
+            else:
+                model_kwargs["num_in_channels"] = 9
+
+            pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
+                self.model_id_or_path,
+                torch_dtype=torch_dtype,
+                load_safety_checker=False,
+                original_config_file=get_config_files()["v1"],
+                brushnet=brushnet,
+                text_encoder_brushnet=text_encoder_brushnet,
+                **model_kwargs,
+            )
+        else:
+            pipe = handle_from_pretrained_exceptions(
+                StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
+                pretrained_model_name_or_path=self.model_id_or_path,
+                torch_dtype=torch_dtype,
+                brushnet=brushnet,
+                text_encoder_brushnet=text_encoder_brushnet,
+                variant="fp16",
+                **model_kwargs,
+            )
 pipe.tokenizer = PowerPaintTokenizer(
             CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
         )
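The branch above hard-codes the UNet input-channel count before `from_single_file` loads the checkpoint: 4 latent channels for a plain SD text-to-image UNet, 9 (4 noisy latents + 4 masked-image latents + 1 mask) for an inpainting UNet. A tiny sketch of that rule; the helper name is illustrative, not part of the codebase:

```python
def unet_in_channels(is_inpaint_unet: bool) -> int:
    # 4 latent channels for a text-to-image SD UNet;
    # 9 = 4 noisy latents + 4 masked-image latents + 1 mask for an inpaint UNet.
    return 9 if is_inpaint_unet else 4
```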
@@ -95,6 +113,34 @@ class PowerPaintV2(DiffusionInpaintModel):
 
         self.callback = kwargs.pop("callback", None)
 
+        # Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
+        self.model.unet.forward = UNet2DConditionModel_forward.__get__(
+            self.model.unet, self.model.unet.__class__
+        )
+
+        # Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
+        for down_block in chain(
+            self.model.unet.down_blocks, self.model.brushnet.down_blocks
+        ):
+            if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
+                down_block.forward = CrossAttnDownBlock2D_forward.__get__(
+                    down_block, down_block.__class__
+                )
+            else:
+                down_block.forward = DownBlock2D_forward.__get__(
+                    down_block, down_block.__class__
+                )
+
+        for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
+            if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
+                up_block.forward = CrossAttnUpBlock2D_forward.__get__(
+                    up_block, up_block.__class__
+                )
+            else:
+                up_block.forward = UpBlock2D_forward.__get__(
+                    up_block, up_block.__class__
+                )
+
     def forward(self, image, mask, config: InpaintRequest):
         """Input image and output image have same size
         image: [H, W, C] RGB
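The hunk above swaps in replacement `forward` implementations at runtime by binding plain functions as methods through the descriptor protocol: `function.__get__(instance, cls)` returns a bound method, and assigning it on the instance overrides only that object. A minimal, self-contained sketch of the technique; the class and function names here are illustrative, not from the diff:

```python
class Greeter:
    def forward(self):
        return "original"

def patched_forward(self):
    # `self` is the Greeter instance the function was bound to.
    return "patched"

g = Greeter()
# __get__ binds the free function to this one instance, like a normal method.
g.forward = patched_forward.__get__(g, Greeter)
assert g.forward() == "patched"        # only this instance is affected
assert Greeter().forward() == "original"
```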
@@ -129,11 +175,10 @@ class PowerPaintV2(DiffusionInpaintModel):
             brushnet_conditioning_scale=1.0,
             guidance_scale=config.sd_guidance_scale,
             output_type="np",
-            callback=self.callback,
+            callback_on_step_end=self.callback,
             height=img_h,
             width=img_w,
             generator=torch.manual_seed(config.sd_seed),
-            callback_steps=1,
         ).images[0]
 
         output = (output * 255).round().astype("uint8")
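This hunk follows the diffusers callback migration: the deprecated `callback`/`callback_steps` pipeline arguments are replaced by `callback_on_step_end`, a hook that receives the pipeline, the step index, the timestep, and a dict of tensors, and must return a (possibly modified) dict. A hedged sketch of the new hook shape, independent of IOPaint's own callback object:

```python
# Sketch of a diffusers `callback_on_step_end` hook (assumes a loaded `pipe`).
def on_step_end(pipe, step: int, timestep, callback_kwargs: dict) -> dict:
    print(f"finished step {step} at t={timestep}")
    return callback_kwargs  # must hand the kwargs back to the pipeline

# result = pipe(prompt="...", callback_on_step_end=on_step_end).images[0]
```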
@@ -2,6 +2,14 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+from diffusers import UNet2DConditionModel
+from diffusers.models.unet_2d_blocks import (
+    get_down_block,
+    get_mid_block,
+    get_up_block,
+    CrossAttnDownBlock2D,
+    DownBlock2D,
+)
 from torch import nn
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -13,18 +21,14 @@ from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnProcessor,
 )
-from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
-    TimestepEmbedding, Timesteps
-from diffusers.models.modeling_utils import ModelMixin
-from .unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    DownBlock2D,
-    get_down_block,
-    get_mid_block,
-    get_up_block
-)
+from diffusers.models.embeddings import (
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
 
 from .unet_2d_condition import UNet2DConditionModel
+from diffusers.models.modeling_utils import ModelMixin
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -132,47 +136,55 @@ class BrushNetModel(ModelMixin, ConfigMixin):
 
     @register_to_config
     def __init__(
-            self,
-            in_channels: int = 4,
-            conditioning_channels: int = 5,
-            flip_sin_to_cos: bool = True,
-            freq_shift: int = 0,
-            down_block_types: Tuple[str, ...] = (
-                "CrossAttnDownBlock2D",
-                "CrossAttnDownBlock2D",
-                "CrossAttnDownBlock2D",
-                "DownBlock2D",
-            ),
-            mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
-            up_block_types: Tuple[str, ...] = (
-                "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
-            ),
-            only_cross_attention: Union[bool, Tuple[bool]] = False,
-            block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
-            layers_per_block: int = 2,
-            downsample_padding: int = 1,
-            mid_block_scale_factor: float = 1,
-            act_fn: str = "silu",
-            norm_num_groups: Optional[int] = 32,
-            norm_eps: float = 1e-5,
-            cross_attention_dim: int = 1280,
-            transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
-            encoder_hid_dim: Optional[int] = None,
-            encoder_hid_dim_type: Optional[str] = None,
-            attention_head_dim: Union[int, Tuple[int, ...]] = 8,
-            num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
-            use_linear_projection: bool = False,
-            class_embed_type: Optional[str] = None,
-            addition_embed_type: Optional[str] = None,
-            addition_time_embed_dim: Optional[int] = None,
-            num_class_embeds: Optional[int] = None,
-            upcast_attention: bool = False,
-            resnet_time_scale_shift: str = "default",
-            projection_class_embeddings_input_dim: Optional[int] = None,
-            brushnet_conditioning_channel_order: str = "rgb",
-            conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
-            global_pool_conditions: bool = False,
-            addition_embed_type_num_heads: int = 64,
+        self,
+        in_channels: int = 4,
+        conditioning_channels: int = 5,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str, ...] = (
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+        up_block_types: Tuple[str, ...] = (
+            "UpBlock2D",
+            "CrossAttnUpBlock2D",
+            "CrossAttnUpBlock2D",
+            "CrossAttnUpBlock2D",
+        ),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
+        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+        layers_per_block: int = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        act_fn: str = "silu",
+        norm_num_groups: Optional[int] = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: int = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
+        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+        use_linear_projection: bool = False,
+        class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
+        resnet_time_scale_shift: str = "default",
+        projection_class_embeddings_input_dim: Optional[int] = None,
+        brushnet_conditioning_channel_order: str = "rgb",
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+            16,
+            32,
+            96,
+            256,
+        ),
+        global_pool_conditions: bool = False,
+        addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()
 
@@ -195,25 +207,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+        if not isinstance(only_cross_attention, bool) and len(
+            only_cross_attention
+        ) != len(down_block_types):
             raise ValueError(
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
+            down_block_types
+        ):
             raise ValueError(
                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )
 
         if isinstance(transformer_layers_per_block, int):
-            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+            transformer_layers_per_block = [transformer_layers_per_block] * len(
+                down_block_types
+            )
 
         # input
         conv_in_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in_condition = nn.Conv2d(
-            in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
-            padding=conv_in_padding
+            in_channels + conditioning_channels,
+            block_out_channels[0],
+            kernel_size=conv_in_kernel,
+            padding=conv_in_padding,
         )
 
         # time
@@ -229,7 +249,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
             self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
-            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+            logger.info(
+                "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
+            )
 
         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
             raise ValueError(
@@ -274,7 +296,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
             # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
             # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
-            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(
+                projection_class_embeddings_input_dim, time_embed_dim
+            )
         else:
             self.class_embedding = None
@@ -285,21 +309,31 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 text_time_embedding_from_dim = cross_attention_dim
 
             self.add_embedding = TextTimeEmbedding(
-                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+                text_time_embedding_from_dim,
+                time_embed_dim,
+                num_heads=addition_embed_type_num_heads,
             )
         elif addition_embed_type == "text_image":
             # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
             # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
             self.add_embedding = TextImageTimeEmbedding(
-                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+                text_embed_dim=cross_attention_dim,
+                image_embed_dim=cross_attention_dim,
+                time_embed_dim=time_embed_dim,
             )
         elif addition_embed_type == "text_time":
-            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
-            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+            self.add_time_proj = Timesteps(
+                addition_time_embed_dim, flip_sin_to_cos, freq_shift
+            )
+            self.add_embedding = TimestepEmbedding(
+                projection_class_embeddings_input_dim, time_embed_dim
+            )
+
         elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+            raise ValueError(
+                f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
+            )
 
         self.down_blocks = nn.ModuleList([])
         self.brushnet_down_blocks = nn.ModuleList([])
@@ -338,7 +372,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 num_attention_heads=num_attention_heads[i],
-                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                attention_head_dim=attention_head_dim[i]
+                if attention_head_dim[i] is not None
+                else output_channel,
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -348,12 +384,16 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             self.down_blocks.append(down_block)
 
             for _ in range(layers_per_block):
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_down_blocks.append(brushnet_block)
 
             if not is_final_block:
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_down_blocks.append(brushnet_block)
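`zero_module`, used throughout this hunk, follows the ControlNet convention: the 1×1 projection convs start as all-zeros so the freshly attached BrushNet branch contributes nothing at initialization and training starts from the unmodified UNet. A sketch of the usual implementation; the diff does not show IOPaint's own definition, so treat this as an assumption:

```python
import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    # Zero every parameter in place; a zeroed 1x1 conv outputs all zeros,
    # so the residual branch it feeds is initially a no-op.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
```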
@@ -386,7 +426,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_num_attention_heads = list(reversed(num_attention_heads))
-        reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
+        reversed_transformer_layers_per_block = list(
+            reversed(transformer_layers_per_block)
+        )
         only_cross_attention = list(reversed(only_cross_attention))
 
         output_channel = reversed_block_out_channels[0]
@@ -399,7 +441,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
 
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
-            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+            input_channel = reversed_block_out_channels[
+                min(i + 1, len(block_out_channels) - 1)
+            ]
 
             # add upsample block for all BUT final layer
             if not is_final_block:
@@ -427,29 +471,40 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
-                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                attention_head_dim=attention_head_dim[i]
+                if attention_head_dim[i] is not None
+                else output_channel,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
 
             for _ in range(layers_per_block + 1):
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_up_blocks.append(brushnet_block)
 
             if not is_final_block:
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_up_blocks.append(brushnet_block)
 
     @classmethod
     def from_unet(
-            cls,
-            unet: UNet2DConditionModel,
-            brushnet_conditioning_channel_order: str = "rgb",
-            conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
-            load_weights_from_unet: bool = True,
-            conditioning_channels: int = 5,
+        cls,
+        unet: UNet2DConditionModel,
+        brushnet_conditioning_channel_order: str = "rgb",
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+            16,
+            32,
+            96,
+            256,
+        ),
+        load_weights_from_unet: bool = True,
+        conditioning_channels: int = 5,
     ):
         r"""
        Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
@@ -460,13 +515,27 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         where applicable.
         """
         transformer_layers_per_block = (
-            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+            unet.config.transformer_layers_per_block
+            if "transformer_layers_per_block" in unet.config
+            else 1
         )
+        encoder_hid_dim = (
+            unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+        )
+        encoder_hid_dim_type = (
+            unet.config.encoder_hid_dim_type
+            if "encoder_hid_dim_type" in unet.config
+            else None
+        )
+        addition_embed_type = (
+            unet.config.addition_embed_type
+            if "addition_embed_type" in unet.config
+            else None
+        )
-        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
-        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
-        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
         addition_time_embed_dim = (
-            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+            unet.config.addition_time_embed_dim
+            if "addition_time_embed_dim" in unet.config
+            else None
         )
 
         brushnet = cls(
@@ -475,14 +544,21 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             flip_sin_to_cos=unet.config.flip_sin_to_cos,
             freq_shift=unet.config.freq_shift,
             # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
-            down_block_types=["CrossAttnDownBlock2D",
-                              "CrossAttnDownBlock2D",
-                              "CrossAttnDownBlock2D",
-                              "DownBlock2D", ],
+            down_block_types=[
+                "CrossAttnDownBlock2D",
+                "CrossAttnDownBlock2D",
+                "CrossAttnDownBlock2D",
+                "DownBlock2D",
+            ],
             # mid_block_type='MidBlock2D',
             mid_block_type="UNetMidBlock2DCrossAttn",
             # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
-            up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+            up_block_types=[
+                "UpBlock2D",
+                "CrossAttnUpBlock2D",
+                "CrossAttnUpBlock2D",
+                "CrossAttnUpBlock2D",
+            ],
             only_cross_attention=unet.config.only_cross_attention,
             block_out_channels=unet.config.block_out_channels,
             layers_per_block=unet.config.layers_per_block,
@@ -510,21 +586,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         )
 
         if load_weights_from_unet:
-            conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
+            conv_in_condition_weight = torch.zeros_like(
+                brushnet.conv_in_condition.weight
+            )
             conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
             conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
-            brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
+            brushnet.conv_in_condition.weight = torch.nn.Parameter(
+                conv_in_condition_weight
+            )
             brushnet.conv_in_condition.bias = unet.conv_in.bias
 
             brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
             brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
 
             if brushnet.class_embedding:
-                brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+                brushnet.class_embedding.load_state_dict(
+                    unet.class_embedding.state_dict()
+                )
 
-            brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
-            brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
-            brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
+            brushnet.down_blocks.load_state_dict(
+                unet.down_blocks.state_dict(), strict=False
+            )
+            brushnet.mid_block.load_state_dict(
+                unet.mid_block.state_dict(), strict=False
+            )
+            brushnet.up_blocks.load_state_dict(
+                unet.up_blocks.state_dict(), strict=False
+            )
 
         return brushnet.to(unet.dtype)
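`from_unet` seeds the 9-channel `conv_in_condition` from the UNet's 4-channel `conv_in`: the pretrained kernel is copied into both the noisy-latent slot (channels 0-3) and the masked-image-latent slot (channels 4-7), while the mask channel (8) stays zero. A standalone sketch of that weight surgery with assumed SD 1.5 shapes (320 output channels, 3×3 kernel):

```python
import torch
import torch.nn as nn

unet_conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)  # pretrained 4-channel stem
cond_conv_in = nn.Conv2d(9, 320, kernel_size=3, padding=1)  # 4 latents + 4 masked latents + 1 mask

w = torch.zeros_like(cond_conv_in.weight)  # shape (320, 9, 3, 3)
w[:, :4] = unet_conv_in.weight             # reuse weights for the noisy latents
w[:, 4:8] = unet_conv_in.weight            # ...and for the masked-image latents
cond_conv_in.weight = nn.Parameter(w)      # mask channel 8 is left at zero
cond_conv_in.bias = unet_conv_in.bias
```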
@@ -539,9 +627,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # set recursively
         processors = {}
 
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor(
+                    return_deprecated_lora=True
+                )
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -554,7 +648,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         return processors
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -593,9 +689,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         """
         Disables custom attention processors and sets the default attention implementation.
         """
-        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+        if all(
+            proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
+            for proc in self.attn_processors.values()
+        ):
             processor = AttnAddedKVProcessor()
-        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+        elif all(
+            proc.__class__ in CROSS_ATTENTION_PROCESSORS
+            for proc in self.attn_processors.values()
+        ):
             processor = AttnProcessor()
         else:
             raise ValueError(
@@ -642,7 +744,11 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             # make smallest slice possible
             slice_size = num_sliceable_layers * [1]
 
-        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = (
+            num_sliceable_layers * [slice_size]
+            if not isinstance(slice_size, list)
+            else slice_size
+        )
 
         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
@@ -659,7 +765,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # Recursively walk through all the children.
         # Any children which exposes the set_attention_slice method
         # gets the message
-        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+        def fn_recursive_set_attention_slice(
+            module: torch.nn.Module, slice_size: List[int]
+        ):
             if hasattr(module, "set_attention_slice"):
                 module.set_attention_slice(slice_size.pop())
@@ -675,19 +783,19 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             module.gradient_checkpointing = value
 
     def forward(
-            self,
-            sample: torch.FloatTensor,
-            timestep: Union[torch.Tensor, float, int],
-            encoder_hidden_states: torch.Tensor,
-            brushnet_cond: torch.FloatTensor,
-            conditioning_scale: float = 1.0,
-            class_labels: Optional[torch.Tensor] = None,
-            timestep_cond: Optional[torch.Tensor] = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
-            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-            guess_mode: bool = False,
-            return_dict: bool = True,
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        brushnet_cond: torch.FloatTensor,
+        conditioning_scale: float = 1.0,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
+        return_dict: bool = True,
     ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
         """
         The [`BrushNetModel`] forward method.
@@ -737,7 +845,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         elif channel_order == "bgr":
             brushnet_cond = torch.flip(brushnet_cond, dims=[1])
         else:
-            raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
+            raise ValueError(
+                f"unknown `brushnet_conditioning_channel_order`: {channel_order}"
+            )
 
         # prepare attention_mask
         if attention_mask is not None:
@@ -773,7 +883,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
 
         if self.class_embedding is not None:
             if class_labels is None:
-                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+                raise ValueError(
+                    "class_labels should be provided when num_class_embeds > 0"
+                )
 
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
@@ -812,7 +924,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+            if (
+                hasattr(downsample_block, "has_cross_attention")
+                and downsample_block.has_cross_attention
+            ):
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -827,13 +942,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
 
         # 4. PaintingNet down blocks
         brushnet_down_block_res_samples = ()
-        for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
+        for down_block_res_sample, brushnet_down_block in zip(
+            down_block_res_samples, self.brushnet_down_blocks
+        ):
             down_block_res_sample = brushnet_down_block(down_block_res_sample)
-            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
+            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (
+                down_block_res_sample,
+            )
 
         # 5. mid
         if self.mid_block is not None:
-            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+            if (
+                hasattr(self.mid_block, "has_cross_attention")
+                and self.mid_block.has_cross_attention
+            ):
                 sample = self.mid_block(
                     sample,
                     emb,
@@ -852,15 +974,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         for i, upsample_block in enumerate(self.up_blocks):
             is_final_block = i == len(self.up_blocks) - 1
 
-            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
-            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[
+                : -len(upsample_block.resnets)
+            ]
 
             # if we have not reached the final block and need to forward the
             # upsample size, we do it here
             if not is_final_block:
                 upsample_size = down_block_res_samples[-1].shape[2:]
 
-            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+            if (
+                hasattr(upsample_block, "has_cross_attention")
+                and upsample_block.has_cross_attention
+            ):
                 sample, up_res_samples = upsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -869,7 +996,7 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
-                    return_res_samples=True
+                    return_res_samples=True,
                 )
             else:
                 sample, up_res_samples = upsample_block(
@@ -877,53 +1004,87 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     upsample_size=upsample_size,
-                    return_res_samples=True
+                    return_res_samples=True,
                 )
 
             up_block_res_samples += up_res_samples
 
         # 8. BrushNet up blocks
         brushnet_up_block_res_samples = ()
-        for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
+        for up_block_res_sample, brushnet_up_block in zip(
+            up_block_res_samples, self.brushnet_up_blocks
+        ):
             up_block_res_sample = brushnet_up_block(up_block_res_sample)
-            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
+            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (
+                up_block_res_sample,
+            )
 
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
-            scales = torch.logspace(-1, 0,
-                                    len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
-                                    device=sample.device)  # 0.1 to 1.0
+            scales = torch.logspace(
+                -1,
+                0,
+                len(brushnet_down_block_res_samples)
+                + 1
+                + len(brushnet_up_block_res_samples),
+                device=sample.device,
+            )  # 0.1 to 1.0
             scales = scales * conditioning_scale
 
-            brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
-            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
-            brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1:])]
+            brushnet_down_block_res_samples = [
+                sample * scale
+                for sample, scale in zip(
+                    brushnet_down_block_res_samples,
+                    scales[: len(brushnet_down_block_res_samples)],
+                )
+            ]
+            brushnet_mid_block_res_sample = (
+                brushnet_mid_block_res_sample
+                * scales[len(brushnet_down_block_res_samples)]
+            )
+            brushnet_up_block_res_samples = [
+                sample * scale
+                for sample, scale in zip(
+                    brushnet_up_block_res_samples,
+                    scales[len(brushnet_down_block_res_samples) + 1 :],
+                )
+            ]
         else:
-            brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
-            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
-            brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
+            brushnet_down_block_res_samples = [
+                sample * conditioning_scale
+                for sample in brushnet_down_block_res_samples
+            ]
+            brushnet_mid_block_res_sample = (
+                brushnet_mid_block_res_sample * conditioning_scale
+            )
+            brushnet_up_block_res_samples = [
+                sample * conditioning_scale for sample in brushnet_up_block_res_samples
+            ]
 
         if self.config.global_pool_conditions:
             brushnet_down_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
+                torch.mean(sample, dim=(2, 3), keepdim=True)
+                for sample in brushnet_down_block_res_samples
             ]
-            brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
+            brushnet_mid_block_res_sample = torch.mean(
+                brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True
+            )
             brushnet_up_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
+                torch.mean(sample, dim=(2, 3), keepdim=True)
+                for sample in brushnet_up_block_res_samples
            ]
 
         if not return_dict:
-            return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
+            return (
+                brushnet_down_block_res_samples,
+                brushnet_mid_block_res_sample,
+                brushnet_up_block_res_samples,
+            )
 
         return BrushNetOutput(
             down_block_res_samples=brushnet_down_block_res_samples,
             mid_block_res_sample=brushnet_mid_block_res_sample,
-            up_block_res_samples=brushnet_up_block_res_samples
+            up_block_res_samples=brushnet_up_block_res_samples,
         )
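In guess mode the residuals are scaled by a geometric ramp: `torch.logspace(-1, 0, n)` yields n values from 0.1 to 1.0, so the earliest down-block residuals are damped most (×0.1) and the final up-block residual passes through at full strength (×1.0), with the mid-block residual taking the single value between the two slices. A small sketch of the ramp and how it is split across down/mid/up residuals (toy counts, not the real block counts):

```python
import torch

n_down, n_up = 3, 3                                 # toy residual counts
scales = torch.logspace(-1, 0, n_down + 1 + n_up)   # 7 values from 0.1 to 1.0
down_scales = scales[:n_down]                       # weakest scales, shallow residuals
mid_scale = scales[n_down]                          # single scale for the mid residual
up_scales = scales[n_down + 1:]                     # strongest scales, ending at 1.0
print(scales)  # tensor([0.1000, 0.1468, 0.2154, 0.3162, 0.4642, 0.6813, 1.0000])
```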
(3 file diffs suppressed because they are too large)