add plugins
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
@@ -16,12 +15,12 @@ import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
from lama_cleaner.const import SD15_MODELS
|
||||
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
||||
from lama_cleaner.make_gif import make_compare_gif
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
|
||||
from lama_cleaner.schema import Config
|
||||
from lama_cleaner.file_manager import FileManager
|
||||
|
||||
@@ -85,7 +84,6 @@ CORS(app, expose_headers=["Content-Disposition"])
|
||||
model: ModelManager = None
|
||||
thumb: FileManager = None
|
||||
output_dir: str = None
|
||||
interactive_seg_model: InteractiveSeg = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
is_disable_model_switch: bool = False
|
||||
@@ -94,6 +92,7 @@ is_enable_file_manager: bool = False
|
||||
is_enable_auto_saving: bool = False
|
||||
is_desktop: bool = False
|
||||
image_quality: int = 95
|
||||
plugins = {}
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
@@ -319,35 +318,37 @@ def process():
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/interactive_seg", methods=["POST"])
|
||||
def interactive_seg():
|
||||
input = request.files
|
||||
origin_image_bytes = input["image"].read() # RGB
|
||||
image, _ = load_img(origin_image_bytes)
|
||||
if "mask" in input:
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
else:
|
||||
mask = None
|
||||
@app.route("/run_plugin/", methods=["POST"])
|
||||
def run_plugin():
|
||||
form = request.form
|
||||
files = request.files
|
||||
|
||||
_clicks = json.loads(request.form["clicks"])
|
||||
clicks = []
|
||||
for i, click in enumerate(_clicks):
|
||||
clicks.append(
|
||||
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
|
||||
)
|
||||
name = form["name"]
|
||||
if name not in plugins:
|
||||
return "Plugin not found", 500
|
||||
|
||||
origin_image_bytes = files["image"].read() # RGB
|
||||
rgb_np_img, _ = load_img(origin_image_bytes)
|
||||
|
||||
start = time.time()
|
||||
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
|
||||
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
|
||||
res = plugins[name](rgb_np_img, files, form)
|
||||
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
|
||||
torch_gc()
|
||||
|
||||
response = make_response(
|
||||
send_file(
|
||||
io.BytesIO(numpy_to_bytes(new_mask, "png")),
|
||||
io.BytesIO(numpy_to_bytes(res, "png")),
|
||||
mimetype=f"image/png",
|
||||
)
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/plugins/", methods=["GET"])
|
||||
def get_plugins():
|
||||
return list(plugins.keys()), 200
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.name, 200
|
||||
@@ -423,14 +424,21 @@ def set_input_photo():
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
class FSHandler(FileSystemEventHandler):
|
||||
def on_modified(self, event):
|
||||
print("File modified: %s" % event.src_path)
|
||||
def build_plugins(args):
|
||||
global plugins
|
||||
if args.enable_interactive_seg:
|
||||
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||
plugins[InteractiveSeg.name] = InteractiveSeg()
|
||||
if args.enable_remove_bg:
|
||||
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||
plugins[RemoveBG.name] = RemoveBG()
|
||||
if args.enable_realesrgan:
|
||||
logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
|
||||
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)
|
||||
|
||||
|
||||
def main(args):
|
||||
global model
|
||||
global interactive_seg_model
|
||||
global device
|
||||
global input_image_path
|
||||
global is_disable_model_switch
|
||||
@@ -442,6 +450,8 @@ def main(args):
|
||||
global is_controlnet
|
||||
global image_quality
|
||||
|
||||
build_plugins(args)
|
||||
|
||||
image_quality = args.quality
|
||||
|
||||
if args.sd_controlnet and args.model in SD15_MODELS:
|
||||
@@ -496,8 +506,6 @@ def main(args):
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
interactive_seg_model = InteractiveSeg()
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
Reference in New Issue
Block a user