add MAT model

This commit is contained in:
Qing
2022-08-22 23:24:02 +08:00
parent a5e840765e
commit 6d2b24ed6b
8 changed files with 2132 additions and 9 deletions

View File

@@ -11,13 +11,19 @@ from lama_cleaner.schema import Config, HDStrategy, LDMSampler
current_dir = Path(__file__).parent.absolute().resolve()
def get_data(fx=1):
def get_data(fx=1, fy=1.0):
img = cv2.imread(str(current_dir / "image.png"))
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
mask = cv2.imread(str(current_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)
# img = cv2.imread("/Users/qing/code/github/MAT/test_sets/Places/images/test1.jpg")
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# mask = cv2.imread("/Users/qing/code/github/MAT/test_sets/Places/masks/mask1.png", cv2.IMREAD_GRAYSCALE)
# mask = 255 - mask
if fx != 1:
img = cv2.resize(img, None, fx=fx, fy=1)
mask = cv2.resize(mask, None, fx=fx, fy=1)
img = cv2.resize(img, None, fx=fx, fy=fy)
mask = cv2.resize(mask, None, fx=fx, fy=fy)
return img, mask
@@ -34,8 +40,8 @@ def get_config(strategy, **kwargs):
return Config(**data)
def assert_equal(model, config, gt_name, fx=1):
img, mask = get_data(fx=fx)
def assert_equal(model, config, gt_name, fx=1, fy=1):
img, mask = get_data(fx=fx, fy=fy)
res = model(img, mask, config)
cv2.imwrite(
str(current_dir / gt_name),
@@ -111,6 +117,20 @@ def test_zits(strategy, zits_wireframe):
assert_equal(
model,
cfg,
f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_fx_{fx}_result.png",
f"zits_{strategy.capitalize()}_wireframe_{zits_wireframe}_fx_{fx}_result.png",
fx=fx,
)
@pytest.mark.parametrize(
"strategy", [HDStrategy.ORIGINAL]
)
def test_mat(strategy):
model = ModelManager(name="mat", device="cpu")
cfg = get_config(strategy)
assert_equal(
model,
cfg,
f"mat_{strategy.capitalize()}_result.png",
)