add controlnet inpainting

This commit is contained in:
Qing
2023-03-19 22:40:23 +08:00
parent 61928c9861
commit 5f4c62ac18
17 changed files with 1197 additions and 186 deletions

View File

@@ -7,17 +7,35 @@ import numpy as np
import collections
from itertools import repeat
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
UniPCMultistepScheduler,
)
from lama_cleaner.schema import SDSampler
from torch import conv2d, conv_transpose2d
def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
def make_beta_schedule(
device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == "linear":
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
torch.linspace(
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
)
** 2
)
elif schedule == "cosine":
timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device)
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
).to(device)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2).to(device)
alphas = alphas / alphas[0]
@@ -25,9 +43,14 @@ def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_e
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
betas = (
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
@@ -39,33 +62,47 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
print(
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
)
print(
f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
if ddim_discr_method == 'uniform':
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
elif ddim_discr_method == "quad":
ddim_timesteps = (
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
print(f"Selected timesteps for ddim sampler: {steps_out}")
return steps_out
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
@@ -81,7 +118,9 @@ def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=Fal
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=device)
args = timesteps[:, None].float() * freqs[None]
@@ -115,9 +154,8 @@ class EasyDict(dict):
del self[name]
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
"""
def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
assert isinstance(x, torch.Tensor)
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
@@ -147,7 +185,9 @@ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=N
return x
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):
def bias_act(
x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
):
r"""Fused bias and activation function.
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
@@ -178,8 +218,10 @@ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
Tensor of the same shape and datatype as `x`.
"""
assert isinstance(x, torch.Tensor)
assert impl in ['ref', 'cuda']
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
assert impl in ["ref", "cuda"]
return _bias_act_ref(
x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
)
def _get_filter_size(f):
@@ -223,7 +265,14 @@ def _parse_padding(padding):
return padx0, padx1, pady0, pady1
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
def setup_filter(
f,
device=torch.device("cpu"),
normalize=True,
flip_filter=False,
gain=1,
separable=None,
):
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
Args:
@@ -255,7 +304,7 @@ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=Fals
# Separable?
if separable is None:
separable = (f.ndim == 1 and f.numel() >= 8)
separable = f.ndim == 1 and f.numel() >= 8
if f.ndim == 1 and not separable:
f = f.ger(f)
assert f.ndim == (1 if separable else 2)
@@ -282,27 +331,82 @@ def _ntuple(n):
to_2tuple = _ntuple(2)
activation_funcs = {
'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2,
ref='y', has_2nd_grad=False),
'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2,
def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y',
has_2nd_grad=True),
'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y',
has_2nd_grad=True),
'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y',
has_2nd_grad=True),
'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y',
has_2nd_grad=True),
'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8,
ref='y', has_2nd_grad=True),
'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x',
has_2nd_grad=True),
"linear": EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
cuda_idx=1,
ref="",
has_2nd_grad=False,
),
"relu": EasyDict(
func=lambda x, **_: torch.nn.functional.relu(x),
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=2,
ref="y",
has_2nd_grad=False,
),
"lrelu": EasyDict(
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
def_alpha=0.2,
def_gain=np.sqrt(2),
cuda_idx=3,
ref="y",
has_2nd_grad=False,
),
"tanh": EasyDict(
func=lambda x, **_: torch.tanh(x),
def_alpha=0,
def_gain=1,
cuda_idx=4,
ref="y",
has_2nd_grad=True,
),
"sigmoid": EasyDict(
func=lambda x, **_: torch.sigmoid(x),
def_alpha=0,
def_gain=1,
cuda_idx=5,
ref="y",
has_2nd_grad=True,
),
"elu": EasyDict(
func=lambda x, **_: torch.nn.functional.elu(x),
def_alpha=0,
def_gain=1,
cuda_idx=6,
ref="y",
has_2nd_grad=True,
),
"selu": EasyDict(
func=lambda x, **_: torch.nn.functional.selu(x),
def_alpha=0,
def_gain=1,
cuda_idx=7,
ref="y",
has_2nd_grad=True,
),
"softplus": EasyDict(
func=lambda x, **_: torch.nn.functional.softplus(x),
def_alpha=0,
def_gain=1,
cuda_idx=8,
ref="y",
has_2nd_grad=True,
),
"swish": EasyDict(
func=lambda x, **_: torch.sigmoid(x) * x,
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=9,
ref="x",
has_2nd_grad=True,
),
}
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
r"""Pad, upsample, filter, and downsample a batch of 2D images.
Performs the following sequence of operations for each channel:
@@ -344,12 +448,13 @@ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cu
"""
# assert isinstance(x, torch.Tensor)
# assert impl in ['ref', 'cuda']
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
return _upfirdn2d_ref(
x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
)
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
"""
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
# Validate arguments.
assert isinstance(x, torch.Tensor) and x.ndim == 4
if f is None:
@@ -372,8 +477,15 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
# Pad or crop.
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]
x = torch.nn.functional.pad(
x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
)
x = x[
:,
:,
max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
]
# Setup filter.
f = f * (gain ** (f.ndim / 2))
@@ -394,7 +506,7 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
return x
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
r"""Downsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a fraction of the input.
@@ -431,10 +543,12 @@ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'
pady0 + (fh - downy + 1) // 2,
pady1 + (fh - downy) // 2,
]
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
return upfirdn2d(
x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
)
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
r"""Upsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a multiple of the input.
@@ -471,7 +585,15 @@ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
pady0 + (fh + upy - 1) // 2,
pady1 + (fh - upy) // 2,
]
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl)
return upfirdn2d(
x,
f,
up=up,
padding=p,
flip_filter=flip_filter,
gain=gain * upx * upy,
impl=impl,
)
class MinibatchStdLayer(torch.nn.Module):
@@ -482,13 +604,17 @@ class MinibatchStdLayer(torch.nn.Module):
def forward(self, x):
N, C, H, W = x.shape
G = torch.min(torch.as_tensor(self.group_size),
torch.as_tensor(N)) if self.group_size is not None else N
G = (
torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
if self.group_size is not None
else N
)
F = self.num_channels
c = C // F
y = x.reshape(G, -1, F, c, H,
W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
y = x.reshape(
G, -1, F, c, H, W
) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
@@ -500,17 +626,24 @@ class MinibatchStdLayer(torch.nn.Module):
class FullyConnectedLayer(torch.nn.Module):
def __init__(self,
in_features, # Number of input features.
out_features, # Number of output features.
bias=True, # Apply additive bias before the activation function?
activation='linear', # Activation function: 'relu', 'lrelu', etc.
lr_multiplier=1, # Learning rate multiplier.
bias_init=0, # Initial value for the additive bias.
):
def __init__(
self,
in_features, # Number of input features.
out_features, # Number of output features.
bias=True, # Apply additive bias before the activation function?
activation="linear", # Activation function: 'relu', 'lrelu', etc.
lr_multiplier=1, # Learning rate multiplier.
bias_init=0, # Initial value for the additive bias.
):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
self.weight = torch.nn.Parameter(
torch.randn([out_features, in_features]) / lr_multiplier
)
self.bias = (
torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
if bias
else None
)
self.activation = activation
self.weight_gain = lr_multiplier / np.sqrt(in_features)
@@ -522,7 +655,7 @@ class FullyConnectedLayer(torch.nn.Module):
if b is not None and self.bias_gain != 1:
b = b * self.bias_gain
if self.activation == 'linear' and b is not None:
if self.activation == "linear" and b is not None:
# out = torch.addmm(b.unsqueeze(0), x, w.t())
x = x.matmul(w.t())
out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
@@ -532,22 +665,33 @@ class FullyConnectedLayer(torch.nn.Module):
return out
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
"""
def _conv2d_wrapper(
x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
):
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
# Flip weight if requested.
if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
if (
not flip_weight
): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
w = w.flip([2, 3])
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
if (
kw == 1
and kh == 1
and stride == 1
and padding in [0, [0, 0], (0, 0)]
and not transpose
):
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
if out_channels <= 4 and groups == 1:
in_shape = x.shape
x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
x = w.squeeze(3).squeeze(2) @ x.reshape(
[in_shape[0], in_channels_per_group, -1]
)
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
else:
x = x.to(memory_format=torch.contiguous_format)
@@ -560,7 +704,9 @@ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_w
return op(x, w, stride=stride, padding=padding, groups=groups)
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
def conv2d_resample(
x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
):
r"""2D convolution with optional up/downsampling.
Padding is performed only once at the beginning, not between the operations.
@@ -587,7 +733,9 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
# Validate arguments.
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
assert f is None or (
isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32
)
assert isinstance(up, int) and (up >= 1)
assert isinstance(down, int) and (down >= 1)
# assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
@@ -610,20 +758,31 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
if kw == 1 and kh == 1 and (down > 1 and up == 1):
x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
x = upfirdn2d(
x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
)
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
return x
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
if kw == 1 and kh == 1 and (up > 1 and down == 1):
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
x = upfirdn2d(
x=x,
f=f,
up=up,
padding=[px0, px1, py0, py1],
gain=up**2,
flip_filter=flip_filter,
)
return x
# Fast path: downsampling only => use strided convolution.
if down > 1 and up == 1:
x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
x = _conv2d_wrapper(
x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
)
return x
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
@@ -633,17 +792,31 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
else:
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
w = w.transpose(1, 2)
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
w = w.reshape(
groups * in_channels_per_group, out_channels // groups, kh, kw
)
px0 -= kw - 1
px1 -= kw - up
py0 -= kh - 1
py1 -= kh - up
pxt = max(min(-px0, -px1), 0)
pyt = max(min(-py0, -py1), 0)
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True,
flip_weight=(not flip_weight))
x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2,
flip_filter=flip_filter)
x = _conv2d_wrapper(
x=x,
w=w,
stride=up,
padding=[pyt, pxt],
groups=groups,
transpose=True,
flip_weight=(not flip_weight),
)
x = upfirdn2d(
x=x,
f=f,
padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
gain=up**2,
flip_filter=flip_filter,
)
if down > 1:
x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
return x
@@ -651,11 +824,19 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
if up == 1 and down == 1:
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight)
return _conv2d_wrapper(
x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
)
# Fallback: Generic reference implementation.
x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2,
flip_filter=flip_filter)
x = upfirdn2d(
x=x,
f=(f if up > 1 else None),
up=up,
padding=[px0, px1, py0, py1],
gain=up**2,
flip_filter=flip_filter,
)
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
if down > 1:
x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
@@ -663,50 +844,68 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
class Conv2dLayer(torch.nn.Module):
def __init__(self,
in_channels, # Number of input channels.
out_channels, # Number of output channels.
kernel_size, # Width and height of the convolution kernel.
bias=True, # Apply additive bias before the activation function?
activation='linear', # Activation function: 'relu', 'lrelu', etc.
up=1, # Integer upsampling factor.
down=1, # Integer downsampling factor.
resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
channels_last=False, # Expect the input to have memory_format=channels_last?
trainable=True, # Update the weights of this layer during training?
):
def __init__(
self,
in_channels, # Number of input channels.
out_channels, # Number of output channels.
kernel_size, # Width and height of the convolution kernel.
bias=True, # Apply additive bias before the activation function?
activation="linear", # Activation function: 'relu', 'lrelu', etc.
up=1, # Integer upsampling factor.
down=1, # Integer downsampling factor.
resample_filter=[
1,
3,
3,
1,
], # Low-pass filter to apply when resampling activations.
conv_clamp=None, # Clamp the output to +-X, None = disable clamping.
channels_last=False, # Expect the input to have memory_format=channels_last?
trainable=True, # Update the weights of this layer during training?
):
super().__init__()
self.activation = activation
self.up = up
self.down = down
self.register_buffer('resample_filter', setup_filter(resample_filter))
self.register_buffer("resample_filter", setup_filter(resample_filter))
self.conv_clamp = conv_clamp
self.padding = kernel_size // 2
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
self.act_gain = activation_funcs[activation].def_gain
memory_format = torch.channels_last if channels_last else torch.contiguous_format
weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
memory_format = (
torch.channels_last if channels_last else torch.contiguous_format
)
weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
memory_format=memory_format
)
bias = torch.zeros([out_channels]) if bias else None
if trainable:
self.weight = torch.nn.Parameter(weight)
self.bias = torch.nn.Parameter(bias) if bias is not None else None
else:
self.register_buffer('weight', weight)
self.register_buffer("weight", weight)
if bias is not None:
self.register_buffer('bias', bias)
self.register_buffer("bias", bias)
else:
self.bias = None
def forward(self, x, gain=1):
w = self.weight * self.weight_gain
x = conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down,
padding=self.padding)
x = conv2d_resample(
x=x,
w=w,
f=self.resample_filter,
up=self.up,
down=self.down,
padding=self.padding,
)
act_gain = self.act_gain * gain
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
out = bias_act(
x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
)
return out
@@ -721,3 +920,22 @@ def set_seed(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_scheduler(sd_sampler, scheduler_config):
if sd_sampler == SDSampler.ddim:
return DDIMScheduler.from_config(scheduler_config)
elif sd_sampler == SDSampler.pndm:
return PNDMScheduler.from_config(scheduler_config)
elif sd_sampler == SDSampler.k_lms:
return LMSDiscreteScheduler.from_config(scheduler_config)
elif sd_sampler == SDSampler.k_euler:
return EulerDiscreteScheduler.from_config(scheduler_config)
elif sd_sampler == SDSampler.k_euler_a:
return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
elif sd_sampler == SDSampler.dpm_plus_plus:
return DPMSolverMultistepScheduler.from_config(scheduler_config)
elif sd_sampler == SDSampler.uni_pc:
return UniPCMultistepScheduler.from_config(scheduler_config)
else:
raise ValueError(sd_sampler)