make model switch work with toast

This commit is contained in:
Sanster
2022-04-17 23:31:12 +08:00
parent 205286a414
commit f7e1e073dc
18 changed files with 447 additions and 28 deletions

View File

@@ -10,7 +10,7 @@ from lama_cleaner.schema import Config
torch.manual_seed(42)
import torch.nn as nn
from tqdm import tqdm
from lama_cleaner.helper import download_model, norm_img
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
from lama_cleaner.model.utils import make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, noise_like, \
timestep_embedding
@@ -266,6 +266,7 @@ class DDIMSampler(object):
def load_jit_model(url, device):
model_path = download_model(url)
logger.info(f"Load LDM model from: {model_path}")
model = torch.jit.load(model_path).to(device)
model.eval()
return model
@@ -286,6 +287,15 @@ class LDM(InpaintModel):
model = LatentDiffusion(self.diffusion_model, device)
self.sampler = DDIMSampler(model)
@staticmethod
def is_downloaded() -> bool:
model_paths = [
get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL),
get_cache_path_by_url(LDM_DECODE_MODEL_URL),
get_cache_path_by_url(LDM_ENCODE_MODEL_URL),
]
return all([os.path.exists(it) for it in model_paths])
def forward(self, image, mask, config: Config):
"""
image: [H, W, C] RGB