wip fp16 mat

This commit is contained in:
Qing
2023-03-25 21:29:13 +08:00
parent 094b3c4f69
commit 7e028c3908

View File

@@ -21,6 +21,7 @@ from lama_cleaner.model.utils import (
MinibatchStdLayer, MinibatchStdLayer,
to_2tuple, to_2tuple,
normalize_2nd_moment, normalize_2nd_moment,
set_seed,
) )
from lama_cleaner.schema import Config from lama_cleaner.schema import Config
@@ -361,6 +362,7 @@ class MappingNet(torch.nn.Module):
activation="lrelu", # Activation function: 'relu', 'lrelu', etc. activation="lrelu", # Activation function: 'relu', 'lrelu', etc.
lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track.
torch_dtype=torch.float32,
): ):
super().__init__() super().__init__()
self.z_dim = z_dim self.z_dim = z_dim
@@ -369,6 +371,7 @@ class MappingNet(torch.nn.Module):
self.num_ws = num_ws self.num_ws = num_ws
self.num_layers = num_layers self.num_layers = num_layers
self.w_avg_beta = w_avg_beta self.w_avg_beta = w_avg_beta
self.torch_dtype = torch_dtype
if embed_features is None: if embed_features is None:
embed_features = w_dim embed_features = w_dim
@@ -399,14 +402,16 @@ class MappingNet(torch.nn.Module):
def forward( def forward(
self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
): ):
import ipdb
ipdb.set_trace()
# Embed, normalize, and concat inputs. # Embed, normalize, and concat inputs.
x = None x = None
with torch.autograd.profiler.record_function("input"): if self.z_dim > 0:
if self.z_dim > 0: x = normalize_2nd_moment(z)
x = normalize_2nd_moment(z.to(torch.float32)) if self.c_dim > 0:
if self.c_dim > 0: y = normalize_2nd_moment(self.embed(c))
y = normalize_2nd_moment(self.embed(c.to(torch.float32))) x = torch.cat([x, y], dim=1) if x is not None else y
x = torch.cat([x, y], dim=1) if x is not None else y
# Main layers. # Main layers.
for idx in range(self.num_layers): for idx in range(self.num_layers):
@@ -415,26 +420,21 @@ class MappingNet(torch.nn.Module):
# Update moving average of W. # Update moving average of W.
if self.w_avg_beta is not None and self.training and not skip_w_avg_update: if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
with torch.autograd.profiler.record_function("update_w_avg"): self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
self.w_avg.copy_(
x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
)
# Broadcast. # Broadcast.
if self.num_ws is not None: if self.num_ws is not None:
with torch.autograd.profiler.record_function("broadcast"): x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
# Apply truncation. # Apply truncation.
if truncation_psi != 1: if truncation_psi != 1:
with torch.autograd.profiler.record_function("truncate"): assert self.w_avg_beta is not None
assert self.w_avg_beta is not None if self.num_ws is None or truncation_cutoff is None:
if self.num_ws is None or truncation_cutoff is None: x = self.w_avg.lerp(x, truncation_psi)
x = self.w_avg.lerp(x, truncation_psi) else:
else: x[:, :truncation_cutoff] = self.w_avg.lerp(
x[:, :truncation_cutoff] = self.w_avg.lerp( x[:, :truncation_cutoff], truncation_psi
x[:, :truncation_cutoff], truncation_psi )
)
return x return x
@@ -713,7 +713,6 @@ class WindowAttention(nn.Module):
attn_drop=0.0, attn_drop=0.0,
proj_drop=0.0, proj_drop=0.0,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.window_size = window_size # Wh, Ww self.window_size = window_size # Wh, Ww
@@ -1058,7 +1057,6 @@ class BasicLayer(nn.Module):
downsample=None, downsample=None,
use_checkpoint=False, use_checkpoint=False,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.input_resolution = input_resolution self.input_resolution = input_resolution
@@ -1882,14 +1880,22 @@ class MAT(InpaintModel):
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs):
seed = 240 # pick up a random number seed = 240 # pick up a random number
random.seed(seed) set_seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3) self.torch_dtype = torch.float16
self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5) G = Generator(
self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device) # [1., 512] z_dim=512,
self.label = torch.zeros([1, self.model.c_dim], device=device) c_dim=0,
w_dim=512,
img_resolution=512,
img_channels=3,
mapping_kwargs={"torch_dtype": self.torch_dtype},
)
# fmt: off
self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5).to(self.torch_dtype)
self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(self.torch_dtype).to(device)
self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
# fmt: on
@staticmethod @staticmethod
def is_downloaded() -> bool: def is_downloaded() -> bool:
@@ -1909,8 +1915,10 @@ class MAT(InpaintModel):
mask = 255 - mask mask = 255 - mask
mask = norm_img(mask) mask = norm_img(mask)
image = torch.from_numpy(image).unsqueeze(0).to(self.device) image = (
mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device)
)
mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device)
output = self.model( output = self.model(
image, mask, self.z, self.label, truncation_psi=1, noise_mode="none" image, mask, self.z, self.label, truncation_psi=1, noise_mode="none"