big update

This commit is contained in:
Sanster
2022-04-16 00:11:51 +08:00
parent 2b031603ed
commit 205286a414
40 changed files with 539 additions and 376 deletions

94
main.py
View File

@@ -2,19 +2,21 @@
import argparse
import io
import logging
import multiprocessing
import os
import time
import imghdr
from pathlib import Path
from typing import Union
import cv2
import torch
import numpy as np
from lama_cleaner.lama import LaMa
from lama_cleaner.ldm import LDM
from loguru import logger
from flaskwebgui import FlaskUI
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config
try:
torch._C._jit_override_can_fuse_on_cpu(False)
@@ -29,7 +31,6 @@ from flask_cors import CORS
from lama_cleaner.helper import (
load_img,
norm_img,
numpy_to_bytes,
resize_max_size,
)
@@ -46,11 +47,19 @@ if os.environ.get("CACHE_DIR"):
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
class InterceptHandler(logging.Handler):
def emit(self, record):
logger_opt = logger.opt(depth=6, exception=record.exc_info)
logger_opt.log(record.levelno, record.getMessage())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
CORS(app)
app.logger.addHandler(InterceptHandler())
CORS(app, expose_headers=["Content-Disposition"])
model = None
model: ModelManager = None
device = None
input_image_path: str = None
@@ -72,24 +81,31 @@ def process():
original_shape = image.shape
interpolation = cv2.INTER_CUBIC
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
form = request.form
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
if size_limit == "Original":
size_limit = max(image.shape)
else:
size_limit = int(size_limit)
print(f"Origin image shape: {original_shape}")
config = Config(
ldm_steps=form['ldmSteps'],
hd_strategy=form['hdStrategy'],
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
)
logger.info(f"Origin image shape: {original_shape}")
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
print(f"Resized image shape: {image.shape}")
image = norm_img(image)
logger.info(f"Resized image shape: {image.shape}")
mask, _ = load_img(input["mask"].read(), gray=True)
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
mask = norm_img(mask)
start = time.time()
res_np_img = model(image, mask)
print(f"process time: {(time.time() - start) * 1000}ms")
res_np_img = model(image, mask, config)
logger.info(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache()
@@ -109,6 +125,19 @@ def process():
)
@app.route("/switch_model", methods=["POST"])
def switch_model():
new_name = request.form.get("name")
if new_name == model.name:
return "Same model", 200
try:
model.switch(new_name)
except NotImplementedError:
return f"{new_name} not implemented", 403
return f"ok, switch to {new_name}", 200
@app.route("/")
def index():
return send_file(os.path.join(BUILD_DIR, "index.html"))
@@ -120,7 +149,9 @@ def set_input_photo():
with open(input_image_path, "rb") as f:
image_in_bytes = f.read()
return send_file(
io.BytesIO(image_in_bytes),
input_image_path,
as_attachment=True,
download_name=Path(input_image_path).name,
mimetype=f"image/{get_image_ext(image_in_bytes)}",
)
else:
@@ -135,29 +166,6 @@ def get_args_parser():
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=8080, type=int)
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
parser.add_argument(
"--crop-trigger-size",
default=[2042, 2042],
nargs=2,
type=int,
help="If image size large then crop-trigger-size, "
"crop each area from original image to do inference."
"Mainly for performance and memory reasons"
"Only for lama",
)
parser.add_argument(
"--crop-margin",
type=int,
default=256,
help="Margin around bounding box of painted stroke when crop mode triggered",
)
parser.add_argument(
"--ldm-steps",
default=50,
type=int,
help="Steps for DDIM sampling process."
"The larger the value, the better the result, but it will be more time-consuming",
)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument(
@@ -188,19 +196,11 @@ def main():
device = torch.device(args.device)
input_image_path = args.input
if args.model == "lama":
model = LaMa(
crop_trigger_size=args.crop_trigger_size,
crop_margin=args.crop_margin,
device=device,
)
elif args.model == "ldm":
model = LDM(device, steps=args.ldm_steps)
else:
raise NotImplementedError(f"Not supported model: {args.model}")
model = ModelManager(name=args.model, device=device)
if args.gui:
app_width, app_height = args.gui_size
from flaskwebgui import FlaskUI
ui = FlaskUI(app, width=app_width, height=app_height)
ui.run()
else: