sd make change sampler work

This commit is contained in:
Qing
2022-09-22 12:38:32 +08:00
parent 047474ab84
commit e1fb0030d1
3 changed files with 46 additions and 19 deletions

View File

@@ -4,12 +4,13 @@ import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import PNDMScheduler, DDIMScheduler
from loguru import logger
from lama_cleaner.helper import norm_img
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
from lama_cleaner.schema import Config, SDSampler
#
@@ -43,13 +44,12 @@ class SD(InpaintModel):
min_size = 512
def init_model(self, device: torch.device, **kwargs):
# return
from .sd_pipeline import StableDiffusionInpaintPipeline
self.model = StableDiffusionInpaintPipeline.from_pretrained(
self.model_id_or_path,
revision="fp16",
torch_dtype=torch.float16,
revision="fp16" if torch.cuda.is_available() else 'main',
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
use_auth_token=kwargs["hf_access_token"],
)
# https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
@@ -59,7 +59,6 @@ class SD(InpaintModel):
@torch.cuda.amp.autocast()
def forward(self, image, mask, config: Config):
# return image
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
@@ -76,9 +75,30 @@ class SD(InpaintModel):
#
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
# import time
# time.sleep(2)
# return image
if config.sd_sampler == SDSampler.ddim:
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
elif config.sd_sampler == SDSampler.pndm:
PNDM_kwargs = {
"tensor_format": "pt",
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"beta_end": 0.012,
"num_train_timesteps": 1000,
"skip_prk_steps": True
}
scheduler = PNDMScheduler(**PNDM_kwargs)
else:
raise ValueError(config.sd_sampler)
self.model.scheduler = scheduler
seed = config.sd_seed
random.seed(seed)
np.random.seed(seed)