add more realesrgan models
This commit is contained in:
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lama_cleaner.const import *
|
from lama_cleaner.const import *
|
||||||
|
from lama_cleaner.plugins.realesrgan import RealESRGANModelName, RealESRGANModelNameList
|
||||||
from lama_cleaner.runtime import dump_environment_info
|
from lama_cleaner.runtime import dump_environment_info
|
||||||
|
|
||||||
|
|
||||||
@@ -92,7 +93,13 @@ def parse_args():
|
|||||||
help="Enable realesrgan super resolution",
|
help="Enable realesrgan super resolution",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda"]
|
"--realesrgan-device", default="cpu", type=str, choices=["cpu", "cuda", "mps"]
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--realesrgan-model",
|
||||||
|
default=RealESRGANModelName.realesr_general_x4v3.value,
|
||||||
|
type=str,
|
||||||
|
choices=RealESRGANModelNameList,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-gif",
|
"--enable-gif",
|
||||||
|
|||||||
@@ -1,33 +1,79 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from lama_cleaner.helper import download_model
|
from lama_cleaner.helper import download_model
|
||||||
|
|
||||||
|
|
||||||
|
class RealESRGANModelName(str, Enum):
|
||||||
|
realesr_general_x4v3 = "realesr-general-x4v3"
|
||||||
|
RealESRGAN_x4plus = "RealESRGAN_x4plus"
|
||||||
|
RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
|
||||||
|
|
||||||
|
|
||||||
|
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
|
||||||
|
|
||||||
|
|
||||||
class RealESRGANUpscaler:
|
class RealESRGANUpscaler:
|
||||||
name = "RealESRGAN"
|
name = "RealESRGAN"
|
||||||
|
|
||||||
def __init__(self, device):
|
def __init__(self, name, device):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
|
|
||||||
scale = 4
|
REAL_ESRGAN_MODELS = {
|
||||||
model = RRDBNet(
|
RealESRGANModelName.realesr_general_x4v3: {
|
||||||
num_in_ch=3,
|
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||||
num_out_ch=3,
|
"scale": 4,
|
||||||
num_feat=64,
|
"model": lambda: SRVGGNetCompact(
|
||||||
num_block=23,
|
num_in_ch=3,
|
||||||
num_grow_ch=32,
|
num_out_ch=3,
|
||||||
scale=4,
|
num_feat=64,
|
||||||
)
|
num_conv=32,
|
||||||
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
|
upscale=4,
|
||||||
model_md5 = "99ec365d4afad750833258a1a24f44ca"
|
act_type="prelu",
|
||||||
model_path = download_model(url, model_md5)
|
),
|
||||||
|
"model_md5": "91a7644643c884ee00737db24e478156",
|
||||||
|
},
|
||||||
|
RealESRGANModelName.RealESRGAN_x4plus: {
|
||||||
|
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
|
"scale": 4,
|
||||||
|
"model": lambda: RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
),
|
||||||
|
"model_md5": "99ec365d4afad750833258a1a24f44ca",
|
||||||
|
},
|
||||||
|
RealESRGANModelName.RealESRGAN_x4plus_anime_6B: {
|
||||||
|
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
|
"scale": 4,
|
||||||
|
"model": lambda: RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=6,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=4,
|
||||||
|
),
|
||||||
|
"model_md5": "d58ce384064ec1591c2ea7b79dbf47ba",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if name not in REAL_ESRGAN_MODELS:
|
||||||
|
raise ValueError(f"Unknown RealESRGAN model name: {name}")
|
||||||
|
model_info = REAL_ESRGAN_MODELS[name]
|
||||||
|
|
||||||
|
model_path = download_model(model_info["url"], model_info["model_md5"])
|
||||||
|
|
||||||
self.model = RealESRGANer(
|
self.model = RealESRGANer(
|
||||||
scale=scale,
|
scale=model_info["scale"],
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
model=model,
|
model=model_info["model"](),
|
||||||
half=True if "cuda" in str(device) else False,
|
half=True if "cuda" in str(device) else False,
|
||||||
tile=640,
|
tile=640,
|
||||||
tile_pad=10,
|
tile_pad=10,
|
||||||
|
|||||||
@@ -423,8 +423,12 @@ def build_plugins(args):
|
|||||||
logger.info(f"Initialize {RemoveBG.name} plugin")
|
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||||
plugins[RemoveBG.name] = RemoveBG()
|
plugins[RemoveBG.name] = RemoveBG()
|
||||||
if args.enable_realesrgan:
|
if args.enable_realesrgan:
|
||||||
logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
|
logger.info(
|
||||||
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)
|
f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}"
|
||||||
|
)
|
||||||
|
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
||||||
|
args.realesrgan_model, args.realesrgan_device
|
||||||
|
)
|
||||||
if args.enable_gif:
|
if args.enable_gif:
|
||||||
logger.info(f"Initialize GIF plugin")
|
logger.info(f"Initialize GIF plugin")
|
||||||
plugins[MakeGIF.name] = MakeGIF()
|
plugins[MakeGIF.name] = MakeGIF()
|
||||||
|
|||||||
Reference in New Issue
Block a user