update
This commit is contained in:
@@ -32,7 +32,15 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
from flask import Flask, request, send_file, cli, make_response, send_from_directory, jsonify
|
||||
from flask import (
|
||||
Flask,
|
||||
request,
|
||||
send_file,
|
||||
cli,
|
||||
make_response,
|
||||
send_from_directory,
|
||||
jsonify,
|
||||
)
|
||||
|
||||
# Disable ability for Flask to display warning about using a development server in a production environment.
|
||||
# https://gist.github.com/jerblack/735b9953ba1ab6234abb43174210d356
|
||||
@@ -43,6 +51,7 @@ from lama_cleaner.helper import (
|
||||
load_img,
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
pil_to_bytes,
|
||||
)
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
@@ -103,14 +112,13 @@ def make_gif():
|
||||
origin_image, _ = load_img(origin_image_bytes)
|
||||
clean_image, _ = load_img(clean_image_bytes)
|
||||
gif_bytes = make_compare_gif(
|
||||
Image.fromarray(origin_image),
|
||||
Image.fromarray(clean_image)
|
||||
Image.fromarray(origin_image), Image.fromarray(clean_image)
|
||||
)
|
||||
return send_file(
|
||||
io.BytesIO(gif_bytes),
|
||||
mimetype='image/gif',
|
||||
mimetype="image/gif",
|
||||
as_attachment=True,
|
||||
attachment_filename=filename
|
||||
attachment_filename=filename,
|
||||
)
|
||||
|
||||
|
||||
@@ -121,12 +129,12 @@ def save_image():
|
||||
origin_image_bytes = input["image"].read() # RGB
|
||||
image, _ = load_img(origin_image_bytes)
|
||||
thumb.save_to_output_directory(image, request.form["filename"])
|
||||
return 'ok', 200
|
||||
return "ok", 200
|
||||
|
||||
|
||||
@app.route("/medias/<tab>")
|
||||
def medias(tab):
|
||||
if tab == 'image':
|
||||
if tab == "image":
|
||||
response = make_response(jsonify(thumb.media_names), 200)
|
||||
else:
|
||||
response = make_response(jsonify(thumb.output_media_names), 200)
|
||||
@@ -137,18 +145,18 @@ def medias(tab):
|
||||
return response
|
||||
|
||||
|
||||
@app.route('/media/<tab>/<filename>')
|
||||
@app.route("/media/<tab>/<filename>")
|
||||
def media_file(tab, filename):
|
||||
if tab == 'image':
|
||||
if tab == "image":
|
||||
return send_from_directory(thumb.root_directory, filename)
|
||||
return send_from_directory(thumb.output_dir, filename)
|
||||
|
||||
|
||||
@app.route('/media_thumbnail/<tab>/<filename>')
|
||||
@app.route("/media_thumbnail/<tab>/<filename>")
|
||||
def media_thumbnail_file(tab, filename):
|
||||
args = request.args
|
||||
width = args.get('width')
|
||||
height = args.get('height')
|
||||
width = args.get("width")
|
||||
height = args.get("height")
|
||||
if width is None and height is None:
|
||||
width = 256
|
||||
if width:
|
||||
@@ -157,9 +165,11 @@ def media_thumbnail_file(tab, filename):
|
||||
height = int(float(height))
|
||||
|
||||
directory = thumb.root_directory
|
||||
if tab == 'output':
|
||||
if tab == "output":
|
||||
directory = thumb.output_dir
|
||||
thumb_filename, (width, height) = thumb.get_thumbnail(directory, filename, width, height)
|
||||
thumb_filename, (width, height) = thumb.get_thumbnail(
|
||||
directory, filename, width, height
|
||||
)
|
||||
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
|
||||
|
||||
response = make_response(send_file(thumb_filepath))
|
||||
@@ -173,13 +183,16 @@ def process():
|
||||
input = request.files
|
||||
# RGB
|
||||
origin_image_bytes = input["image"].read()
|
||||
image, alpha_channel = load_img(origin_image_bytes)
|
||||
image, alpha_channel, exif = load_img(origin_image_bytes, return_exif=True)
|
||||
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||
|
||||
if image.shape[:2] != mask.shape[:2]:
|
||||
return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
|
||||
return (
|
||||
f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}",
|
||||
400,
|
||||
)
|
||||
|
||||
original_shape = image.shape
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
@@ -192,7 +205,9 @@ def process():
|
||||
size_limit = int(size_limit)
|
||||
|
||||
if "paintByExampleImage" in input:
|
||||
paint_by_example_example_image, _ = load_img(input["paintByExampleImage"].read())
|
||||
paint_by_example_example_image, _ = load_img(
|
||||
input["paintByExampleImage"].read()
|
||||
)
|
||||
paint_by_example_example_image = Image.fromarray(paint_by_example_example_image)
|
||||
else:
|
||||
paint_by_example_example_image = None
|
||||
@@ -221,7 +236,7 @@ def process():
|
||||
sd_seed=form["sdSeed"],
|
||||
sd_match_histograms=form["sdMatchHistograms"],
|
||||
cv2_flag=form["cv2Flag"],
|
||||
cv2_radius=form['cv2Radius'],
|
||||
cv2_radius=form["cv2Radius"],
|
||||
paint_by_example_steps=form["paintByExampleSteps"],
|
||||
paint_by_example_guidance_scale=form["paintByExampleGuidanceScale"],
|
||||
paint_by_example_mask_blur=form["paintByExampleMaskBlur"],
|
||||
@@ -259,6 +274,7 @@ def process():
|
||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
res_np_img = cv2.cvtColor(res_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||
if alpha_channel is not None:
|
||||
if alpha_channel.shape[:2] != res_np_img.shape[:2]:
|
||||
alpha_channel = cv2.resize(
|
||||
@@ -270,9 +286,15 @@ def process():
|
||||
|
||||
ext = get_image_ext(origin_image_bytes)
|
||||
|
||||
if exif is not None:
|
||||
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext, exif=exif))
|
||||
else:
|
||||
bytes_io = io.BytesIO(pil_to_bytes(Image.fromarray(res_np_img), ext))
|
||||
|
||||
response = make_response(
|
||||
send_file(
|
||||
io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
||||
# io.BytesIO(numpy_to_bytes(res_np_img, ext)),
|
||||
bytes_io,
|
||||
mimetype=f"image/{ext}",
|
||||
)
|
||||
)
|
||||
@@ -285,7 +307,7 @@ def interactive_seg():
|
||||
input = request.files
|
||||
origin_image_bytes = input["image"].read() # RGB
|
||||
image, _ = load_img(origin_image_bytes)
|
||||
if 'mask' in input:
|
||||
if "mask" in input:
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
else:
|
||||
mask = None
|
||||
@@ -293,14 +315,16 @@ def interactive_seg():
|
||||
_clicks = json.loads(request.form["clicks"])
|
||||
clicks = []
|
||||
for i, click in enumerate(_clicks):
|
||||
clicks.append(Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1))
|
||||
clicks.append(
|
||||
Click(coords=(click[1], click[0]), indx=i, is_positive=click[2] == 1)
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
new_mask = interactive_seg_model(image, clicks=clicks, prev_mask=mask)
|
||||
logger.info(f"interactive seg process time: {(time.time() - start) * 1000}ms")
|
||||
response = make_response(
|
||||
send_file(
|
||||
io.BytesIO(numpy_to_bytes(new_mask, 'png')),
|
||||
io.BytesIO(numpy_to_bytes(new_mask, "png")),
|
||||
mimetype=f"image/png",
|
||||
)
|
||||
)
|
||||
@@ -314,13 +338,13 @@ def current_model():
|
||||
|
||||
@app.route("/is_disable_model_switch")
|
||||
def get_is_disable_model_switch():
|
||||
res = 'true' if is_disable_model_switch else 'false'
|
||||
res = "true" if is_disable_model_switch else "false"
|
||||
return res, 200
|
||||
|
||||
|
||||
@app.route("/is_enable_file_manager")
|
||||
def get_is_enable_file_manager():
|
||||
res = 'true' if is_enable_file_manager else 'false'
|
||||
res = "true" if is_enable_file_manager else "false"
|
||||
return res, 200
|
||||
|
||||
|
||||
@@ -389,14 +413,18 @@ def main(args):
|
||||
is_disable_model_switch = args.disable_model_switch
|
||||
is_desktop = args.gui
|
||||
if is_disable_model_switch:
|
||||
logger.info(f"Start with --disable-model-switch, model switch on frontend is disable")
|
||||
logger.info(
|
||||
f"Start with --disable-model-switch, model switch on frontend is disable"
|
||||
)
|
||||
|
||||
if args.input and os.path.isdir(args.input):
|
||||
logger.info(f"Initialize file manager")
|
||||
thumb = FileManager(app)
|
||||
is_enable_file_manager = True
|
||||
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(args.output_dir, 'lama_cleaner_thumbnails')
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
|
||||
args.output_dir, "lama_cleaner_thumbnails"
|
||||
)
|
||||
thumb.output_dir = Path(args.output_dir)
|
||||
# thumb.start()
|
||||
# try:
|
||||
@@ -432,8 +460,12 @@ def main(args):
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
ui = FlaskUI(
|
||||
app, width=app_width, height=app_height, host=args.host, port=args.port,
|
||||
close_server_on_exit=not args.no_gui_auto_close
|
||||
app,
|
||||
width=app_width,
|
||||
height=app_height,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
close_server_on_exit=not args.no_gui_auto_close,
|
||||
)
|
||||
ui.run()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user