From b1ec157467a18a085cf152419ae5fc0b49ed7446 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 14 Feb 2023 09:08:56 +0800 Subject: [PATCH] better error handling --- lama_cleaner/helper.py | 16 ++++++++++------ lama_cleaner/model/lama.py | 22 ++++++---------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 8097c0d..7dfaafd 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -14,9 +14,11 @@ from torch.hub import download_url_to_file, get_dir def switch_mps_device(model_name, device): - if model_name not in MPS_SUPPORT_MODELS and (device == "mps" or device == torch.device('mps')): + if model_name not in MPS_SUPPORT_MODELS and ( + device == "mps" or device == torch.device("mps") + ): logger.info(f"{model_name} not support mps, switch to cpu") - return torch.device('cpu') + return torch.device("cpu") return device @@ -51,12 +53,14 @@ def load_jit_model(url_or_path, device): model_path = url_or_path else: model_path = download_model(url_or_path) - logger.info(f"Load model from: {model_path}") + logger.info(f"Loading model from: {model_path}") try: - model = torch.jit.load(model_path).to(device) - except: + model = torch.jit.load(model_path, map_location="cpu").to(device) + except Exception as e: logger.error( - f"Failed to load {model_path}, delete model and restart lama-cleaner" + f"Failed to load {model_path}, please delete model and restart lama-cleaner.\n" + f"If you still have errors, please try downloading the model manually first https://lama-cleaner-docs.vercel.app/install/download_model_manually.\n" + f"If all of the above operations don't work, please submit an issue at https://github.com/Sanster/lama-cleaner/issues and include a screenshot of the error:\n{e}" ) exit(-1) model.eval() diff --git a/lama_cleaner/model/lama.py b/lama_cleaner/model/lama.py index adaec35..68d6010 100644 --- a/lama_cleaner/model/lama.py +++ b/lama_cleaner/model/lama.py @@ -3,9 +3,12 @@ import os import cv2 
import numpy as np import torch -from loguru import logger -from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url +from lama_cleaner.helper import ( + norm_img, + get_cache_path_by_url, + load_jit_model, +) from lama_cleaner.model.base import InpaintModel from lama_cleaner.schema import Config @@ -20,20 +23,7 @@ class LaMa(InpaintModel): pad_mod = 8 def init_model(self, device, **kwargs): - if os.environ.get("LAMA_MODEL"): - model_path = os.environ.get("LAMA_MODEL") - if not os.path.exists(model_path): - raise FileNotFoundError( - f"lama torchscript model not found: {model_path}" - ) - else: - model_path = download_model(LAMA_MODEL_URL) - logger.info(f"Load LaMa model from: {model_path}") - model = torch.jit.load(model_path, map_location="cpu") - model = model.to(device) - model.eval() - self.model = model - self.model_path = model_path + self.model = load_jit_model(LAMA_MODEL_URL, device).eval() @staticmethod def is_downloaded() -> bool: