Files
IOPaint/lama_cleaner/server.py
2023-12-30 23:36:44 +08:00

187 lines
4.8 KiB
Python

#!/usr/bin/env python3
import multiprocessing
import os
import cv2
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# Kept as a string because the remaining uses below are environment-variable
# values, which must be str.
NUM_THREADS = str(multiprocessing.cpu_count())
# cv2.setNumThreads() takes an integer; passing the string raises TypeError.
cv2.setNumThreads(int(NUM_THREADS))
# fix libomp problem on windows https://github.com/Sanster/lama-cleaner/issues/56
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# Thread-count hints for the native math backends. They are set before the
# numpy/torch imports further down this file; presumably each backend reads its
# variable when it initialises, so the order matters — TODO confirm per backend.
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
import hashlib
import traceback
from dataclasses import dataclass
import io
import random
import time
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from lama_cleaner.const import *
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import (
InteractiveSeg,
RemoveBG,
AnimeSeg,
build_plugins,
)
from lama_cleaner.schema import InpaintRequest
from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
pil_to_bytes,
is_mac,
get_image_ext, concat_alpha_channel,
)
# Disable the TorchScript JIT fusers. These are private torch._C hooks whose
# availability varies between torch releases, so each one is applied on its
# own, best-effort: a missing or failing hook must neither abort start-up nor
# prevent the remaining hooks from being applied. `except Exception` (rather
# than a bare `except:`) still lets KeyboardInterrupt/SystemExit propagate.
for _jit_hook_name in (
    "_jit_override_can_fuse_on_cpu",
    "_jit_override_can_fuse_on_gpu",
    "_jit_set_texpr_fuser_enabled",
    "_jit_set_nvfuser_enabled",
):
    try:
        getattr(torch._C, _jit_hook_name)(False)
    except Exception:
        # Deliberately ignored: this torch build lacks (or rejects) the hook.
        pass
# Module-level FastAPI application; start() configures global state around it.
app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True may let
# any web page issue credentialed cross-origin requests to this server —
# acceptable for a local tool, worth revisiting if it is ever exposed publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Overridable through the LAMA_CLEANER_BUILD_DIR environment variable.
BUILD_DIR = os.getenv("LAMA_CLEANER_BUILD_DIR", "app/build")
# Shared mutable configuration object; start() assigns its attributes.
global_config = GlobalConfig()
def diffuser_callback(i, t, latents):
    """Report the current diffusion step to connected clients.

    Passed as ``callback`` to ``ModelManager`` in ``start``. When a module
    global named ``socketio`` exists, emits the ``"diffusion_progress"`` event
    with payload ``{"step": i}``. ``t`` and ``latents`` are not used here;
    presumably the caller's callback signature requires them.

    NOTE(review): this module neither imports nor defines ``socketio`` (it
    looks like a leftover from the Flask-SocketIO server), so referencing it
    unconditionally raised NameError on every step, aborting the inference
    run. The lookup below keeps the original behaviour whenever a
    ``socketio`` global is provided, and is a no-op otherwise.
    """
    emitter = globals().get("socketio")
    if emitter is None:
        return
    emitter.emit("diffusion_progress", {"step": i})
