update test
This commit is contained in:
@@ -14,7 +14,7 @@ save_dir.mkdir(exist_ok=True, parents=True)
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
def get_data(fx=1, fy=1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
||||
def get_data(fx: float = 1, fy: float = 1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
||||
img = cv2.imread(str(img_p))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||
@@ -36,7 +36,10 @@ def get_config(strategy, **kwargs):
|
||||
return Config(**data)
|
||||
|
||||
|
||||
def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
||||
def assert_equal(model, config, gt_name,
|
||||
fx: float = 1, fy: float = 1,
|
||||
img_p=current_dir / "image.png",
|
||||
mask_p=current_dir / "mask.png"):
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
print(f"Input image shape: {img.shape}")
|
||||
res = model(img, mask, config)
|
||||
@@ -157,91 +160,32 @@ def test_fcf(strategy):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ['cpu'])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
|
||||
def test_sd(strategy, sampler):
|
||||
def callback(i, t, latents):
|
||||
print(f"sd_step_{i}")
|
||||
|
||||
sd_steps = 50
|
||||
model = ModelManager(name="sd1.4",
|
||||
device=device,
|
||||
hf_access_token=os.environ['HF_ACCESS_TOKEN'],
|
||||
sd_run_local=False,
|
||||
sd_disable_nsfw=False,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback)
|
||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_{strategy.capitalize()}_{sampler}_result.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_{strategy.capitalize()}_{sampler}_blur_mask_result.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
|
||||
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||
@pytest.mark.parametrize("cpu_textencoder", [True, False])
|
||||
def test_sd_run_local(strategy, sampler, disable_nsfw, cpu_textencoder):
|
||||
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||
def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
|
||||
def callback(i, t, latents):
|
||||
print(f"sd_step_{i}")
|
||||
|
||||
sd_steps = 50
|
||||
model = ModelManager(
|
||||
name="sd1.4",
|
||||
device=device,
|
||||
# hf_access_token=os.environ.get('HF_ACCESS_TOKEN', None),
|
||||
hf_access_token=None,
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
)
|
||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_{strategy.capitalize()}_{sampler}_local_disablensfw_{disable_nsfw}_cputextencoder_{cpu_textencoder}_result.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim, SDSampler.pndm, SDSampler.k_lms])
|
||||
def test_runway_sd_1_5(strategy, sampler):
|
||||
def callback(i, t, latents):
|
||||
print(f"sd_step_{i}")
|
||||
|
||||
sd_steps = 20
|
||||
sd_steps = 1
|
||||
model = ModelManager(name="sd1.5",
|
||||
device=device,
|
||||
hf_access_token=None,
|
||||
device=sd_device,
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
sd_disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
callback=callback)
|
||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
name = f"{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_{strategy.capitalize()}_{sampler}_result.png",
|
||||
f"runway_sd_{strategy.capitalize()}_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.3
|
||||
@@ -250,7 +194,7 @@ def test_runway_sd_1_5(strategy, sampler):
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_{strategy.capitalize()}_{sampler}_blur_mask_result.png",
|
||||
f"runway_sd_{strategy.capitalize()}_{name}_blur_mask.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||
fy=1.3
|
||||
|
||||
Reference in New Issue
Block a user