From 7e028c39082cef45fcb8ce23cb871c121c763dd6 Mon Sep 17 00:00:00 2001
From: Qing
Date: Sat, 25 Mar 2023 21:29:13 +0800
Subject: [PATCH] wip fp16 mat

---
 lama_cleaner/model/mat.py | 67 +++++++++++++++++++++------------------
 1 file changed, 36 insertions(+), 31 deletions(-)

diff --git a/lama_cleaner/model/mat.py b/lama_cleaner/model/mat.py
index 5c6de9c..ec6aed4 100644
--- a/lama_cleaner/model/mat.py
+++ b/lama_cleaner/model/mat.py
@@ -21,6 +21,7 @@ from lama_cleaner.model.utils import (
     MinibatchStdLayer,
     to_2tuple,
     normalize_2nd_moment,
+    set_seed,
 )
 
 from lama_cleaner.schema import Config
@@ -361,6 +362,7 @@ class MappingNet(torch.nn.Module):
         activation="lrelu",  # Activation function: 'relu', 'lrelu', etc.
         lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
         w_avg_beta=0.995,  # Decay for tracking the moving average of W during training, None = do not track.
+        torch_dtype=torch.float32,
     ):
         super().__init__()
         self.z_dim = z_dim
@@ -369,6 +371,7 @@ class MappingNet(torch.nn.Module):
         self.num_ws = num_ws
         self.num_layers = num_layers
         self.w_avg_beta = w_avg_beta
+        self.torch_dtype = torch_dtype
 
         if embed_features is None:
             embed_features = w_dim
@@ -399,14 +402,13 @@ class MappingNet(torch.nn.Module):
     def forward(
         self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False
     ):
         # Embed, normalize, and concat inputs.
         x = None
-        with torch.autograd.profiler.record_function("input"):
-            if self.z_dim > 0:
-                x = normalize_2nd_moment(z.to(torch.float32))
-            if self.c_dim > 0:
-                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
-                x = torch.cat([x, y], dim=1) if x is not None else y
+        if self.z_dim > 0:
+            x = normalize_2nd_moment(z)
+        if self.c_dim > 0:
+            y = normalize_2nd_moment(self.embed(c))
+            x = torch.cat([x, y], dim=1) if x is not None else y
 
         # Main layers.
         for idx in range(self.num_layers):
@@ -415,26 +417,21 @@ class MappingNet(torch.nn.Module):
 
         # Update moving average of W.
         if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
-            with torch.autograd.profiler.record_function("update_w_avg"):
-                self.w_avg.copy_(
-                    x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)
-                )
+            self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
 
         # Broadcast.
         if self.num_ws is not None:
-            with torch.autograd.profiler.record_function("broadcast"):
-                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+            x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
 
         # Apply truncation.
         if truncation_psi != 1:
-            with torch.autograd.profiler.record_function("truncate"):
-                assert self.w_avg_beta is not None
-                if self.num_ws is None or truncation_cutoff is None:
-                    x = self.w_avg.lerp(x, truncation_psi)
-                else:
-                    x[:, :truncation_cutoff] = self.w_avg.lerp(
-                        x[:, :truncation_cutoff], truncation_psi
-                    )
+            assert self.w_avg_beta is not None
+            if self.num_ws is None or truncation_cutoff is None:
+                x = self.w_avg.lerp(x, truncation_psi)
+            else:
+                x[:, :truncation_cutoff] = self.w_avg.lerp(
+                    x[:, :truncation_cutoff], truncation_psi
+                )
 
         return x
 
@@ -713,7 +710,6 @@ class WindowAttention(nn.Module):
         attn_drop=0.0,
         proj_drop=0.0,
     ):
-
         super().__init__()
         self.dim = dim
         self.window_size = window_size  # Wh, Ww
@@ -1058,7 +1054,6 @@ class BasicLayer(nn.Module):
         downsample=None,
         use_checkpoint=False,
     ):
-
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -1882,14 +1877,22 @@ class MAT(InpaintModel):
 
     def init_model(self, device, **kwargs):
         seed = 240  # pick up a random number
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
+        set_seed(seed)
 
-        G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3)
-        self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5)
-        self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(device)  # [1., 512]
-        self.label = torch.zeros([1, self.model.c_dim], device=device)
+        self.torch_dtype = torch.float16
+        G = Generator(
+            z_dim=512,
+            c_dim=0,
+            w_dim=512,
+            img_resolution=512,
+            img_channels=3,
+            mapping_kwargs={"torch_dtype": self.torch_dtype},
+        )
+        # fmt: off
+        self.model = load_model(G, MAT_MODEL_URL, device, MAT_MODEL_MD5).to(self.torch_dtype)
+        self.z = torch.from_numpy(np.random.randn(1, G.z_dim)).to(self.torch_dtype).to(device)
+        self.label = torch.zeros([1, self.model.c_dim], device=device).to(self.torch_dtype)
+        # fmt: on
 
     @staticmethod
     def is_downloaded() -> bool:
@@ -1909,8 +1912,10 @@
         mask = 255 - mask
         mask = norm_img(mask)
 
-        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
-        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
+        image = (
+            torch.from_numpy(image).unsqueeze(0).to(self.torch_dtype).to(self.device)
+        )
+        mask = torch.from_numpy(mask).unsqueeze(0).to(self.torch_dtype).to(self.device)
 
         output = self.model(
             image, mask, self.z, self.label, truncation_psi=1, noise_mode="none"
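Note on the half-precision pattern the patch applies: it is the standard PyTorch fp16 inference recipe of casting the module's parameters (and floating-point buffers such as w_avg) plus every input tensor to one dtype before the forward pass, and dropping hard-coded .to(torch.float32) casts that would silently promote activations back to fp32, as the MappingNet.forward hunk does. Below is a minimal, self-contained sketch of that recipe; TinyMapping and its shapes are illustrative stand-ins, not part of MAT:

    import torch
    import torch.nn as nn

    class TinyMapping(nn.Module):
        # Illustrative stand-in for the real Generator; any nn.Module works the same.
        def __init__(self, dim=512):
            super().__init__()
            self.fc = nn.Linear(dim, dim)

        def forward(self, z):
            # No z.to(torch.float32) here: activations stay in the module's
            # dtype, mirroring the casts removed from MappingNet.forward.
            return self.fc(z)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # fp16 kernels are primarily a GPU feature; fall back to fp32 on CPU.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    # Module.to(dtype) converts floating-point parameters and buffers alike,
    # analogous to load_model(...).to(self.torch_dtype) in init_model.
    model = TinyMapping().to(torch_dtype).to(device).eval()

    # Inputs must match the parameter dtype, as the patch ensures for
    # image, mask, z, and label.
    z = torch.randn(1, 512).to(torch_dtype).to(device)

    with torch.no_grad():
        w = model(z)
    print(w.dtype)  # torch.float16 on GPU, torch.float32 on CPU

An alternative worth weighing is torch.autocast("cuda", dtype=torch.float16), which keeps weights in fp32 and casts per-op at runtime; the whole-model cast chosen here halves weight memory but requires every buffer and input to agree on dtype, which is why image, mask, z, and label are all cast explicitly.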