wip
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
@@ -11,10 +12,10 @@ current_dir = Path(__file__).parent.absolute().resolve()
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
def get_data(fx=1, fy=1.0):
|
||||
img = cv2.imread(str(current_dir / "image.png"))
|
||||
def get_data(fx=1, fy=1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
||||
img = cv2.imread(str(img_p))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||
mask = cv2.imread(str(current_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)
|
||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
if fx != 1:
|
||||
img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
||||
@@ -35,8 +36,8 @@ def get_config(strategy, **kwargs):
|
||||
return Config(**data)
|
||||
|
||||
|
||||
def assert_equal(model, config, gt_name, fx=1, fy=1):
|
||||
img, mask = get_data(fx=fx, fy=fy)
|
||||
def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
res = model(img, mask, config)
|
||||
cv2.imwrite(
|
||||
str(current_dir / gt_name),
|
||||
@@ -153,3 +154,26 @@ def test_fcf(strategy):
|
||||
fx=3.8,
|
||||
fy=2
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_sd(strategy, capfd):
|
||||
def callback(step: int):
|
||||
print(f"sd_step_{step}")
|
||||
|
||||
sd_steps = 2
|
||||
model = ModelManager(name="sd", device=device, hf_access_token=os.environ['HF_ACCESS_TOKEN'], callbacks=[callback])
|
||||
cfg = get_config(strategy, prompt='a cat sitting on a bench', sd_steps=sd_steps)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_{strategy.capitalize()}_result.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=0.5,
|
||||
fy=0.5
|
||||
)
|
||||
|
||||
captured = capfd.readouterr()
|
||||
for i in range(sd_steps):
|
||||
assert f'sd_step_{i}' in captured.out
|
||||
|
||||
Reference in New Issue
Block a user