switch to FastAPI
This commit is contained in:
@@ -175,6 +175,7 @@ class Api:
|
||||
def api_inpaint(self, req: InpaintRequest):
|
||||
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
|
||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||
if image.shape[:2] != mask.shape[:2]:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -7,13 +7,7 @@ from PIL import Image, ImageOps, PngImagePlugin
|
||||
from fastapi import FastAPI, UploadFile, HTTPException
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
from ..schema import (
|
||||
MediasResponse,
|
||||
MediasRequest,
|
||||
MediaFileRequest,
|
||||
MediaTab,
|
||||
MediaThumbnailFileRequest,
|
||||
)
|
||||
from ..schema import MediasResponse, MediaTab
|
||||
|
||||
LARGE_ENOUGH_NUMBER = 100
|
||||
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
|
||||
@@ -34,9 +28,9 @@ class FileManager:
|
||||
|
||||
# fmt: off
|
||||
self.app.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
|
||||
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["POST"], response_model=List[MediasResponse])
|
||||
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["POST"], response_model=None)
|
||||
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["POST"], response_model=None)
|
||||
self.app.add_api_route("/api/v1/medias", self.api_medias, methods=["GET"], response_model=List[MediasResponse])
|
||||
self.app.add_api_route("/api/v1/media_file", self.api_media_file, methods=["GET"])
|
||||
self.app.add_api_route("/api/v1/media_thumbnail_file", self.api_media_thumbnail_file, methods=["GET"])
|
||||
# fmt: on
|
||||
|
||||
def api_save_image(self, file: UploadFile):
|
||||
@@ -45,18 +39,21 @@ class FileManager:
|
||||
with open(self.output_dir / filename, "wb") as fw:
|
||||
fw.write(origin_image_bytes)
|
||||
|
||||
def api_medias(self, req: MediasRequest) -> List[MediasResponse]:
|
||||
img_dir = self._get_dir(req.tab)
|
||||
def api_medias(self, tab: MediaTab) -> List[MediasResponse]:
|
||||
img_dir = self._get_dir(tab)
|
||||
return self._media_names(img_dir)
|
||||
|
||||
def api_media_file(self, req: MediaFileRequest) -> FileResponse:
|
||||
file_path = self._get_file(req.tab, req.filename)
|
||||
return FileResponse(file_path)
|
||||
def api_media_file(self, tab: MediaTab, filename: str) -> FileResponse:
|
||||
file_path = self._get_file(tab, filename)
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
def api_media_thumbnail_file(self, req: MediaThumbnailFileRequest) -> FileResponse:
|
||||
img_dir = self._get_dir(req.tab)
|
||||
# tab=${tab}?filename=${filename.name}?width=${width}&height=${height}
|
||||
def api_media_thumbnail_file(
|
||||
self, tab: MediaTab, filename: str, width: int, height: int
|
||||
) -> FileResponse:
|
||||
img_dir = self._get_dir(tab)
|
||||
thumb_filename, (width, height) = self.get_thumbnail(
|
||||
img_dir, req.filename, width=req.width, height=req.height
|
||||
img_dir, filename, width=width, height=height
|
||||
)
|
||||
thumbnail_filepath = self.thumbnail_directory / thumb_filename
|
||||
return FileResponse(
|
||||
@@ -65,6 +62,7 @@ class FileManager:
|
||||
"X-Width": str(width),
|
||||
"X-Height": str(height),
|
||||
},
|
||||
media_type="image/jpeg",
|
||||
)
|
||||
|
||||
def _get_dir(self, tab: MediaTab) -> Path:
|
||||
|
||||
@@ -236,10 +236,6 @@ class RunPluginRequest(BaseModel):
|
||||
MediaTab = Literal["input", "output"]
|
||||
|
||||
|
||||
class MediasRequest(BaseModel):
|
||||
tab: MediaTab
|
||||
|
||||
|
||||
class MediasResponse(BaseModel):
|
||||
name: str
|
||||
height: int
|
||||
@@ -248,18 +244,6 @@ class MediasResponse(BaseModel):
|
||||
mtime: float
|
||||
|
||||
|
||||
class MediaFileRequest(BaseModel):
|
||||
tab: MediaTab
|
||||
filename: str
|
||||
|
||||
|
||||
class MediaThumbnailFileRequest(BaseModel):
|
||||
tab: MediaTab
|
||||
filename: str
|
||||
width: int = 0
|
||||
height: int = 0
|
||||
|
||||
|
||||
class GenInfoResponse(BaseModel):
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
import cv2
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
cv2.setNumThreads(NUM_THREADS)
|
||||
|
||||
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
|
||||
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
|
||||
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
|
||||
|
||||
import hashlib
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
|
||||
import io
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from lama_cleaner.const import *
|
||||
from lama_cleaner.file_manager import FileManager
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.plugins import (
|
||||
InteractiveSeg,
|
||||
RemoveBG,
|
||||
AnimeSeg,
|
||||
build_plugins,
|
||||
)
|
||||
from lama_cleaner.schema import InpaintRequest
|
||||
from lama_cleaner.helper import (
|
||||
load_img,
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
pil_to_bytes,
|
||||
is_mac,
|
||||
get_image_ext, concat_alpha_channel,
|
||||
)
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(False)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "app/build")
|
||||
|
||||
global_config = GlobalConfig()
|
||||
|
||||
def diffuser_callback(i, t, latents):
|
||||
socketio.emit("diffusion_progress", {"step": i})
|
||||
|
||||
|
||||
def start(
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
no_half: bool,
|
||||
cpu_offload: bool,
|
||||
disable_nsfw_checker,
|
||||
cpu_textencoder: bool,
|
||||
device: Device,
|
||||
gui: bool,
|
||||
disable_model_switch: bool,
|
||||
input: Path,
|
||||
output_dir: Path,
|
||||
quality: int,
|
||||
enable_interactive_seg: bool,
|
||||
interactive_seg_model: InteractiveSegModel,
|
||||
interactive_seg_device: Device,
|
||||
enable_remove_bg: bool,
|
||||
enable_anime_seg: bool,
|
||||
enable_realesrgan: bool,
|
||||
realesrgan_device: Device,
|
||||
realesrgan_model: RealESRGANModel,
|
||||
enable_gfpgan: bool,
|
||||
gfpgan_device: Device,
|
||||
enable_restoreformer: bool,
|
||||
restoreformer_device: Device,
|
||||
):
|
||||
if input:
|
||||
if not input.exists():
|
||||
logger.error(f"invalid --input: {input} not exists")
|
||||
exit()
|
||||
if input.is_dir():
|
||||
logger.info(f"Initialize file manager")
|
||||
file_manager = FileManager(app)
|
||||
app.config["THUMBNAIL_MEDIA_ROOT"] = input
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
|
||||
output_dir, "lama_cleaner_thumbnails"
|
||||
)
|
||||
file_manager.output_dir = output_dir
|
||||
global_config.file_manager = file_manager
|
||||
else:
|
||||
global_config.input_image_path = input
|
||||
|
||||
global_config.image_quality = quality
|
||||
global_config.disable_model_switch = disable_model_switch
|
||||
global_config.is_desktop = gui
|
||||
build_plugins(
|
||||
global_config,
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
enable_remove_bg,
|
||||
enable_anime_seg,
|
||||
enable_realesrgan,
|
||||
realesrgan_device,
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
no_half,
|
||||
)
|
||||
if output_dir:
|
||||
output_dir = output_dir.expanduser().absolute()
|
||||
logger.info(f"Image will auto save to output dir: {output_dir}")
|
||||
if not output_dir.exists():
|
||||
logger.info(f"Create output dir: {output_dir}")
|
||||
output_dir.mkdir(parents=True)
|
||||
global_config.output_dir = output_dir
|
||||
|
||||
global_config.model_manager = ModelManager(
|
||||
name=model,
|
||||
device=torch.device(device),
|
||||
no_half=no_half,
|
||||
disable_nsfw=disable_nsfw_checker,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
cpu_offload=cpu_offload,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
if gui:
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
ui = FlaskUI(
|
||||
app,
|
||||
socketio=socketio,
|
||||
width=1200,
|
||||
height=800,
|
||||
host=host,
|
||||
port=port,
|
||||
close_server_on_exit=True,
|
||||
idle_interval=60,
|
||||
)
|
||||
ui.run()
|
||||
else:
|
||||
socketio.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
allow_unsafe_werkzeug=True,
|
||||
)
|
||||
Reference in New Issue
Block a user