Add model MD5 check

This commit is contained in:
Qing
2023-02-26 09:19:48 +08:00
parent 64336498ba
commit ecfecac050
9 changed files with 2002 additions and 933 deletions

View File

@@ -11,6 +11,15 @@ import torch
from lama_cleaner.const import MPS_SUPPORT_MODELS
from loguru import logger
from torch.hub import download_url_to_file, get_dir
import hashlib
def md5sum(filename):
    """Return the hexadecimal MD5 digest of the file at *filename*.

    The file is read in chunks (128 * digest block size) so that large
    model files do not need to fit in memory at once.
    """
    digest = hashlib.md5()
    chunk_size = 128 * digest.block_size
    with open(filename, "rb") as fp:
        chunk = fp.read(chunk_size)
        while chunk:
            digest.update(chunk)
            chunk = fp.read(chunk_size)
    return digest.hexdigest()
def switch_mps_device(model_name, device):
@@ -33,12 +42,22 @@ def get_cache_path_by_url(url):
return cached_file
def download_model(url, model_md5: str = None):
    """Download the model at *url* into the local cache, optionally
    verifying its MD5 checksum.

    Args:
        url: Remote URL of the model file.
        model_md5: Expected MD5 hex digest of the downloaded file. When
            given, a freshly downloaded file is verified against it and
            the process exits on mismatch (the partial/corrupt file must
            be deleted manually before retrying).

    Returns:
        Local filesystem path of the cached model file.
    """
    cached_file = get_cache_path_by_url(url)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        download_url_to_file(url, cached_file, hash_prefix, progress=True)
        # Only verify files we just downloaded; a pre-existing cached file
        # is trusted here and validated lazily by handle_error on load failure.
        if model_md5:
            _md5 = md5sum(cached_file)
            if model_md5 == _md5:
                logger.info(f"Download model success, md5: {_md5}")
            else:
                logger.error(
                    f"Download model failed, md5: {_md5}, expected: {model_md5}. Please delete model at {cached_file} and restart lama-cleaner"
                )
                # sys.exit instead of the site-module builtin exit()
                sys.exit(-1)
    return cached_file
@@ -48,42 +67,49 @@ def ceil_modulo(x, mod):
return (x // mod + 1) * mod
def load_jit_model(url_or_path, device, model_md5: str):
    """Load a TorchScript model from a local path or URL onto *device*.

    Args:
        url_or_path: Local filesystem path, or URL to download into the cache.
        device: Torch device the model is moved to.
        model_md5: Expected MD5 hex digest, forwarded to download/verification
            and to the failure diagnosis in handle_error.

    Returns:
        The loaded model in eval mode.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)
    logger.info(f"Loading model from: {model_path}")
    try:
        # Load on CPU first, then move: avoids device-mismatch issues when
        # the checkpoint was saved on a different device.
        model = torch.jit.load(model_path, map_location="cpu").to(device)
    except Exception as e:
        # handle_error logs a diagnosis (corrupt file vs. real bug) and exits.
        handle_error(model_path, model_md5, e)
    model.eval()
    return model
def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
    """Load a state dict into *model* from a local path or URL.

    Args:
        model: Module instance whose parameters are populated in place.
        url_or_path: Local filesystem path, or URL to download into the cache.
        device: Torch device the model is moved to.
        model_md5: Expected MD5 hex digest, forwarded to download/verification
            and to the failure diagnosis in handle_error.

    Returns:
        The same *model*, with weights loaded, moved to *device*, in eval mode.
    """
    if os.path.exists(url_or_path):
        model_path = url_or_path
    else:
        model_path = download_model(url_or_path, model_md5)
    try:
        logger.info(f"Loading model from: {model_path}")
        # strict=True surfaces any key mismatch instead of silently skipping.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
        model.to(device)
        logger.info(f"Load model from: {model_path}")
    except Exception as e:
        # Narrowed from a bare except; handle_error diagnoses and exits.
        handle_error(model_path, model_md5, e)
    model.eval()
    return model