get samplers from backend

This commit is contained in:
Qing
2024-01-02 14:34:36 +08:00
parent a2fd5bb3ea
commit f38be37f8c
14 changed files with 141 additions and 101 deletions

View File

@@ -3,7 +3,9 @@ import cv2
import pytest
import torch
from lama_cleaner.helper import encode_pil_to_base64
from lama_cleaner.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
from PIL import Image
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
@@ -21,7 +23,7 @@ def check_device(device: str) -> int:
def assert_equal(
model,
config,
config: InpaintRequest,
gt_name,
fx: float = 1,
fy: float = 1,
@@ -29,6 +31,8 @@ def assert_equal(
mask_p=current_dir / "mask.png",
):
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0]
config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0]
print(f"Input image shape: {img.shape}")
res = model(img, mask, config)
ok = cv2.imwrite(
@@ -72,4 +76,4 @@ def get_config(**kwargs):
hd_strategy_resize_limit=200,
)
data.update(**kwargs)
return InpaintRequest(**data)
return InpaintRequest(image="", mask="", **data)