diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index d6ea7f3..576f6a7 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -104,6 +104,11 @@ def parse_args(): type=str, choices=RealESRGANModelNameList, ) + parser.add_argument( + "--realesrgan-no-half", + action="store_true", + help="Disable half precision for RealESRGAN", + ) parser.add_argument("--enable-gfpgan", action="store_true", help=GFPGAN_HELP) parser.add_argument( "--gfpgan-device", default="cpu", type=str, choices=GFPGAN_AVAILABLE_DEVICES diff --git a/lama_cleaner/plugins/realesrgan.py b/lama_cleaner/plugins/realesrgan.py index a33546f..6dcf761 100644 --- a/lama_cleaner/plugins/realesrgan.py +++ b/lama_cleaner/plugins/realesrgan.py @@ -11,7 +11,7 @@ from lama_cleaner.plugins.base_plugin import BasePlugin class RealESRGANUpscaler(BasePlugin): name = "RealESRGAN" - def __init__(self, name, device): + def __init__(self, name, device, no_half=False): super().__init__() from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer @@ -69,7 +69,7 @@ class RealESRGANUpscaler(BasePlugin): scale=model_info["scale"], model_path=model_path, model=model_info["model"](), - half=True if "cuda" in str(device) else False, + half=True if "cuda" in str(device) and not no_half else False, tile=512, tile_pad=10, pre_pad=10, diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index b03016a..725f398 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -446,7 +446,9 @@ def build_plugins(args): f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}" ) plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler( - args.realesrgan_model, args.realesrgan_device + args.realesrgan_model, + args.realesrgan_device, + no_half=args.realesrgan_no_half, ) if args.enable_gfpgan: logger.info(f"Initialize {GFPGANPlugin.name} plugin")