fix deprecate warning

This commit is contained in:
Qing
2024-11-23 18:58:34 +08:00
parent 4cec8ded64
commit 36e65f0097
3 changed files with 113 additions and 54 deletions

View File

@@ -1,8 +1,13 @@
from typing import Union, Optional, Dict, Any, Tuple from typing import Union, Optional, Dict, Any, Tuple
import torch import torch
from diffusers.models.unet_2d_condition import UNet2DConditionOutput from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import USE_PEFT_BACKEND, unscale_lora_layers, deprecate, scale_lora_layers from diffusers.utils import (
USE_PEFT_BACKEND,
unscale_lora_layers,
deprecate,
scale_lora_layers,
)
def brushnet_unet_forward( def brushnet_unet_forward(
@@ -112,7 +117,9 @@ def brushnet_unet_forward(
# convert encoder_attention_mask to a bias the same way we do for attention_mask # convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None: if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = (
1 - encoder_attention_mask.to(sample.dtype)
) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1) encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary # 0. center input if necessary
@@ -132,7 +139,9 @@ def brushnet_unet_forward(
emb = emb + class_emb emb = emb + class_emb
aug_emb = self.get_aug_embed( aug_emb = self.get_aug_embed(
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs emb=emb,
encoder_hidden_states=encoder_hidden_states,
added_cond_kwargs=added_cond_kwargs,
) )
if self.config.addition_embed_type == "image_hint": if self.config.addition_embed_type == "image_hint":
aug_emb, hint = aug_emb aug_emb, hint = aug_emb
@@ -151,25 +160,43 @@ def brushnet_unet_forward(
sample = self.conv_in(sample) sample = self.conv_in(sample)
# 2.5 GLIGEN position net # 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: if (
cross_attention_kwargs is not None
and cross_attention_kwargs.get("gligen", None) is not None
):
cross_attention_kwargs = cross_attention_kwargs.copy() cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen") gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down # 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
if USE_PEFT_BACKEND: if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer # weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale) scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_controlnet = (
mid_block_additional_residual is not None
and down_block_additional_residuals is not None
)
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None is_adapter = down_intrablock_additional_residuals is not None
# maintain backward compatibility for legacy usage, where # maintain backward compatibility for legacy usage, where
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other # but can only use one or the other
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None is_brushnet = (
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: down_block_add_samples is not None
and mid_block_add_sample is not None
and up_block_add_samples is not None
)
if (
not is_adapter
and mid_block_additional_residual is None
and down_block_additional_residuals is not None
):
deprecate( deprecate(
"T2I should not use down_block_additional_residuals", "T2I should not use down_block_additional_residuals",
"1.3.0", "1.3.0",
@@ -187,16 +214,25 @@ def brushnet_unet_forward(
sample = sample + down_block_add_samples.pop(0) sample = sample + down_block_add_samples.pop(0)
for downsample_block in self.down_blocks: for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
# For t2i-adapter CrossAttnDownBlock2D # For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {} additional_residuals = {}
if is_adapter and len(down_intrablock_additional_residuals) > 0: if is_adapter and len(down_intrablock_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) additional_residuals["additional_residuals"] = (
down_intrablock_additional_residuals.pop(0)
)
if is_brushnet and len(down_block_add_samples) > 0: if is_brushnet and len(down_block_add_samples) > 0:
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) additional_residuals["down_block_add_samples"] = [
down_block_add_samples.pop(0)
for _ in range( for _ in range(
len(downsample_block.resnets) + (downsample_block.downsamplers != None))] len(downsample_block.resnets)
+ (downsample_block.downsamplers != None)
)
]
sample, res_samples = downsample_block( sample, res_samples = downsample_block(
hidden_states=sample, hidden_states=sample,
@@ -210,12 +246,17 @@ def brushnet_unet_forward(
else: else:
additional_residuals = {} additional_residuals = {}
if is_brushnet and len(down_block_add_samples) > 0: if is_brushnet and len(down_block_add_samples) > 0:
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) additional_residuals["down_block_add_samples"] = [
down_block_add_samples.pop(0)
for _ in range( for _ in range(
len(downsample_block.resnets) + (downsample_block.downsamplers != None))] len(downsample_block.resnets)
+ (downsample_block.downsamplers != None)
)
]
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, sample, res_samples = downsample_block(
**additional_residuals) hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals
)
if is_adapter and len(down_intrablock_additional_residuals) > 0: if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0) sample += down_intrablock_additional_residuals.pop(0)
@@ -227,14 +268,21 @@ def brushnet_unet_forward(
for down_block_res_sample, down_block_additional_residual in zip( for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals down_block_res_samples, down_block_additional_residuals
): ):
down_block_res_sample = down_block_res_sample + down_block_additional_residual down_block_res_sample = (
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_sample + down_block_additional_residual
)
new_down_block_res_samples = new_down_block_res_samples + (
down_block_res_sample,
)
down_block_res_samples = new_down_block_res_samples down_block_res_samples = new_down_block_res_samples
# 4. mid # 4. mid
if self.mid_block is not None: if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: if (
hasattr(self.mid_block, "has_cross_attention")
and self.mid_block.has_cross_attention
):
sample = self.mid_block( sample = self.mid_block(
sample, sample,
emb, emb,
@@ -272,12 +320,19 @@ def brushnet_unet_forward(
if not is_final_block and forward_upsample_size: if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:] upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
additional_residuals = {} additional_residuals = {}
if is_brushnet and len(up_block_add_samples) > 0: if is_brushnet and len(up_block_add_samples) > 0:
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) additional_residuals["up_block_add_samples"] = [
up_block_add_samples.pop(0)
for _ in range( for _ in range(
len(upsample_block.resnets) + (upsample_block.upsamplers != None))] len(upsample_block.resnets)
+ (upsample_block.upsamplers != None)
)
]
sample = upsample_block( sample = upsample_block(
hidden_states=sample, hidden_states=sample,
@@ -293,9 +348,13 @@ def brushnet_unet_forward(
else: else:
additional_residuals = {} additional_residuals = {}
if is_brushnet and len(up_block_add_samples) > 0: if is_brushnet and len(up_block_add_samples) > 0:
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) additional_residuals["up_block_add_samples"] = [
up_block_add_samples.pop(0)
for _ in range( for _ in range(
len(upsample_block.resnets) + (upsample_block.upsamplers != None))] len(upsample_block.resnets)
+ (upsample_block.upsamplers != None)
)
]
sample = upsample_block( sample = upsample_block(
hidden_states=sample, hidden_states=sample,

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_blocks import ( from diffusers.models.unets.unet_2d_blocks import (
get_down_block, get_down_block,
get_mid_block, get_mid_block,
get_up_block, get_up_block,

View File

@@ -15,7 +15,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from diffusers.models.unet_2d_condition import UNet2DConditionOutput from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import ( from diffusers.utils import (
USE_PEFT_BACKEND, USE_PEFT_BACKEND,
deprecate, deprecate,