add download command
This commit is contained in:
@@ -8,7 +8,12 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
|
||||
from lama_cleaner.helper import (
|
||||
load_model,
|
||||
get_cache_path_by_url,
|
||||
norm_img,
|
||||
download_model,
|
||||
)
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.model.utils import (
|
||||
setup_filter,
|
||||
@@ -52,7 +57,7 @@ class ModulatedConv2d(nn.Module):
|
||||
)
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
|
||||
self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
|
||||
self.padding = self.kernel_size // 2
|
||||
self.up = up
|
||||
self.down = down
|
||||
@@ -213,7 +218,7 @@ class DecBlockFirst(nn.Module):
|
||||
super().__init__()
|
||||
self.fc = FullyConnectedLayer(
|
||||
in_features=in_channels * 2,
|
||||
out_features=in_channels * 4 ** 2,
|
||||
out_features=in_channels * 4**2,
|
||||
activation=activation,
|
||||
)
|
||||
self.conv = StyleConv(
|
||||
@@ -312,7 +317,7 @@ class DecBlock(nn.Module):
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
style_dim=style_dim,
|
||||
resolution=2 ** res,
|
||||
resolution=2**res,
|
||||
kernel_size=3,
|
||||
up=2,
|
||||
use_noise=use_noise,
|
||||
@@ -323,7 +328,7 @@ class DecBlock(nn.Module):
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
style_dim=style_dim,
|
||||
resolution=2 ** res,
|
||||
resolution=2**res,
|
||||
kernel_size=3,
|
||||
use_noise=use_noise,
|
||||
activation=activation,
|
||||
@@ -507,7 +512,7 @@ class Discriminator(torch.nn.Module):
|
||||
self.img_channels = img_channels
|
||||
|
||||
resolution_log2 = int(np.log2(img_resolution))
|
||||
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
||||
assert img_resolution == 2**resolution_log2 and img_resolution >= 4
|
||||
self.resolution_log2 = resolution_log2
|
||||
|
||||
def nf(stage):
|
||||
@@ -543,7 +548,7 @@ class Discriminator(torch.nn.Module):
|
||||
)
|
||||
self.Dis = nn.Sequential(*Dis)
|
||||
|
||||
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
||||
self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation)
|
||||
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
||||
|
||||
def forward(self, images_in, masks_in, c):
|
||||
@@ -562,7 +567,7 @@ class Discriminator(torch.nn.Module):
|
||||
|
||||
def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
|
||||
NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
|
||||
return NF[2 ** stage]
|
||||
return NF[2**stage]
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
@@ -659,7 +664,7 @@ class Conv2dLayerPartial(nn.Module):
|
||||
)
|
||||
|
||||
self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
|
||||
self.slide_winsize = kernel_size ** 2
|
||||
self.slide_winsize = kernel_size**2
|
||||
self.stride = down
|
||||
self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
|
||||
|
||||
@@ -715,7 +720,7 @@ class WindowAttention(nn.Module):
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
|
||||
self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
|
||||
@@ -1211,7 +1216,7 @@ class Encoder(nn.Module):
|
||||
self.resolution = []
|
||||
|
||||
for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
|
||||
res = 2 ** i
|
||||
res = 2**i
|
||||
self.resolution.append(res)
|
||||
if i == res_log2:
|
||||
block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
|
||||
@@ -1296,7 +1301,7 @@ class DecBlockFirstV2(nn.Module):
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
style_dim=style_dim,
|
||||
resolution=2 ** res,
|
||||
resolution=2**res,
|
||||
kernel_size=3,
|
||||
use_noise=use_noise,
|
||||
activation=activation,
|
||||
@@ -1341,7 +1346,7 @@ class DecBlock(nn.Module):
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
style_dim=style_dim,
|
||||
resolution=2 ** res,
|
||||
resolution=2**res,
|
||||
kernel_size=3,
|
||||
up=2,
|
||||
use_noise=use_noise,
|
||||
@@ -1352,7 +1357,7 @@ class DecBlock(nn.Module):
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
style_dim=style_dim,
|
||||
resolution=2 ** res,
|
||||
resolution=2**res,
|
||||
kernel_size=3,
|
||||
use_noise=use_noise,
|
||||
activation=activation,
|
||||
@@ -1389,7 +1394,7 @@ class Decoder(nn.Module):
|
||||
for res in range(5, res_log2 + 1):
|
||||
setattr(
|
||||
self,
|
||||
"Dec_%dx%d" % (2 ** res, 2 ** res),
|
||||
"Dec_%dx%d" % (2**res, 2**res),
|
||||
DecBlock(
|
||||
res,
|
||||
nf(res - 1),
|
||||
@@ -1406,7 +1411,7 @@ class Decoder(nn.Module):
|
||||
def forward(self, x, ws, gs, E_features, noise_mode="random"):
|
||||
x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
|
||||
for res in range(5, self.res_log2 + 1):
|
||||
block = getattr(self, "Dec_%dx%d" % (2 ** res, 2 ** res))
|
||||
block = getattr(self, "Dec_%dx%d" % (2**res, 2**res))
|
||||
x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
|
||||
|
||||
return img
|
||||
@@ -1431,7 +1436,7 @@ class DecStyleBlock(nn.Module):
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
style_dim=style_dim,
|
||||
resolution=2 ** res,
|
||||
resolution=2**res,
|
||||
kernel_size=3,
|
||||
up=2,
|
||||
use_noise=use_noise,
|
||||
@@ -1442,7 +1447,7 @@ class DecStyleBlock(nn.Module):
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
style_dim=style_dim,
|
||||
resolution=2 ** res,
|
||||
resolution=2**res,
|
||||
kernel_size=3,
|
||||
use_noise=use_noise,
|
||||
activation=activation,
|
||||
@@ -1640,7 +1645,7 @@ class SynthesisNet(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
resolution_log2 = int(np.log2(img_resolution))
|
||||
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
||||
assert img_resolution == 2**resolution_log2 and img_resolution >= 4
|
||||
|
||||
self.num_layers = resolution_log2 * 2 - 3 * 2
|
||||
self.img_resolution = img_resolution
|
||||
@@ -1781,7 +1786,7 @@ class Discriminator(torch.nn.Module):
|
||||
self.img_channels = img_channels
|
||||
|
||||
resolution_log2 = int(np.log2(img_resolution))
|
||||
assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
|
||||
assert img_resolution == 2**resolution_log2 and img_resolution >= 4
|
||||
self.resolution_log2 = resolution_log2
|
||||
|
||||
if cmap_dim == None:
|
||||
@@ -1812,7 +1817,7 @@ class Discriminator(torch.nn.Module):
|
||||
)
|
||||
self.Dis = nn.Sequential(*Dis)
|
||||
|
||||
self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
|
||||
self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation)
|
||||
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
|
||||
|
||||
# for 64x64
|
||||
@@ -1837,7 +1842,7 @@ class Discriminator(torch.nn.Module):
|
||||
self.Dis_stg1 = nn.Sequential(*Dis_stg1)
|
||||
|
||||
self.fc0_stg1 = FullyConnectedLayer(
|
||||
nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation
|
||||
nf(2) // 2 * 4**2, nf(2) // 2, activation=activation
|
||||
)
|
||||
self.fc1_stg1 = FullyConnectedLayer(
|
||||
nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim
|
||||
@@ -1898,6 +1903,10 @@ class MAT(InpaintModel):
|
||||
self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
|
||||
# fmt: on
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
download_model(MAT_MODEL_URL, MAT_MODEL_MD5)
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
|
||||
|
||||
Reference in New Issue
Block a user