update diffusers to 0.9; add SD2

This commit is contained in:
Qing
2022-12-04 13:41:48 +08:00
parent 15fe87e42d
commit 6a0ffdc96e
7 changed files with 41 additions and 30 deletions

View File

@@ -5,7 +5,7 @@ import cv2
import numpy as np
import torch
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
EulerAncestralDiscreteScheduler
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
from loguru import logger
from lama_cleaner.model.base import InpaintModel
@@ -102,27 +102,20 @@ class SD(InpaintModel):
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
scheduler_kwargs = dict(
beta_schedule="scaled_linear",
beta_start=0.00085,
beta_end=0.012,
num_train_timesteps=1000,
)
scheduler_config = self.model.scheduler.config
if config.sd_sampler == SDSampler.ddim:
scheduler = DDIMScheduler(
**scheduler_kwargs,
clip_sample=False,
set_alpha_to_one=False,
)
scheduler = DDIMScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.pndm:
scheduler = PNDMScheduler(**scheduler_kwargs, skip_prk_steps=True)
scheduler = PNDMScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.k_lms:
scheduler = LMSDiscreteScheduler(**scheduler_kwargs)
scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.k_euler:
scheduler = EulerDiscreteScheduler(**scheduler_kwargs)
scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.k_euler_a:
scheduler = EulerAncestralDiscreteScheduler(**scheduler_kwargs)
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.dpm_plus_plus:
scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
else:
raise ValueError(config.sd_sampler)
@@ -138,13 +131,10 @@ class SD(InpaintModel):
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
_kwargs = {
self.image_key: PIL.Image.fromarray(image),
}
img_h, img_w = image.shape[:2]
output = self.model(
image=PIL.Image.fromarray(image),
prompt=config.prompt,
negative_prompt=config.negative_prompt,
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
@@ -155,7 +145,6 @@ class SD(InpaintModel):
callback=self.callback,
height=img_h,
width=img_w,
**_kwargs
).images[0]
output = (output * 255).round().astype("uint8")
@@ -217,4 +206,7 @@ class SD(InpaintModel):
class SD15(SD):
model_id_or_path = "runwayml/stable-diffusion-inpainting"
image_key = "image"
class SD2(SD):
model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"