diff --git a/iopaint/plugins/remove_bg.py b/iopaint/plugins/remove_bg.py index 64bf785..e066662 100644 --- a/iopaint/plugins/remove_bg.py +++ b/iopaint/plugins/remove_bg.py @@ -17,6 +17,14 @@ class RemoveBG(BasePlugin): super().__init__() self.model_name = model_name + if model_name.startswith("birefnet"): + import rembg + + if rembg.__version__ < "2.0.59": + raise ValueError( + "To use birefnet models, please upgrade rembg to >= 2.0.59. pip install -U rembg" + ) + hub_dir = get_dir() model_dir = os.path.join(hub_dir, "checkpoints") os.environ["U2NET_HOME"] = model_dir @@ -66,6 +74,4 @@ class RemoveBG(BasePlugin): try: import rembg except ImportError: - return ( - "RemoveBG is not installed, please install it first. pip install rembg" - ) + return "RemoveBG is not installed, please install it first. pip install -U rembg" diff --git a/iopaint/schema.py b/iopaint/schema.py index 3150ab4..0b01a93 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -144,13 +144,21 @@ class RealESRGANModel(Choices): class RemoveBGModel(Choices): + briaai_rmbg_1_4 = "briaai/RMBG-1.4" + # models from https://github.com/danielgatis/rembg u2net = "u2net" u2netp = "u2netp" u2net_human_seg = "u2net_human_seg" u2net_cloth_seg = "u2net_cloth_seg" silueta = "silueta" isnet_general_use = "isnet-general-use" - briaai_rmbg_1_4 = "briaai/RMBG-1.4" + birefnet_general = "birefnet-general" + birefnet_general_lite = "birefnet-general-lite" + birefnet_portrait = "birefnet-portrait" + birefnet_dis = "birefnet-dis" + birefnet_hrsod = "birefnet-hrsod" + birefnet_cod = "birefnet-cod" + birefnet_massive = "birefnet-massive" class Device(Choices): diff --git a/iopaint/tests/test_plugins.py b/iopaint/tests/test_plugins.py index dd1eafd..ffa35e6 100644 --- a/iopaint/tests/test_plugins.py +++ b/iopaint/tests/test_plugins.py @@ -36,23 +36,25 @@ def _save(img, name): cv2.imwrite(str(save_dir / name), img) -def test_remove_bg(): - model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4) +@pytest.mark.parametrize("model_name", RemoveBGModel.values()) +def test_remove_bg(model_name): + print(f"Testing {model_name}") + model = RemoveBG(model_name) rgba_np_img = model.gen_image( rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) ) res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA) - _save(res, "test_remove_bg.png") + _save(res, f"test_remove_bg_{model_name}.png") bgr_np_img = model.gen_mask( rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) ) res_mask = gen_frontend_mask(bgr_np_img) - _save(res_mask, "test_remove_bg_frontend_mask.png") + _save(res_mask, f"test_remove_bg_frontend_mask_{model_name}.png") assert len(bgr_np_img.shape) == 2 - _save(bgr_np_img, "test_remove_bg_mask.jpeg") + _save(bgr_np_img, f"test_remove_bg_mask_{model_name}.jpeg") def test_anime_seg():