Author: Qing
Date:   2023-11-16 11:08:34 +08:00
Parent: 8f942e27c4
Commit: 0cfec489b7
7 changed files with 63 additions and 28 deletions


@@ -12,8 +12,9 @@ from lama_cleaner.model.utils import torch_gc, get_scheduler
 from lama_cleaner.schema import Config
 
 
-class CPUTextEncoderWrapper:
+class CPUTextEncoderWrapper(torch.nn.Module):
     def __init__(self, text_encoder, torch_dtype):
+        super().__init__()
         self.config = text_encoder.config
         self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
         self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
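The hunk above only changes the class's base to torch.nn.Module and adds the super().__init__() call; the rest of the wrapper is outside this hunk. A minimal sketch of the surrounding pattern, assuming (this part is not shown in the diff) that the wrapper runs the encoder on CPU in float32 and casts the output back to the pipeline's dtype:

# Sketch only: __call__ and the dtype property are assumptions for illustration,
# not part of this hunk.
import torch


class CPUTextEncoderWrapper(torch.nn.Module):
    def __init__(self, text_encoder, torch_dtype):
        super().__init__()  # must run before any submodule is assigned
        self.config = text_encoder.config
        # Keep the encoder on CPU in float32 to free GPU memory.
        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
        self.torch_dtype = torch_dtype

    def __call__(self, x, **kwargs):
        input_device = x.device
        # Encode on CPU, then move the hidden states back and restore the dtype.
        hidden = self.text_encoder(x.to("cpu"), **kwargs)[0]
        return [hidden.to(input_device).to(self.torch_dtype)]

    @property
    def dtype(self):
        return self.torch_dtype

Keeping the text encoder on CPU trades a slower prompt encode for lower GPU memory use during inpainting.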


@@ -285,6 +285,28 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
 
         return mask, masked_image_latents
 
+    def _default_height_width(self, height, width, image):
+        if isinstance(image, list):
+            image = image[0]
+
+        if height is None:
+            if isinstance(image, PIL.Image.Image):
+                height = image.height
+            elif isinstance(image, torch.Tensor):
+                height = image.shape[3]
+
+            height = (height // 8) * 8  # round down to nearest multiple of 8
+
+        if width is None:
+            if isinstance(image, PIL.Image.Image):
+                width = image.width
+            elif isinstance(image, torch.Tensor):
+                width = image.shape[2]
+
+            width = (width // 8) * 8  # round down to nearest multiple of 8
+
+        return height, width
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
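The _default_height_width helper added above resolves a missing height or width from the control image and rounds each down to a multiple of 8, since Stable Diffusion's VAE downsamples by a factor of 8 and the latent dimensions must divide evenly. A rough standalone sketch of just the PIL branch (the function name and image size below are illustrative only):

# Standalone sketch of the rounding behaviour; names and sizes are made up.
import PIL.Image


def default_height_width(height, width, image):
    if height is None:
        height = (image.height // 8) * 8  # round down to nearest multiple of 8
    if width is None:
        width = (image.width // 8) * 8
    return height, width


img = PIL.Image.new("RGB", (517, 389))        # width=517, height=389
print(default_height_width(None, None, img))  # -> (384, 512)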
@@ -402,14 +424,11 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
-            control_image,
-            height,
-            width,
-            callback_steps,
-            negative_prompt,
-            prompt_embeds,
-            negative_prompt_embeds,
+            prompt=prompt,
+            image=control_image,
+            callback_steps=callback_steps,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
         )
 
         # 2. Define call parameters
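This second hunk switches the check_inputs call from positional to keyword arguments, so the call no longer depends on the exact parameter order of the parent pipeline's check_inputs, whose signature has shifted across diffusers releases. A generic sketch of why keyword binding is the safer pattern here (check_inputs_v2 is a hypothetical stand-in, not the diffusers API):

# Hypothetical validator for illustration only.
def check_inputs_v2(prompt, image, callback_steps,
                    prompt_embeds=None, negative_prompt_embeds=None):
    print("ok:", prompt, callback_steps)


# Binding by name survives signature changes between library versions.
check_inputs_v2(
    prompt="a cat",
    image=None,
    callback_steps=1,
)

# An old positional call like check_inputs_v2("a cat", None, 512, 512, 1)
# would not fail loudly here; it would silently bind 512 to callback_steps
# and prompt_embeds, which is exactly the misbinding keywords avoid.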


@@ -7,12 +7,13 @@ import torch
 from loguru import logger
 
 from lama_cleaner.model.base import DiffusionInpaintModel
-from lama_cleaner.model.utils import torch_gc, get_scheduler
-from lama_cleaner.schema import Config, SDSampler
+from lama_cleaner.model.utils import torch_gc
+from lama_cleaner.schema import Config
 
 
-class CPUTextEncoderWrapper:
+class CPUTextEncoderWrapper(torch.nn.Module):
     def __init__(self, text_encoder, torch_dtype):
+        super().__init__()
        self.config = text_encoder.config
         self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
         self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
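One note on the super().__init__() line added in both files: torch.nn.Module registers submodules through __setattr__, and that machinery is created in Module.__init__, so the call must come before self.text_encoder is assigned. A minimal sketch of the failure mode (the Broken class is illustrative only):

# Illustrative only: assigning a submodule before Module.__init__ runs fails.
import torch


class Broken(torch.nn.Module):
    def __init__(self):
        # Missing super().__init__()
        self.layer = torch.nn.Linear(4, 4)


try:
    Broken()
except AttributeError as err:
    print(err)  # cannot assign module before Module.__init__() call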