add model md5 check

This commit is contained in:
Qing
2023-02-26 09:19:48 +08:00
parent 64336498ba
commit ecfecac050
9 changed files with 2002 additions and 933 deletions

View File

@@ -26,17 +26,27 @@ LDM_ENCODE_MODEL_URL = os.environ.get(
"LDM_ENCODE_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt",
)
LDM_ENCODE_MODEL_MD5 = os.environ.get(
"LDM_ENCODE_MODEL_MD5", "23239fc9081956a3e70de56472b3f296"
)
LDM_DECODE_MODEL_URL = os.environ.get(
"LDM_DECODE_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt",
)
LDM_DECODE_MODEL_MD5 = os.environ.get(
"LDM_DECODE_MODEL_MD5", "fe419cd15a750d37a4733589d0d3585c"
)
LDM_DIFFUSION_MODEL_URL = os.environ.get(
"LDM_DIFFUSION_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt",
)
LDM_DIFFUSION_MODEL_MD5 = os.environ.get(
"LDM_DIFFUSION_MODEL_MD5", "b0afda12bf790c03aba2a7431f11d22d"
)
class DDPM(nn.Module):
# classic DDPM with Gaussian diffusion, in image space
@@ -234,9 +244,15 @@ class LDM(InpaintModel):
self.device = device
def init_model(self, device, **kwargs):
self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device)
self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device)
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device)
self.diffusion_model = load_jit_model(
LDM_DIFFUSION_MODEL_URL, device, LDM_DIFFUSION_MODEL_MD5
)
self.cond_stage_model_decode = load_jit_model(
LDM_DECODE_MODEL_URL, device, LDM_DECODE_MODEL_MD5
)
self.cond_stage_model_encode = load_jit_model(
LDM_ENCODE_MODEL_URL, device, LDM_ENCODE_MODEL_MD5
)
if self.fp16 and "cuda" in str(device):
self.diffusion_model = self.diffusion_model.half()
self.cond_stage_model_decode = self.cond_stage_model_decode.half()