wip: add interactive seg model
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
@@ -15,6 +16,7 @@ import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.interactive_seg import InteractiveSeg, Click
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
@@ -71,6 +73,7 @@ CORS(app, expose_headers=["Content-Disposition"])
|
||||
# socketio = SocketIO(app, max_http_buffer_size=MAX_BUFFER_SIZE, async_mode='threading')
|
||||
|
||||
model: ModelManager = None
|
||||
interactive_seg_model: InteractiveSeg = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
is_disable_model_switch: bool = False
|
||||
@@ -97,6 +100,8 @@ def process():
|
||||
|
||||
image, alpha_channel = load_img(origin_image_bytes)
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||
|
||||
if image.shape[:2] != mask.shape[:2]:
|
||||
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
|
||||
|
||||
@@ -181,6 +186,33 @@ def process():
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/interactive_seg", methods=["POST"])
|
||||
def interactive_seg():
|
||||
input = request.files
|
||||
origin_image_bytes = input["image"].read() # RGB
|
||||
image, _ = load_img(origin_image_bytes)
|
||||
if 'mask' in input:
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
else:
|
||||
mask = None
|
||||
|
||||
_clicks = json.loads(request.form["clicks"])
|
||||
clicks = []
|
||||
for i, click in enumerate(_clicks):
|
||||
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1))
|
||||
|
||||
start = time.time()
|
||||
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
|
||||
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
|
||||
response = make_response(
|
||||
send_file(
|
||||
io.BytesIO(numpy_to_bytes(new_mask, 'png')),
|
||||
mimetype=f"image/png",
|
||||
)
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.name, 200
|
||||
@@ -240,6 +272,7 @@ def set_input_photo():
|
||||
|
||||
def main(args):
|
||||
global model
|
||||
global interactive_seg_model
|
||||
global device
|
||||
global input_image_path
|
||||
global is_disable_model_switch
|
||||
@@ -263,6 +296,8 @@ def main(args):
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
interactive_seg_model = InteractiveSeg()
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
Reference in New Issue
Block a user