This commit is contained in:
Qing
2023-11-16 11:08:34 +08:00
parent 8f942e27c4
commit 0cfec489b7
7 changed files with 63 additions and 28 deletions

View File

@@ -10,13 +10,13 @@ from lama_cleaner.schema import HDStrategy
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / 'result'
save_dir.mkdir(exist_ok=True, parents=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'mps'
@pytest.mark.parametrize("disable_nsfw", [True, False])
@pytest.mark.parametrize("cpu_offload", [False, True])
def test_instruct_pix2pix(disable_nsfw, cpu_offload):
sd_steps = 50 if device == 'cuda' else 1
sd_steps = 50 if device == 'cuda' else 20
model = ModelManager(name="instruct_pix2pix",
device=torch.device(device),
hf_access_token="",
@@ -41,7 +41,7 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
@pytest.mark.parametrize("disable_nsfw", [False])
@pytest.mark.parametrize("cpu_offload", [False])
def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
sd_steps = 50 if device == 'cuda' else 1
sd_steps = 50 if device == 'cuda' else 20
model = ModelManager(name="instruct_pix2pix",
device=torch.device(device),
hf_access_token="",