🎨 完整的 IOPaint 项目更新
## 主要更新 - ✨ 更新所有依赖到最新稳定版本 - 📝 添加详细的项目文档和模型推荐 - 🔧 配置 VSCode Cloud Studio 预览功能 - 🐛 修复 PyTorch API 弃用警告 ## 依赖更新 - diffusers: 0.27.2 → 0.35.2 - gradio: 4.21.0 → 5.46.0 - peft: 0.7.1 → 0.18.0 - Pillow: 9.5.0 → 11.3.0 - fastapi: 0.108.0 → 0.116.2 ## 新增文件 - CLAUDE.md - 项目架构和开发指南 - UPGRADE_NOTES.md - 详细的升级说明 - .vscode/preview.yml - 预览配置 - .vscode/LAUNCH_GUIDE.md - 启动指南 - .gitignore - 更新的忽略规则 ## 代码修复 - 修复 iopaint/model/ldm.py 中的 torch.cuda.amp.autocast() 弃用警告 ## 文档更新 - README.md - 添加模型推荐和使用指南 - 完整的项目源码(iopaint/) - Web 前端源码(web_app/) 🤖 Generated with Claude Code
2
iopaint/tests/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*_result.png
|
||||
result/
|
||||
0
iopaint/tests/__init__.py
Normal file
BIN
iopaint/tests/anime_test.png
Normal file
|
After Width: | Height: | Size: 480 KiB |
BIN
iopaint/tests/anytext_mask.jpg
Normal file
|
After Width: | Height: | Size: 6.7 KiB |
BIN
iopaint/tests/anytext_ref.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
iopaint/tests/bunny.jpeg
Normal file
|
After Width: | Height: | Size: 51 KiB |
BIN
iopaint/tests/cat.png
Normal file
|
After Width: | Height: | Size: 481 KiB |
BIN
iopaint/tests/icc_profile_test.jpg
Normal file
|
After Width: | Height: | Size: 215 KiB |
BIN
iopaint/tests/icc_profile_test.png
Normal file
|
After Width: | Height: | Size: 305 KiB |
BIN
iopaint/tests/image.png
Normal file
|
After Width: | Height: | Size: 129 KiB |
BIN
iopaint/tests/mask.png
Normal file
|
After Width: | Height: | Size: 7.7 KiB |
BIN
iopaint/tests/overture-creations-5sI6fQgYIuo.png
Normal file
|
After Width: | Height: | Size: 395 KiB |
BIN
iopaint/tests/overture-creations-5sI6fQgYIuo_mask.png
Normal file
|
After Width: | Height: | Size: 12 KiB |
BIN
iopaint/tests/overture-creations-5sI6fQgYIuo_mask_blur.png
Normal file
|
After Width: | Height: | Size: 38 KiB |
BIN
iopaint/tests/png_parameter_test.png
Normal file
|
After Width: | Height: | Size: 69 KiB |
17
iopaint/tests/test_adjust_mask.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import cv2
|
||||
from iopaint.helper import adjust_mask
|
||||
from iopaint.tests.utils import current_dir, save_dir
|
||||
|
||||
mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
|
||||
def test_adjust_mask():
|
||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||
res_mask = adjust_mask(mask, 0, "expand")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask)
|
||||
res_mask = adjust_mask(mask, 40, "expand")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask)
|
||||
res_mask = adjust_mask(mask, 20, "shrink")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask)
|
||||
res_mask = adjust_mask(mask, 20, "reverse")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_reverse.png"), res_mask)
|
||||
45
iopaint/tests/test_anytext.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
|
||||
from iopaint.tests.utils import check_device, get_config, assert_equal
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
def test_anytext(device):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="Sanster/AnyText",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt='Characters written in chalk on the blackboard that says "DADDY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
|
||||
negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=9.0,
|
||||
sd_seed=66273235,
|
||||
sd_match_histograms=True
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"anytext.png",
|
||||
img_p=current_dir / "anytext_ref.jpg",
|
||||
mask_p=current_dir / "anytext_mask.jpg",
|
||||
)
|
||||
110
iopaint/tests/test_brushnet.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import os
|
||||
|
||||
from iopaint.const import SD_BRUSHNET_CHOICES
|
||||
from iopaint.tests.utils import check_device, get_config, assert_equal
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy, SDSampler, PowerPaintTask
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras])
|
||||
def test_runway_brushnet(device, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-v1-5",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=7.5,
|
||||
enable_brushnet=True,
|
||||
brushnet_method=SD_BRUSHNET_CHOICES[0],
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"brushnet_random_mask_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
|
||||
def test_runway_powerpaint_v2(device, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-v1-5",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
|
||||
tasks = {
|
||||
PowerPaintTask.text_guided: {
|
||||
"prompt": "face of a fox, sitting on a bench",
|
||||
"scale": 7.5,
|
||||
},
|
||||
PowerPaintTask.context_aware: {
|
||||
"prompt": "face of a fox, sitting on a bench",
|
||||
"scale": 7.5,
|
||||
},
|
||||
PowerPaintTask.shape_guided: {
|
||||
"prompt": "face of a fox, sitting on a bench",
|
||||
"scale": 7.5,
|
||||
},
|
||||
PowerPaintTask.object_remove: {
|
||||
"prompt": "",
|
||||
"scale": 12,
|
||||
},
|
||||
PowerPaintTask.outpainting: {
|
||||
"prompt": "",
|
||||
"scale": 7.5,
|
||||
},
|
||||
}
|
||||
|
||||
for task, data in tasks.items():
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt=data["prompt"],
|
||||
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=data["scale"],
|
||||
enable_powerpaint_v2=True,
|
||||
powerpaint_task=task,
|
||||
sd_sampler=sampler,
|
||||
sd_mask_blur=11,
|
||||
sd_seed=42,
|
||||
# sd_keep_unmasked_area=False
|
||||
)
|
||||
if task == PowerPaintTask.outpainting:
|
||||
cfg.use_extender = True
|
||||
cfg.extender_x = -128
|
||||
cfg.extender_y = -128
|
||||
cfg.extender_width = 768
|
||||
cfg.extender_height = 768
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"powerpaint_v2_{device}_{task}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
118
iopaint/tests/test_controlnet.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
|
||||
from iopaint.const import SD_CONTROLNET_CHOICES
|
||||
from iopaint.tests.utils import current_dir, check_device, get_config, assert_equal
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy, SDSampler
|
||||
|
||||
|
||||
model_name = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
|
||||
def convert_controlnet_method_name(name):
|
||||
return name.replace("/", "--")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("controlnet_method", [SD_CONTROLNET_CHOICES[0]])
|
||||
def test_runway_sd_1_5(device, controlnet_method):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name=model_name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=device == "cuda",
|
||||
enable_controlnet=True,
|
||||
controlnet_method=controlnet_method,
|
||||
)
|
||||
|
||||
cfg = get_config(
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
enable_controlnet=True,
|
||||
controlnet_conditioning_scale=0.5,
|
||||
controlnet_method=controlnet_method,
|
||||
)
|
||||
name = f"device_{device}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_controlnet_{convert_controlnet_method_name(controlnet_method)}_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
def test_controlnet_switch(device):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name=model_name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
enable_controlnet=True,
|
||||
controlnet_method="lllyasviel/control_v11p_sd15_canny",
|
||||
)
|
||||
cfg = get_config(
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
enable_controlnet=True,
|
||||
controlnet_method="lllyasviel/control_v11f1p_sd15_depth",
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"controlnet_switch_canny_to_depth_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize(
|
||||
"local_file", ["sd-v1-5-inpainting.ckpt", "v1-5-pruned-emaonly.safetensors"]
|
||||
)
|
||||
def test_local_file_path(device, local_file):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
controlnet_kwargs = dict(
|
||||
enable_controlnet=True,
|
||||
controlnet_method=SD_CONTROLNET_CHOICES[0],
|
||||
)
|
||||
|
||||
model = ModelManager(
|
||||
name=local_file,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
**controlnet_kwargs,
|
||||
)
|
||||
cfg = get_config(
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
**controlnet_kwargs,
|
||||
)
|
||||
|
||||
name = f"device_{device}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"{convert_controlnet_method_name(controlnet_kwargs['controlnet_method'])}_local_model_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
40
iopaint/tests/test_instruct_pix2pix.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy
|
||||
from iopaint.tests.utils import get_config, check_device, assert_equal, current_dir
|
||||
|
||||
model_name = "timbrooks/instruct-pix2pix"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("disable_nsfw", [True, False])
|
||||
@pytest.mark.parametrize("cpu_offload", [False, True])
|
||||
def test_instruct_pix2pix(device, disable_nsfw, cpu_offload):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name=model_name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=cpu_offload,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="What if it were snowing?",
|
||||
sd_steps=sd_steps
|
||||
)
|
||||
|
||||
name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"instruct_pix2pix_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.3,
|
||||
)
|
||||
19
iopaint/tests/test_load_img.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from iopaint.helper import load_img
|
||||
from iopaint.tests.utils import current_dir
|
||||
|
||||
png_img_p = current_dir / "image.png"
|
||||
jpg_img_p = current_dir / "bunny.jpeg"
|
||||
|
||||
|
||||
def test_load_png_image():
|
||||
with open(png_img_p, "rb") as f:
|
||||
np_img, alpha_channel = load_img(f.read())
|
||||
assert np_img.shape == (256, 256, 3)
|
||||
assert alpha_channel.shape == (256, 256)
|
||||
|
||||
|
||||
def test_load_jpg_image():
|
||||
with open(jpg_img_p, "rb") as f:
|
||||
np_img, alpha_channel = load_img(f.read())
|
||||
assert np_img.shape == (394, 448, 3)
|
||||
assert alpha_channel is None
|
||||
102
iopaint/tests/test_low_mem.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.tests.utils import check_device, get_config, assert_equal, current_dir
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy, SDSampler
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
def test_runway_sd_1_5_low_mem(device):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
low_mem=True,
|
||||
)
|
||||
|
||||
all_samplers = [member.value for member in SDSampler.__members__.values()]
|
||||
print(all_samplers)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_sampler=SDSampler.ddim,
|
||||
)
|
||||
|
||||
name = f"device_{device}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_{name}_low_mem.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.lcm])
|
||||
def test_runway_sd_lcm_lora_low_mem(device, sampler):
|
||||
check_device(device)
|
||||
|
||||
sd_steps = 5
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
low_mem=True,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=2,
|
||||
sd_lcm_lora=True,
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_1_5_lcm_lora_device_{device}_low_mem.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_runway_norm_sd_model(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-v1-5",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
low_mem=True,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy, prompt="face of a fox, sitting on a bench", sd_steps=sd_steps
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_{device}_norm_sd_model_device_{device}_low_mem.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
36
iopaint/tests/test_match_histograms.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import SDSampler, HDStrategy
|
||||
from iopaint.tests.utils import check_device, get_config, assert_equal, current_dir
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_sd_match_histograms(device, sampler):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=7.5,
|
||||
sd_lcm_lora=False,
|
||||
sd_match_histograms=True,
|
||||
sd_sampler=sampler
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_1_5_device_{device}_match_histograms.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
160
iopaint/tests/test_model.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy, LDMSampler
|
||||
from iopaint.tests.utils import assert_equal, get_config, current_dir, check_device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
def test_lama(device, strategy):
|
||||
check_device(device)
|
||||
model = ModelManager(name="lama", device=device)
|
||||
assert_equal(
|
||||
model,
|
||||
get_config(strategy=strategy),
|
||||
f"lama_{strategy[0].upper() + strategy[1:]}_result.png",
|
||||
)
|
||||
|
||||
fx = 1.3
|
||||
assert_equal(
|
||||
model,
|
||||
get_config(strategy=strategy),
|
||||
f"lama_{strategy[0].upper() + strategy[1:]}_fx_{fx}_result.png",
|
||||
fx=1.3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
@pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms])
|
||||
def test_ldm(device, strategy, ldm_sampler):
|
||||
check_device(device)
|
||||
model = ModelManager(name="ldm", device=device)
|
||||
cfg = get_config(strategy=strategy, ldm_sampler=ldm_sampler)
|
||||
assert_equal(
|
||||
model, cfg, f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png"
|
||||
)
|
||||
|
||||
fx = 1.3
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_fx_{fx}_result.png",
|
||||
fx=fx,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
@pytest.mark.parametrize("zits_wireframe", [False, True])
|
||||
def test_zits(device, strategy, zits_wireframe):
|
||||
check_device(device)
|
||||
model = ModelManager(name="zits", device=device)
|
||||
cfg = get_config(strategy=strategy, zits_wireframe=zits_wireframe)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png",
|
||||
)
|
||||
|
||||
fx = 1.3
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"zits_{strategy.capitalize()}_wireframe_{zits_wireframe}_fx_{fx}_result.png",
|
||||
fx=fx,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("no_half", [True, False])
|
||||
def test_mat(device, strategy, no_half):
|
||||
check_device(device)
|
||||
model = ModelManager(name="mat", device=device, no_half=no_half)
|
||||
cfg = get_config(strategy=strategy)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"mat_{strategy.capitalize()}_result.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_fcf(device, strategy):
|
||||
check_device(device)
|
||||
model = ModelManager(name="fcf", device=device)
|
||||
cfg = get_config(strategy=strategy)
|
||||
|
||||
assert_equal(model, cfg, f"fcf_{strategy.capitalize()}_result.png", fx=2, fy=2)
|
||||
assert_equal(model, cfg, f"fcf_{strategy.capitalize()}_result.png", fx=3.8, fy=2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
@pytest.mark.parametrize("cv2_flag", ["INPAINT_NS", "INPAINT_TELEA"])
|
||||
@pytest.mark.parametrize("cv2_radius", [3, 15])
|
||||
def test_cv2(strategy, cv2_flag, cv2_radius):
|
||||
model = ModelManager(
|
||||
name="cv2",
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
cfg = get_config(strategy=strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"cv2_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
def test_manga(device, strategy):
|
||||
check_device(device)
|
||||
model = ModelManager(
|
||||
name="manga",
|
||||
device=torch.device(device),
|
||||
)
|
||||
cfg = get_config(strategy=strategy)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"manga_{strategy.capitalize()}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_mi_gan(device, strategy):
|
||||
check_device(device)
|
||||
model = ModelManager(
|
||||
name="migan",
|
||||
device=torch.device(device),
|
||||
)
|
||||
cfg = get_config(strategy=strategy)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"migan_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.5,
|
||||
fy=1.7
|
||||
)
|
||||
16
iopaint/tests/test_model_md5.py
Normal file
@@ -0,0 +1,16 @@
|
||||
def test_load_model():
|
||||
from iopaint.plugins import InteractiveSeg
|
||||
from iopaint.model_manager import ModelManager
|
||||
|
||||
interactive_seg_model = InteractiveSeg("vit_l", "cpu")
|
||||
|
||||
models = ["lama", "ldm", "zits", "mat", "fcf", "manga", "migan"]
|
||||
for m in models:
|
||||
ModelManager(
|
||||
name=m,
|
||||
device="cpu",
|
||||
no_half=False,
|
||||
disable_nsfw=False,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=True,
|
||||
)
|
||||
70
iopaint/tests/test_model_switch.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
|
||||
from iopaint.schema import InpaintRequest
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
|
||||
|
||||
def test_model_switch():
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
enable_controlnet=True,
|
||||
controlnet_method="lllyasviel/control_v11p_sd15_canny",
|
||||
device=torch.device("mps"),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
)
|
||||
|
||||
model.switch("lama")
|
||||
|
||||
|
||||
def test_controlnet_switch_onoff(caplog):
|
||||
name = "runwayml/stable-diffusion-inpainting"
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
enable_controlnet=True,
|
||||
controlnet_method="lllyasviel/control_v11p_sd15_canny",
|
||||
device=torch.device("mps"),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
)
|
||||
|
||||
model.switch_controlnet_method(
|
||||
InpaintRequest(
|
||||
name=name,
|
||||
enable_controlnet=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert "Disable controlnet" in caplog.text
|
||||
|
||||
|
||||
def test_switch_controlnet_method(caplog):
|
||||
name = "runwayml/stable-diffusion-inpainting"
|
||||
old_method = "lllyasviel/control_v11p_sd15_canny"
|
||||
new_method = "lllyasviel/control_v11p_sd15_openpose"
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
enable_controlnet=True,
|
||||
controlnet_method=old_method,
|
||||
device=torch.device("mps"),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
)
|
||||
|
||||
model.switch_controlnet_method(
|
||||
InpaintRequest(
|
||||
name=name,
|
||||
enable_controlnet=True,
|
||||
controlnet_method=new_method,
|
||||
)
|
||||
)
|
||||
|
||||
assert f"Switch Controlnet method from {old_method} to {new_method}" in caplog.text
|
||||
137
iopaint/tests/test_outpainting.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
|
||||
from iopaint.tests.utils import current_dir, check_device
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import SDSampler
|
||||
from iopaint.tests.test_model import get_config, assert_equal
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", ["runwayml/stable-diffusion-inpainting"])
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize(
|
||||
"rect",
|
||||
[
|
||||
[0, -100, 512, 512 - 128 + 100],
|
||||
[0, 128, 512, 512 - 128 + 100],
|
||||
[128, 0, 512 - 128 + 100, 512],
|
||||
[-100, 0, 512 - 128 + 100, 512],
|
||||
[0, 0, 512, 512 + 200],
|
||||
[256, 0, 512 + 200, 512],
|
||||
[-100, -100, 512 + 200, 512 + 200],
|
||||
],
|
||||
)
|
||||
def test_outpainting(name, device, rect):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
prompt="a dog sitting on a bench in the park",
|
||||
sd_steps=sd_steps,
|
||||
use_extender=True,
|
||||
extender_x=rect[0],
|
||||
extender_y=rect[1],
|
||||
extender_width=rect[2],
|
||||
extender_height=rect[3],
|
||||
sd_guidance_scale=8.0,
|
||||
sd_sampler=SDSampler.dpm_plus_plus_2m,
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"{name.replace('/', '--')}_outpainting_{'_'.join(map(str, rect))}_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", ["kandinsky-community/kandinsky-2-2-decoder-inpaint"])
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize(
|
||||
"rect",
|
||||
[
|
||||
[-128, -128, 768, 768],
|
||||
],
|
||||
)
|
||||
def test_kandinsky_outpainting(name, device, rect):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
prompt="a cat",
|
||||
negative_prompt="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature",
|
||||
sd_steps=sd_steps,
|
||||
use_extender=True,
|
||||
extender_x=rect[0],
|
||||
extender_y=rect[1],
|
||||
extender_width=rect[2],
|
||||
extender_height=rect[3],
|
||||
sd_guidance_scale=7,
|
||||
sd_sampler=SDSampler.dpm_plus_plus_2m,
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"{name.replace('/', '--')}_outpainting_{'_'.join(map(str, rect))}_device_{device}.png",
|
||||
img_p=current_dir / "cat.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1,
|
||||
fy=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", ["Sanster/PowerPaint-V1-stable-diffusion-inpainting"])
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize(
|
||||
"rect",
|
||||
[
|
||||
[-100, -100, 512 + 200, 512 + 200],
|
||||
],
|
||||
)
|
||||
def test_powerpaint_outpainting(name, device, rect):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
low_mem=True,
|
||||
)
|
||||
cfg = get_config(
|
||||
prompt="a dog sitting on a bench in the park",
|
||||
sd_steps=sd_steps,
|
||||
use_extender=True,
|
||||
extender_x=rect[0],
|
||||
extender_y=rect[1],
|
||||
extender_width=rect[2],
|
||||
extender_height=rect[3],
|
||||
sd_guidance_scale=8.0,
|
||||
sd_sampler=SDSampler.dpm_plus_plus_2m,
|
||||
powerpaint_task="outpainting",
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"{name.replace('/', '--')}_outpainting_{'_'.join(map(str, rect))}_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
58
iopaint/tests/test_paint_by_example.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import cv2
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from iopaint.helper import encode_pil_to_base64
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy
|
||||
from iopaint.tests.utils import (
|
||||
current_dir,
|
||||
get_config,
|
||||
get_data,
|
||||
save_dir,
|
||||
check_device,
|
||||
)
|
||||
|
||||
model_name = "Fantasy-Studio/Paint-by-Example"
|
||||
|
||||
|
||||
def assert_equal(
|
||||
model,
|
||||
config,
|
||||
save_name: str,
|
||||
fx: float = 1,
|
||||
fy: float = 1,
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
example_p=current_dir / "bunny.jpeg",
|
||||
):
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
|
||||
example_image = cv2.imread(str(example_p))
|
||||
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGRA2RGB)
|
||||
example_image = cv2.resize(
|
||||
example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA
|
||||
)
|
||||
|
||||
print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
|
||||
config.paint_by_example_example_image = encode_pil_to_base64(
|
||||
Image.fromarray(example_image), 100, {}
|
||||
).decode("utf-8")
|
||||
res = model(img, mask, config)
|
||||
cv2.imwrite(str(save_dir / save_name), res)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
def test_paint_by_example(device):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(name=model_name, device=device, disable_nsfw=True)
|
||||
cfg = get_config(strategy=HDStrategy.ORIGINAL, sd_steps=sd_steps)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"paint_by_example_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fy=0.9,
|
||||
fx=1.3,
|
||||
)
|
||||
130
iopaint/tests/test_plugins.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
from iopaint.helper import encode_pil_to_base64, gen_frontend_mask
|
||||
from iopaint.plugins.anime_seg import AnimeSeg
|
||||
from iopaint.schema import Device, RunPluginRequest, RemoveBGModel, InteractiveSegModel
|
||||
from iopaint.tests.utils import check_device, current_dir, save_dir
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
import cv2
|
||||
import pytest
|
||||
|
||||
from iopaint.plugins import (
|
||||
RemoveBG,
|
||||
RealESRGANUpscaler,
|
||||
GFPGANPlugin,
|
||||
RestoreFormerPlugin,
|
||||
InteractiveSeg,
|
||||
)
|
||||
|
||||
img_p = current_dir / "bunny.jpeg"
|
||||
img_bytes = open(img_p, "rb").read()
|
||||
bgr_img = cv2.imread(str(img_p))
|
||||
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
|
||||
rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {})
|
||||
bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {})
|
||||
|
||||
person_p = current_dir / "image.png"
|
||||
person_bgr_img = cv2.imread(str(person_p))
|
||||
person_rgb_img = cv2.cvtColor(person_bgr_img, cv2.COLOR_BGR2RGB)
|
||||
person_rgb_img = cv2.resize(person_rgb_img, (512, 512))
|
||||
|
||||
|
||||
def _save(img, name):
|
||||
name = name.replace("/", "_")
|
||||
cv2.imwrite(str(save_dir / name), img)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", RemoveBGModel.values())
|
||||
@pytest.mark.parametrize("device", Device.values())
|
||||
def test_remove_bg(model_name, device):
|
||||
check_device(device)
|
||||
print(f"Testing {model_name} on {device}")
|
||||
model = RemoveBG(model_name, device)
|
||||
rgba_np_img = model.gen_image(
|
||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||
)
|
||||
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
|
||||
_save(res, f"test_remove_bg_{model_name}_{device}.png")
|
||||
|
||||
bgr_np_img = model.gen_mask(
|
||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||
)
|
||||
|
||||
res_mask = gen_frontend_mask(bgr_np_img)
|
||||
_save(res_mask, f"test_remove_bg_frontend_mask_{model_name}_{device}.png")
|
||||
|
||||
assert len(bgr_np_img.shape) == 2
|
||||
_save(bgr_np_img, f"test_remove_bg_mask_{model_name}_{device}.jpeg")
|
||||
|
||||
|
||||
def test_anime_seg():
|
||||
model = AnimeSeg()
|
||||
img = cv2.imread(str(current_dir / "anime_test.png"))
|
||||
img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})
|
||||
res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||
assert len(res.shape) == 3
|
||||
assert res.shape[-1] == 4
|
||||
_save(res, "test_anime_seg.png")
|
||||
|
||||
res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||
assert len(res.shape) == 2
|
||||
_save(res, "test_anime_seg_mask.png")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||
def test_upscale(device):
|
||||
check_device(device)
|
||||
model = RealESRGANUpscaler("realesr-general-x4v3", device)
|
||||
res = model.gen_image(
|
||||
rgb_img,
|
||||
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2),
|
||||
)
|
||||
_save(res, f"test_upscale_x2_{device}.png")
|
||||
|
||||
res = model.gen_image(
|
||||
rgb_img,
|
||||
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4),
|
||||
)
|
||||
_save(res, f"test_upscale_x4_{device}.png")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||
def test_gfpgan(device):
|
||||
check_device(device)
|
||||
model = GFPGANPlugin(device)
|
||||
res = model.gen_image(
|
||||
person_rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
|
||||
)
|
||||
_save(res, f"test_gfpgan_{device}.png")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||
def test_restoreformer(device):
|
||||
check_device(device)
|
||||
model = RestoreFormerPlugin(device)
|
||||
res = model.gen_image(
|
||||
person_rgb_img,
|
||||
RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64),
|
||||
)
|
||||
_save(res, f"test_restoreformer_{device}.png")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", InteractiveSegModel.values())
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||
def test_segment_anything(name, device):
|
||||
check_device(device)
|
||||
model = InteractiveSeg(name, device)
|
||||
new_mask = model.gen_mask(
|
||||
rgb_img,
|
||||
RunPluginRequest(
|
||||
name=InteractiveSeg.name,
|
||||
image=rgb_img_base64,
|
||||
clicks=([[448 // 2, 394 // 2, 1]]),
|
||||
),
|
||||
)
|
||||
|
||||
save_name = f"test_segment_anything_{name}_{device}.png"
|
||||
_save(new_mask, save_name)
|
||||
59
iopaint/tests/test_save_exif.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import io
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from iopaint.helper import pil_to_bytes, load_img
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
|
||||
|
||||
def print_exif(exif):
|
||||
for k, v in exif.items():
|
||||
print(f"{k}: {v}")
|
||||
|
||||
|
||||
def extra_info(img_p: Path):
|
||||
ext = img_p.suffix.strip(".")
|
||||
img_bytes = img_p.read_bytes()
|
||||
np_img, _, infos = load_img(img_bytes, False, True)
|
||||
res_pil_bytes = pil_to_bytes(Image.fromarray(np_img), ext=ext, infos=infos)
|
||||
res_img = Image.open(io.BytesIO(res_pil_bytes))
|
||||
return infos, res_img.info, res_pil_bytes
|
||||
|
||||
|
||||
def assert_keys(keys: List[str], infos, res_infos):
|
||||
for k in keys:
|
||||
assert k in infos
|
||||
assert k in res_infos
|
||||
assert infos[k] == res_infos[k]
|
||||
|
||||
|
||||
def run_test(file_path, keys):
|
||||
infos, res_infos, res_pil_bytes = extra_info(file_path)
|
||||
assert_keys(keys, infos, res_infos)
|
||||
with tempfile.NamedTemporaryFile("wb", suffix=file_path.suffix) as temp_file:
|
||||
temp_file.write(res_pil_bytes)
|
||||
temp_file.flush()
|
||||
infos, res_infos, res_pil_bytes = extra_info(Path(temp_file.name))
|
||||
assert_keys(keys, infos, res_infos)
|
||||
|
||||
|
||||
def test_png_icc_profile_png():
|
||||
run_test(current_dir / "icc_profile_test.png", ["icc_profile", "exif"])
|
||||
|
||||
|
||||
def test_png_icc_profile_jpeg():
|
||||
run_test(current_dir / "icc_profile_test.jpg", ["icc_profile", "exif"])
|
||||
|
||||
|
||||
def test_jpeg():
|
||||
jpg_img_p = current_dir / "bunny.jpeg"
|
||||
run_test(jpg_img_p, ["dpi", "exif"])
|
||||
|
||||
|
||||
def test_png_parameter():
|
||||
jpg_img_p = current_dir / "png_parameter_test.png"
|
||||
run_test(jpg_img_p, ["parameters"])
|
||||
71
iopaint/tests/test_save_quality.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import io
|
||||
from PIL import Image
|
||||
from iopaint.helper import pil_to_bytes
|
||||
|
||||
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def test_jpeg_quality():
|
||||
# Test JPEG quality settings
|
||||
img_path = os.path.join(TESTS_DIR, "bunny.jpeg")
|
||||
pil_img = Image.open(img_path)
|
||||
|
||||
# Test different quality settings
|
||||
high_quality = pil_to_bytes(pil_img, "jpg", quality=95)
|
||||
low_quality = pil_to_bytes(pil_img, "jpg", quality=50)
|
||||
|
||||
# Print file sizes in KB
|
||||
print(f"High quality JPEG size: {len(high_quality) / 1024:.2f} KB")
|
||||
print(f"Low quality JPEG size: {len(low_quality) / 1024:.2f} KB")
|
||||
|
||||
# Higher quality should result in larger file size
|
||||
assert len(high_quality) > len(low_quality)
|
||||
|
||||
# Verify the output can be opened as an image
|
||||
Image.open(io.BytesIO(high_quality))
|
||||
Image.open(io.BytesIO(low_quality))
|
||||
|
||||
|
||||
def test_png_parameters():
|
||||
# Test PNG with parameters
|
||||
img_path = os.path.join(TESTS_DIR, "cat.png")
|
||||
pil_img = Image.open(img_path)
|
||||
|
||||
# Test PNG with parameters
|
||||
params = {"parameters": "test_param=value"}
|
||||
png_with_params = pil_to_bytes(pil_img, "png", infos=params)
|
||||
|
||||
# Test PNG without parameters
|
||||
png_without_params = pil_to_bytes(pil_img, "png")
|
||||
|
||||
# Print file sizes in KB
|
||||
print(f"PNG with parameters size: {len(png_with_params) / 1024:.2f} KB")
|
||||
print(f"PNG without parameters size: {len(png_without_params) / 1024:.2f} KB")
|
||||
|
||||
# Verify both outputs can be opened as images
|
||||
Image.open(io.BytesIO(png_with_params))
|
||||
Image.open(io.BytesIO(png_without_params))
|
||||
|
||||
|
||||
def test_format_conversion():
|
||||
# Test format conversion
|
||||
jpeg_path = os.path.join(TESTS_DIR, "bunny.jpeg")
|
||||
png_path = os.path.join(TESTS_DIR, "cat.png")
|
||||
|
||||
# Convert JPEG to PNG
|
||||
jpeg_img = Image.open(jpeg_path)
|
||||
jpeg_to_png = pil_to_bytes(jpeg_img, "png")
|
||||
converted_png = Image.open(io.BytesIO(jpeg_to_png))
|
||||
print(f"JPEG to PNG size: {len(jpeg_to_png) / 1024:.2f} KB")
|
||||
assert converted_png.format.lower() == "png"
|
||||
|
||||
# Convert PNG to JPEG
|
||||
png_img = Image.open(png_path)
|
||||
# Convert RGBA to RGB if necessary
|
||||
if png_img.mode == "RGBA":
|
||||
png_img = png_img.convert("RGB")
|
||||
png_to_jpeg = pil_to_bytes(png_img, "jpg")
|
||||
print(f"PNG to JPEG size: {len(png_to_jpeg) / 1024:.2f} KB")
|
||||
converted_jpeg = Image.open(io.BytesIO(png_to_jpeg))
|
||||
assert converted_jpeg.format.lower() == "jpeg"
|
||||
240
iopaint/tests/test_sd_model.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import os
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.tests.utils import check_device, get_config, assert_equal
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy, SDSampler
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
def test_runway_sd_1_5_all_samplers(device):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
|
||||
all_samplers = [member.value for member in SDSampler.__members__.values()]
|
||||
print(all_samplers)
|
||||
for sampler in all_samplers:
|
||||
print(f"Testing sampler {sampler}")
|
||||
if (
|
||||
sampler
|
||||
in [SDSampler.dpm2_karras, SDSampler.dpm2_a_karras, SDSampler.lms_karras]
|
||||
and device == "mps"
|
||||
):
|
||||
# diffusers 0.25.0 still has bug on these sampler on mps, wait main branch released to fix it
|
||||
logger.warning(
|
||||
"skip dpm2_karras on mps, diffusers does not support it on mps. TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead."
|
||||
)
|
||||
continue
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_sampler=sampler,
|
||||
)
|
||||
|
||||
name = f"device_{device}_{sampler}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.lcm])
|
||||
def test_runway_sd_lcm_lora(device, sampler):
|
||||
check_device(device)
|
||||
|
||||
sd_steps = 5
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=2,
|
||||
sd_lcm_lora=True,
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_1_5_lcm_lora_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_runway_sd_sd_strength(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy,
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_strength=0.8,
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_strength_0.8_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_runway_sd_cpu_textencoder(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy,
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_sampler=sampler,
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_device_{device}_cpu_textencoder.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_runway_norm_sd_model(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-v1-5",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy, prompt="face of a fox, sitting on a bench", sd_steps=sd_steps
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_{device}_norm_sd_model_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
|
||||
def test_runway_sd_1_5_cpu_offload(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
name = f"device_{device}_{sampler}"
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"runway_sd_{strategy.capitalize()}_{name}_cpu_offload.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
@pytest.mark.parametrize(
|
||||
"name",
|
||||
[
|
||||
"sd-v1-5-inpainting.safetensors",
|
||||
"v1-5-pruned-emaonly.safetensors",
|
||||
"sd_xl_base_1.0.safetensors",
|
||||
"sd_xl_base_1.0_inpainting_0.1.safetensors",
|
||||
],
|
||||
)
|
||||
def test_local_file_path(device, sampler, name):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="a fox sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
name = f"device_{device}_{sampler}_{name}"
|
||||
|
||||
is_sdxl = "sd_xl" in name
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sd_local_model_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.5 if is_sdxl else 1,
|
||||
fy=1.5 if is_sdxl else 1,
|
||||
)
|
||||
118
iopaint/tests/test_sdxl.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
|
||||
from iopaint.tests.utils import check_device, current_dir
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy, SDSampler
|
||||
from iopaint.tests.test_model import get_config, assert_equal
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_sdxl(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_strength=1.0,
|
||||
sd_guidance_scale=7.0,
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sdxl_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=2,
|
||||
fy=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
def test_sdxl_cpu_text_encoder(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_strength=1.0,
|
||||
sd_guidance_scale=7.0,
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sdxl_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=2,
|
||||
fy=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize(
|
||||
"rect",
|
||||
[
|
||||
[-128, -128, 1024, 1024],
|
||||
],
|
||||
)
|
||||
def test_sdxl_outpainting(device, rect):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="a dog sitting on a bench in the park",
|
||||
sd_steps=sd_steps,
|
||||
use_extender=True,
|
||||
extender_x=rect[0],
|
||||
extender_y=rect[1],
|
||||
extender_width=rect[2],
|
||||
extender_height=rect[3],
|
||||
sd_strength=1.0,
|
||||
sd_guidance_scale=8.0,
|
||||
sd_sampler=SDSampler.ddim,
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"sdxl_outpainting_dog_ddim_{'_'.join(map(str, rect))}_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.5,
|
||||
fy=1.5,
|
||||
)
|
||||
77
iopaint/tests/utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
|
||||
import numpy as np
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
def check_device(device: str) -> int:
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available, skip test on cuda")
|
||||
if device == "mps" and not torch.backends.mps.is_available():
|
||||
pytest.skip("mps is not available, skip test on mps")
|
||||
steps = 2 if device == "cpu" else 20
|
||||
return steps
|
||||
|
||||
|
||||
def assert_equal(
|
||||
model,
|
||||
config: InpaintRequest,
|
||||
gt_name,
|
||||
fx: float = 1,
|
||||
fy: float = 1,
|
||||
img_p=current_dir / "image.png",
|
||||
mask_p=current_dir / "mask.png",
|
||||
):
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
print(f"Input image shape: {img.shape}")
|
||||
|
||||
res = model(img, mask, config)
|
||||
ok = cv2.imwrite(
|
||||
str(save_dir / gt_name),
|
||||
res,
|
||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
||||
)
|
||||
assert ok, save_dir / gt_name
|
||||
|
||||
"""
|
||||
Note that JPEG is lossy compression, so even if it is the highest quality 100,
|
||||
when the saved images is reloaded, a difference occurs with the original pixel value.
|
||||
If you want to save the original images as it is, save it as PNG or BMP.
|
||||
"""
|
||||
# gt = cv2.imread(str(current_dir / gt_name), cv2.IMREAD_UNCHANGED)
|
||||
# assert np.array_equal(res, gt)
|
||||
|
||||
|
||||
def get_data(
|
||||
fx: float = 1,
|
||||
fy: float = 1.0,
|
||||
img_p=current_dir / "image.png",
|
||||
mask_p=current_dir / "mask.png",
|
||||
):
|
||||
img = cv2.imread(str(img_p))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||
img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
||||
mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
|
||||
return img, mask
|
||||
|
||||
|
||||
def get_config(**kwargs):
|
||||
data = dict(
|
||||
sd_sampler=kwargs.get("sd_sampler", SDSampler.uni_pc),
|
||||
ldm_steps=1,
|
||||
ldm_sampler=LDMSampler.plms,
|
||||
hd_strategy=kwargs.get("strategy", HDStrategy.ORIGINAL),
|
||||
hd_strategy_crop_margin=32,
|
||||
hd_strategy_crop_trigger_size=200,
|
||||
hd_strategy_resize_limit=200,
|
||||
)
|
||||
data.update(**kwargs)
|
||||
return InpaintRequest(image="", mask="", **data)
|
||||