update sd inpainting pipeline
This commit is contained in:
@@ -4,9 +4,8 @@ import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import PNDMScheduler, DDIMScheduler
|
||||
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
|
||||
from loguru import logger
|
||||
from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
|
||||
|
||||
from lama_cleaner.helper import norm_img
|
||||
|
||||
@@ -39,30 +38,6 @@ from lama_cleaner.schema import Config, SDSampler
|
||||
# mask = torch.from_numpy(mask)
|
||||
# return mask
|
||||
|
||||
class DummyFeatureExtractorOutput:
|
||||
def __init__(self, pixel_values):
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
def to(self, device):
|
||||
return self
|
||||
|
||||
|
||||
class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return DummyFeatureExtractorOutput(torch.empty(0, 3))
|
||||
|
||||
|
||||
class DummySafetyChecker:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, clip_input, images):
|
||||
return images, False
|
||||
|
||||
|
||||
class SD(InpaintModel):
|
||||
pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505
|
||||
min_size = 512
|
||||
@@ -74,8 +49,7 @@ class SD(InpaintModel):
|
||||
if kwargs['sd_disable_nsfw']:
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(dict(
|
||||
feature_extractor=DummyFeatureExtractor(),
|
||||
safety_checker=DummySafetyChecker(),
|
||||
safety_checker=None,
|
||||
))
|
||||
|
||||
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
@@ -94,7 +68,7 @@ class SD(InpaintModel):
|
||||
self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True)
|
||||
self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True )
|
||||
|
||||
self.callbacks = kwargs.pop("callbacks", None)
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
@torch.cuda.amp.autocast()
|
||||
def forward(self, image, mask, config: Config):
|
||||
@@ -133,6 +107,8 @@ class SD(InpaintModel):
|
||||
"skip_prk_steps": True,
|
||||
}
|
||||
scheduler = PNDMScheduler(**PNDM_kwargs)
|
||||
elif config.sd_sampler == SDSampler.k_lms:
|
||||
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
else:
|
||||
raise ValueError(config.sd_sampler)
|
||||
|
||||
@@ -156,7 +132,7 @@ class SD(InpaintModel):
|
||||
num_inference_steps=config.sd_steps,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np.array",
|
||||
callbacks=self.callbacks,
|
||||
callback=self.callback,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
|
||||
Reference in New Issue
Block a user