make plugin work

This commit is contained in:
Qing
2023-03-25 09:53:22 +08:00
parent 996a264797
commit 6e54f77ed6
16 changed files with 528 additions and 284 deletions

View File

@@ -17,10 +17,10 @@ import numpy as np
from loguru import logger
from lama_cleaner.const import SD15_MODELS
from lama_cleaner.make_gif import make_compare_gif
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import InteractiveSeg, RemoveBG, RealESRGANUpscaler
from lama_cleaner.plugins.gif import MakeGIF
from lama_cleaner.schema import Config
from lama_cleaner.file_manager import FileManager
@@ -318,11 +318,10 @@ def process():
return response
@app.route("/run_plugin/", methods=["POST"])
@app.route("/run_plugin", methods=["POST"])
def run_plugin():
form = request.form
files = request.files
name = form["name"]
if name not in plugins:
return "Plugin not found", 500
@@ -335,18 +334,33 @@ def run_plugin():
logger.info(f"{name} process time: {(time.time() - start) * 1000}ms")
torch_gc()
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(res, "png")),
mimetype=f"image/png",
if name == MakeGIF.name:
filename = form["filename"]
return send_file(
io.BytesIO(res),
mimetype="image/gif",
as_attachment=True,
attachment_filename=filename,
)
else:
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(res, "png")),
mimetype=f"image/png",
)
)
)
return response
@app.route("/plugins/", methods=["GET"])
def get_plugins():
return list(plugins.keys()), 200
@app.route("/server_config", methods=["GET"])
def get_server_config():
return {
"isControlNet": is_controlnet,
"isDisableModelSwitchState": is_disable_model_switch,
"isEnableAutoSaving": is_enable_file_manager,
"enableFileManager": is_enable_auto_saving,
"plugins": list(plugins.keys()),
}, 200
@app.route("/model")
@@ -354,30 +368,6 @@ def current_model():
return model.name, 200
@app.route("/is_controlnet")
def get_is_controlnet():
res = "true" if is_controlnet else "false"
return res, 200
@app.route("/is_disable_model_switch")
def get_is_disable_model_switch():
res = "true" if is_disable_model_switch else "false"
return res, 200
@app.route("/is_enable_file_manager")
def get_is_enable_file_manager():
res = "true" if is_enable_file_manager else "false"
return res, 200
@app.route("/is_enable_auto_saving")
def get_is_enable_auto_saving():
res = "true" if is_enable_auto_saving else "false"
return res, 200
@app.route("/model_downloaded/<name>")
def model_downloaded(name):
return str(model.is_downloaded(name)), 200
@@ -435,6 +425,9 @@ def build_plugins(args):
if args.enable_realesrgan:
logger.info(f"Initialize {RealESRGANUpscaler.name} plugin")
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(args.realesrgan_device)
if args.enable_gif:
logger.info(f"Initialize GIF plugin")
plugins[MakeGIF.name] = MakeGIF()
def main(args):