Files
IOPaint/iopaint/batch_processing.py
let5sne 1b87a98261 🎨 完整的 IOPaint 项目更新
## 主要更新
-  更新所有依赖到最新稳定版本
- 📝 添加详细的项目文档和模型推荐
- 🔧 配置 VSCode Cloud Studio 预览功能
- 🐛 修复 PyTorch API 弃用警告

## 依赖更新
- diffusers: 0.27.2 → 0.35.2
- gradio: 4.21.0 → 5.46.0
- peft: 0.7.1 → 0.18.0
- Pillow: 9.5.0 → 11.3.0
- fastapi: 0.108.0 → 0.116.2

## 新增文件
- CLAUDE.md - 项目架构和开发指南
- UPGRADE_NOTES.md - 详细的升级说明
- .vscode/preview.yml - 预览配置
- .vscode/LAUNCH_GUIDE.md - 启动指南
- .gitignore - 更新的忽略规则

## 代码修复
- 修复 iopaint/model/ldm.py 中的 torch.cuda.amp.autocast() 弃用警告

## 文档更新
- README.md - 添加模型推荐和使用指南
- 完整的项目源码(iopaint/)
- Web 前端源码(web_app/)

🤖 Generated with Claude Code
2025-11-28 17:10:24 +00:00

129 lines
4.1 KiB
Python

import json
from pathlib import Path
from typing import Dict, Optional
import cv2
import numpy as np
from PIL import Image
from loguru import logger
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TimeElapsedColumn,
MofNCompleteColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
)
from iopaint.helper import pil_to_bytes
from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager
from iopaint.schema import InpaintRequest
def glob_images(path: Path) -> Dict[str, Path]:
# png/jpg/jpeg
if path.is_file():
return {path.stem: path}
elif path.is_dir():
res = {}
for it in path.glob("*.*"):
if it.suffix.lower() in [".png", ".jpg", ".jpeg"]:
res[it.stem] = it
return res
def batch_inpaint(
model: str,
device,
image: Path,
mask: Path,
output: Path,
config: Optional[Path] = None,
concat: bool = False,
):
if image.is_dir() and output.is_file():
logger.error(
"invalid --output: when image is a directory, output should be a directory"
)
exit(-1)
output.mkdir(parents=True, exist_ok=True)
image_paths = glob_images(image)
mask_paths = glob_images(mask)
if len(image_paths) == 0:
logger.error("invalid --image: empty image folder")
exit(-1)
if len(mask_paths) == 0:
logger.error("invalid --mask: empty mask folder")
exit(-1)
if config is None:
inpaint_request = InpaintRequest()
logger.info(f"Using default config: {inpaint_request}")
else:
with open(config, "r", encoding="utf-8") as f:
inpaint_request = InpaintRequest(**json.load(f))
logger.info(f"Using config: {inpaint_request}")
model_manager = ModelManager(name=model, device=device)
first_mask = list(mask_paths.values())[0]
console = Console()
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
TimeElapsedColumn(),
console=console,
transient=False,
) as progress:
task = progress.add_task("Batch processing...", total=len(image_paths))
for stem, image_p in image_paths.items():
if stem not in mask_paths and mask.is_dir():
progress.log(f"mask for {image_p} not found")
progress.update(task, advance=1)
continue
mask_p = mask_paths.get(stem, first_mask)
infos = Image.open(image_p).info
img = np.array(Image.open(image_p).convert("RGB"))
mask_img = np.array(Image.open(mask_p).convert("L"))
if mask_img.shape[:2] != img.shape[:2]:
progress.log(
f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
)
mask_img = cv2.resize(
mask_img,
(img.shape[1], img.shape[0]),
interpolation=cv2.INTER_NEAREST,
)
mask_img[mask_img >= 127] = 255
mask_img[mask_img < 127] = 0
# bgr
inpaint_result = model_manager(img, mask_img, inpaint_request)
inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
if concat:
mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])
img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
save_p = output / f"{stem}.png"
with open(save_p, "wb") as fw:
fw.write(img_bytes)
progress.update(task, advance=1)
torch_gc()
# pid = psutil.Process().pid
# memory_info = psutil.Process(pid).memory_info()
# memory_in_mb = memory_info.rss / (1024 * 1024)
# print(f"原图大小:{img.shape},当前进程的内存占用:{memory_in_mb}MB")