def start(
    host: str,
    port: int,
    model: str,
    no_half: bool,
    cpu_offload: bool,
    disable_nsfw_checker: bool,
    cpu_textencoder: bool,
    device: Device,
    gui: bool,
    disable_model_switch: bool,
    input: Path,
    output_dir: Path,
    quality: int,
    enable_interactive_seg: bool,
    interactive_seg_model: InteractiveSegModel,
    interactive_seg_device: Device,
    enable_remove_bg: bool,
    enable_anime_seg: bool,
    enable_realesrgan: bool,
    realesrgan_device: Device,
    realesrgan_model: RealESRGANModel,
    enable_gfpgan: bool,
    gfpgan_device: Device,
    enable_restoreformer: bool,
    restoreformer_device: Device,
):
    """Populate ``global_config``, load the model and plugins, then serve.

    Args:
        host, port: Address handed to the desktop window (``gui``) or to
            ``socketio.run``.
        model, no_half, cpu_offload, disable_nsfw_checker, cpu_textencoder,
            device: Forwarded to ``ModelManager`` (``no_half`` is also passed
            to ``build_plugins``).
        gui: Run inside a flaskwebgui desktop window instead of a plain
            server; also stored as ``global_config.is_desktop``.
        disable_model_switch: Stored on ``global_config``.
        input: Optional image file or directory of images. A directory sets
            up a ``FileManager``; a file is stored as
            ``global_config.input_image_path``. (The name shadows the builtin
            but is part of the public signature.)
        output_dir: Optional directory for saved results. It is expanded,
            made absolute and created if missing.
        quality: Stored as ``global_config.image_quality``.
        enable_* / *_device / *_model: Plugin switches forwarded to
            ``build_plugins``.

    Raises:
        SystemExit: with status 1 if ``input`` does not exist, or if it is a
            directory and no ``output_dir`` was given.
    """
    import sys

    # Resolve the output directory before handling `input`: the directory
    # branch below hands it to the FileManager and joins paths onto it, so it
    # must already be normalised (and must not be None there).
    if output_dir:
        output_dir = output_dir.expanduser().absolute()
        logger.info(f"Image will auto save to output dir: {output_dir}")
        if not output_dir.exists():
            logger.info(f"Create output dir: {output_dir}")
            output_dir.mkdir(parents=True, exist_ok=True)
        global_config.output_dir = output_dir

    if input:
        if not input.exists():
            logger.error(f"invalid --input: {input} not exists")
            # Non-zero status: a bare exit() reports success (0) and relies on
            # the `site` module having installed the builtin.
            sys.exit(1)
        if input.is_dir():
            if not output_dir:
                logger.error(
                    f"an output dir must be set when --input is a directory: {input}"
                )
                sys.exit(1)
            logger.info("Initialize file manager")
            file_manager = FileManager(app)
            # NOTE(review): `app.config` is a Flask API; a FastAPI app has no
            # such attribute, so these assignments look like migration leftovers
            # that will raise AttributeError. Left as-is until the FastAPI-side
            # FileManager configuration is known — confirm.
            app.config["THUMBNAIL_MEDIA_ROOT"] = input
            app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
                output_dir, "lama_cleaner_thumbnails"
            )
            file_manager.output_dir = output_dir
            global_config.file_manager = file_manager
        else:
            global_config.input_image_path = input

    global_config.image_quality = quality
    global_config.disable_model_switch = disable_model_switch
    global_config.is_desktop = gui
    build_plugins(
        global_config,
        enable_interactive_seg,
        interactive_seg_model,
        interactive_seg_device,
        enable_remove_bg,
        enable_anime_seg,
        enable_realesrgan,
        realesrgan_device,
        realesrgan_model,
        enable_gfpgan,
        gfpgan_device,
        enable_restoreformer,
        restoreformer_device,
        no_half,
    )

    global_config.model_manager = ModelManager(
        name=model,
        device=torch.device(device),
        no_half=no_half,
        disable_nsfw=disable_nsfw_checker,
        sd_cpu_textencoder=cpu_textencoder,
        cpu_offload=cpu_offload,
        callback=diffuser_callback,
    )

    # NOTE(review): `socketio` is neither imported nor defined in this module,
    # so both launch paths below currently raise NameError. FlaskUI is also
    # handed a FastAPI app; confirm flaskwebgui is configured for it.
    if gui:
        from flaskwebgui import FlaskUI

        ui = FlaskUI(
            app,
            socketio=socketio,
            width=1200,
            height=800,
            host=host,
            port=port,
            close_server_on_exit=True,
            idle_interval=60,
        )
        ui.run()
    else:
        socketio.run(
            app,
            host=host,
            port=port,
            allow_unsafe_werkzeug=True,
        )