wip
@@ -14,17 +14,17 @@ class InpaintModel:
     pad_mod = 8
     pad_to_square = False

-    def __init__(self, device):
+    def __init__(self, device, **kwargs):
         """

         Args:
             device:
         """
         self.device = device
-        self.init_model(device)
+        self.init_model(device, **kwargs)

     @abc.abstractmethod
-    def init_model(self, device):
+    def init_model(self, device, **kwargs):
         ...

     @staticmethod
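
The widened signatures thread keyword arguments from `InpaintModel.__init__` down to each backend's `init_model`, so a backend can accept extra options (this diff uses it for `hf_access_token` and `callbacks` in the Stable Diffusion model below) without touching the other backends. A minimal sketch of a backend consuming a forwarded option; the class and option names here are illustrative, not part of the project:

    # Hypothetical backend, for illustration only.
    class DemoModel(InpaintModel):
        pad_mod = 8

        def init_model(self, device, **kwargs):
            # Options passed as DemoModel(device, token="...") arrive here unchanged;
            # backends that don't need them simply ignore **kwargs.
            self.token = kwargs.get("token")

        def forward(self, image, mask, config):
            ...
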
@@ -36,15 +36,19 @@ class InpaintModel:
     def forward(self, image, mask, config: Config):
         """Input images and output images have same size
         images: [H, W, C] RGB
-        masks: [H, W] 255 marks the mask region
+        masks: [H, W, 1] 255 marks the mask region
         return: BGR IMAGE
         """
         ...

     def _pad_forward(self, image, mask, config: Config):
         origin_height, origin_width = image.shape[:2]
-        pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
-        pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
+        pad_image = pad_img_to_modulo(
+            image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
+        )
+        pad_mask = pad_img_to_modulo(
+            mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
+        )

         logger.info(f"final forward pad size: {pad_image.shape}")
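
For context, both the image and the mask are padded so height and width become multiples of `pad_mod` (8 for LaMa, 32 for the Stable Diffusion model added below, 512 with `pad_to_square` for FcF and MAT), because the networks downsample by fixed factors. A rough illustration of the idea, not the project's `pad_img_to_modulo` helper:

    import numpy as np

    def pad_to_modulo_sketch(img: np.ndarray, mod: int) -> np.ndarray:
        # Pad bottom/right so H and W become the next multiples of `mod`.
        h, w = img.shape[:2]
        out_h = (h + mod - 1) // mod * mod
        out_w = (w + mod - 1) // mod * mod
        pad = [(0, out_h - h), (0, out_w - w)] + [(0, 0)] * (img.ndim - 2)
        return np.pad(img, pad, mode="symmetric")
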
@@ -81,18 +85,30 @@ class InpaintModel:
         elif config.hd_strategy == HDStrategy.RESIZE:
             if max(image.shape) > config.hd_strategy_resize_limit:
                 origin_size = image.shape[:2]
-                downsize_image = resize_max_size(image, size_limit=config.hd_strategy_resize_limit)
-                downsize_mask = resize_max_size(mask, size_limit=config.hd_strategy_resize_limit)
+                downsize_image = resize_max_size(
+                    image, size_limit=config.hd_strategy_resize_limit
+                )
+                downsize_mask = resize_max_size(
+                    mask, size_limit=config.hd_strategy_resize_limit
+                )

-                logger.info(f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}")
-                inpaint_result = self._pad_forward(downsize_image, downsize_mask, config)
+                logger.info(
+                    f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
+                )
+                inpaint_result = self._pad_forward(
+                    downsize_image, downsize_mask, config
+                )

                 # only paste masked area result
-                inpaint_result = cv2.resize(inpaint_result,
-                                            (origin_size[1], origin_size[0]),
-                                            interpolation=cv2.INTER_CUBIC)
+                inpaint_result = cv2.resize(
+                    inpaint_result,
+                    (origin_size[1], origin_size[0]),
+                    interpolation=cv2.INTER_CUBIC,
+                )
                 original_pixel_indices = mask < 127
-                inpaint_result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
+                inpaint_result[original_pixel_indices] = image[:, :, ::-1][
+                    original_pixel_indices
+                ]

         if inpaint_result is None:
             inpaint_result = self._pad_forward(image, mask, config)
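
The resize strategy inpaints a downscaled copy, upsamples the result back to the original resolution, and then keeps only the pixels under the mask, so untouched regions stay pixel-identical to the input. A minimal sketch of that paste-back step, assuming an [H, W] mask where values >= 127 mark the inpaint area:

    import cv2
    import numpy as np

    def paste_back_sketch(image_rgb: np.ndarray, mask: np.ndarray, small_result_bgr: np.ndarray) -> np.ndarray:
        h, w = image_rgb.shape[:2]
        # Upscale the low-resolution inpainting back to the original size...
        result = cv2.resize(small_result_bgr, (w, h), interpolation=cv2.INTER_CUBIC)
        # ...then restore original pixels wherever the mask says "keep" (mask < 127),
        # converting the RGB input to BGR to match the model output.
        keep = mask < 127
        result[keep] = image_rgb[:, :, ::-1][keep]
        return result
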
@@ -133,11 +149,11 @@ class InpaintModel:
         if _l < 0:
             r += abs(_l)
         if _r > img_w:
-            l -= (_r - img_w)
+            l -= _r - img_w
         if _t < 0:
             b += abs(_t)
         if _b > img_h:
-            t -= (_b - img_h)
+            t -= _b - img_h

         l = max(l, 0)
         r = min(r, img_w)
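
The crop-box adjustment above first slides an out-of-bounds window back toward the opposite edge and only then clamps, so crops near the border keep their intended size whenever the image is large enough. A small self-contained sketch of the same shift-then-clamp idea:

    def shift_then_clamp_sketch(l, t, r, b, img_w, img_h):
        # Slide the window back inside the image when one edge overflows...
        if l < 0:
            r += -l
        if r > img_w:
            l -= r - img_w
        if t < 0:
            b += -t
        if b > img_h:
            t -= b - img_h
        # ...then clamp whatever still sticks out (boxes larger than the image).
        return max(l, 0), max(t, 0), min(r, img_w), min(b, img_h)
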
@@ -1135,7 +1135,7 @@ class FcF(InpaintModel):
     pad_mod = 512
     pad_to_square = True

-    def init_model(self, device):
+    def init_model(self, device, **kwargs):
         seed = 0
         random.seed(seed)
         np.random.seed(seed)
@@ -18,16 +18,7 @@ LAMA_MODEL_URL = os.environ.get(
 class LaMa(InpaintModel):
     pad_mod = 8

-    def __init__(self, device):
-        """
-
-        Args:
-            device:
-        """
-        super().__init__(device)
-        self.device = device
-
-    def init_model(self, device):
+    def init_model(self, device, **kwargs):
         if os.environ.get("LAMA_MODEL"):
             model_path = os.environ.get("LAMA_MODEL")
             if not os.path.exists(model_path):
||||
@@ -227,7 +227,7 @@ class LDM(InpaintModel):
|
||||
super().__init__(device)
|
||||
self.device = device
|
||||
|
||||
def init_model(self, device):
|
||||
def init_model(self, device, **kwargs):
|
||||
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
|
||||
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
|
||||
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
|
||||
|
||||
@@ -1405,7 +1405,7 @@ class MAT(InpaintModel):
|
||||
pad_mod = 512
|
||||
pad_to_square = True
|
||||
|
||||
def init_model(self, device):
|
||||
def init_model(self, device, **kwargs):
|
||||
seed = 240 # pick up a random number
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
lama_cleaner/model/sd.py (new file, 152 lines)
@@ -0,0 +1,152 @@
+import random
+
+import PIL.Image
+import cv2
+import numpy as np
+import torch
+from loguru import logger
+
+from lama_cleaner.helper import norm_img
+
+from lama_cleaner.model.base import InpaintModel
+from lama_cleaner.schema import Config
+
+
+#
+#
+# def preprocess_image(image):
+#     w, h = image.size
+#     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+#     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+#     image = np.array(image).astype(np.float32) / 255.0
+#     image = image[None].transpose(0, 3, 1, 2)
+#     image = torch.from_numpy(image)
+#     # [-1, 1]
+#     return 2.0 * image - 1.0
+#
+#
+# def preprocess_mask(mask):
+#     mask = mask.convert("L")
+#     w, h = mask.size
+#     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+#     mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+#     mask = np.array(mask).astype(np.float32) / 255.0
+#     mask = np.tile(mask, (4, 1, 1))
+#     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+#     mask = 1 - mask  # repaint white, keep black
+#     mask = torch.from_numpy(mask)
+#     return mask


+class SD(InpaintModel):
+    pad_mod = 32
+    min_size = 512
+
+    def init_model(self, device: torch.device, **kwargs):
+        # return
+        from .sd_pipeline import StableDiffusionInpaintPipeline
+
+        self.model = StableDiffusionInpaintPipeline.from_pretrained(
+            self.model_id_or_path,
+            revision="fp16",
+            torch_dtype=torch.float16,
+            use_auth_token=kwargs["hf_access_token"],
+        )
+        # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
+        self.model.enable_attention_slicing()
+        self.model = self.model.to(device)
+        self.callbacks = kwargs.pop("callbacks", None)
+
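
Two memory-oriented choices here: loading the fp16 weights (`revision="fp16"`, `torch_dtype=torch.float16`) roughly halves weight and activation memory, and `enable_attention_slicing()` computes self-attention in chunks instead of all heads at once. Rough illustrative arithmetic for why slicing matters at 512x512 (the 8-head figure is an assumption for illustration, not taken from this diff):

    # For a 512x512 input, the UNet's largest self-attention runs over a 64x64
    # latent grid, i.e. 4096 tokens.
    tokens = 64 * 64
    bytes_per_value = 2                              # fp16
    per_head = tokens * tokens * bytes_per_value     # one attention map, about 32 MiB
    per_layer = per_head * 8                         # about 256 MiB if that layer had 8 heads
    # Slicing with slice_size = attention_head_dim // 2 (the "auto" default in
    # sd_pipeline.py below) processes the heads in two chunks, so the peak
    # activation memory for this map is roughly halved at a small speed cost.
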
+    @torch.cuda.amp.autocast()
+    def forward(self, image, mask, config: Config):
+        # return image
+        """Input image and output image have same size
+        image: [H, W, C] RGB
+        mask: [H, W, 1] 255 means area to repaint
+        return: BGR IMAGE
+        """
+
+        # image = norm_img(image)  # [0, 1]
+        # image = image * 2 - 1  # [0, 1] -> [-1, 1]
+
+        # resize to latent feature map size
+        # h, w = mask.shape[:2]
+        # mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA)
+        # mask = norm_img(mask)
+        #
+        # image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+        # mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+        # import time
+        # time.sleep(2)
+        # return image
+        seed = config.sd_seed
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+        output = self.model(
+            prompt=config.prompt,
+            init_image=PIL.Image.fromarray(image),
+            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
+            strength=config.sd_strength,
+            num_inference_steps=config.sd_steps,
+            guidance_scale=config.sd_guidance_scale,
+            output_type="np.array",
+            callbacks=self.callbacks,
+        ).images[0]
+
+        output = (output * 255).round().astype("uint8")
+        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+        return output
+
+    @torch.no_grad()
+    def __call__(self, image, mask, config: Config):
+        """
+        images: [H, W, C] RGB, not normalized
+        masks: [H, W]
+        return: BGR IMAGE
+        """
+        img_h, img_w = image.shape[:2]
+
+        # boxes = boxes_from_mask(mask)
+        if config.use_croper:
+            logger.info("use croper")
+            l, t, w, h = (
+                config.croper_x,
+                config.croper_y,
+                config.croper_width,
+                config.croper_height,
+            )
+            r = l + w
+            b = t + h
+
+            l = max(l, 0)
+            r = min(r, img_w)
+            t = max(t, 0)
+            b = min(b, img_h)
+
+            crop_img = image[t:b, l:r, :]
+            crop_mask = mask[t:b, l:r]
+
+            crop_image = self._pad_forward(crop_img, crop_mask, config)
+
+            inpaint_result = image[:, :, ::-1]
+            inpaint_result[t:b, l:r, :] = crop_image
+        else:
+            inpaint_result = self._pad_forward(image, mask, config)
+
+        return inpaint_result
+
+    @staticmethod
+    def is_downloaded() -> bool:
+        # the model is downloaded at app start and can't be switched in the frontend settings
+        return True
+
+
+class SD14(SD):
+    model_id_or_path = "CompVis/stable-diffusion-v1-4"
+
+
+class SD15(SD):
+    model_id_or_path = "runwayml/stable-diffusion-v1-5"
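
Taken together, the new SD backends plug into the same `InpaintModel` interface as LaMa and the others: `__call__` handles optional cropping, `_pad_forward` pads to `pad_mod=32`, and `forward` drives the diffusion pipeline. A hedged end-to-end usage sketch; the `Config` fields named here come from this diff, while the remaining required `Config` fields and their defaults are assumed:

    import cv2
    import numpy as np
    import torch
    from lama_cleaner.model.sd import SD15
    from lama_cleaner.schema import Config

    model = SD15(device=torch.device("cuda"), hf_access_token="hf_xxx")

    image = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGB)  # [H, W, 3] RGB
    mask = np.zeros((*image.shape[:2], 1), dtype=np.uint8)            # [H, W, 1], 255 = repaint
    mask[100:300, 200:400] = 255

    config = Config(  # other required Config fields omitted for brevity
        prompt="a wooden park bench",
        sd_seed=42,
        sd_strength=0.75,
        sd_steps=50,
        sd_guidance_scale=7.5,
        use_croper=False,
    )

    result_bgr = model(image, mask, config)  # __call__ -> _pad_forward -> forward
    cv2.imwrite("result.png", result_bgr)
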
lama_cleaner/model/sd_pipeline.py (new file, 309 lines)
@@ -0,0 +1,309 @@
+import inspect
+from typing import List, Optional, Union, Callable
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
+from diffusers.utils import logging
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+logger = logging.get_logger(__name__)
+
+
+def preprocess_image(image):
+    w, h = image.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask):
+    mask = mask.convert("L")
+    w, h = mask.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+    mask = np.array(mask).astype(np.float32) / 255.0
+    mask = np.tile(mask, (4, 1, 1))
+    mask = mask[None].transpose(0, 1, 2, 3)  # keeps shape (1, 4, h//8, w//8); this transpose is a no-op
+    mask = 1 - mask  # repaint white, keep black
+    mask = torch.from_numpy(mask)
+    return mask
+
+
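
The `// 8` resize and the 4-way tile in `preprocess_mask` exist because masking happens in latent space: Stable Diffusion's VAE encodes images into 4-channel latents at 1/8 the spatial resolution (the same latents scaled by `0.18215` below), so the mask must match that shape to gate latents element-wise inside the denoising loop. A quick shape walk-through for a 512x512 input:

    import numpy as np

    # 512x512 pixel mask -> 64x64 latent-resolution mask (the PIL NEAREST resize above)
    mask_small = np.zeros((512 // 8, 512 // 8), dtype=np.float32)
    # one copy per latent channel, plus a batch axis: (1, 4, 64, 64)
    mask_latent = np.tile(mask_small, (4, 1, 1))[None]
    # This matches the shape of the vae.encode(...) latents (batch, 4, H/8, W/8), so
    # `init_latents_proper * mask + latents * (1 - mask)` blends element-wise.
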
+class StableDiffusionInpaintPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler],
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
+    ):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+
+    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+                `attention_head_dim` must be a multiple of `slice_size`.
+        """
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = self.unet.config.attention_head_dim // 2
+        self.unet.set_attention_slice(slice_size)
+
+    def disable_attention_slicing(self):
+        r"""
+        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+        back to computing attention in one step.
+        """
+        # set slice_size = `None` to disable `set_attention_slice`
+        self.enable_attention_slicing(None)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+        strength: float = 0.8,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callbacks: List[Callable[[int], None]] = None,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                process. This is the image whose masked region will be inpainted.
+            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
+                replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
+                converted to a single channel (luminance) before use.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+                is 1, the denoising process will be run on the masked area for the full number of iterations specified
+                in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
+                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+        # set timesteps
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        extra_set_kwargs = {}
+        offset = 0
+        if accepts_offset:
+            offset = 1
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # preprocess image
+        init_image = preprocess_image(init_image).to(self.device)
+
+        # encode the init image into latents and scale the latents
+        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+
+        init_latents = 0.18215 * init_latents
+
+        # Expand init_latents for batch_size
+        init_latents = torch.cat([init_latents] * batch_size)
+        init_latents_orig = init_latents
+
+        # preprocess mask
+        mask = preprocess_mask(mask_image).to(self.device)
+        mask = torch.cat([mask] * batch_size)
+
+        # check sizes
+        if not mask.shape == init_latents.shape:
+            raise ValueError("The mask and init_image should be the same size!")
+
+        # get the original timestep using init_timestep
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
+        # get prompt text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        latents = init_latents
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+            # masking
+            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
+            latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+            if callbacks is not None:
+                for callback in callbacks:
+                    callback(i)
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents).sample
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
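
The `callbacks` parameter is the local addition relative to the upstream pipeline: after every denoising step each callable is invoked with the step index, which is what lets the frontend report per-step progress. A minimal sketch of wiring one up through the kwargs plumbing introduced in this commit:

    import torch
    from lama_cleaner.model.sd import SD15

    progress = {"step": 0}

    def on_step(i: int) -> None:
        # Called once per denoising step with the current step index.
        progress["step"] = i

    # `callbacks` is popped from **kwargs in SD.init_model and forwarded to the
    # pipeline call in SD.forward as `callbacks=self.callbacks`.
    model = SD15(device=torch.device("cuda"), hf_access_token="hf_xxx", callbacks=[on_step])
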
@@ -216,7 +216,7 @@ class ZITS(InpaintModel):
         self.device = device
         self.sample_edge_line_iterations = 1

-    def init_model(self, device):
+    def init_model(self, device, **kwargs):
         self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device)
         self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device)
         self.structure_upsample = load_jit_model(