add model md5 check

This commit is contained in:
Qing
2023-02-26 09:19:48 +08:00
parent 64336498ba
commit ecfecac050
9 changed files with 2002 additions and 933 deletions

View File

@@ -17,21 +17,33 @@ ZITS_INPAINT_MODEL_URL = os.environ.get(
"ZITS_INPAINT_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt",
)
ZITS_INPAINT_MODEL_MD5 = os.environ.get(
"ZITS_INPAINT_MODEL_MD5", "9978cc7157dc29699e42308d675b2154"
)
ZITS_EDGE_LINE_MODEL_URL = os.environ.get(
"ZITS_EDGE_LINE_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt",
)
ZITS_EDGE_LINE_MODEL_MD5 = os.environ.get(
"ZITS_EDGE_LINE_MODEL_MD5", "55e31af21ba96bbf0c80603c76ea8c5f"
)
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get(
"ZITS_STRUCTURE_UPSAMPLE_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt",
)
ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5 = os.environ.get(
"ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5", "3d88a07211bd41b2ec8cc0d999f29927"
)
ZITS_WIRE_FRAME_MODEL_URL = os.environ.get(
"ZITS_WIRE_FRAME_MODEL_URL",
"https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt",
)
ZITS_WIRE_FRAME_MODEL_MD5 = os.environ.get(
"ZITS_WIRE_FRAME_MODEL_MD5", "a9727c63a8b48b65c905d351b21ce46b"
)
def resize(img, height, width, center_crop=False):
@@ -219,12 +231,12 @@ class ZITS(InpaintModel):
self.sample_edge_line_iterations = 1
def init_model(self, device, **kwargs):
self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device)
self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device)
self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device, ZITS_WIRE_FRAME_MODEL_MD5)
self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device, ZITS_EDGE_LINE_MODEL_MD5)
self.structure_upsample = load_jit_model(
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device
ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device, ZITS_STRUCTURE_UPSAMPLE_MODEL_MD5
)
self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device)
self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device, ZITS_INPAINT_MODEL_MD5)
@staticmethod
def is_downloaded() -> bool: