From cfcaf82a21f3de4d4c1f7be7c907fd96efbdf285 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 19 Jul 2022 21:47:21 +0800 Subject: [PATCH] zits use structure_upsample_model --- .../app/src/components/Settings/Settings.scss | 2 +- lama_cleaner/helper.py | 22 +++++-- lama_cleaner/model/zits.py | 64 ++++++++++++------- 3 files changed, 58 insertions(+), 30 deletions(-) diff --git a/lama_cleaner/app/src/components/Settings/Settings.scss b/lama_cleaner/app/src/components/Settings/Settings.scss index e743959..3f8f803 100644 --- a/lama_cleaner/app/src/components/Settings/Settings.scss +++ b/lama_cleaner/app/src/components/Settings/Settings.scss @@ -8,7 +8,7 @@ background-color: var(--modal-bg); color: var(--modal-text-color); box-shadow: 0px 0px 20px rgb(0, 0, 40, 0.2); - width: 700px; + width: 600px; @include mobile { display: grid; diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index b93c61a..5d477f0 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -42,17 +42,23 @@ def load_jit_model(url_or_path, device): else: model_path = download_model(url_or_path) logger.info(f"Load model from: {model_path}") - model = torch.jit.load(model_path).to(device) + try: + model = torch.jit.load(model_path).to(device) + except: + logger.error( + f"Failed to load {model_path}, delete model and restart lama-cleaner" + ) + exit(-1) model.eval() return model def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: - data = cv2.imencode(f".{ext}", image_numpy, - [ - int(cv2.IMWRITE_JPEG_QUALITY), 100, - int(cv2.IMWRITE_PNG_COMPRESSION), 0 - ])[1] + data = cv2.imencode( + f".{ext}", + image_numpy, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], + )[1] image_bytes = data.tobytes() return image_bytes @@ -95,7 +101,9 @@ def resize_max_size( return np_img -def pad_img_to_modulo(img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None): +def pad_img_to_modulo( + img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None +): """ Args: diff --git a/lama_cleaner/model/zits.py b/lama_cleaner/model/zits.py index 3faf5eb..f3d3483 100644 --- a/lama_cleaner/model/zits.py +++ b/lama_cleaner/model/zits.py @@ -41,7 +41,7 @@ def resize(img, height, width, center_crop=False): side = np.minimum(imgh, imgw) j = (imgh - side) // 2 i = (imgw - side) // 2 - img = img[j: j + side, i: i + side, ...] + img = img[j : j + side, i : i + side, ...] if imgh > height and imgw > width: inter = cv2.INTER_AREA @@ -219,7 +219,9 @@ class ZITS(InpaintModel): def init_model(self, device): self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device) self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device) - # self.structure_upsample = load_jit_model(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device) + self.structure_upsample = load_jit_model( + ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device + ) self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device) @staticmethod @@ -227,7 +229,7 @@ class ZITS(InpaintModel): model_paths = [ get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL), get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL), - # get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL), + get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL), get_cache_path_by_url(ZITS_INPAINT_MODEL_URL), ] return all([os.path.exists(it) for it in model_paths]) @@ -272,20 +274,27 @@ class ZITS(InpaintModel): # cv2.imwrite("line_pred.jpg", np_line_pred) # exit() - # No structure_upsample_model input_size = min(items["h"], items["w"]) - edge_pred = F.interpolate( - edge_pred, - size=(input_size, input_size), - mode="bilinear", - align_corners=False, - ) - line_pred = F.interpolate( - line_pred, - size=(input_size, input_size), - mode="bilinear", - align_corners=False, - ) + if input_size != 256 and input_size > 256: + while edge_pred.shape[2] < input_size: + edge_pred = self.structure_upsample(edge_pred) + edge_pred = torch.sigmoid((edge_pred + 2) * 2) + + line_pred = self.structure_upsample(line_pred) + line_pred = torch.sigmoid((line_pred + 2) * 2) + + edge_pred = F.interpolate( + edge_pred, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) + line_pred = F.interpolate( + line_pred, + size=(input_size, input_size), + mode="bilinear", + align_corners=False, + ) # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred) @@ -308,12 +317,19 @@ class ZITS(InpaintModel): self.wireframe_edge_and_line(items, config.zits_wireframe) - inpainted_image = self.inpaint(items["images"], items["masks"], - items["edge"], items["line"], - items["rel_pos"], items["direct"]) + inpainted_image = self.inpaint( + items["images"], + items["masks"], + items["edge"], + items["line"], + items["rel_pos"], + items["direct"], + ) inpainted_image = inpainted_image * 255.0 - inpainted_image = inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8) + inpainted_image = ( + inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8) + ) inpainted_image = inpainted_image[:, :, ::-1] # cv2.imwrite("inpainted.jpg", inpainted_image) @@ -362,7 +378,9 @@ class ZITS(InpaintModel): lines_tensor = torch.cat(lines_tensor, dim=0) return lines_tensor.detach().to(self.device) - def sample_edge_line_logits(self, context, mask=None, iterations=1, add_v=0, mul_v=4): + def sample_edge_line_logits( + self, context, mask=None, iterations=1, add_v=0, mul_v=4 + ): [img, edge, line] = context img = img * (1 - mask) @@ -391,7 +409,9 @@ class ZITS(InpaintModel): edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100) line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100) - indices = torch.sort(edge_max_probs + line_max_probs, dim=-1, descending=True)[1] + indices = torch.sort( + edge_max_probs + line_max_probs, dim=-1, descending=True + )[1] for ii in range(b): keep = int((i + 1) / iterations * torch.sum(mask[ii, ...]))