add anime segmentation
This commit is contained in:
@@ -2,6 +2,8 @@ import hashlib
|
||||
import os
|
||||
import time
|
||||
|
||||
from lama_cleaner.plugins.anime_seg import AnimeSeg
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
@@ -36,6 +38,15 @@ def test_remove_bg():
|
||||
_save(res, "test_remove_bg.png")
|
||||
|
||||
|
||||
def test_anime_seg():
|
||||
model = AnimeSeg()
|
||||
img = cv2.imread(str(current_dir / "anime_test.png"))
|
||||
res = model.forward(img)
|
||||
assert len(res.shape) == 3
|
||||
assert res.shape[-1] == 4
|
||||
_save(res, "test_anime_seg.png")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||
def test_upscale(device):
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
|
||||
Reference in New Issue
Block a user