add sam_hq

This commit is contained in:
Qing
2024-02-20 09:03:11 +08:00
parent 6447e821cb
commit b358e6cbce
10 changed files with 1281 additions and 19 deletions

View File

@@ -5,7 +5,7 @@ from PIL import Image
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
from iopaint.plugins.anime_seg import AnimeSeg
from iopaint.schema import RunPluginRequest, RemoveBGModel
from iopaint.schema import RunPluginRequest, RemoveBGModel, InteractiveSegModel
from iopaint.tests.utils import check_device, current_dir, save_dir
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -103,10 +103,11 @@ def test_restoreformer(device):
_save(res, f"test_restoreformer_{device}.png")
@pytest.mark.parametrize("name", InteractiveSegModel.values())
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_segment_anything(device):
def test_segment_anything(name, device):
check_device(device)
model = InteractiveSeg("vit_l", device)
model = InteractiveSeg(name, device)
new_mask = model.gen_mask(
rgb_img,
RunPluginRequest(
@@ -116,5 +117,5 @@ def test_segment_anything(device):
),
)
save_name = f"test_segment_anything_{device}.png"
save_name = f"test_segment_anything_{name}_{device}.png"
_save(new_mask, save_name)