From 4d1809e9082e831fff0c261a585c6dacfa98f847 Mon Sep 17 00:00:00 2001 From: Qing Date: Sun, 26 Mar 2023 20:52:06 +0800 Subject: [PATCH] pass upscaler to GFPGAN --- lama_cleaner/plugins/gfpgan_plugin.py | 3 ++- lama_cleaner/server.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lama_cleaner/plugins/gfpgan_plugin.py b/lama_cleaner/plugins/gfpgan_plugin.py index 0d465e0..cee1834 100644 --- a/lama_cleaner/plugins/gfpgan_plugin.py +++ b/lama_cleaner/plugins/gfpgan_plugin.py @@ -8,7 +8,7 @@ from lama_cleaner.plugins.base_plugin import BasePlugin class GFPGANPlugin(BasePlugin): name = "GFPGAN" - def __init__(self, device): + def __init__(self, device, upscaler=None): super().__init__() from .gfpganer import MyGFPGANer @@ -24,6 +24,7 @@ class GFPGANPlugin(BasePlugin): arch="clean", channel_multiplier=2, device=device, + bg_upsampler=upscaler.model if upscaler is not None else None, ) def __call__(self, rgb_np_img, files, form): diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index cb79b8e..8ea10b9 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -449,7 +449,9 @@ def build_plugins(args): ) if args.enable_gfpgan: logger.info(f"Initialize {GFPGANPlugin.name} plugin") - plugins[GFPGANPlugin.name] = GFPGANPlugin(args.gfpgan_device) + plugins[GFPGANPlugin.name] = GFPGANPlugin( + args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None) + ) if args.enable_gif: logger.info(f"Initialize GIF plugin") plugins[MakeGIF.name] = MakeGIF()