big update
This commit is contained in:
94
main.py
94
main.py
@@ -2,19 +2,21 @@
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import imghdr
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from lama_cleaner.lama import LaMa
|
||||
from lama_cleaner.ldm import LDM
|
||||
from loguru import logger
|
||||
|
||||
from flaskwebgui import FlaskUI
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
@@ -29,7 +31,6 @@ from flask_cors import CORS
|
||||
|
||||
from lama_cleaner.helper import (
|
||||
load_img,
|
||||
norm_img,
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
)
|
||||
@@ -46,11 +47,19 @@ if os.environ.get("CACHE_DIR"):
|
||||
|
||||
BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
logger_opt = logger.opt(depth=6, exception=record.exc_info)
|
||||
logger_opt.log(record.levelno, record.getMessage())
|
||||
|
||||
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
CORS(app)
|
||||
app.logger.addHandler(InterceptHandler())
|
||||
CORS(app, expose_headers=["Content-Disposition"])
|
||||
|
||||
model = None
|
||||
model: ModelManager = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
|
||||
@@ -72,24 +81,31 @@ def process():
|
||||
original_shape = image.shape
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
|
||||
size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
|
||||
form = request.form
|
||||
size_limit: Union[int, str] = form.get("sizeLimit", "1080")
|
||||
if size_limit == "Original":
|
||||
size_limit = max(image.shape)
|
||||
else:
|
||||
size_limit = int(size_limit)
|
||||
|
||||
print(f"Origin image shape: {original_shape}")
|
||||
config = Config(
|
||||
ldm_steps=form['ldmSteps'],
|
||||
hd_strategy=form['hdStrategy'],
|
||||
hd_strategy_crop_margin=form['hdStrategyCropMargin'],
|
||||
hd_strategy_crop_trigger_size=form['hdStrategyCropTrigerSize'],
|
||||
hd_strategy_resize_limit=form['hdStrategyResizeLimit'],
|
||||
)
|
||||
|
||||
logger.info(f"Origin image shape: {original_shape}")
|
||||
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
||||
print(f"Resized image shape: {image.shape}")
|
||||
image = norm_img(image)
|
||||
logger.info(f"Resized image shape: {image.shape}")
|
||||
|
||||
mask, _ = load_img(input["mask"].read(), gray=True)
|
||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||
mask = norm_img(mask)
|
||||
|
||||
start = time.time()
|
||||
res_np_img = model(image, mask)
|
||||
print(f"process time: {(time.time() - start) * 1000}ms")
|
||||
res_np_img = model(image, mask, config)
|
||||
logger.info(f"process time: {(time.time() - start) * 1000}ms")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -109,6 +125,19 @@ def process():
|
||||
)
|
||||
|
||||
|
||||
@app.route("/switch_model", methods=["POST"])
|
||||
def switch_model():
|
||||
new_name = request.form.get("name")
|
||||
if new_name == model.name:
|
||||
return "Same model", 200
|
||||
|
||||
try:
|
||||
model.switch(new_name)
|
||||
except NotImplementedError:
|
||||
return f"{new_name} not implemented", 403
|
||||
return f"ok, switch to {new_name}", 200
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||
@@ -120,7 +149,9 @@ def set_input_photo():
|
||||
with open(input_image_path, "rb") as f:
|
||||
image_in_bytes = f.read()
|
||||
return send_file(
|
||||
io.BytesIO(image_in_bytes),
|
||||
input_image_path,
|
||||
as_attachment=True,
|
||||
download_name=Path(input_image_path).name,
|
||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||
)
|
||||
else:
|
||||
@@ -135,29 +166,6 @@ def get_args_parser():
|
||||
parser.add_argument("--host", default="127.0.0.1")
|
||||
parser.add_argument("--port", default=8080, type=int)
|
||||
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
||||
parser.add_argument(
|
||||
"--crop-trigger-size",
|
||||
default=[2042, 2042],
|
||||
nargs=2,
|
||||
type=int,
|
||||
help="If image size large then crop-trigger-size, "
|
||||
"crop each area from original image to do inference."
|
||||
"Mainly for performance and memory reasons"
|
||||
"Only for lama",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-margin",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Margin around bounding box of painted stroke when crop mode triggered",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ldm-steps",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Steps for DDIM sampling process."
|
||||
"The larger the value, the better the result, but it will be more time-consuming",
|
||||
)
|
||||
parser.add_argument("--device", default="cuda", type=str)
|
||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||
parser.add_argument(
|
||||
@@ -188,19 +196,11 @@ def main():
|
||||
device = torch.device(args.device)
|
||||
input_image_path = args.input
|
||||
|
||||
if args.model == "lama":
|
||||
model = LaMa(
|
||||
crop_trigger_size=args.crop_trigger_size,
|
||||
crop_margin=args.crop_margin,
|
||||
device=device,
|
||||
)
|
||||
elif args.model == "ldm":
|
||||
model = LDM(device, steps=args.ldm_steps)
|
||||
else:
|
||||
raise NotImplementedError(f"Not supported model: {args.model}")
|
||||
model = ModelManager(name=args.model, device=device)
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
from flaskwebgui import FlaskUI
|
||||
ui = FlaskUI(app, width=app_width, height=app_height)
|
||||
ui.run()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user