This commit is contained in:
Qing
2023-02-06 22:00:47 +08:00
parent 24bff09534
commit 3f6bc8fada
9 changed files with 307 additions and 91 deletions

View File

@@ -32,7 +32,15 @@ try:
except:
pass
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
from flask import (
Flask,
request,
send_file,
cli,
make_response,
send_from_directory,
jsonify,
)
# Disable ability for Flask to display warning about using a development server in a production environment.
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
@@ -43,6 +51,7 @@ from lama_cleaner.helper import (
load_img,
numpy_to_bytes,
resize_max_size,
pil_to_bytes,
)
NUM_THREADS = str(multiprocessing.cpu_count())
@@ -103,14 +112,13 @@ def make_gif():
origin_image, _ = load_img(origin_image_bytes)
clean_image, _ = load_img(clean_image_bytes)
gif_bytes = make_compare_gif(
Image.fromarray(origin_image),
Image.fromarray(clean_image)
Image.fromarray(origin_image), Image.fromarray(clean_image)
)
return send_file(
io.BytesIO(gif_bytes),
mimetype='image/gif',
mimetype="image/gif",
as_attachment=True,
attachment_filename=filename
attachment_filename=filename,
)
@@ -121,12 +129,12 @@ def save_image():
origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes)
thumb.save_to_output_directory(image, request.form["filename"])
return 'ok', 200
return "ok", 200
@app.route("/medias/<tab>")
def medias(tab):
if tab == 'image':
if tab == "image":
response = make_response(jsonify(thumb.media_names), 200)
else:
response = make_response(jsonify(thumb.output_media_names), 200)
@@ -137,18 +145,18 @@ def medias(tab):
return response
@app.route('/media/<tab>/<filename>')
@app.route("/media/<tab>/<filename>")
def media_file(tab, filename):
if tab == 'image':
if tab == "image":
return send_from_directory(thumb.root_directory, filename)
return send_from_directory(thumb.output_dir, filename)
@app.route('/media_thumbnail/<tab>/<filename>')
@app.route("/media_thumbnail/<tab>/<filename>")
def media_thumbnail_file(tab, filename):
args = request.args
width = args.get('width')
height = args.get('height')
width = args.get("width")
height = args.get("height")
if width is None and height is None:
width = 256
if width:
@@ -157,9 +165,11 @@ def media_thumbnail_file(tab, filename):
height = int(float(height))
directory = thumb.root_directory
if tab == 'output':
if tab == "output":
directory = thumb.output_dir
thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height)
thumb_filename, (width, height) = thumb.get_thumbnail(
directory, filename, width, height
)
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
response = make_response(send_file(thumb_filepath))
@@ -173,13 +183,16 @@ def process():
input = request.files
# RGB
origin_image_bytes = input["image"].read()
image, alpha_channel = load_img(origin_image_bytes)
image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)
mask, _ = load_img(input["mask"].read(), gray=True)
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]:
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
return (
f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}",
400,
)
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
@@ -192,7 +205,9 @@ def process():
size_limit = int(size_limit)
if "paintByExampleImage" in input:
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
paint_by_example_example_image, _ = load_img(
input["paintByExampleImage"].read()
)
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
else:
paint_by_example_example_image = None
@@ -221,7 +236,7 @@ def process():
sd_seed=form["sdSeed"],
sd_match_histograms=form["sdMatchHistograms"],
cv2_flag=form["cv2Flag"],
cv2_radius=form['cv2Radius'],
cv2_radius=form["cv2Radius"],
paint_by_example_steps=form["paintByExampleSteps"],
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
@@ -259,6 +274,7 @@ def process():
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
if alpha_channel is not None:
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
alpha_channel = cv2.resize(
@@ -270,9 +286,15 @@ def process():
ext = get_image_ext(origin_image_bytes)
if exif is not None:
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, exif=exif))
else:
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext))
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
# io.BytesIO(numpy_to_bytes(res_np_img, ext)),
bytes_io,
mimetype=f"image/{ext}",
)
)
@@ -285,7 +307,7 @@ def interactive_seg():
input = request.files
origin_image_bytes = input["image"].read() # RGB
image, _ = load_img(origin_image_bytes)
if 'mask' in input:
if "mask" in input:
mask, _ = load_img(input["mask"].read(), gray=True)
else:
mask = None
@@ -293,14 +315,16 @@ def interactive_seg():
_clicks = json.loads(request.form["clicks"])
clicks = []
for i, click in enumerate(_clicks):
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1))
clicks.append(
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
)
start = time.time()
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
response = make_response(
send_file(
io.BytesIO(numpy_to_bytes(new_mask, 'png')),
io.BytesIO(numpy_to_bytes(new_mask, "png")),
mimetype=f"image/png",
)
)
@@ -314,13 +338,13 @@ def current_model():
@app.route("/is_disable_model_switch")
def get_is_disable_model_switch():
res = 'true' if is_disable_model_switch else 'false'
res = "true" if is_disable_model_switch else "false"
return res, 200
@app.route("/is_enable_file_manager")
def get_is_enable_file_manager():
res = 'true' if is_enable_file_manager else 'false'
res = "true" if is_enable_file_manager else "false"
return res, 200
@@ -389,14 +413,18 @@ def main(args):
is_disable_model_switch = args.disable_model_switch
is_desktop = args.gui
if is_disable_model_switch:
logger.info(f"Start with --disable-model-switch, model switch on frontend is disable")
logger.info(
f"Start with --disable-model-switch, model switch on frontend is disable"
)
if args.input and os.path.isdir(args.input):
logger.info(f"Initialize file manager")
thumb = FileManager(app)
is_enable_file_manager = True
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'lama_cleaner_thumbnails')
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
args.output_dir, "lama_cleaner_thumbnails"
)
thumb.output_dir = Path(args.output_dir)
# thumb.start()
# try:
@@ -432,8 +460,12 @@ def main(args):
from flaskwebgui import FlaskUI
ui = FlaskUI(
app, width=app_width, height=app_height, host=args.host, port=args.port,
close_server_on_exit=not args.no_gui_auto_close
app,
width=app_width,
height=app_height,
host=args.host,
port=args.port,
close_server_on_exit=not args.no_gui_auto_close,
)
ui.run()
else: