add Segment Anything
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
@@ -315,6 +316,10 @@ def run_plugin():
|
||||
|
||||
start = time.time()
|
||||
try:
|
||||
form = dict(form)
|
||||
if name == InteractiveSeg.name:
|
||||
img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
|
||||
form["img_md5"] = img_md5
|
||||
bgr_res = plugins[name](rgb_np_img, files, form)
|
||||
except RuntimeError as e:
|
||||
torch.cuda.empty_cache()
|
||||
@@ -437,7 +442,9 @@ def build_plugins(args):
|
||||
global plugins
|
||||
if args.enable_interactive_seg:
|
||||
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||
plugins[InteractiveSeg.name] = InteractiveSeg()
|
||||
plugins[InteractiveSeg.name] = InteractiveSeg(
|
||||
args.interactive_seg_model, args.interactive_seg_device
|
||||
)
|
||||
if args.enable_remove_bg:
|
||||
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||
plugins[RemoveBG.name] = RemoveBG()
|
||||
@@ -452,6 +459,12 @@ def build_plugins(args):
|
||||
)
|
||||
if args.enable_gfpgan:
|
||||
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
|
||||
if args.enable_realesrgan:
|
||||
logger.info("Use realesrgan as GFPGAN background upscaler")
|
||||
else:
|
||||
logger.info(
|
||||
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
|
||||
)
|
||||
plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user