add anime segmentation

This commit is contained in:
Qing
2023-05-09 19:07:12 +08:00
parent 7fcce78e40
commit e5ac6a105a
11 changed files with 510 additions and 5 deletions

View File

@@ -2,6 +2,8 @@ import hashlib
import os
import time
from lama_cleaner.plugins.anime_seg import AnimeSeg
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path
@@ -36,6 +38,15 @@ def test_remove_bg():
_save(res, "test_remove_bg.png")
def test_anime_seg():
model = AnimeSeg()
img = cv2.imread(str(current_dir / "anime_test.png"))
res = model.forward(img)
assert len(res.shape) == 3
assert res.shape[-1] == 4
_save(res, "test_anime_seg.png")
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_upscale(device):
if device == "cuda" and not torch.cuda.is_available():