add plugins

This commit is contained in:
Qing
2023-03-22 12:57:18 +08:00
parent b48d964c2c
commit 5a38d28ad1
11 changed files with 283 additions and 91 deletions

View File

@@ -1,7 +1,6 @@
#!/usr/bin/env python3
import io
import json
import logging
import multiprocessing
import os
@@ -16,12 +15,12 @@ import cv2
import torch
import numpy as np
from loguru import logger
from watchdog.events import FileSystemEventHandler
from lama_cleaner.const import SD15_MODELS
from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.make_gif import make_compare_gif
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager
@@ -85,7 +84,6 @@ CORS(app, expose_headers=["Content-Disposition"])
model: ModelManager = None
thumb: FileManager = None
output_dir: str = None
interactive_seg_model: InteractiveSeg = None
device = None
input_image_path: str = None
is_disable_model_switch: bool = False
@@ -94,6 +92,7 @@ is_enable_file_manager: bool = False
is_enable_auto_saving: bool = False
is_desktop: bool = False
image_quality: int = 95
plugins = {}
def get_image_ext(img_bytes):
@@ -319,35 +318,37 @@ def process():
return response
@app.route("/interactive_seg", methods=["POST"])
def interactive_seg():
input = request.files
origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes)
if "mask" in input:
mask, _ = load_img(input["mask"].read(), gray=True)
else:
mask = None
@app.route("/run_plugin/", methods=["POST"])
def run_plugin():
form = request.form
files = request.files
_clicks = json.loads(request.form["clicks"])
clicks = []
for i, click in enumerate(_clicks):
clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
)
name = form["name"]
if name not in plugins:
return "Plugin not found", 500
origin_image_bytes = files["image"].read() # RGB
rgb_np_img, _ = load_img(origin_image_bytes)
start = time.time()
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
res = plugins[name](rgb_np_img, files, form)
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
torch_gc()
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(new_mask, "png")),
io.BytesIO(numpy_to_bytes(res, "png")),
mimetype=f"image/png",
)
)
return response
@app.route("/plugins/", methods=["GET"])
def get_plugins():
return list(plugins.keys()), 200
@app.route("/model")
def current_model():
return model.name, 200
@@ -423,14 +424,21 @@ def set_input_photo():
return "No Input Image"
class FSHandler(FileSystemEventHandler):
def on_modified(self, event):
print("File modified: %s" % event.src_path)
def build_plugins(args):
global plugins
if args.enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")
plugins[InteractiveSeg.name] = InteractiveSeg()
if args.enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin")
plugins[RemoveBG.name] = RemoveBG()
if args.enable_realesrgan:
logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)
def main(args):
global model
global interactive_seg_model
global device
global input_image_path
global is_disable_model_switch
@@ -442,6 +450,8 @@ def main(args):
global is_controlnet
global image_quality
build_plugins(args)
image_quality = args.quality
if args.sd_controlnet and args.model in SD15_MODELS:
@@ -496,8 +506,6 @@ def main(args):
callback=diffuser_callback,
)
interactive_seg_model = InteractiveSeg()
if args.gui:
app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI