get samplers from backend
This commit is contained in:
@@ -35,9 +35,7 @@ class Kandinsky(DiffusionInpaintModel):
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
scheduler_config = self.model.scheduler.config
|
||||
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
|
||||
self.model.scheduler = scheduler
|
||||
self.set_scheduler(config)
|
||||
|
||||
generator = torch.manual_seed(config.sd_seed)
|
||||
mask = mask.astype(np.float32) / 255
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
@@ -18,10 +19,13 @@ from diffusers import (
|
||||
DPMSolverMultistepScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
LCMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from huggingface_hub.utils import RevisionNotFoundError
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from loguru import logger
|
||||
from requests import HTTPError
|
||||
|
||||
from lama_cleaner.schema import SDSampler
|
||||
from torch import conv2d, conv_transpose2d
|
||||
@@ -930,22 +934,41 @@ def set_seed(seed: int):
|
||||
|
||||
|
||||
def get_scheduler(sd_sampler, scheduler_config):
|
||||
if sd_sampler == SDSampler.ddim:
|
||||
return DDIMScheduler.from_config(scheduler_config)
|
||||
elif sd_sampler == SDSampler.pndm:
|
||||
return PNDMScheduler.from_config(scheduler_config)
|
||||
elif sd_sampler == SDSampler.k_lms:
|
||||
return LMSDiscreteScheduler.from_config(scheduler_config)
|
||||
elif sd_sampler == SDSampler.k_euler:
|
||||
return EulerDiscreteScheduler.from_config(scheduler_config)
|
||||
elif sd_sampler == SDSampler.k_euler_a:
|
||||
return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
|
||||
elif sd_sampler == SDSampler.dpm_plus_plus:
|
||||
return DPMSolverMultistepScheduler.from_config(scheduler_config)
|
||||
elif sd_sampler == SDSampler.uni_pc:
|
||||
return UniPCMultistepScheduler.from_config(scheduler_config)
|
||||
elif sd_sampler == SDSampler.lcm:
|
||||
return LCMScheduler.from_config(scheduler_config)
|
||||
# https://github.com/huggingface/diffusers/issues/4167
|
||||
keys_to_pop = ["use_karras_sigmas", "algorithm_type"]
|
||||
scheduler_config = dict(scheduler_config)
|
||||
for it in keys_to_pop:
|
||||
scheduler_config.pop(it, None)
|
||||
|
||||
# fmt: off
|
||||
samplers = {
|
||||
SDSampler.dpm_plus_plus_2m: [DPMSolverMultistepScheduler],
|
||||
SDSampler.dpm_plus_plus_2m_karras: [DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)],
|
||||
SDSampler.dpm_plus_plus_2m_sde: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++")],
|
||||
SDSampler.dpm_plus_plus_2m_sde_karras: [DPMSolverMultistepScheduler, dict(algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)],
|
||||
SDSampler.dpm_plus_plus_sde: [DPMSolverSinglestepScheduler],
|
||||
SDSampler.dpm_plus_plus_sde_karras: [DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)],
|
||||
SDSampler.dpm2: [KDPM2DiscreteScheduler],
|
||||
SDSampler.dpm2_karras: [KDPM2DiscreteScheduler, dict(use_karras_sigmas=True)],
|
||||
SDSampler.dpm2_a: [KDPM2AncestralDiscreteScheduler],
|
||||
SDSampler.dpm2_a_karras: [KDPM2AncestralDiscreteScheduler, dict(use_karras_sigmas=True)],
|
||||
SDSampler.euler: [EulerDiscreteScheduler],
|
||||
SDSampler.euler_a: [EulerAncestralDiscreteScheduler],
|
||||
SDSampler.heun: [HeunDiscreteScheduler],
|
||||
SDSampler.lms: [LMSDiscreteScheduler],
|
||||
SDSampler.lms_karras: [LMSDiscreteScheduler, dict(use_karras_sigmas=True)],
|
||||
SDSampler.ddim: [DDIMScheduler],
|
||||
SDSampler.pndm: [PNDMScheduler],
|
||||
SDSampler.uni_pc: [UniPCMultistepScheduler],
|
||||
SDSampler.lcm: [LCMScheduler],
|
||||
}
|
||||
# fmt: on
|
||||
if sd_sampler in samplers:
|
||||
if len(samplers[sd_sampler]) == 2:
|
||||
scheduler_cls, kwargs = samplers[sd_sampler]
|
||||
else:
|
||||
scheduler_cls, kwargs = samplers[sd_sampler][0], {}
|
||||
return scheduler_cls.from_config(scheduler_config, **kwargs)
|
||||
else:
|
||||
raise ValueError(sd_sampler)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user