make model switch work with toast
This commit is contained in:
@@ -10,7 +10,7 @@ from lama_cleaner.schema import Config
|
||||
torch.manual_seed(42)
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from lama_cleaner.helper import download_model, norm_img
|
||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
|
||||
timestep_embedding
|
||||
|
||||
@@ -266,6 +266,7 @@ class DDIMSampler(object):
|
||||
|
||||
def load_jit_model(url, device):
|
||||
model_path = download_model(url)
|
||||
logger.info(f"Load LDM model from: {model_path}")
|
||||
model = torch.jit.load(model_path).to(device)
|
||||
model.eval()
|
||||
return model
|
||||
@@ -286,6 +287,15 @@ class LDM(InpaintModel):
|
||||
model = LatentDiffusion(self.diffusion_model, device)
|
||||
self.sampler = DDIMSampler(model)
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
model_paths = [
|
||||
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
|
||||
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
|
||||
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
|
||||
]
|
||||
return all([os.path.exists(it) for it in model_paths])
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""
|
||||
image: [H, W, C] RGB
|
||||
|
||||
Reference in New Issue
Block a user