add euler sampler
This commit is contained in:
@@ -4,7 +4,8 @@ import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
|
||||
from diffusers import PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
|
||||
EulerAncestralDiscreteScheduler
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
@@ -98,25 +99,27 @@ class SD(InpaintModel):
|
||||
# image = torch.from_numpy(image).unsqueeze(0).to(self.device)
|
||||
# mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
|
||||
|
||||
scheduler_kwargs = dict(
|
||||
beta_schedule="scaled_linear",
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
num_train_timesteps=1000,
|
||||
)
|
||||
|
||||
if config.sd_sampler == SDSampler.ddim:
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
**scheduler_kwargs,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
elif config.sd_sampler == SDSampler.pndm:
|
||||
PNDM_kwargs = {
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"beta_end": 0.012,
|
||||
"num_train_timesteps": 1000,
|
||||
"skip_prk_steps": True,
|
||||
}
|
||||
scheduler = PNDMScheduler(**PNDM_kwargs)
|
||||
scheduler = PNDMScheduler(**scheduler_kwargs, skip_prk_steps=True)
|
||||
elif config.sd_sampler == SDSampler.k_lms:
|
||||
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
scheduler = LMSDiscreteScheduler(**scheduler_kwargs)
|
||||
elif config.sd_sampler == SDSampler.k_euler:
|
||||
scheduler = EulerDiscreteScheduler(**scheduler_kwargs)
|
||||
elif config.sd_sampler == SDSampler.k_euler_a:
|
||||
scheduler = EulerAncestralDiscreteScheduler(**scheduler_kwargs)
|
||||
else:
|
||||
raise ValueError(config.sd_sampler)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user