This commit is contained in:
Qing
2022-09-15 22:21:27 +08:00
parent 3ac6ee7f44
commit 32854d40da
52 changed files with 2258 additions and 205 deletions

View File

@@ -1,3 +1,4 @@
import os
from pathlib import Path
import cv2
@@ -11,10 +12,10 @@ current_dir = Path(__file__).parent.absolute().resolve()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_data(fx=1, fy=1.0):
img = cv2.imread(str(current_dir / "image.png"))
def get_data(fx=1, fy=1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
img = cv2.imread(str(img_p))
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
mask = cv2.imread(str(current_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
if fx != 1:
img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
@@ -35,8 +36,8 @@ def get_config(strategy, **kwargs):
return Config(**data)
def assert_equal(model, config, gt_name, fx=1, fy=1):
img, mask = get_data(fx=fx, fy=fy)
def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
res = model(img, mask, config)
cv2.imwrite(
str(current_dir / gt_name),
@@ -153,3 +154,26 @@ def test_fcf(strategy):
fx=3.8,
fy=2
)
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_sd(strategy, capfd):
def callback(step: int):
print(f"sd_step_{step}")
sd_steps = 2
model = ModelManager(name="sd", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'], callbacks=[callback])
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
assert_equal(
model,
cfg,
f"sd_{strategy.capitalize()}_result.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fx=0.5,
fy=0.5
)
captured = capfd.readouterr()
for i in range(sd_steps):
assert f'sd_step_{i}' in captured.out