From 07ae89b7c092572fb8f56ef3f58f5cabe4d092ef Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 30 Mar 2023 21:06:07 +0800 Subject: [PATCH] update test --- lama_cleaner/tests/test_plugins.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/lama_cleaner/tests/test_plugins.py b/lama_cleaner/tests/test_plugins.py index 2f41ed8..cf438b7 100644 --- a/lama_cleaner/tests/test_plugins.py +++ b/lama_cleaner/tests/test_plugins.py @@ -4,7 +4,12 @@ import cv2 import pytest import torch.cuda -from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler, GFPGANPlugin +from lama_cleaner.plugins import ( + RemoveBG, + RealESRGANUpscaler, + GFPGANPlugin, + RestoreFormerPlugin, +) current_dir = Path(__file__).parent.absolute().resolve() save_dir = current_dir / "result" @@ -48,3 +53,14 @@ def test_gfpgan(device): model = GFPGANPlugin(device) res = model(rgb_img, None, None) _save(res, f"test_gfpgan_{device}.png") + + +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) +def test_restoreformer(device): + if device == "cuda" and not torch.cuda.is_available(): + return + if device == "mps" and not torch.backends.mps.is_available(): + return + model = RestoreFormerPlugin(device) + res = model(rgb_img, None, None) + _save(res, f"test_restoreformer_{device}.png")