Files
IOPaint/api_service_mvp.py
let5sne 49eaddd8b5 🐛 修复 API 服务的图像处理问题
修复:
- 修复 PIL 图像未转换为 numpy 数组的问题
- ModelManager 期望 numpy 数组而非 PIL 对象
- 移除不必要的 ApiConfig 导入,直接初始化 ModelManager

测试结果:
- 健康检查端点正常
- 统计端点正常
- 去水印功能正常 (处理时间: ~1.4s)
- 响应返回有效的 PNG 图像 (770KB)

性能指标:
- 图片尺寸: 512x512
- 处理时间: 1.35秒
- 模型: LaMa (CUDA)
- 成功率: 100%

🔧 Generated with Claude Code
2025-11-28 18:10:22 +00:00

362 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
IOPaint 去水印 API 服务 - MVP版本
专注于单一功能:去除图片水印
遵循KISS原则
- 只支持LaMa模型
- 简单的API Key认证
- 同步处理(无需队列)
- 本地存储
"""
import hashlib
import hmac
import io
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, JSONResponse
from loguru import logger
from PIL import Image

from iopaint.helper import (
    decode_base64_to_image,
    numpy_to_bytes,
    load_img,
)
from iopaint.model_manager import ModelManager
from iopaint.schema import InpaintRequest, HDStrategy
# ==================== 配置 ====================
class Config:
    """Service settings, resolved once when the class body is executed."""

    # API key; production deployments should provide it via the API_KEY
    # environment variable instead of relying on the placeholder default.
    API_KEY = os.environ.get("API_KEY", "your_secret_key_change_me")

    # Model settings: a single fixed model, on CUDA when torch reports it.
    MODEL_NAME = "lama"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Upload limits.
    MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "4096"))  # longest side
    MAX_FILE_SIZE = 10 * 1024 ** 2  # 10MB

    # Log directory; created as a side effect of defining the class.
    LOG_DIR = Path("./logs")
    LOG_DIR.mkdir(exist_ok=True)

    # Metrics are on unless ENABLE_METRICS is set to something other than "true".
    ENABLE_METRICS = os.environ.get("ENABLE_METRICS", "true").lower() == "true"
# ==================== 应用初始化 ====================
# Application object; interactive API docs are served at /docs and /redoc.
app = FastAPI(
    title="IOPaint 去水印 API",
    description="基于LaMa模型的图片去水印API服务",
    version="1.0.0-MVP",
    docs_url="/docs",  # Swagger UI
    redoc_url="/redoc",  # ReDoc
)

# CORS configuration.
# Credentials (cookies) are disabled: FastAPI documents that wildcard
# origins/methods/headers must not be combined with allow_credentials=True,
# and this service authenticates with the X-API-Key header, not cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict to specific domains in production
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ==================== 全局变量 ====================
# In-process request counters, mutated by the request handlers in this module.
request_stats = dict(
    total=0,
    success=0,
    failed=0,
    total_processing_time=0.0,
)
# Model handle; None until it is assigned elsewhere in this module.
model_manager: Optional[ModelManager] = None
# ==================== 认证中间件 ====================
async def verify_api_key(x_api_key: str = Header(None, alias="X-API-Key")):
    """Validate the caller's API key against ``Config.API_KEY``.

    Args:
        x_api_key: Value of the ``X-API-Key`` request header; ``None`` or
            empty when the header was not supplied.

    Returns:
        The supplied key, unchanged, when it matches.

    Raises:
        HTTPException: 401 when the key is missing or does not match.
    """
    if not x_api_key:
        raise HTTPException(
            status_code=401,
            detail="Missing API Key. Please provide X-API-Key header."
        )

    # Compare in constant time so response timing does not reveal how much of
    # the key matched. Both sides are encoded to bytes because
    # hmac.compare_digest rejects non-ASCII str arguments.
    supplied = x_api_key.encode("utf-8")
    expected = Config.API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        # Deliberately log no part of the submitted value: a near-miss typo
        # would otherwise write the real key's prefix into the logs.
        logger.warning("Invalid API key attempt")
        raise HTTPException(
            status_code=401,
            detail="Invalid API Key"
        )
    return x_api_key
# ==================== 启动/关闭事件 ====================
@app.on_event("startup")
async def startup_event():
    """Log the effective configuration, then load the model.

    Assigns the module-level ``model_manager``. If constructing the
    ModelManager raises, the error is logged and re-raised.
    """
    global model_manager

    rule = "=" * 60
    banner = (
        rule,
        "IOPaint API Service - MVP Version",
        rule,
        f"Device: {Config.DEVICE}",
        f"Model: {Config.MODEL_NAME}",
        f"Max Image Size: {Config.MAX_IMAGE_SIZE}",
        # Only the last four characters of the key are shown.
        f"API Key: {'*' * 20}{Config.API_KEY[-4:]}",
        rule,
    )
    for line in banner:
        logger.info(line)

    try:
        # Build the model manager directly rather than going through ApiConfig.
        manager_options = dict(
            name=Config.MODEL_NAME,
            device=torch.device(Config.DEVICE),
            no_half=False,
            low_mem=False,
            cpu_offload=False,
            disable_nsfw=True,
            sd_cpu_textencoder=False,
            local_files_only=False,
            cpu_textencoder=False,
        )
        model_manager = ModelManager(**manager_options)
        logger.success(f"✓ Model {Config.MODEL_NAME} loaded successfully on {Config.DEVICE}")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
@app.on_event("shutdown")
async def shutdown_event():
    """Log a shutdown notice and, when metrics are enabled, the final counters."""
    logger.info("Shutting down API service...")
    if not Config.ENABLE_METRICS:
        return

    rule = "=" * 60
    logger.info(rule)
    logger.info("Final Statistics:")
    logger.info(f" Total Requests: {request_stats['total']}")
    logger.info(f" Successful: {request_stats['success']}")
    logger.info(f" Failed: {request_stats['failed']}")
    # The average is only defined once at least one request has succeeded.
    if request_stats['success'] > 0:
        avg_time = request_stats['total_processing_time'] / request_stats['success']
        logger.info(f" Avg Processing Time: {avg_time:.2f}s")
    logger.info(rule)
# ==================== API路由 ====================
@app.get("/")
async def root():
    """Return basic service information for the root path."""
    service_info = dict(
        service="IOPaint Watermark Removal API",
        version="1.0.0-MVP",
        status="running",
        model=Config.MODEL_NAME,
        device=Config.DEVICE,
        docs="/docs",
    )
    return service_info
@app.get("/api/v1/health")
async def health_check():
    """Health check: report the configured model/device and current CUDA availability."""
    cuda_ok = torch.cuda.is_available()
    return dict(
        status="healthy",
        model=Config.MODEL_NAME,
        device=Config.DEVICE,
        gpu_available=cuda_ok,
    )
@app.get("/api/v1/stats")
async def get_stats(api_key: str = Header(None, alias="X-API-Key")):
    """Return usage statistics (API key required).

    Returns a copy of the request counters plus ``avg_processing_time``,
    which is 0 until at least one request has succeeded.

    Raises:
        HTTPException: 401 from the key check, or 404 when metrics are disabled.
    """
    await verify_api_key(api_key)
    if not Config.ENABLE_METRICS:
        raise HTTPException(status_code=404, detail="Metrics disabled")

    snapshot = dict(request_stats)
    succeeded = snapshot['success']
    snapshot['avg_processing_time'] = (
        snapshot['total_processing_time'] / succeeded if succeeded > 0 else 0
    )
    return snapshot
def _decode_upload(data: bytes, mode: str, label: str, max_side: Optional[int] = None):
    """Decode uploaded bytes into a PIL image converted to ``mode``.

    Args:
        data: Raw bytes of the uploaded file.
        mode: PIL mode to convert to (e.g. "RGB" or "L").
        label: Word used in the error message ("image" or "mask").
        max_side: Optional limit on the longest side, checked from the
            opened image's ``size`` before the pixel data is converted.

    Raises:
        HTTPException: 400 when the data cannot be opened/converted, or when
            the longest side exceeds ``max_side``.
    """
    try:
        opened = Image.open(io.BytesIO(data))
        oversized = max_side is not None and max(opened.size) > max_side
        # Skip the conversion (and thus the pixel decode) for rejected images.
        decoded = None if oversized else opened.convert(mode)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {label} format: {str(e)}"
        ) from e
    if oversized:
        raise HTTPException(
            status_code=400,
            detail=f"Image too large. Max dimension: {max_side}px"
        )
    return decoded


@app.post("/api/v1/remove-watermark")
async def remove_watermark(
    request: Request,
    image: UploadFile = File(..., description="原始图片"),
    mask: Optional[UploadFile] = File(None, description="水印遮罩(可选)"),
    api_key: str = Header(None, alias="X-API-Key")
):
    """
    Remove a watermark from an image.

    Parameters:
    - image: source image file (required)
    - mask: watermark mask image (optional; black areas are kept, white
      areas are repaired). Without a mask, an all-white mask is used, so the
      whole image is processed.
    - request: accepted but not used by this handler.

    Returns:
    - the processed image as PNG, with X-Processing-Time and X-Image-Size headers

    Errors:
    - 401 missing/invalid API key
    - 400 oversized or undecodable image/mask
    - 503 model not loaded
    - 500 any other processing failure
    """
    # Authenticate before touching the counters or the upload.
    await verify_api_key(api_key)
    start_time = time.time()
    request_stats["total"] += 1
    try:
        # Fail clearly instead of hitting "'NoneType' is not callable" later.
        if model_manager is None:
            raise HTTPException(status_code=503, detail="Model is not loaded")

        # 1. Read and validate the source image.
        image_bytes = await image.read()
        if len(image_bytes) > Config.MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"Image too large. Max size: {Config.MAX_FILE_SIZE / 1024 / 1024}MB"
            )
        # Decode from the bytes already read; the upload stream itself is at
        # EOF after read(), so it is not reused.
        pil_image = _decode_upload(
            image_bytes, "RGB", "image", max_side=Config.MAX_IMAGE_SIZE
        )
        width, height = pil_image.size
        logger.info(f"Processing image: {width}x{height}")

        # 2. Read the mask, if one was provided.
        if mask:
            mask_bytes = await mask.read()
            # Apply the same byte limit as for the image.
            if len(mask_bytes) > Config.MAX_FILE_SIZE:
                raise HTTPException(
                    status_code=400,
                    detail=f"Mask too large. Max size: {Config.MAX_FILE_SIZE / 1024 / 1024}MB"
                )
            mask_pil = _decode_upload(mask_bytes, "L", "mask")
            # Make sure the mask has the same dimensions as the source image.
            if mask_pil.size != pil_image.size:
                mask_pil = mask_pil.resize(pil_image.size, Image.LANCZOS)
        else:
            # No mask: use an all-white mask so the entire image is repaired.
            logger.info("No mask provided, will process entire image")
            mask_pil = Image.new("L", pil_image.size, 255)

        # 3. The model manager is called with numpy arrays, not PIL objects.
        image_np = np.array(pil_image)
        mask_np = np.array(mask_pil)

        # 4. Build the request configuration.
        inpaint_request = InpaintRequest(
            image="",  # base64 payloads are not needed; arrays are passed directly
            mask="",
            hd_strategy=HDStrategy.ORIGINAL,
            hd_strategy_crop_margin=128,
            hd_strategy_crop_trigger_size=800,
            hd_strategy_resize_limit=2048,
        )

        # 5. Run the model.
        # NOTE: this is a synchronous call inside an async handler, so the
        # event loop is blocked for the duration of the inference.
        logger.info("Running model inference...")
        inference_start = time.time()
        result_image = model_manager(
            image=image_np,
            mask=mask_np,
            config=inpaint_request,
        )
        inference_time = time.time() - inference_start
        logger.info(f"Inference completed in {inference_time:.2f}s")

        # 6. Encode the result as PNG bytes.
        output_bytes = numpy_to_bytes(
            result_image,
            ext="png",
        )

        # 7. Update statistics.
        processing_time = time.time() - start_time
        request_stats["success"] += 1
        request_stats["total_processing_time"] += processing_time
        logger.success(
            f"✓ Request completed in {processing_time:.2f}s "
            f"(inference: {inference_time:.2f}s)"
        )

        # 8. Return the result.
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={
                "X-Processing-Time": f"{processing_time:.3f}",
                "X-Image-Size": f"{width}x{height}",
            }
        )
    except HTTPException:
        # Client/availability errors keep their status code but count as failures.
        request_stats["failed"] += 1
        raise
    except Exception as e:
        request_stats["failed"] += 1
        # One log record carrying both the message and the traceback.
        logger.exception(f"Error processing request: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Processing failed: {str(e)}"
        ) from e
# ==================== 主函数 ====================
def main(host: str = "0.0.0.0", port: int = 8080):
    """Configure file logging and run the API server (blocks until it exits).

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port to listen on.
    """
    # Log to a per-day file under Config.LOG_DIR, rotated daily and kept 7 days.
    logger.add(
        Config.LOG_DIR / "api_{time:YYYY-MM-DD}.log",
        rotation="1 day",
        retention="7 days",
        level="INFO",
    )
    # Start the server.
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
    )


if __name__ == "__main__":
    main()