add ldm model
This commit is contained in:
63
main.py
63
main.py
@@ -4,14 +4,14 @@ import argparse
|
||||
import io
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from distutils.util import strtobool
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lama_cleaner.lama import LaMa
|
||||
from lama_cleaner.ldm import LDM
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
@@ -24,13 +24,10 @@ from flask import Flask, request, send_file
|
||||
from flask_cors import CORS
|
||||
|
||||
from lama_cleaner.helper import (
|
||||
download_model,
|
||||
load_img,
|
||||
norm_img,
|
||||
numpy_to_bytes,
|
||||
pad_img_to_modulo,
|
||||
resize_max_size,
|
||||
)
|
||||
resize_max_size, )
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
|
||||
@@ -55,6 +52,7 @@ device = None
|
||||
@app.route("/inpaint", methods=["POST"])
|
||||
def process():
|
||||
input = request.files
|
||||
# RGB
|
||||
image = load_img(input["image"].read())
|
||||
original_shape = image.shape
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
@@ -74,14 +72,7 @@ def process():
|
||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||
mask = norm_img(mask)
|
||||
|
||||
res_np_img = run(image, mask)
|
||||
|
||||
# resize to original size
|
||||
# res_np_img = cv2.resize(
|
||||
# res_np_img,
|
||||
# dsize=(original_shape[1], original_shape[0]),
|
||||
# interpolation=interpolation,
|
||||
# )
|
||||
res_np_img = model(image, mask)
|
||||
|
||||
return send_file(
|
||||
io.BytesIO(numpy_to_bytes(res_np_img)),
|
||||
@@ -96,35 +87,12 @@ def index():
|
||||
return send_file(os.path.join(BUILD_DIR, "index.html"))
|
||||
|
||||
|
||||
def run(image, mask):
|
||||
"""
|
||||
image: [C, H, W]
|
||||
mask: [1, H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
origin_height, origin_width = image.shape[1:]
|
||||
image = pad_img_to_modulo(image, mod=8)
|
||||
mask = pad_img_to_modulo(mask, mod=8)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
start = time.time()
|
||||
with torch.no_grad():
|
||||
inpainted_image = model(image, mask)
|
||||
|
||||
print(f"process time: {(time.time() - start)*1000}ms")
|
||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
cur_res = cur_res[0:origin_height, 0:origin_width, :]
|
||||
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
||||
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
|
||||
return cur_res
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port", default=8080, type=int)
|
||||
parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
|
||||
parser.add_argument("--ldm-steps", default=50, type=int, help="Steps for DDIM sampling process."
|
||||
"The larger the value, the better the result, but it will be more time-consuming")
|
||||
parser.add_argument("--device", default="cuda", type=str)
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
return parser.parse_args()
|
||||
@@ -136,16 +104,13 @@ def main():
|
||||
args = get_args_parser()
|
||||
device = torch.device(args.device)
|
||||
|
||||
if os.environ.get("LAMA_MODEL"):
|
||||
model_path = os.environ.get("LAMA_MODEL")
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
|
||||
if args.model == "lama":
|
||||
model = LaMa(device)
|
||||
elif args.model == "ldm":
|
||||
model = LDM(device, steps=args.ldm_steps)
|
||||
else:
|
||||
model_path = download_model()
|
||||
raise NotImplementedError(f"Not supported model: {args.model}")
|
||||
|
||||
model = torch.jit.load(model_path, map_location="cpu")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user