add euler sampler

This commit is contained in:
Qing
2022-11-15 21:09:51 +08:00
parent 6503d7ec32
commit d7c3149f67
5 changed files with 144 additions and 86 deletions

View File

@@ -4,7 +4,8 @@ import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
EulerAncestralDiscreteScheduler
from loguru import logger
from lama_cleaner.model.base import InpaintModel
@@ -98,25 +99,27 @@ class SD(InpaintModel):
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
scheduler_kwargs = dict(
beta_schedule="scaled_linear",
beta_start=0.00085,
beta_end=0.012,
num_train_timesteps=1000,
)
if config.sd_sampler == SDSampler.ddim:
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
**scheduler_kwargs,
clip_sample=False,
set_alpha_to_one=False,
)
elif config.sd_sampler == SDSampler.pndm:
PNDM_kwargs = {
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"beta_end": 0.012,
"num_train_timesteps": 1000,
"skip_prk_steps": True,
}
scheduler = PNDMScheduler(**PNDM_kwargs)
scheduler = PNDMScheduler(**scheduler_kwargs, skip_prk_steps=True)
elif config.sd_sampler == SDSampler.k_lms:
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler = LMSDiscreteScheduler(**scheduler_kwargs)
elif config.sd_sampler == SDSampler.k_euler:
scheduler = EulerDiscreteScheduler(**scheduler_kwargs)
elif config.sd_sampler == SDSampler.k_euler_a:
scheduler = EulerAncestralDiscreteScheduler(**scheduler_kwargs)
else:
raise ValueError(config.sd_sampler)