From 35b92ba9de44bc9d007e2f1eded635134f356089 Mon Sep 17 00:00:00 2001 From: Qing Date: Sun, 12 Jun 2022 13:14:17 +0800 Subject: [PATCH] ldm add plms sampler --- lama_cleaner/app/src/adapters/inpainting.ts | 1 + .../components/Settings/HDSettingBlock.tsx | 5 + .../components/Settings/ModelSettingBlock.tsx | 20 ++ .../components/Shortcuts/ShortcutsModal.tsx | 2 +- .../src/components/shared/NumberInput.scss | 3 +- lama_cleaner/app/src/store/Atoms.tsx | 4 +- lama_cleaner/model/ddim_sampler.py | 193 +++++++++++++++ lama_cleaner/model/ldm.py | 225 ++---------------- lama_cleaner/model/plms_sampler.py | 225 ++++++++++++++++++ lama_cleaner/schema.py | 6 + lama_cleaner/server.py | 1 + 11 files changed, 478 insertions(+), 207 deletions(-) create mode 100644 lama_cleaner/model/ddim_sampler.py create mode 100644 lama_cleaner/model/plms_sampler.py diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts index fedd6e3..3965d54 100644 --- a/lama_cleaner/app/src/adapters/inpainting.ts +++ b/lama_cleaner/app/src/adapters/inpainting.ts @@ -16,6 +16,7 @@ export default async function inpaint( fd.append('mask', mask) fd.append('ldmSteps', settings.ldmSteps.toString()) + fd.append('ldmSampler', settings.ldmSampler.toString()) fd.append('hdStrategy', settings.hdStrategy) fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString()) fd.append( diff --git a/lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx b/lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx index 565950f..57fed00 100644 --- a/lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx +++ b/lama_cleaner/app/src/components/Settings/HDSettingBlock.tsx @@ -11,6 +11,11 @@ export enum HDStrategy { CROP = 'Crop', } +export enum LDMSampler { + ddim = 'ddim', + plms = 'plms', +} + function HDSettingBlock() { const [setting, setSettingState] = useRecoilState(settingState) diff --git a/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx index aac6e1b..2d7a1f9 100644 --- a/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx +++ b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx @@ -2,6 +2,7 @@ import React, { ReactNode } from 'react' import { useRecoilState } from 'recoil' import { settingState } from '../../store/Atoms' import Selector from '../shared/Selector' +import { LDMSampler } from './HDSettingBlock' import NumberInputSetting from './NumberInputSetting' import SettingBlock from './SettingBlock' @@ -19,6 +20,12 @@ function ModelSettingBlock() { }) } + const onLDMSamplerChange = (value: LDMSampler) => { + setSettingState(old => { + return { ...old, ldmSampler: value } + }) + } + const renderModelDesc = ( name: string, paperUrl: string, @@ -65,6 +72,19 @@ function ModelSettingBlock() { }) }} /> + + onLDMSamplerChange(val as LDMSampler)} + /> + } + /> ) } diff --git a/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx b/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx index ef8efbb..643c47b 100644 --- a/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx +++ b/lama_cleaner/app/src/components/Shortcuts/ShortcutsModal.tsx @@ -56,7 +56,7 @@ export default function ShortcutsModal() { /> - + diff --git a/lama_cleaner/app/src/components/shared/NumberInput.scss b/lama_cleaner/app/src/components/shared/NumberInput.scss index d51dc32..673cb98 100644 --- a/lama_cleaner/app/src/components/shared/NumberInput.scss +++ b/lama_cleaner/app/src/components/shared/NumberInput.scss @@ -2,8 +2,9 @@ all: unset; flex: 1 0 auto; border-radius: 0.5rem; - padding: 0.4rem 0.8rem; + padding: 0 0.8rem; outline: 1px solid var(--border-color); + height: 36px; &:focus-visible { outline: 1px solid var(--yellow-accent); diff --git a/lama_cleaner/app/src/store/Atoms.tsx b/lama_cleaner/app/src/store/Atoms.tsx index 01efff2..5317637 100644 --- a/lama_cleaner/app/src/store/Atoms.tsx +++ b/lama_cleaner/app/src/store/Atoms.tsx @@ -1,5 +1,5 @@ import { atom } from 'recoil' -import { HDStrategy } from '../components/Settings/HDSettingBlock' +import { HDStrategy, LDMSampler } from '../components/Settings/HDSettingBlock' import { AIModel } from '../components/Settings/ModelSettingBlock' import { ToastState } from '../components/shared/Toast' @@ -43,6 +43,7 @@ export interface Settings { // For LDM ldmSteps: number + ldmSampler: LDMSampler } export const settingStateDefault = { @@ -50,6 +51,7 @@ export const settingStateDefault = { runInpaintingManually: false, model: AIModel.LAMA, ldmSteps: 50, + ldmSampler: LDMSampler.plms, hdStrategy: HDStrategy.RESIZE, hdStrategyResizeLimit: 2048, hdStrategyCropTrigerSize: 2048, diff --git a/lama_cleaner/model/ddim_sampler.py b/lama_cleaner/model/ddim_sampler.py new file mode 100644 index 0000000..d1e4400 --- /dev/null +++ b/lama_cleaner/model/ddim_sampler.py @@ -0,0 +1,193 @@ +import torch +import numpy as np +from tqdm import tqdm + +from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like + +from loguru import logger + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear"): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + # array([1]) + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000]) + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample(self, steps, conditioning, batch_size, shape): + self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + # samples: 1,3,128,128 + return self.ddim_sampling( + conditioning, + size, + quantize_denoised=False, + ddim_use_original_steps=False, + noise_dropout=0, + temperature=1.0, + ) + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + ddim_use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + ): + device = self.model.betas.device + b = shape[0] + img = torch.randn(shape, device=device, dtype=cond.dtype) + timesteps = ( + self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + ) + + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + logger.info(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + ) + img, _ = outs + + return img + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + ): + b, *_, device = *x.shape, x.device + e_t = self.model.apply_model(x, t, c) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: # 没用 + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: # 没用 + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/lama_cleaner/model/ldm.py b/lama_cleaner/model/ldm.py index 2fea7c0..7720092 100644 --- a/lama_cleaner/model/ldm.py +++ b/lama_cleaner/model/ldm.py @@ -5,17 +5,15 @@ import torch from loguru import logger from lama_cleaner.model.base import InpaintModel -from lama_cleaner.schema import Config +from lama_cleaner.model.ddim_sampler import DDIMSampler +from lama_cleaner.model.plms_sampler import PLMSSampler +from lama_cleaner.schema import Config, LDMSampler torch.manual_seed(42) import torch.nn as nn -from tqdm import tqdm from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url from lama_cleaner.model.utils import ( make_beta_schedule, - make_ddim_timesteps, - make_ddim_sampling_parameters, - noise_like, timestep_embedding, ) @@ -94,7 +92,7 @@ class DDPM(nn.Module): self.linear_start = linear_start self.linear_end = linear_end assert ( - alphas_cumprod.shape[0] == self.num_timesteps + alphas_cumprod.shape[0] == self.num_timesteps ), "alphas have to be defined for each timestep" to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device) @@ -120,7 +118,7 @@ class DDPM(nn.Module): # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * ( - 1.0 - alphas_cumprod_prev + 1.0 - alphas_cumprod_prev ) / (1.0 - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.register_buffer("posterior_variance", to_torch(posterior_variance)) @@ -142,16 +140,16 @@ class DDPM(nn.Module): if self.parameterization == "eps": lvlb_weights = self.betas ** 2 / ( - 2 - * self.posterior_variance - * to_torch(alphas) - * (1 - self.alphas_cumprod) + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) ) elif self.parameterization == "x0": lvlb_weights = ( - 0.5 - * np.sqrt(torch.Tensor(alphas_cumprod)) - / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) ) else: raise NotImplementedError("mu not supported") @@ -221,192 +219,6 @@ class LatentDiffusion(DDPM): return x_recon -class DDIMSampler(object): - def __init__(self, model, schedule="linear"): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - setattr(self, name, attr) - - def make_schedule( - self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True - ): - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - # array([1]) - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod # torch.Size([1000]) - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer("betas", to_torch(self.model.betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer( - "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", - to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), - ) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer("ddim_sigmas", ddim_sigmas) - self.register_buffer("ddim_alphas", ddim_alphas) - self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) - self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps - ) - - @torch.no_grad() - def sample(self, steps, conditioning, batch_size, shape): - self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - - # samples: 1,3,128,128 - return self.ddim_sampling( - conditioning, - size, - quantize_denoised=False, - ddim_use_original_steps=False, - noise_dropout=0, - temperature=1.0, - ) - - @torch.no_grad() - def ddim_sampling( - self, - cond, - shape, - ddim_use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - ): - device = self.model.betas.device - b = shape[0] - img = torch.randn(shape, device=device, dtype=cond.dtype) - timesteps = ( - self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - ) - - time_range = ( - reversed(range(0, timesteps)) - if ddim_use_original_steps - else np.flip(timesteps) - ) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - logger.info(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - outs = self.p_sample_ddim( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - ) - img, _ = outs - - return img - - @torch.no_grad() - def p_sample_ddim( - self, - x, - c, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - ): - b, *_, device = *x.shape, x.device - e_t = self.model.apply_model(x, t, c) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: # 没用 - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: # 没用 - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - def load_jit_model(url, device): model_path = download_model(url) logger.info(f"Load LDM model from: {model_path}") @@ -432,8 +244,7 @@ class LDM(InpaintModel): self.cond_stage_model_decode = self.cond_stage_model_decode.half() self.cond_stage_model_encode = self.cond_stage_model_encode.half() - model = LatentDiffusion(self.diffusion_model, device) - self.sampler = DDIMSampler(model) + self.model = LatentDiffusion(self.diffusion_model, device) @staticmethod def is_downloaded() -> bool: @@ -454,6 +265,13 @@ class LDM(InpaintModel): # image [1,3,512,512] float32 # mask: [1,1,512,512] float32 # masked_image: [1,3,512,512] float32 + if config.ldm_sampler == LDMSampler.ddim: + sampler = DDIMSampler(self.model) + elif config.ldm_sampler == LDMSampler.plms: + sampler = PLMSSampler(self.model) + else: + raise ValueError() + steps = config.ldm_steps image = norm_img(image) mask = norm_img(mask) @@ -465,7 +283,6 @@ class LDM(InpaintModel): mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) masked_image = (1 - mask) * image - image = self._norm(image) mask = self._norm(mask) masked_image = self._norm(masked_image) @@ -476,7 +293,7 @@ class LDM(InpaintModel): c = torch.cat((c, cc), dim=1) # 1,4,128,128 shape = (c.shape[1] - 1,) + c.shape[2:] - samples_ddim = self.sampler.sample( + samples_ddim = sampler.sample( steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape ) torch.cuda.empty_cache() diff --git a/lama_cleaner/model/plms_sampler.py b/lama_cleaner/model/plms_sampler.py new file mode 100644 index 0000000..d9c0416 --- /dev/null +++ b/lama_cleaner/model/plms_sampler.py @@ -0,0 +1,225 @@ +# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py +import torch +import numpy as np +from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like +from tqdm import tqdm + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + steps, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=False, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + return img + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py index 0d6fad8..a66afa1 100644 --- a/lama_cleaner/schema.py +++ b/lama_cleaner/schema.py @@ -9,8 +9,14 @@ class HDStrategy(str, Enum): CROP = 'Crop' +class LDMSampler(str, Enum): + ddim = 'ddim' + plms = 'plms' + + class Config(BaseModel): ldm_steps: int + ldm_sampler: str hd_strategy: str hd_strategy_crop_margin: int hd_strategy_crop_trigger_size: int diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 64e21e4..79b8a0b 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -93,6 +93,7 @@ def process(): config = Config( ldm_steps=form["ldmSteps"], + ldm_sampler=form["ldmSampler"], hd_strategy=form["hdStrategy"], hd_strategy_crop_margin=form["hdStrategyCropMargin"], hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],