wip: add interactive seg model

This commit is contained in:
Qing
2022-11-27 21:25:27 +08:00
parent af87cca643
commit 023306ae40
20 changed files with 820 additions and 46 deletions

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import io
import json
import logging
import multiprocessing
import os
@@ -15,6 +16,7 @@ import torch
import numpy as np
from loguru import logger
from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
@@ -71,6 +73,7 @@ CORS(app, expose_headers=["Content-Disposition"])
# socketio = SocketIO(app, max_http_buffer_size=MAX_BUFFER_SIZE, async_mode='threading')
model: ModelManager = None
interactive_seg_model: InteractiveSeg = None
device = None
input_image_path: str = None
is_disable_model_switch: bool = False
@@ -97,6 +100,8 @@ def process():
image, alpha_channel = load_img(origin_image_bytes)
mask, _ = load_img(input["mask"].read(), gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]:
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
@@ -181,6 +186,33 @@ def process():
return response
@app.route("/interactive_seg", methods=["POST"])
def interactive_seg():
input = request.files
origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes)
if 'mask' in input:
mask, _ = load_img(input["mask"].read(), gray=True)
else:
mask = None
_clicks = json.loads(request.form["clicks"])
clicks = []
for i, click in enumerate(_clicks):
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1))
start = time.time()
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(new_mask, 'png')),
mimetype=f"image/png",
)
)
return response
@app.route("/model")
def current_model():
return model.name, 200
@@ -240,6 +272,7 @@ def set_input_photo():
def main(args):
global model
global interactive_seg_model
global device
global input_image_path
global is_disable_model_switch
@@ -263,6 +296,8 @@ def main(args):
callback=diffuser_callback,
)
interactive_seg_model = InteractiveSeg()
if args.gui:
app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI