update test
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
@@ -9,9 +8,9 @@ import numpy as np
|
||||
import nvidia_smi
|
||||
import psutil
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lama_cleaner.lama import LaMa
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
@@ -21,8 +20,6 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
from lama_cleaner.helper import norm_img
|
||||
|
||||
NUM_THREADS = str(4)
|
||||
|
||||
os.environ["OMP_NUM_THREADS"] = NUM_THREADS
|
||||
@@ -37,20 +34,23 @@ if os.environ.get("CACHE_DIR"):
|
||||
def run_model(model, size):
|
||||
# RGB
|
||||
image = np.random.randint(0, 256, (size[0], size[1], 3)).astype(np.uint8)
|
||||
image = norm_img(image)
|
||||
|
||||
mask = np.random.randint(0, 255, size).astype(np.uint8)
|
||||
mask = norm_img(mask)
|
||||
model(image, mask)
|
||||
|
||||
config = Config(
|
||||
ldm_steps=2,
|
||||
hd_strategy=HDStrategy.ORIGINAL,
|
||||
hd_strategy_crop_margin=128,
|
||||
hd_strategy_crop_trigger_size=128,
|
||||
hd_strategy_resize_limit=128,
|
||||
prompt="a fox is sitting on a bench",
|
||||
sd_steps=5,
|
||||
sd_sampler=SDSampler.ddim
|
||||
)
|
||||
model(image, mask, config)
|
||||
|
||||
|
||||
def benchmark(model, times: int, empty_cache: bool):
|
||||
sizes = [
|
||||
(512, 512),
|
||||
(640, 640),
|
||||
(1080, 800),
|
||||
(2000, 2000)
|
||||
]
|
||||
sizes = [(512, 512)]
|
||||
|
||||
nvidia_smi.nvmlInit()
|
||||
device_id = 0
|
||||
@@ -71,8 +71,6 @@ def benchmark(model, times: int, empty_cache: bool):
|
||||
start = time.time()
|
||||
run_model(model, size)
|
||||
torch.cuda.synchronize()
|
||||
if empty_cache:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# cpu_metrics.append(process.cpu_percent())
|
||||
time_metrics.append((time.time() - start) * 1000)
|
||||
@@ -90,8 +88,9 @@ def benchmark(model, times: int, empty_cache: bool):
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--name")
|
||||
parser.add_argument("--device", default="cuda", type=str)
|
||||
parser.add_argument("--times", default=20, type=int)
|
||||
parser.add_argument("--times", default=10, type=int)
|
||||
parser.add_argument("--empty-cache", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -99,5 +98,12 @@ def get_args_parser():
|
||||
if __name__ == "__main__":
|
||||
args = get_args_parser()
|
||||
device = torch.device(args.device)
|
||||
model = LaMa(device)
|
||||
model = ModelManager(
|
||||
name=args.name,
|
||||
device=device,
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
hf_access_token="123"
|
||||
)
|
||||
benchmark(model, args.times, args.empty_cache)
|
||||
|
||||
Reference in New Issue
Block a user