add download command

Author: Qing
Date: 2023-11-16 21:12:06 +08:00
Parent: 20e660aa4a
Commit: 1d145d1cd6
17 changed files with 233 additions and 67 deletions
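The file shown below gives the MAT inpainting model a per-model download entry point: a `download()` static method built on the new `download_model` helper; the other hunks are formatting-only changes to exponent spacing (`x ** 2` → `x**2`). As a hedged illustration of how the commit's `download` command could dispatch to such per-model static methods, here is a hypothetical sketch; the argparse wiring, the model registry, and the import path are assumptions, not code from this commit:

# Hypothetical sketch only: the command name, argparse wiring, registry, and
# import path are assumptions; only MAT.download() and MAT.is_downloaded()
# come from this diff.
import argparse

from lama_cleaner.model.mat import MAT  # module path inferred, not shown in the diff

AVAILABLE_MODELS = {"mat": MAT}  # assumed registry; other models would register here too


def main():
    parser = argparse.ArgumentParser(description="Pre-download model weights")
    parser.add_argument("model", choices=sorted(AVAILABLE_MODELS))
    args = parser.parse_args()

    model_cls = AVAILABLE_MODELS[args.model]
    if not model_cls.is_downloaded():
        model_cls.download()


if __name__ == "__main__":
    main()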


@@ -8,7 +8,12 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
-from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img
+from lama_cleaner.helper import (
+    load_model,
+    get_cache_path_by_url,
+    norm_img,
+    download_model,
+)
from lama_cleaner.model.base import InpaintModel
from lama_cleaner.model.utils import (
setup_filter,
@@ -52,7 +57,7 @@ class ModulatedConv2d(nn.Module):
)
self.out_channels = out_channels
self.kernel_size = kernel_size
-self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
self.padding = self.kernel_size // 2
self.up = up
self.down = down
@@ -213,7 +218,7 @@ class DecBlockFirst(nn.Module):
super().__init__()
self.fc = FullyConnectedLayer(
in_features=in_channels * 2,
-out_features=in_channels * 4 ** 2,
+out_features=in_channels * 4**2,
activation=activation,
)
self.conv = StyleConv(
@@ -312,7 +317,7 @@ class DecBlock(nn.Module):
in_channels=in_channels,
out_channels=out_channels,
style_dim=style_dim,
-resolution=2 ** res,
+resolution=2**res,
kernel_size=3,
up=2,
use_noise=use_noise,
@@ -323,7 +328,7 @@ class DecBlock(nn.Module):
in_channels=out_channels,
out_channels=out_channels,
style_dim=style_dim,
-resolution=2 ** res,
+resolution=2**res,
kernel_size=3,
use_noise=use_noise,
activation=activation,
@@ -507,7 +512,7 @@ class Discriminator(torch.nn.Module):
self.img_channels = img_channels
resolution_log2 = int(np.log2(img_resolution))
-assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
+assert img_resolution == 2**resolution_log2 and img_resolution >= 4
self.resolution_log2 = resolution_log2
def nf(stage):
@@ -543,7 +548,7 @@ class Discriminator(torch.nn.Module):
)
self.Dis = nn.Sequential(*Dis)
-self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
+self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation)
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
def forward(self, images_in, masks_in, c):
@@ -562,7 +567,7 @@ class Discriminator(torch.nn.Module):
def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512):
NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512}
-return NF[2 ** stage]
+return NF[2**stage]
class Mlp(nn.Module):
@@ -659,7 +664,7 @@ class Conv2dLayerPartial(nn.Module):
)
self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size)
-self.slide_winsize = kernel_size ** 2
+self.slide_winsize = kernel_size**2
self.stride = down
self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
@@ -715,7 +720,7 @@ class WindowAttention(nn.Module):
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
-self.scale = qk_scale or head_dim ** -0.5
+self.scale = qk_scale or head_dim**-0.5
self.q = FullyConnectedLayer(in_features=dim, out_features=dim)
self.k = FullyConnectedLayer(in_features=dim, out_features=dim)
@@ -1211,7 +1216,7 @@ class Encoder(nn.Module):
self.resolution = []
for idx, i in enumerate(range(res_log2, 3, -1)): # from input size to 16x16
-res = 2 ** i
+res = 2**i
self.resolution.append(res)
if i == res_log2:
block = EncFromRGB(img_channels * 2 + 1, nf(i), activation)
@@ -1296,7 +1301,7 @@ class DecBlockFirstV2(nn.Module):
in_channels=in_channels,
out_channels=out_channels,
style_dim=style_dim,
-resolution=2 ** res,
+resolution=2**res,
kernel_size=3,
use_noise=use_noise,
activation=activation,
@@ -1341,7 +1346,7 @@ class DecBlock(nn.Module):
in_channels=in_channels,
out_channels=out_channels,
style_dim=style_dim,
-resolution=2 ** res,
+resolution=2**res,
kernel_size=3,
up=2,
use_noise=use_noise,
@@ -1352,7 +1357,7 @@ class DecBlock(nn.Module):
in_channels=out_channels,
out_channels=out_channels,
style_dim=style_dim,
-resolution=2 ** res,
+resolution=2**res,
kernel_size=3,
use_noise=use_noise,
activation=activation,
@@ -1389,7 +1394,7 @@ class Decoder(nn.Module):
for res in range(5, res_log2 + 1):
setattr(
self,
"Dec_%dx%d" % (2 ** res, 2 ** res),
"Dec_%dx%d" % (2**res, 2**res),
DecBlock(
res,
nf(res - 1),
@@ -1406,7 +1411,7 @@ class Decoder(nn.Module):
def forward(self, x, ws, gs, E_features, noise_mode="random"):
x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode)
for res in range(5, self.res_log2 + 1):
-block = getattr(self, "Dec_%dx%d" % (2 ** res, 2 ** res))
+block = getattr(self, "Dec_%dx%d" % (2**res, 2**res))
x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode)
return img
@@ -1431,7 +1436,7 @@ class DecStyleBlock(nn.Module):
in_channels=in_channels,
out_channels=out_channels,
style_dim=style_dim,
-resolution=2 ** res,
+resolution=2**res,
kernel_size=3,
up=2,
use_noise=use_noise,
@@ -1442,7 +1447,7 @@ class DecStyleBlock(nn.Module):
in_channels=out_channels,
out_channels=out_channels,
style_dim=style_dim,
-resolution=2 ** res,
+resolution=2**res,
kernel_size=3,
use_noise=use_noise,
activation=activation,
@@ -1640,7 +1645,7 @@ class SynthesisNet(nn.Module):
):
super().__init__()
resolution_log2 = int(np.log2(img_resolution))
-assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
+assert img_resolution == 2**resolution_log2 and img_resolution >= 4
self.num_layers = resolution_log2 * 2 - 3 * 2
self.img_resolution = img_resolution
@@ -1781,7 +1786,7 @@ class Discriminator(torch.nn.Module):
self.img_channels = img_channels
resolution_log2 = int(np.log2(img_resolution))
-assert img_resolution == 2 ** resolution_log2 and img_resolution >= 4
+assert img_resolution == 2**resolution_log2 and img_resolution >= 4
self.resolution_log2 = resolution_log2
if cmap_dim == None:
@@ -1812,7 +1817,7 @@ class Discriminator(torch.nn.Module):
)
self.Dis = nn.Sequential(*Dis)
-self.fc0 = FullyConnectedLayer(nf(2) * 4 ** 2, nf(2), activation=activation)
+self.fc0 = FullyConnectedLayer(nf(2) * 4**2, nf(2), activation=activation)
self.fc1 = FullyConnectedLayer(nf(2), 1 if cmap_dim == 0 else cmap_dim)
# for 64x64
@@ -1837,7 +1842,7 @@ class Discriminator(torch.nn.Module):
self.Dis_stg1 = nn.Sequential(*Dis_stg1)
self.fc0_stg1 = FullyConnectedLayer(
-nf(2) // 2 * 4 ** 2, nf(2) // 2, activation=activation
+nf(2) // 2 * 4**2, nf(2) // 2, activation=activation
)
self.fc1_stg1 = FullyConnectedLayer(
nf(2) // 2, 1 if cmap_dim == 0 else cmap_dim
@@ -1898,6 +1903,10 @@ class MAT(InpaintModel):
self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
# fmt: on
+@staticmethod
+def download():
+    download_model(MAT_MODEL_URL, MAT_MODEL_MD5)
@staticmethod
def is_downloaded() -> bool:
return os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
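
A minimal usage sketch of the two static methods above (not part of the commit); the import path and the constructor call are assumptions, everything else appears in the diff:

# Minimal usage sketch: make sure the MAT weights are cached before
# constructing the model.
from lama_cleaner.model.mat import MAT  # module path inferred, not shown in the diff

if not MAT.is_downloaded():  # checks os.path.exists(get_cache_path_by_url(MAT_MODEL_URL))
    MAT.download()  # runs download_model(MAT_MODEL_URL, MAT_MODEL_MD5)

model = MAT("cuda")  # the device argument is an assumption, not shown in this diff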