add sam_hq
This commit is contained in:
@@ -5,7 +5,7 @@ from PIL import Image
|
||||
|
||||
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
|
||||
from iopaint.plugins.anime_seg import AnimeSeg
|
||||
from iopaint.schema import RunPluginRequest, RemoveBGModel
|
||||
from iopaint.schema import RunPluginRequest, RemoveBGModel, InteractiveSegModel
|
||||
from iopaint.tests.utils import check_device, current_dir, save_dir
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
@@ -103,10 +103,11 @@ def test_restoreformer(device):
|
||||
_save(res, f"test_restoreformer_{device}.png")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", InteractiveSegModel.values())
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||
def test_segment_anything(device):
|
||||
def test_segment_anything(name, device):
|
||||
check_device(device)
|
||||
model = InteractiveSeg("vit_l", device)
|
||||
model = InteractiveSeg(name, device)
|
||||
new_mask = model.gen_mask(
|
||||
rgb_img,
|
||||
RunPluginRequest(
|
||||
@@ -116,5 +117,5 @@ def test_segment_anything(device):
|
||||
),
|
||||
)
|
||||
|
||||
save_name = f"test_segment_anything_{device}.png"
|
||||
save_name = f"test_segment_anything_{name}_{device}.png"
|
||||
_save(new_mask, save_name)
|
||||
|
||||
Reference in New Issue
Block a user