add remove_bg_device

This commit is contained in:
Qing
2024-11-23 15:51:05 +08:00
parent d29fe6ecb5
commit b7699a0f26
10 changed files with 64 additions and 20 deletions

View File

@@ -3,7 +3,7 @@ from PIL import Image
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
from iopaint.plugins.anime_seg import AnimeSeg
from iopaint.schema import RunPluginRequest, RemoveBGModel, InteractiveSegModel
from iopaint.schema import Device, RunPluginRequest, RemoveBGModel, InteractiveSegModel
from iopaint.tests.utils import check_device, current_dir, save_dir
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -38,24 +38,26 @@ def _save(img, name):
@pytest.mark.parametrize("model_name", RemoveBGModel.values())
def test_remove_bg(model_name):
print(f"Testing {model_name}")
model = RemoveBG(model_name)
@pytest.mark.parametrize("device", Device.values())
def test_remove_bg(model_name, device):
check_device(device)
print(f"Testing {model_name} on {device}")
model = RemoveBG(model_name, device)
rgba_np_img = model.gen_image(
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
)
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
_save(res, f"test_remove_bg_{model_name}.png")
_save(res, f"test_remove_bg_{model_name}_{device}.png")
bgr_np_img = model.gen_mask(
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
)
res_mask = gen_frontend_mask(bgr_np_img)
_save(res_mask, f"test_remove_bg_frontend_mask_{model_name}.png")
_save(res_mask, f"test_remove_bg_frontend_mask_{model_name}_{device}.png")
assert len(bgr_np_img.shape) == 2
_save(bgr_np_img, f"test_remove_bg_mask_{model_name}.jpeg")
_save(bgr_np_img, f"test_remove_bg_mask_{model_name}_{device}.jpeg")
def test_anime_seg():