resize image using backend;add resize radio button

frontend resize image will reduce image quality
This commit is contained in:
Qing
2021-11-27 20:37:37 +08:00
committed by Sanster
parent 1c2e7fa559
commit 1e2c8fd348
9 changed files with 163 additions and 144 deletions

44
main.py
View File

@@ -4,6 +4,8 @@ import io
import os
import time
import argparse
from distutils.util import strtobool
from typing import Union
import cv2
import numpy as np
import torch
@@ -13,6 +15,8 @@ from flask_cors import CORS
from lama_cleaner.helper import (
download_model,
load_img,
norm_img,
resize_max_size,
numpy_to_bytes,
pad_img_to_modulo,
)
@@ -43,13 +47,38 @@ device = None
def process():
input = request.files
image = load_img(input["image"].read())
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
print(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
print(f"Resized image shape: {image.shape}")
image = norm_img(image)
mask = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
mask = norm_img(mask)
res_np_img = run(image, mask)
# resize to original size
res_np_img = cv2.resize(
res_np_img,
dsize=(original_shape[1], original_shape[0]),
interpolation=interpolation,
)
return send_file(
io.BytesIO(numpy_to_bytes(res_np_img)),
mimetype="image/png",
mimetype="image/jpeg",
as_attachment=True,
attachment_filename="result.png",
attachment_filename="result.jpeg",
)
@@ -61,6 +90,8 @@ def index():
def run(image, mask):
"""
image: [C, H, W]
mask: [1, H, W]
return: BGR IMAGE
"""
origin_height, origin_width = image.shape[1:]
image = pad_img_to_modulo(image, mod=8)
@@ -73,13 +104,11 @@ def run(image, mask):
start = time.time()
inpainted_image = model(image, mask)
print(
f"inpainted image shape: {inpainted_image.shape} process time: {(time.time() - start)*1000}ms"
)
print(f"process time: {(time.time() - start)*1000}ms")
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = cur_res[0:origin_height, 0:origin_width, :]
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
return cur_res
@@ -87,6 +116,7 @@ def get_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--debug", action="store_true")
return parser.parse_args()
@@ -98,7 +128,7 @@ def main():
model_path = download_model()
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
app.run(host="0.0.0.0", port=args.port, debug=False)
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
if __name__ == "__main__":