better error handle
This commit is contained in:
@@ -3,9 +3,12 @@ import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||
from lama_cleaner.helper import (
|
||||
norm_img,
|
||||
get_cache_path_by_url,
|
||||
load_jit_model,
|
||||
)
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
@@ -20,20 +23,7 @@ class LaMa(InpaintModel):
|
||||
pad_mod = 8
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
if os.environ.get("LAMA_MODEL"):
|
||||
model_path = os.environ.get("LAMA_MODEL")
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(
|
||||
f"lama torchscript model not found: {model_path}"
|
||||
)
|
||||
else:
|
||||
model_path = download_model(LAMA_MODEL_URL)
|
||||
logger.info(f"Load LaMa model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
self.model = model
|
||||
self.model_path = model_path
|
||||
self.model = load_jit_model(LAMA_MODEL_URL, device).eval()
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
|
||||
Reference in New Issue
Block a user