sd make change sampler work

This commit is contained in:
Qing
2022-09-22 12:38:32 +08:00
parent 047474ab84
commit e1fb0030d1
3 changed files with 46 additions and 19 deletions

View File

@@ -6,7 +6,7 @@ import pytest
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, LDMSampler
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
current_dir = Path(__file__).parent.absolute().resolve()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -155,25 +155,27 @@ def test_fcf(strategy):
fy=2
)
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_sd(strategy, capfd):
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm])
def test_sd(strategy, sampler, capfd):
def callback(step: int):
print(f"sd_step_{step}")
sd_steps = 2
model = ModelManager(name="sd", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'], callbacks=[callback])
sd_steps = 50
model = ModelManager(name="sd1.4", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'],
callbacks=[callback])
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
cfg.sd_sampler = sampler
assert_equal(
model,
cfg,
f"sd_{strategy.capitalize()}_result.png",
f"sd_{strategy.capitalize()}_{sampler}_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=0.5,
fy=0.5
)
captured = capfd.readouterr()
for i in range(sd_steps):
assert f'sd_step_{i}' in captured.out
# captured = capfd.readouterr()
# for i in range(sd_steps):
# assert f'sd_step_{i}' in captured.out