update sd inpainting pipeline

This commit is contained in:
Qing
2022-10-15 22:32:25 +08:00
parent b92e9d8da6
commit 3c87b050d9
5 changed files with 178 additions and 112 deletions

View File

@@ -4,9 +4,8 @@ import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import PNDMScheduler, DDIMScheduler
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
from loguru import logger
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
from lama_cleaner.helper import norm_img
@@ -39,30 +38,6 @@ from lama_cleaner.schema import Config, SDSampler
# mask = torch.from_numpy(mask)
# return mask
class DummyFeatureExtractorOutput:
def __init__(self, pixel_values):
self.pixel_values = pixel_values
def to(self, device):
return self
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def __call__(self, *args, **kwargs):
return DummyFeatureExtractorOutput(torch.empty(0, 3))
class DummySafetyChecker:
def __init__(self, *args, **kwargs):
pass
def __call__(self, clip_input, images):
return images, False
class SD(InpaintModel):
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
min_size = 512
@@ -74,8 +49,7 @@ class SD(InpaintModel):
if kwargs['sd_disable_nsfw']:
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
feature_extractor=DummyFeatureExtractor(),
safety_checker=DummySafetyChecker(),
safety_checker=None,
))
self.model = StableDiffusionInpaintPipeline.from_pretrained(
@@ -94,7 +68,7 @@ class SD(InpaintModel):
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
self.callbacks = kwargs.pop("callbacks", None)
self.callback = kwargs.pop("callback", None)
@torch.cuda.amp.autocast()
def forward(self, image, mask, config: Config):
@@ -133,6 +107,8 @@ class SD(InpaintModel):
"skip_prk_steps": True,
}
scheduler = PNDMScheduler(**PNDM_kwargs)
elif config.sd_sampler == SDSampler.k_lms:
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
else:
raise ValueError(config.sd_sampler)
@@ -156,7 +132,7 @@ class SD(InpaintModel):
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np.array",
callbacks=self.callbacks,
callback=self.callback,
).images[0]
output = (output * 255).round().astype("uint8")