switch to FastAPI

Author: Qing
Date:   2024-01-01 16:05:34 +08:00
Parent: c4abda3942
Commit: 79a41454f6

8 changed files with 54 additions and 256 deletions


@@ -175,6 +175,7 @@ class Api:
     def api_inpaint(self, req: InpaintRequest):
         image, alpha_channel, infos = decode_base64_to_image(req.image)
         mask, _, _ = decode_base64_to_image(req.mask, gray=True)
+        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
         if image.shape[:2] != mask.shape[:2]:
             raise HTTPException(
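
The added line binarizes the decoded mask before the shape check, snapping anti-aliased or lossily-compressed grayscale values to pure 0/255. A standalone sketch of what cv2.threshold does here (the sample mask values are made up for illustration):

import cv2
import numpy as np

# Grayscale mask with soft edges, as produced by brush anti-aliasing or lossy encoding.
mask = np.array([[0, 60, 127, 128, 200, 255]], dtype=np.uint8)

# cv2.threshold returns (threshold_value, output_image); keep only the image.
# THRESH_BINARY maps pixels > 127 to 255 and everything else to 0.
binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
print(binary)  # [[  0   0   0 255 255 255]]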


@@ -7,13 +7,7 @@ from PIL import Image, ImageOps, PngImagePlugin
 from fastapi import FastAPI, UploadFile, HTTPException
 from starlette.responses import FileResponse
 
-from ..schema import (
-    MediasResponse,
-    MediasRequest,
-    MediaFileRequest,
-    MediaTab,
-    MediaThumbnailFileRequest,
-)
+from ..schema import MediasResponse, MediaTab
 
 LARGE_ENOUGH_NUMBER = 100
 PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
@@ -34,9 +28,9 @@ class FileManager:
         # fmt: off
         self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
-        self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["POST"], response_model=List[MediasResponse])
-        self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["POST"], response_model=None)
-        self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["POST"], response_model=None)
+        self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
+        self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
+        self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
         # fmt: on
 
     def api_save_image(self, file: UploadFile):
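
With the media routes switched from POST with a JSON body to GET with query parameters, the browser can link to files and thumbnails directly. A hypothetical client call against the new routes (host and port are illustrative, not taken from this commit):

import requests

BASE = "http://127.0.0.1:8080"

# "tab" now travels as a query parameter instead of a request body.
resp = requests.get(f"{BASE}/api/v1/medias", params={"tab": "input"})
resp.raise_for_status()
for media in resp.json():
    print(media["name"], media["mtime"])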
@@ -45,18 +39,21 @@ class FileManager:
         with open(self.output_dir / filename, "wb") as fw:
             fw.write(origin_image_bytes)
 
-    def api_medias(self, req: MediasRequest) -> List[MediasResponse]:
-        img_dir = self._get_dir(req.tab)
+    def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
+        img_dir = self._get_dir(tab)
         return self._media_names(img_dir)
 
-    def api_media_file(self, req: MediaFileRequest) -> FileResponse:
-        file_path = self._get_file(req.tab, req.filename)
-        return FileResponse(file_path)
+    def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
+        file_path = self._get_file(tab, filename)
+        return FileResponse(file_path, media_type="image/png")
 
-    def api_media_thumbnail_file(self, req: MediaThumbnailFileRequest) -> FileResponse:
-        img_dir = self._get_dir(req.tab)
+    # tab=${tab}&filename=${filename.name}&width=${width}&height=${height}
+    def api_media_thumbnail_file(
+        self, tab: MediaTab, filename: str, width: int, height: int
+    ) -> FileResponse:
+        img_dir = self._get_dir(tab)
         thumb_filename, (width, height) = self.get_thumbnail(
-            img_dir, req.filename, width=req.width, height=req.height
+            img_dir, filename, width=width, height=height
         )
         thumbnail_filepath = self.thumbnail_directory / thumb_filename
         return FileResponse(
@@ -65,6 +62,7 @@ class FileManager:
                 "X-Width": str(width),
                 "X-Height": str(height),
             },
+            media_type="image/jpeg",
         )
 
     def _get_dir(self, tab: MediaTab) -> Path:
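
The handlers can drop their Pydantic request objects because FastAPI binds simple-typed arguments (str, int, Literal) of a GET route from the query string automatically; the explicit media_type arguments pin down the Content-Type that FileResponse would otherwise guess from the file name. A minimal standalone sketch of the binding pattern (route and names are illustrative, not part of this commit):

from typing import Literal

from fastapi import FastAPI

app = FastAPI()
Tab = Literal["input", "output"]

@app.get("/api/v1/demo_medias")
def demo_medias(tab: Tab, width: int = 0):
    # Both arguments are parsed and validated from ?tab=...&width=...;
    # an invalid tab value is rejected with a 422 before this body runs.
    return {"tab": tab, "width": width}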


@@ -236,10 +236,6 @@ class RunPluginRequest(BaseModel):
 MediaTab = Literal["input", "output"]
 
-class MediasRequest(BaseModel):
-    tab: MediaTab
-
 class MediasResponse(BaseModel):
     name: str
     height: int
@@ -248,18 +244,6 @@ class MediasResponse(BaseModel):
     mtime: float
 
-class MediaFileRequest(BaseModel):
-    tab: MediaTab
-    filename: str
-
-class MediaThumbnailFileRequest(BaseModel):
-    tab: MediaTab
-    filename: str
-    width: int = 0
-    height: int = 0
-
 class GenInfoResponse(BaseModel):
     prompt: str = ""
     negative_prompt: str = ""
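
Only the response model survives; the three request models disappear because their fields now arrive as query parameters on GET routes. A hypothetical thumbnail fetch matching the comment on api_media_thumbnail_file (the filename and size here are made up):

import requests

params = {"tab": "output", "filename": "result.png", "width": 256, "height": 256}
resp = requests.get("http://127.0.0.1:8080/api/v1/media_thumbnail_file", params=params)
resp.raise_for_status()

# The actual thumbnail size is reported in the custom headers next to the JPEG body.
print(resp.headers["X-Width"], resp.headers["X-Height"])
with open("thumb.jpg", "wb") as fw:
    fw.write(resp.content)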


@@ -1,186 +0,0 @@
-#!/usr/bin/env python3
-import multiprocessing
-import os
-
-import cv2
-
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
-
-NUM_THREADS = str(multiprocessing.cpu_count())
-cv2.setNumThreads(NUM_THREADS)
-
-# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
-os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
-os.environ["OMP_NUM_THREADS"] = NUM_THREADS
-os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
-os.environ["MKL_NUM_THREADS"] = NUM_THREADS
-os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
-os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
-
-import hashlib
-import traceback
-from dataclasses import dataclass
-import io
-import random
-import time
-from pathlib import Path
-
-import cv2
-import numpy as np
-import torch
-from PIL import Image
-from loguru import logger
-
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
-from fastapi.responses import FileResponse
-
-from lama_cleaner.const import *
-from lama_cleaner.file_manager import FileManager
-from lama_cleaner.model.utils import torch_gc
-from lama_cleaner.model_manager import ModelManager
-from lama_cleaner.plugins import (
-    InteractiveSeg,
-    RemoveBG,
-    AnimeSeg,
-    build_plugins,
-)
-from lama_cleaner.schema import InpaintRequest
-from lama_cleaner.helper import (
-    load_img,
-    numpy_to_bytes,
-    resize_max_size,
-    pil_to_bytes,
-    is_mac,
-    get_image_ext, concat_alpha_channel,
-)
-
-try:
-    torch._C._jit_override_can_fuse_on_cpu(False)
-    torch._C._jit_override_can_fuse_on_gpu(False)
-    torch._C._jit_set_texpr_fuser_enabled(False)
-    torch._C._jit_set_nvfuser_enabled(False)
-except:
-    pass
-
-app = FastAPI()
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
-
-global_config = GlobalConfig()
-
-
-def diffuser_callback(i, t, latents):
-    socketio.emit("diffusion_progress", {"step": i})
-
-
-def start(
-    host: str,
-    port: int,
-    model: str,
-    no_half: bool,
-    cpu_offload: bool,
-    disable_nsfw_checker,
-    cpu_textencoder: bool,
-    device: Device,
-    gui: bool,
-    disable_model_switch: bool,
-    input: Path,
-    output_dir: Path,
-    quality: int,
-    enable_interactive_seg: bool,
-    interactive_seg_model: InteractiveSegModel,
-    interactive_seg_device: Device,
-    enable_remove_bg: bool,
-    enable_anime_seg: bool,
-    enable_realesrgan: bool,
-    realesrgan_device: Device,
-    realesrgan_model: RealESRGANModel,
-    enable_gfpgan: bool,
-    gfpgan_device: Device,
-    enable_restoreformer: bool,
-    restoreformer_device: Device,
-):
-    if input:
-        if not input.exists():
-            logger.error(f"invalid --input: {input} not exists")
-            exit()
-        if input.is_dir():
-            logger.info(f"Initialize file manager")
-            file_manager = FileManager(app)
-            app.config["THUMBNAIL_MEDIA_ROOT"] = input
-            app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
-                output_dir, "lama_cleaner_thumbnails"
-            )
-            file_manager.output_dir = output_dir
-            global_config.file_manager = file_manager
-        else:
-            global_config.input_image_path = input
-
-    global_config.image_quality = quality
-    global_config.disable_model_switch = disable_model_switch
-    global_config.is_desktop = gui
-    build_plugins(
-        global_config,
-        enable_interactive_seg,
-        interactive_seg_model,
-        interactive_seg_device,
-        enable_remove_bg,
-        enable_anime_seg,
-        enable_realesrgan,
-        realesrgan_device,
-        realesrgan_model,
-        enable_gfpgan,
-        gfpgan_device,
-        enable_restoreformer,
-        restoreformer_device,
-        no_half,
-    )
-
-    if output_dir:
-        output_dir = output_dir.expanduser().absolute()
-        logger.info(f"Image will auto save to output dir: {output_dir}")
-        if not output_dir.exists():
-            logger.info(f"Create output dir: {output_dir}")
-            output_dir.mkdir(parents=True)
-        global_config.output_dir = output_dir
-
-    global_config.model_manager = ModelManager(
-        name=model,
-        device=torch.device(device),
-        no_half=no_half,
-        disable_nsfw=disable_nsfw_checker,
-        sd_cpu_textencoder=cpu_textencoder,
-        cpu_offload=cpu_offload,
-        callback=diffuser_callback,
-    )
-
-    if gui:
-        from flaskwebgui import FlaskUI
-
-        ui = FlaskUI(
-            app,
-            socketio=socketio,
-            width=1200,
-            height=800,
-            host=host,
-            port=port,
-            close_server_on_exit=True,
-            idle_interval=60,
-        )
-        ui.run()
-    else:
-        socketio.run(
-            app,
-            host=host,
-            port=port,
-            allow_unsafe_werkzeug=True,
-        )
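
The deleted entrypoint still launched the FastAPI app through Flask-era machinery: socketio.run() and FlaskUI cannot serve an ASGI application, app.config does not exist on a FastAPI instance, and socketio is never even defined in this file. A minimal sketch of the standard ASGI launch that a FastAPI app needs instead (this is the generic uvicorn pattern, not the project's actual new entrypoint):

import uvicorn
from fastapi import FastAPI

app = FastAPI()

@app.get("/api/v1/health")
def health():
    # Placeholder route so the sketch is runnable end to end.
    return {"status": "ok"}

if __name__ == "__main__":
    # An ASGI server replaces socketio.run()/FlaskUI; host and port are illustrative.
    uvicorn.run(app, host="127.0.0.1", port=8080)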