RemoveBg plugin add birefnet models, require rembg>=2.0.59
This commit is contained in:
@@ -17,6 +17,14 @@ class RemoveBG(BasePlugin):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
|
if model_name.startswith("birefnet"):
|
||||||
|
import rembg
|
||||||
|
|
||||||
|
if rembg.__version__ < "2.0.59":
|
||||||
|
raise ValueError(
|
||||||
|
"To use birefnet models, please upgrade rembg to >= 2.0.59. pip install -U rembg"
|
||||||
|
)
|
||||||
|
|
||||||
hub_dir = get_dir()
|
hub_dir = get_dir()
|
||||||
model_dir = os.path.join(hub_dir, "checkpoints")
|
model_dir = os.path.join(hub_dir, "checkpoints")
|
||||||
os.environ["U2NET_HOME"] = model_dir
|
os.environ["U2NET_HOME"] = model_dir
|
||||||
@@ -66,6 +74,4 @@ class RemoveBG(BasePlugin):
|
|||||||
try:
|
try:
|
||||||
import rembg
|
import rembg
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return (
|
return "RemoveBG is not installed, please install it first. pip install -U rembg"
|
||||||
"RemoveBG is not installed, please install it first. pip install rembg"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -144,13 +144,21 @@ class RealESRGANModel(Choices):
|
|||||||
|
|
||||||
|
|
||||||
class RemoveBGModel(Choices):
|
class RemoveBGModel(Choices):
|
||||||
|
briaai_rmbg_1_4 = "briaai/RMBG-1.4"
|
||||||
|
# models from https://github.com/danielgatis/rembg
|
||||||
u2net = "u2net"
|
u2net = "u2net"
|
||||||
u2netp = "u2netp"
|
u2netp = "u2netp"
|
||||||
u2net_human_seg = "u2net_human_seg"
|
u2net_human_seg = "u2net_human_seg"
|
||||||
u2net_cloth_seg = "u2net_cloth_seg"
|
u2net_cloth_seg = "u2net_cloth_seg"
|
||||||
silueta = "silueta"
|
silueta = "silueta"
|
||||||
isnet_general_use = "isnet-general-use"
|
isnet_general_use = "isnet-general-use"
|
||||||
briaai_rmbg_1_4 = "briaai/RMBG-1.4"
|
birefnet_general = "birefnet-general"
|
||||||
|
birefnet_general_lite = "birefnet-general-lite"
|
||||||
|
birefnet_portrait = "birefnet-portrait"
|
||||||
|
birefnet_dis = "birefnet-dis"
|
||||||
|
birefnet_hrsod = "birefnet-hrsod"
|
||||||
|
birefnet_cod = "birefnet-cod"
|
||||||
|
birefnet_massive = "birefnet-massive"
|
||||||
|
|
||||||
|
|
||||||
class Device(Choices):
|
class Device(Choices):
|
||||||
|
|||||||
@@ -36,23 +36,25 @@ def _save(img, name):
|
|||||||
cv2.imwrite(str(save_dir / name), img)
|
cv2.imwrite(str(save_dir / name), img)
|
||||||
|
|
||||||
|
|
||||||
def test_remove_bg():
|
@pytest.mark.parametrize("model_name", RemoveBGModel.values())
|
||||||
model = RemoveBG(RemoveBGModel.briaai_rmbg_1_4)
|
def test_remove_bg(model_name):
|
||||||
|
print(f"Testing {model_name}")
|
||||||
|
model = RemoveBG(model_name)
|
||||||
rgba_np_img = model.gen_image(
|
rgba_np_img = model.gen_image(
|
||||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||||
)
|
)
|
||||||
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
|
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
|
||||||
_save(res, "test_remove_bg.png")
|
_save(res, f"test_remove_bg_{model_name}.png")
|
||||||
|
|
||||||
bgr_np_img = model.gen_mask(
|
bgr_np_img = model.gen_mask(
|
||||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||||
)
|
)
|
||||||
|
|
||||||
res_mask = gen_frontend_mask(bgr_np_img)
|
res_mask = gen_frontend_mask(bgr_np_img)
|
||||||
_save(res_mask, "test_remove_bg_frontend_mask.png")
|
_save(res_mask, f"test_remove_bg_frontend_mask_{model_name}.png")
|
||||||
|
|
||||||
assert len(bgr_np_img.shape) == 2
|
assert len(bgr_np_img.shape) == 2
|
||||||
_save(bgr_np_img, "test_remove_bg_mask.jpeg")
|
_save(bgr_np_img, f"test_remove_bg_mask_{model_name}.jpeg")
|
||||||
|
|
||||||
|
|
||||||
def test_anime_seg():
|
def test_anime_seg():
|
||||||
|
|||||||
Reference in New Issue
Block a user