add Segment Anything

This commit is contained in:
Qing
2023-04-06 21:55:20 +08:00
parent ed36744339
commit a6aec566d9
20 changed files with 1885 additions and 299 deletions

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python3
import hashlib
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -315,6 +316,10 @@ def run_plugin():
start = time.time()
try:
form = dict(form)
if name == InteractiveSeg.name:
img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
form["img_md5"] = img_md5
bgr_res = plugins[name](rgb_np_img, files, form)
except RuntimeError as e:
torch.cuda.empty_cache()
@@ -437,7 +442,9 @@ def build_plugins(args):
global plugins
if args.enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")
plugins[InteractiveSeg.name] = InteractiveSeg()
plugins[InteractiveSeg.name] = InteractiveSeg(
args.interactive_seg_model, args.interactive_seg_device
)
if args.enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin")
plugins[RemoveBG.name] = RemoveBG()
@@ -452,6 +459,12 @@ def build_plugins(args):
)
if args.enable_gfpgan:
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
if args.enable_realesrgan:
logger.info("Use realesrgan as GFPGAN background upscaler")
else:
logger.info(
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
)
plugins[GFPGANPlugin.name] = GFPGANPlugin(
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
)