add model md5 check
This commit is contained in:
@@ -11,6 +11,15 @@ import torch
|
||||
from lama_cleaner.const import MPS_SUPPORT_MODELS
|
||||
from loguru import logger
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
import hashlib
|
||||
|
||||
|
||||
def md5sum(filename):
|
||||
md5 = hashlib.md5()
|
||||
with open(filename, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(128 * md5.block_size), b""):
|
||||
md5.update(chunk)
|
||||
return md5.hexdigest()
|
||||
|
||||
|
||||
def switch_mps_device(model_name, device):
|
||||
@@ -33,12 +42,22 @@ def get_cache_path_by_url(url):
|
||||
return cached_file
|
||||
|
||||
|
||||
def download_model(url):
|
||||
def download_model(url, model_md5: str = None):
|
||||
cached_file = get_cache_path_by_url(url)
|
||||
if not os.path.exists(cached_file):
|
||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||
hash_prefix = None
|
||||
download_url_to_file(url, cached_file, hash_prefix, progress=True)
|
||||
if model_md5:
|
||||
_md5 = md5sum(cached_file)
|
||||
if model_md5 == _md5:
|
||||
logger.info(f"Download model success, md5: {_md5}")
|
||||
else:
|
||||
logger.error(
|
||||
f"Download model failed, md5: {_md5}, expected: {model_md5}. Please delete model at {cached_file} and restart lama-cleaner"
|
||||
)
|
||||
exit(-1)
|
||||
|
||||
return cached_file
|
||||
|
||||
|
||||
@@ -48,42 +67,49 @@ def ceil_modulo(x, mod):
|
||||
return (x // mod + 1) * mod
|
||||
|
||||
|
||||
def \
|
||||
load_jit_model(url_or_path, device):
|
||||
def handle_error(model_path, model_md5, e):
|
||||
_md5 = md5sum(model_path)
|
||||
if _md5 != model_md5:
|
||||
logger.error(
|
||||
f"Model md5: {_md5}, expected: {model_md5}, please delete {model_path} and restart lama-cleaner."
|
||||
f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to load model {model_path},"
|
||||
f"please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
|
||||
)
|
||||
exit(-1)
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device, model_md5: str):
|
||||
if os.path.exists(url_or_path):
|
||||
model_path = url_or_path
|
||||
else:
|
||||
model_path = download_model(url_or_path)
|
||||
model_path = download_model(url_or_path, model_md5)
|
||||
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
try:
|
||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load {model_path}, please delete model and restart lama-cleaner.\n"
|
||||
f"If you still have errors, please try download model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n"
|
||||
f"If all above operations doesn't work, please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}"
|
||||
)
|
||||
exit(-1)
|
||||
handle_error(model_path, model_md5, e)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_model(model: torch.nn.Module, url_or_path, device):
|
||||
def load_model(model: torch.nn.Module, url_or_path, device, model_md5):
|
||||
if os.path.exists(url_or_path):
|
||||
model_path = url_or_path
|
||||
else:
|
||||
model_path = download_model(url_or_path)
|
||||
model_path = download_model(url_or_path, model_md5)
|
||||
|
||||
try:
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
logger.info(f"Load model from: {model_path}")
|
||||
except:
|
||||
logger.error(
|
||||
f"Failed to load {model_path}, delete model and restart lama-cleaner"
|
||||
)
|
||||
exit(-1)
|
||||
except Exception as e:
|
||||
handle_error(model_path, model_md5, e)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user