diff --git a/api_service_mvp.py b/api_service_mvp.py index 61fac8d..3b757b3 100644 --- a/api_service_mvp.py +++ b/api_service_mvp.py @@ -18,6 +18,7 @@ from datetime import datetime import torch import uvicorn +import numpy as np from fastapi import FastAPI, File, UploadFile, Header, HTTPException, Request from fastapi.responses import Response, JSONResponse from fastapi.middleware.cors import CORSMiddleware @@ -25,7 +26,7 @@ from loguru import logger from PIL import Image from iopaint.model_manager import ModelManager -from iopaint.schema import ApiConfig, InpaintRequest, HDStrategy +from iopaint.schema import InpaintRequest, HDStrategy from iopaint.helper import ( decode_base64_to_image, numpy_to_bytes, @@ -119,30 +120,17 @@ async def startup_event(): logger.info("=" * 60) try: - # 初始化模型管理器 - api_config = ApiConfig( - host="0.0.0.0", - port=8080, - model=Config.MODEL_NAME, - device=Config.DEVICE, - gui=False, - no_gui_auto_close=True, - cpu_offload=False, - disable_nsfw_checker=True, - cpu_textencoder=False, - local_files_only=False, - ) - + # 直接初始化模型管理器,不使用 ApiConfig model_manager = ModelManager( - name=api_config.model, - device=torch.device(api_config.device), + name=Config.MODEL_NAME, + device=torch.device(Config.DEVICE), no_half=False, low_mem=False, cpu_offload=False, - disable_nsfw=api_config.disable_nsfw_checker, - sd_cpu_textencoder=api_config.cpu_textencoder, - local_files_only=api_config.local_files_only, - cpu_textencoder=api_config.cpu_textencoder, + disable_nsfw=True, + sd_cpu_textencoder=False, + local_files_only=False, + cpu_textencoder=False, ) logger.success(f"✓ Model {Config.MODEL_NAME} loaded successfully on {Config.DEVICE}") @@ -282,9 +270,13 @@ async def remove_watermark( logger.info("No mask provided, will process entire image") mask_pil = Image.new("L", pil_image.size, 255) - # 3. 构建请求配置 + # 3. 将 PIL 图像转换为 numpy 数组 + image_np = np.array(pil_image) + mask_np = np.array(mask_pil) + + # 4. 构建请求配置 inpaint_request = InpaintRequest( - image="", # 我们直接传PIL对象,不需要base64 + image="", # 不需要 base64 mask="", hd_strategy=HDStrategy.ORIGINAL, hd_strategy_crop_margin=128, @@ -292,26 +284,26 @@ async def remove_watermark( hd_strategy_resize_limit=2048, ) - # 4. 调用模型进行处理 + # 5. 调用模型进行处理 logger.info("Running model inference...") inference_start = time.time() result_image = model_manager( - image=pil_image, - mask=mask_pil, + image=image_np, + mask=mask_np, config=inpaint_request, ) inference_time = time.time() - inference_start logger.info(f"Inference completed in {inference_time:.2f}s") - # 5. 转换结果为字节 + # 6. 转换结果为字节 output_bytes = numpy_to_bytes( result_image, ext="png", ) - # 6. 更新统计 + # 7. 更新统计 processing_time = time.time() - start_time request_stats["success"] += 1 request_stats["total_processing_time"] += processing_time @@ -321,7 +313,7 @@ async def remove_watermark( f"(inference: {inference_time:.2f}s)" ) - # 7. 返回结果 + # 8. 返回结果 return Response( content=output_bytes, media_type="image/png",