diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py index accdae3..cd70ad7 100644 --- a/lama_cleaner/tests/test_model.py +++ b/lama_cleaner/tests/test_model.py @@ -158,7 +158,7 @@ def test_fcf(strategy): @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm]) -def test_sd(strategy, sampler, capfd): +def test_sd(strategy, sampler): def callback(step: int): print(f"sd_step_{step}") @@ -184,6 +184,38 @@ def test_sd(strategy, sampler, capfd): mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png", ) - # captured = capfd.readouterr() - # for i in range(sd_steps): - # assert f'sd_step_{i}' in captured.out + +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_sd_run_local(strategy, sampler): + def callback(step: int): + print(f"sd_step_{step}") + + sd_steps = 50 + model = ModelManager( + name="sd1.4", + device=device, + hf_access_token=None, + sd_run_local=True, + sd_disable_nsfw=True, + sd_cpu_textencoder=True, + callbacks=[callback] + ) + cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"sd_{strategy.capitalize()}_{sampler}_local_result.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + assert_equal( + model, + cfg, + f"sd_{strategy.capitalize()}_{sampler}_blur_mask_local_result.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png", + )