This commit is contained in:
Qing
2023-11-16 11:08:34 +08:00
parent 8f942e27c4
commit 0cfec489b7
7 changed files with 63 additions and 28 deletions

View File

@@ -13,11 +13,9 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
@pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("cpu_textencoder", [True, False])
@@ -56,7 +54,7 @@ def test_runway_sd_1_5_ddim(
)
@pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize(
"sampler", [SDSampler.pndm, SDSampler.k_lms, SDSampler.k_euler, SDSampler.k_euler_a]
@@ -95,7 +93,7 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
)
@pytest.mark.parametrize("sd_device", ["mps"])
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("sd_prevent_unmasked_area", [False, True])
@@ -140,7 +138,7 @@ def test_runway_sd_1_5_negative_prompt(
)
@pytest.mark.parametrize("sd_device", ["cuda"])
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
@pytest.mark.parametrize("cpu_textencoder", [False])
@@ -151,7 +149,7 @@ def test_runway_sd_1_5_sd_scale(
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 50 if sd_device == "cuda" else 1
sd_steps = 50 if sd_device == "cuda" else 20
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),
@@ -177,7 +175,7 @@ def test_runway_sd_1_5_sd_scale(
)
@pytest.mark.parametrize("sd_device", ["mps"])
@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])
def test_runway_sd_sd_strength(sd_device, strategy, sampler):
@@ -214,7 +212,7 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 50 if sd_device == "cuda" else 1
sd_steps = 50 if sd_device == "cuda" else 20
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),
@@ -246,7 +244,7 @@ def test_local_file_path(sd_device, sampler):
if sd_device == "cuda" and not torch.cuda.is_available():
return
sd_steps = 1 if sd_device == "cpu" else 50
sd_steps = 1 if sd_device == "cpu" else 30
model = ModelManager(
name="sd1.5",
device=torch.device(sd_device),