ldm add plms sampler

This commit is contained in:
Qing
2022-06-12 13:14:17 +08:00
parent 55197f2209
commit 35b92ba9de
11 changed files with 478 additions and 207 deletions

View File

@@ -5,17 +5,15 @@ import torch
from loguru import logger
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.schema import Config
from lama_cleaner.model.ddim_sampler import DDIMSampler
from lama_cleaner.model.plms_sampler import PLMSSampler
from lama_cleaner.schema import Config, LDMSampler
torch.manual_seed(42)
import torch.nn as nn
from tqdm import tqdm
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.utils import (
make_beta_schedule,
make_ddim_timesteps,
make_ddim_sampling_parameters,
noise_like,
timestep_embedding,
)
@@ -94,7 +92,7 @@ class DDPM(nn.Module):
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
@@ -120,7 +118,7 @@ class DDPM(nn.Module):
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (
1.0 - alphas_cumprod_prev
1.0 - alphas_cumprod_prev
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer("posterior_variance", to_torch(posterior_variance))
@@ -142,16 +140,16 @@ class DDPM(nn.Module):
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
elif self.parameterization == "x0":
lvlb_weights = (
0.5
* np.sqrt(torch.Tensor(alphas_cumprod))
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
0.5
* np.sqrt(torch.Tensor(alphas_cumprod))
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
)
else:
raise NotImplementedError("mu not supported")
@@ -221,192 +219,6 @@ class LatentDiffusion(DDPM):
return x_recon
class DDIMSampler(object):
def __init__(self, model, schedule="linear"):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
# array([1])
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000])
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(self, steps, conditioning, batch_size, shape):
self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# samples: 1,3,128,128
return self.ddim_sampling(
conditioning,
size,
quantize_denoised=False,
ddim_use_original_steps=False,
noise_dropout=0,
temperature=1.0,
)
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
ddim_use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
):
device = self.model.betas.device
b = shape[0]
img = torch.randn(shape, device=device, dtype=cond.dtype)
timesteps = (
self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
)
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
logger.info(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
)
img, _ = outs
return img
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
):
b, *_, device = *x.shape, x.device
e_t = self.model.apply_model(x, t, c)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: # 没用
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0: # 没用
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
def load_jit_model(url, device):
model_path = download_model(url)
logger.info(f"Load LDM model from: {model_path}")
@@ -432,8 +244,7 @@ class LDM(InpaintModel):
self.cond_stage_model_decode = self.cond_stage_model_decode.half()
self.cond_stage_model_encode = self.cond_stage_model_encode.half()
model = LatentDiffusion(self.diffusion_model, device)
self.sampler = DDIMSampler(model)
self.model = LatentDiffusion(self.diffusion_model, device)
@staticmethod
def is_downloaded() -> bool:
@@ -454,6 +265,13 @@ class LDM(InpaintModel):
# image [1,3,512,512] float32
# mask: [1,1,512,512] float32
# masked_image: [1,3,512,512] float32
if config.ldm_sampler == LDMSampler.ddim:
sampler = DDIMSampler(self.model)
elif config.ldm_sampler == LDMSampler.plms:
sampler = PLMSSampler(self.model)
else:
raise ValueError()
steps = config.ldm_steps
image = norm_img(image)
mask = norm_img(mask)
@@ -465,7 +283,6 @@ class LDM(InpaintModel):
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
masked_image = (1 - mask) * image
image = self._norm(image)
mask = self._norm(mask)
masked_image = self._norm(masked_image)
@@ -476,7 +293,7 @@ class LDM(InpaintModel):
c = torch.cat((c, cc), dim=1) # 1,4,128,128
shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim = self.sampler.sample(
samples_ddim = sampler.sample(
steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
)
torch.cuda.empty_cache()