This commit is contained in:
Qing
2023-11-16 11:08:34 +08:00
parent 8f942e27c4
commit 0cfec489b7
7 changed files with 63 additions and 28 deletions

View File

@@ -285,6 +285,28 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
def _default_height_width(self, height, width, image):
if isinstance(image, list):
image = image[0]
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[3]
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[2]
width = (width // 8) * 8 # round down to nearest multiple of 8
return height, width
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -402,14 +424,11 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
control_image,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt=prompt,
image=control_image,
callback_steps=callback_steps,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 2. Define call parameters