diff --git a/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx index d722628..d8f7057 100644 --- a/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx +++ b/lama_cleaner/app/src/components/Settings/ModelSettingBlock.tsx @@ -196,6 +196,8 @@ function ModelSettingBlock() { return renderFCFModelDesc() case AIModel.SD15: return undefined + case AIModel.Mange: + return undefined case AIModel.CV2: return renderOpenCV2Desc() default: @@ -241,6 +243,12 @@ function ModelSettingBlock() { 'https://ommer-lab.com/research/latent-diffusion-models/', 'https://github.com/CompVis/stable-diffusion' ) + case AIModel.Mange: + return renderModelDesc( + 'Manga Inpainting', + 'https://www.cse.cuhk.edu.hk/~ttwong/papers/mangainpaint/mangainpaint.html', + 'https://github.com/msxie92/MangaInpainting' + ) case AIModel.CV2: return renderModelDesc( 'OpenCV Image Inpainting', diff --git a/lama_cleaner/app/src/store/Atoms.tsx b/lama_cleaner/app/src/store/Atoms.tsx index 25002c6..14a57f3 100644 --- a/lama_cleaner/app/src/store/Atoms.tsx +++ b/lama_cleaner/app/src/store/Atoms.tsx @@ -11,6 +11,7 @@ export enum AIModel { FCF = 'fcf', SD15 = 'sd1.5', CV2 = 'cv2', + Mange = 'manga', } export const fileState = atom({ @@ -223,6 +224,13 @@ const defaultHDSettings: ModelsHDSettings = { hdStrategyCropMargin: 128, enabled: true, }, + [AIModel.Mange]: { + hdStrategy: HDStrategy.CROP, + hdStrategyResizeLimit: 1280, + hdStrategyCropTrigerSize: 1024, + hdStrategyCropMargin: 196, + enabled: true, + }, [AIModel.CV2]: { hdStrategy: HDStrategy.RESIZE, hdStrategyResizeLimit: 1080, diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index b414f1d..7f85e3b 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -5,7 +5,7 @@ import numpy as np import torch from loguru import logger -from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url +from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config diff --git a/lama_cleaner/model/manga.py b/lama_cleaner/model/manga.py new file mode 100644 index 0000000..2fd73a8 --- /dev/null +++ b/lama_cleaner/model/manga.py @@ -0,0 +1,120 @@ +import os + +import cv2 +import numpy as np +import torch +import time +from loguru import logger + +from lama_cleaner.helper import get_cache_path_by_url, load_jit_model +from lama_cleaner.model.base import InpaintModel +from lama_cleaner.schema import Config + +# def norm(np_img): +# return np_img / 255 * 2 - 1.0 +# +# +# @torch.no_grad() +# def run(): +# name = 'manga_1080x740.jpg' +# img_p = f'/Users/qing/code/github/MangaInpainting/examples/test/imgs/{name}' +# mask_p = f'/Users/qing/code/github/MangaInpainting/examples/test/masks/mask_{name}' +# erika_model = torch.jit.load('erika.jit') +# manga_inpaintor_model = torch.jit.load('manga_inpaintor.jit') +# +# img = cv2.imread(img_p) +# gray_img = cv2.imread(img_p, cv2.IMREAD_GRAYSCALE) +# mask = cv2.imread(mask_p, cv2.IMREAD_GRAYSCALE) +# +# kernel = np.ones((9, 9), dtype=np.uint8) +# mask = cv2.dilate(mask, kernel, 2) +# # cv2.imwrite("mask.jpg", mask) +# # cv2.imshow('dilated_mask', cv2.hconcat([mask, dilated_mask])) +# # cv2.waitKey(0) +# # exit() +# +# # img = pad(img) +# gray_img = pad(gray_img).astype(np.float32) +# mask = pad(mask) +# +# # pad_mod = 16 +# import time +# start = time.time() +# y = erika_model(torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :])) +# y = torch.clamp(y, 0, 255) +# lines = y.cpu().numpy() +# print(f"erika_model time: {time.time() - start}") +# +# cv2.imwrite('lines.png', lines[0][0]) +# +# start = time.time() +# masks = torch.from_numpy(mask[np.newaxis, np.newaxis, :, :]) +# masks = torch.where(masks > 0.5, torch.tensor(1.0), torch.tensor(0.0)) +# noise = torch.randn_like(masks) +# +# images = torch.from_numpy(norm(gray_img)[np.newaxis, np.newaxis, :, :]) +# lines = torch.from_numpy(norm(lines)) +# +# outputs = manga_inpaintor_model(images, lines, masks, noise) +# print(f"manga_inpaintor_model time: {time.time() - start}") +# +# outputs_merged = (outputs * masks) + (images * (1 - masks)) +# outputs_merged = outputs_merged * 127.5 + 127.5 +# outputs_merged = outputs_merged.permute(0, 2, 3, 1)[0].detach().cpu().numpy().astype(np.uint8) +# cv2.imwrite(f'output_{name}', outputs_merged) + + +MANGA_INPAINTOR_MODEL_URL = os.environ.get( + "MANGA_INPAINTOR_MODEL_URL", + "https://github.com/Sanster/models/releases/download/manga/manga_inpaintor.jit" +) +MANGA_LINE_MODEL_URL = os.environ.get( + "MANGA_LINE_MODEL_URL", + "https://github.com/Sanster/models/releases/download/manga/erika.jit" +) + + +class Manga(InpaintModel): + pad_mod = 16 + + def init_model(self, device, **kwargs): + self.inpaintor_model = load_jit_model(MANGA_INPAINTOR_MODEL_URL, device) + self.line_model = load_jit_model(MANGA_LINE_MODEL_URL, device) + + @staticmethod + def is_downloaded() -> bool: + model_paths = [ + get_cache_path_by_url(MANGA_INPAINTOR_MODEL_URL), + get_cache_path_by_url(MANGA_LINE_MODEL_URL), + ] + return all([os.path.exists(it) for it in model_paths]) + + def forward(self, image, mask, config: Config): + """ + image: [H, W, C] RGB + mask: [H, W, 1] + return: BGR IMAGE + """ + gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + gray_img = torch.from_numpy(gray_img[np.newaxis, np.newaxis, :, :].astype(np.float32)).to(self.device) + start = time.time() + lines = self.line_model(gray_img) + lines = torch.clamp(lines, 0, 255) + logger.info(f"erika_model time: {time.time() - start}") + + mask = torch.from_numpy(mask[np.newaxis, :, :, :]).to(self.device) + mask = mask.permute(0, 3, 1, 2) + mask = torch.where(mask > 0.5, torch.tensor(1.0), torch.tensor(0.0)) + noise = torch.randn_like(mask) + + gray_img = gray_img / 255 * 2 - 1.0 + lines = lines / 255 * 2 - 1.0 + + start = time.time() + inpainted_image = self.inpaintor_model(gray_img, lines, mask, noise) + logger.info(f"image_inpaintor_model time: {time.time() - start}") + + cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() + cur_res = (cur_res * 127.5 + 127.5).astype(np.uint8) + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_GRAY2BGR) + return cur_res diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index f3b8eda..70411e1 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -3,13 +3,14 @@ import torch from lama_cleaner.model.fcf import FcF from lama_cleaner.model.lama import LaMa from lama_cleaner.model.ldm import LDM +from lama_cleaner.model.manga import Manga from lama_cleaner.model.mat import MAT -from lama_cleaner.model.sd import SD14, SD15 +from lama_cleaner.model.sd import SD15 from lama_cleaner.model.zits import ZITS from lama_cleaner.model.opencv2 import OpenCV2 from lama_cleaner.schema import Config -models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2} +models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga} class ModelManager: diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index 8308d2f..23bb724 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -10,7 +10,7 @@ def parse_args(): parser.add_argument( "--model", default="lama", - choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2"], + choices=["lama", "ldm", "zits", "mat", "fcf", "sd1.5", "cv2", "manga"], ) parser.add_argument( "--hf_access_token", diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 56f3725..33b8079 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -76,6 +76,7 @@ input_image_path: str = None is_disable_model_switch: bool = False is_desktop: bool = False + def get_image_ext(img_bytes): w = imghdr.what("", img_bytes) if w is None: @@ -147,9 +148,13 @@ def process(): try: res_np_img = model(image, mask, config) except RuntimeError as e: - # NOTE: the string may change? + torch.cuda.empty_cache() if "CUDA out of memory. " in str(e): + # NOTE: the string may change? return "CUDA out of memory", 500 + else: + logger.exception(e) + return "Internal Server Error", 500 finally: logger.info(f"process time: {(time.time() - start) * 1000}ms") torch.cuda.empty_cache() @@ -179,6 +184,7 @@ def process(): def current_model(): return model.name, 200 + @app.route("/is_disable_model_switch") def get_is_disable_model_switch(): res = 'true' if is_disable_model_switch else 'false' @@ -189,10 +195,12 @@ def get_is_disable_model_switch(): def model_downloaded(name): return str(model.is_downloaded(name)), 200 + @app.route("/is_desktop") def get_is_desktop(): return str(is_desktop), 200 + @app.route("/model", methods=["POST"]) def switch_model(): new_name = request.form.get("name")