From a5997e933213d4be1b6501aef0ad850f6742ae94 Mon Sep 17 00:00:00 2001 From: Qing Date: Sat, 25 Mar 2023 10:15:44 +0800 Subject: [PATCH] add more realesrgan models --- lama_cleaner/parse_args.py | 9 +++- lama_cleaner/plugins/realesrgan.py | 76 ++++++++++++++++++++++++------ lama_cleaner/server.py | 8 +++- 3 files changed, 75 insertions(+), 18 deletions(-) diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index 7c7d4bd..a4eefd7 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -6,6 +6,7 @@ from pathlib import Path from loguru import logger from lama_cleaner.const import * +from lama_cleaner.plugins.realesrgan import RealESRGANModelName, RealESRGANModelNameList from lama_cleaner.runtime import dump_environment_info @@ -92,7 +93,13 @@ def parse_args(): help="Enable realesrgan super resolution", ) parser.add_argument( - "--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"] + "--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda", "mps"] + ) + parser.add_argument( + "--realesrgan-model", + default=RealESRGANModelName.realesr_general_x4v3.value, + type=str, + choices=RealESRGANModelNameList, ) parser.add_argument( "--enable-gif", diff --git a/lama_cleaner/plugins/realesrgan.py b/lama_cleaner/plugins/realesrgan.py index 2758291..9000ee1 100644 --- a/lama_cleaner/plugins/realesrgan.py +++ b/lama_cleaner/plugins/realesrgan.py @@ -1,33 +1,79 @@ +from enum import Enum + import cv2 from lama_cleaner.helper import download_model +class RealESRGANModelName(str, Enum): + realesr_general_x4v3 = "realesr-general-x4v3" + RealESRGAN_x4plus = "RealESRGAN_x4plus" + RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B" + + +RealESRGANModelNameList = [e.value for e in RealESRGANModelName] + + class RealESRGANUpscaler: name = "RealESRGAN" - def __init__(self, device): + def __init__(self, name, device): super().__init__() from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer + from realesrgan.archs.srvgg_arch import SRVGGNetCompact - scale = 4 - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=4, - ) - url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth" - model_md5 = "99ec365d4afad750833258a1a24f44ca" - model_path = download_model(url, model_md5) + REAL_ESRGAN_MODELS = { + RealESRGANModelName.realesr_general_x4v3: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + "scale": 4, + "model": lambda: SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=4, + act_type="prelu", + ), + "model_md5": "91a7644643c884ee00737db24e478156", + }, + RealESRGANModelName.RealESRGAN_x4plus: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + "scale": 4, + "model": lambda: RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ), + "model_md5": "99ec365d4afad750833258a1a24f44ca", + }, + RealESRGANModelName.RealESRGAN_x4plus_anime_6B: { + "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + "scale": 4, + "model": lambda: RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=6, + num_grow_ch=32, + scale=4, + ), + "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba", + }, + } + if name not in REAL_ESRGAN_MODELS: + raise ValueError(f"Unknown RealESRGAN model name: {name}") + model_info = REAL_ESRGAN_MODELS[name] + + model_path = download_model(model_info["url"], model_info["model_md5"]) self.model = RealESRGANer( - scale=scale, + scale=model_info["scale"], model_path=model_path, - model=model, + model=model_info["model"](), half=True if "cuda" in str(device) else False, tile=640, tile_pad=10, diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index d6b49f0..53a2738 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -423,8 +423,12 @@ def build_plugins(args): logger.info(f"Initialize {RemoveBG.name} plugin") plugins[RemoveBG.name] = RemoveBG() if args.enable_realesrgan: - logger.info(f"Initialize {RealESRGANUpscaler.name} plugin") - plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device) + logger.info( + f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}" + ) + plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( + args.realesrgan_model, args.realesrgan_device + ) if args.enable_gif: logger.info(f"Initialize GIF plugin") plugins[MakeGIF.name] = MakeGIF()