update
This commit is contained in:
@@ -12,8 +12,9 @@ from lama_cleaner.model.utils import torch_gc, get_scheduler
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
|
||||
class CPUTextEncoderWrapper:
|
||||
class CPUTextEncoderWrapper(torch.nn.Module):
|
||||
def __init__(self, text_encoder, torch_dtype):
|
||||
super().__init__()
|
||||
self.config = text_encoder.config
|
||||
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||
|
||||
@@ -285,6 +285,28 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
|
||||
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def _default_height_width(self, height, width, image):
|
||||
if isinstance(image, list):
|
||||
image = image[0]
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[3]
|
||||
|
||||
height = (height // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[2]
|
||||
|
||||
width = (width // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
return height, width
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -402,14 +424,11 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
control_image,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt=prompt,
|
||||
image=control_image,
|
||||
callback_steps=callback_steps,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
|
||||
@@ -7,12 +7,13 @@ import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
from lama_cleaner.model.utils import torch_gc, get_scheduler
|
||||
from lama_cleaner.schema import Config, SDSampler
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
|
||||
class CPUTextEncoderWrapper:
|
||||
class CPUTextEncoderWrapper(torch.nn.Module):
|
||||
def __init__(self, text_encoder, torch_dtype):
|
||||
super().__init__()
|
||||
self.config = text_encoder.config
|
||||
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||
|
||||
Reference in New Issue
Block a user