pass upscaler to GFPGAN
This commit is contained in:
@@ -8,7 +8,7 @@ from lama_cleaner.plugins.base_plugin import BasePlugin
|
|||||||
class GFPGANPlugin(BasePlugin):
|
class GFPGANPlugin(BasePlugin):
|
||||||
name = "GFPGAN"
|
name = "GFPGAN"
|
||||||
|
|
||||||
def __init__(self, device):
|
def __init__(self, device, upscaler=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
from .gfpganer import MyGFPGANer
|
from .gfpganer import MyGFPGANer
|
||||||
|
|
||||||
@@ -24,6 +24,7 @@ class GFPGANPlugin(BasePlugin):
|
|||||||
arch="clean",
|
arch="clean",
|
||||||
channel_multiplier=2,
|
channel_multiplier=2,
|
||||||
device=device,
|
device=device,
|
||||||
|
bg_upsampler=upscaler.model if upscaler is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, rgb_np_img, files, form):
|
def __call__(self, rgb_np_img, files, form):
|
||||||
|
|||||||
@@ -449,7 +449,9 @@ def build_plugins(args):
|
|||||||
)
|
)
|
||||||
if args.enable_gfpgan:
|
if args.enable_gfpgan:
|
||||||
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
|
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
|
||||||
plugins[GFPGANPlugin.name] = GFPGANPlugin(args.gfpgan_device)
|
plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||||
|
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
|
||||||
|
)
|
||||||
if args.enable_gif:
|
if args.enable_gif:
|
||||||
logger.info(f"Initialize GIF plugin")
|
logger.info(f"Initialize GIF plugin")
|
||||||
plugins[MakeGIF.name] = MakeGIF()
|
plugins[MakeGIF.name] = MakeGIF()
|
||||||
|
|||||||
Reference in New Issue
Block a user