🎨 Complete IOPaint project update

## Main updates
- Update all dependencies to the latest stable versions
- 📝 Add detailed project documentation and model recommendations
- 🔧 Configure VSCode Cloud Studio preview support
- 🐛 Fix PyTorch API deprecation warning

## Dependency updates
- diffusers: 0.27.2 → 0.35.2
- gradio: 4.21.0 → 5.46.0
- peft: 0.7.1 → 0.18.0
- Pillow: 9.5.0 → 11.3.0
- fastapi: 0.108.0 → 0.116.2
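To verify an environment actually picked up these versions, a quick check (a convenience sketch, not part of this commit):

```python
# Print the installed versions of the upgraded packages and compare
# against the pins listed above. Requires Python 3.8+ (importlib.metadata).
from importlib.metadata import version

for pkg in ["diffusers", "gradio", "peft", "Pillow", "fastapi"]:
    print(f"{pkg}: {version(pkg)}")
```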

## New files
- CLAUDE.md - project architecture and development guide
- UPGRADE_NOTES.md - detailed upgrade notes
- .vscode/preview.yml - preview configuration
- .vscode/LAUNCH_GUIDE.md - launch guide
- .gitignore - updated ignore rules

## Code fixes
- Fix the torch.cuda.amp.autocast() deprecation warning in iopaint/model/ldm.py
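The fix follows the standard PyTorch migration; a minimal sketch, assuming the usual autocast usage (the exact code in iopaint/model/ldm.py may differ):

```python
import torch

def autocast_forward(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Deprecated since PyTorch 2.4 and emits a FutureWarning:
    #   with torch.cuda.amp.autocast():
    # The device-agnostic replacement keeps the same behavior:
    with torch.amp.autocast("cuda"):
        return model(x)
```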

## Documentation updates
- README.md - add model recommendations and a usage guide
- Complete project source code (iopaint/)
- Web frontend source code (web_app/)

🤖 Generated with Claude Code
Author: let5sne
Date: 2025-11-28 17:10:24 +00:00
Parent: 03b999e9ea
Commit: 1b87a98261
332 changed files with 77453 additions and 26 deletions

File diff suppressed because it is too large

@@ -0,0 +1,101 @@
import PIL.Image
import cv2
import torch
from loguru import logger
from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..utils import (
handle_from_pretrained_exceptions,
get_torch_dtype,
enable_low_mem,
is_local_files_only,
)
from iopaint.schema import InpaintRequest
from .powerpaint_tokenizer import add_task_to_prompt
from ...const import POWERPAINT_NAME
class PowerPaint(DiffusionInpaintModel):
name = POWERPAINT_NAME
pad_mod = 8
min_size = 512
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
def init_model(self, device: torch.device, **kwargs):
from .pipeline_powerpaint import StableDiffusionInpaintPipeline
from .powerpaint_tokenizer import PowerPaintTokenizer
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
self.model = handle_from_pretrained_exceptions(
StableDiffusionInpaintPipeline.from_pretrained,
pretrained_model_name_or_path=self.name,
variant="fp16",
torch_dtype=torch_dtype,
**model_kwargs,
)
self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
self.set_scheduler(config)
img_h, img_w = image.shape[:2]
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
config.prompt, config.negative_prompt, config.powerpaint_task
)
output = self.model(
image=PIL.Image.fromarray(image),
promptA=promptA,
promptB=promptB,
tradoff=config.fitting_degree,
tradoff_nag=config.fitting_degree,
negative_promptA=negative_promptA,
negative_promptB=negative_promptB,
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
num_inference_steps=config.sd_steps,
strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
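
As a quick orientation on the conventions in `forward` above, a hypothetical call sketch (the `power_paint` and `config` objects are assumptions, elided here):

```python
import numpy as np

# Shapes and dtypes follow the forward() docstring above.
image = np.zeros((512, 512, 3), dtype=np.uint8)  # [H, W, C] RGB
mask = np.zeros((512, 512, 1), dtype=np.uint8)   # [H, W, 1], 255 = repaint
mask[128:256, 128:256] = 255                     # region to repaint

# result = power_paint.forward(image, mask, config)  # returns a BGR uint8 image
```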

@@ -0,0 +1,186 @@
from itertools import chain
import PIL.Image
import cv2
import torch
from iopaint.model.original_sd_configs import get_config_files
from loguru import logger
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..utils import (
get_torch_dtype,
enable_low_mem,
is_local_files_only,
handle_from_pretrained_exceptions,
)
from .powerpaint_tokenizer import task_to_prompt
from iopaint.schema import InpaintRequest, ModelType
from .v2.BrushNet_CA import BrushNetModel
from .v2.unet_2d_condition import UNet2DConditionModel_forward
from .v2.unet_2d_blocks import (
CrossAttnDownBlock2D_forward,
DownBlock2D_forward,
CrossAttnUpBlock2D_forward,
UpBlock2D_forward,
)
class PowerPaintV2(DiffusionInpaintModel):
pad_mod = 8
min_size = 512
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
hf_model_id = "Sanster/PowerPaint_v2"
def init_model(self, device: torch.device, **kwargs):
from .v2.pipeline_PowerPaint_Brushnet_CA import (
StableDiffusionPowerPaintBrushNetPipeline,
)
from .powerpaint_tokenizer import PowerPaintTokenizer
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
text_encoder_brushnet = CLIPTextModel.from_pretrained(
self.hf_model_id,
subfolder="text_encoder_brushnet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
brushnet = BrushNetModel.from_pretrained(
self.hf_model_id,
subfolder="PowerPaint_Brushnet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9
pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
self.model_id_or_path,
torch_dtype=torch_dtype,
load_safety_checker=False,
original_config_file=get_config_files()["v1"],
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
**model_kwargs,
)
else:
pipe = handle_from_pretrained_exceptions(
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
torch_dtype=torch_dtype,
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
variant="fp16",
**model_kwargs,
)
pipe.tokenizer = PowerPaintTokenizer(
CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
)
self.model = pipe
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
self.model.unet.forward = UNet2DConditionModel_forward.__get__(
self.model.unet, self.model.unet.__class__
)
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
for down_block in chain(
self.model.unet.down_blocks, self.model.brushnet.down_blocks
):
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
down_block, down_block.__class__
)
else:
down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)
for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
up_block, up_block.__class__
)
else:
up_block.forward = UpBlock2D_forward.__get__(
up_block, up_block.__class__
)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
self.set_scheduler(config)
image = image * (1 - mask / 255.0)
img_h, img_w = image.shape[:2]
image = PIL.Image.fromarray(image.astype(np.uint8))
mask = PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB")
promptA, promptB, negative_promptA, negative_promptB = task_to_prompt(
config.powerpaint_task
)
output = self.model(
image=image,
mask=mask,
promptA=promptA,
promptB=promptB,
promptU=config.prompt,
tradoff=config.fitting_degree,
tradoff_nag=config.fitting_degree,
negative_promptA=negative_promptA,
negative_promptB=negative_promptB,
negative_promptU=config.negative_prompt,
num_inference_steps=config.sd_steps,
# strength=config.sd_strength,
brushnet_conditioning_scale=1.0,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
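
The monkey patching in `init_model` above works through the descriptor protocol: `function.__get__(instance, cls)` returns a bound method, so `self` is supplied automatically on each call. A minimal self-contained illustration:

```python
class Block:
    def forward(self, x):
        return x

def patched_forward(self, x):
    # Replacement behavior; `self` is bound via __get__ below.
    return x * 2

block = Block()
block.forward = patched_forward.__get__(block, Block)
assert block.forward(3) == 6  # the instance now uses the patched method
```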

@@ -0,0 +1,254 @@
import copy
import random
from typing import Any, List, Union
from transformers import CLIPTokenizer
from iopaint.schema import PowerPaintTask
def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
if task == PowerPaintTask.object_remove:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj"
elif task == PowerPaintTask.context_aware:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
elif task == PowerPaintTask.shape_guided:
promptA = prompt + " P_shape"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
elif task == PowerPaintTask.outpainting:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj"
else:
promptA = prompt + " P_obj"
promptB = prompt + " P_obj"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
return promptA, promptB, negative_promptA, negative_promptB
def task_to_prompt(task: PowerPaintTask):
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
"", "", task
)
return (
promptA.strip(),
promptB.strip(),
negative_promptA.strip(),
negative_promptB.strip(),
)
class PowerPaintTokenizer:
def __init__(self, tokenizer: CLIPTokenizer):
self.wrapped = tokenizer
self.token_map = {}
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
num_vec_per_token = 10
for placeholder_token in placeholder_tokens:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
output.append(ith_token)
self.token_map[placeholder_token] = output
def __getattr__(self, name: str) -> Any:
if name == "wrapped":
return super().__getattr__("wrapped")
try:
return getattr(self.wrapped, name)
except AttributeError:
try:
return super().__getattr__(name)
except AttributeError:
raise AttributeError(
"'name' cannot be found in both "
f"'{self.__class__.__name__}' and "
f"'{self.__class__.__name__}.tokenizer'."
)
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
"""Attempt to add tokens to the tokenizer.
Args:
tokens (Union[str, List[str]]): The tokens to be added.
"""
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
assert num_added_tokens != 0, (
f"The tokenizer already contains the token {tokens}. Please pass "
"a different `placeholder_token` that is not already in the "
"tokenizer."
)
def get_token_info(self, token: str) -> dict:
"""Get the information of a token, including its start and end index in
the current tokenizer.
Args:
token (str): The token to be queried.
Returns:
dict: The information of the token, including its start and end
index in current tokenizer.
"""
token_ids = self.__call__(token).input_ids
start, end = token_ids[1], token_ids[-2] + 1
return {"name": token, "start": start, "end": end}
def add_placeholder_token(
self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
):
"""Add placeholder tokens to the tokenizer.
Args:
placeholder_token (str): The placeholder token to be added.
num_vec_per_token (int, optional): The number of vectors of
the added placeholder token.
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
"""
output = []
if num_vec_per_token == 1:
self.try_adding_tokens(placeholder_token, *args, **kwargs)
output.append(placeholder_token)
else:
output = []
for i in range(num_vec_per_token):
ith_token = placeholder_token + f"_{i}"
self.try_adding_tokens(ith_token, *args, **kwargs)
output.append(ith_token)
for token in self.token_map:
if token in placeholder_token:
raise ValueError(
f"The tokenizer already has placeholder token {token} "
f"that can get confused with {placeholder_token} "
"keep placeholder tokens independent"
)
self.token_map[placeholder_token] = output
def replace_placeholder_tokens_in_text(
self,
text: Union[str, List[str]],
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
) -> Union[str, List[str]]:
"""Replace the keywords in text with placeholder tokens. This function
will be called in `self.__call__` and `self.encode`.
Args:
text (Union[str, List[str]]): The text to be processed.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(
self.replace_placeholder_tokens_in_text(
text[i], vector_shuffle=vector_shuffle
)
)
return output
for placeholder_token in self.token_map:
if placeholder_token in text:
tokens = self.token_map[placeholder_token]
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
if vector_shuffle:
tokens = copy.copy(tokens)
random.shuffle(tokens)
text = text.replace(placeholder_token, " ".join(tokens))
return text
def replace_text_with_placeholder_tokens(
self, text: Union[str, List[str]]
) -> Union[str, List[str]]:
"""Replace the placeholder tokens in text with the original keywords.
This function will be called in `self.decode`.
Args:
text (Union[str, List[str]]): The text to be processed.
Returns:
Union[str, List[str]]: The processed text.
"""
if isinstance(text, list):
output = []
for i in range(len(text)):
output.append(self.replace_text_with_placeholder_tokens(text[i]))
return output
for placeholder_token, tokens in self.token_map.items():
merged_tokens = " ".join(tokens)
if merged_tokens in text:
text = text.replace(merged_tokens, placeholder_token)
return text
def __call__(
self,
text: Union[str, List[str]],
*args,
vector_shuffle: bool = False,
prop_tokens_to_load: float = 1.0,
**kwargs,
):
"""The call function of the wrapper.
Args:
text (Union[str, List[str]]): The text to be tokenized.
vector_shuffle (bool, optional): Whether to shuffle the vectors.
Defaults to False.
prop_tokens_to_load (float, optional): The proportion of tokens to
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
)
return self.wrapped.__call__(replaced_text, *args, **kwargs)
def encode(self, text: Union[str, List[str]], *args, **kwargs):
"""Encode the passed text to token index.
Args:
text (Union[str, List[str]]): The text to be encode.
*args, **kwargs: The arguments for `self.wrapped.__call__`.
"""
replaced_text = self.replace_placeholder_tokens_in_text(text)
return self.wrapped(replaced_text, *args, **kwargs)
def decode(
self, token_ids, return_raw: bool = False, *args, **kwargs
) -> Union[str, List[str]]:
"""Decode the token index to text.
Args:
token_ids: The token index to be decoded.
return_raw: Whether keep the placeholder token in the text.
Defaults to False.
*args, **kwargs: The arguments for `self.wrapped.decode`.
Returns:
Union[str, List[str]]: The decoded text.
"""
text = self.wrapped.decode(token_ids, *args, **kwargs)
if return_raw:
return text
replaced_text = self.replace_text_with_placeholder_tokens(text)
return replaced_text
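
To make the placeholder expansion concrete, here is a standalone re-implementation of the core of `replace_placeholder_tokens_in_text` for a single string (shuffling and truncation omitted):

```python
# Each placeholder maps to 10 sub-tokens, mirroring PowerPaintTokenizer.__init__.
token_map = {"P_ctxt": [f"P_ctxt_{i}" for i in range(10)]}

text = "a sofa in a living room P_ctxt"
for placeholder, tokens in token_map.items():
    if placeholder in text:
        text = text.replace(placeholder, " ".join(tokens))

print(text)  # "a sofa in a living room P_ctxt_0 P_ctxt_1 ... P_ctxt_9"
```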

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -0,0 +1,342 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import torch
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import apply_freeu
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def CrossAttnDownBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals: Optional[torch.FloatTensor] = None,
down_block_add_samples: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(0)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=lora_scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(
0
) # todo: add before or after
output_states = output_states + (hidden_states,)
return hidden_states, output_states
def DownBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
down_block_add_samples: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(0)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, scale=scale)
if down_block_add_samples is not None:
hidden_states = hidden_states + down_block_add_samples.pop(
0
) # todo: add before or after
output_states = output_states + (hidden_states,)
return hidden_states, output_states
def CrossAttnUpBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
return_res_samples: Optional[bool] = False,
up_block_add_samples: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
if return_res_samples:
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(0)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(0)
if return_res_samples:
return hidden_states, output_states
else:
return hidden_states
def UpBlock2D_forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
return_res_samples: Optional[bool] = False,
up_block_add_samples: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
if return_res_samples:
output_states = ()
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(
0
) # todo: add before or after
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
if return_res_samples:
output_states = output_states + (hidden_states,)
if up_block_add_samples is not None:
hidden_states = hidden_states + up_block_add_samples.pop(
0
) # todo: add before or after
if return_res_samples:
return hidden_states, output_states
else:
return hidden_states
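
During training these forwards route each resnet through `torch.utils.checkpoint.checkpoint`, trading compute for memory; a minimal sketch of the non-reentrant variant selected above for torch >= 1.11:

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

# Activations inside `layer` are freed after the forward pass and
# recomputed during backward.
out = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
out.sum().backward()
```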

@@ -0,0 +1,402 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def UNet2DConditionModel_forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
The [`UNet2DConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
through the `self.time_embedding` layer to obtain the timestep embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
A tuple of tensors that if specified are added to the residuals of down unet blocks.
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
example from ControlNet side model(s)
mid_block_additional_residual (`torch.Tensor`, *optional*):
additional residual to be added to UNet mid block output, for example from ControlNet side model
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
    # By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
for dim in sample.shape[-2:]:
if dim % default_overall_up_factor != 0:
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
break
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (
1 - encoder_attention_mask.to(sample.dtype)
) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
if class_emb is not None:
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
aug_emb = self.get_aug_embed(
emb=emb,
encoder_hidden_states=encoder_hidden_states,
added_cond_kwargs=added_cond_kwargs,
)
if self.config.addition_embed_type == "image_hint":
aug_emb, hint = aug_emb
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
encoder_hidden_states = self.process_encoder_hidden_states(
encoder_hidden_states=encoder_hidden_states,
added_cond_kwargs=added_cond_kwargs,
)
# 2. pre-process
sample = self.conv_in(sample)
# 2.5 GLIGEN position net
if (
cross_attention_kwargs is not None
and cross_attention_kwargs.get("gligen", None) is not None
):
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = (
mid_block_additional_residual is not None
and down_block_additional_residuals is not None
)
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
# maintain backward compatibility for legacy usage, where
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other
is_brushnet = (
down_block_add_samples is not None
and mid_block_add_sample is not None
and up_block_add_samples is not None
)
if (
not is_adapter
and mid_block_additional_residual is None
and down_block_additional_residuals is not None
):
deprecate(
"T2I should not use down_block_additional_residuals",
"1.3.0",
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
standard_warn=False,
)
down_intrablock_additional_residuals = down_block_additional_residuals
is_adapter = True
down_block_res_samples = (sample,)
if is_brushnet:
sample = sample + down_block_add_samples.pop(0)
for downsample_block in self.down_blocks:
if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
# For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {}
if is_adapter and len(down_intrablock_additional_residuals) > 0:
additional_residuals["additional_residuals"] = (
down_intrablock_additional_residuals.pop(0)
)
if is_brushnet and len(down_block_add_samples) > 0:
additional_residuals["down_block_add_samples"] = [
down_block_add_samples.pop(0)
for _ in range(
len(downsample_block.resnets)
                        + (downsample_block.downsamplers is not None)
)
]
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
additional_residuals = {}
if is_brushnet and len(down_block_add_samples) > 0:
additional_residuals["down_block_add_samples"] = [
down_block_add_samples.pop(0)
for _ in range(
len(downsample_block.resnets)
                        + (downsample_block.downsamplers is not None)
)
]
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
scale=lora_scale,
**additional_residuals,
)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0)
down_block_res_samples += res_samples
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = (
down_block_res_sample + down_block_additional_residual
)
new_down_block_res_samples = new_down_block_res_samples + (
down_block_res_sample,
)
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
if (
hasattr(self.mid_block, "has_cross_attention")
and self.mid_block.has_cross_attention
):
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = self.mid_block(sample, emb)
# To support T2I-Adapter-XL
if (
is_adapter
and len(down_intrablock_additional_residuals) > 0
and sample.shape == down_intrablock_additional_residuals[0].shape
):
sample += down_intrablock_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual
if is_brushnet:
sample = sample + mid_block_add_sample
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
additional_residuals = {}
if is_brushnet and len(up_block_add_samples) > 0:
additional_residuals["up_block_add_samples"] = [
up_block_add_samples.pop(0)
for _ in range(
len(upsample_block.resnets)
                        + (upsample_block.upsamplers is not None)
)
]
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
additional_residuals = {}
if is_brushnet and len(up_block_add_samples) > 0:
additional_residuals["up_block_add_samples"] = [
up_block_add_samples.pop(0)
for _ in range(
len(upsample_block.resnets)
                        + (upsample_block.upsamplers is not None)
)
]
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
**additional_residuals,
)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)
return UNet2DConditionOutput(sample=sample)
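
The BrushNet samples arrive as flat lists and are consumed in order with `pop(0)`; each block takes `len(resnets)` samples plus one more when it also has a down/upsampler. A toy sketch of that bookkeeping (shapes are placeholders):

```python
import torch

add_samples = [torch.zeros(1, 4, 8, 8) for _ in range(5)]
num_resnets, has_downsampler = 2, True

# Matches the per-block slice taken in the down/up loops above.
take = num_resnets + int(has_downsampler)
block_samples = [add_samples.pop(0) for _ in range(take)]

assert len(block_samples) == 3 and len(add_samples) == 2
```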