add sd1.5

This commit is contained in:
Qing
2022-10-20 21:01:14 +08:00
parent d892d9166f
commit 6ccb6cd291
6 changed files with 64 additions and 424 deletions

View File

@@ -7,8 +7,6 @@ import torch
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
from loguru import logger
from lama_cleaner.helper import norm_img
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config, SDSampler
@@ -38,12 +36,22 @@ from lama_cleaner.schema import Config, SDSampler
# mask = torch.from_numpy(mask)
# return mask
class CPUTextEncoderWrapper:
    """Callable proxy that keeps a text encoder on the CPU in float32.

    The wrapped encoder is moved to the CPU and cast to ``torch.float32``
    once, at construction time.  Calling the wrapper moves the input tensor
    to the encoder's device, runs the encoder, and moves the first element
    of its output back to the device the input came from.
    """

    def __init__(self, text_encoder):
        """Move *text_encoder* to the CPU and cast it to float32.

        Args:
            text_encoder: Module to wrap.  It must expose a ``device``
                attribute, which ``__call__`` reads on every invocation.
        """
        # One .to() call performs both the device move and the dtype cast.
        self.text_encoder = text_encoder.to(
            device=torch.device("cpu"),
            dtype=torch.float32,
            non_blocking=True,
        )

    def __call__(self, x, **kwargs):
        """Run the wrapped encoder on *x* and return ``[first_output]``.

        Args:
            x: Input tensor; it may live on any device.
            **kwargs: Extra keyword arguments forwarded to the encoder.
                Tensor values are moved to the encoder's device first so
                they never end up on a different device than ``x``.

        Returns:
            A one-element list holding the first item of the encoder's
            output, moved back to the device ``x`` was on.
        """
        input_device = x.device
        encoder_device = self.text_encoder.device
        forwarded = {
            key: value.to(encoder_device) if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        first_output = self.text_encoder(x.to(encoder_device), **forwarded)[0]
        return [first_output.to(input_device)]
class SD(InpaintModel):
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
pad_mod = 8 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from .sd_pipeline import StableDiffusionInpaintPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
model_kwargs = {"local_files_only": kwargs['sd_run_local']}
if kwargs['sd_disable_nsfw']:
@@ -65,8 +73,7 @@ class SD(InpaintModel):
if kwargs['sd_cpu_textencoder']:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder)
self.callback = kwargs.pop("callback", None)
@@ -99,7 +106,6 @@ class SD(InpaintModel):
)
elif config.sd_sampler == SDSampler.pndm:
PNDM_kwargs = {
"tensor_format": "pt",
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"beta_end": 0.012,
@@ -124,15 +130,19 @@ class SD(InpaintModel):
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
_kwargs = {
self.image_key: PIL.Image.fromarray(image),
}
output = self.model(
prompt=config.prompt,
init_image=PIL.Image.fromarray(image),
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
strength=config.sd_strength,
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np.array",
callback=self.callback,
**_kwargs
).images[0]
output = (output * 255).round().astype("uint8")
@@ -185,7 +195,9 @@ class SD(InpaintModel):
class SD14(SD):
    # Model identifier for this variant (presumably a Hugging Face hub id
    # consumed when the pipeline is loaded -- confirm against init_model).
    model_id_or_path = "CompVis/stable-diffusion-v1-4"
    # Keyword name under which SD passes the input image to the pipeline call.
    image_key = "init_image"
class SD15(SD):
    """SD variant backed by the ``runwayml/stable-diffusion-inpainting`` model id."""

    # Model identifier for this variant (presumably a Hugging Face hub id
    # consumed when the pipeline is loaded -- confirm against init_model).
    # The earlier "CompVis/stable-diffusion-v1-5" assignment was dead code:
    # it was overwritten by this value on the very next line.
    model_id_or_path = "runwayml/stable-diffusion-inpainting"
    # Keyword name under which SD passes the input image to the pipeline call.
    image_key = "image"