remove iopath

fix: https://github.com/Sanster/IOPaint/issues/593
This commit is contained in:
Qing
2024-11-04 09:23:57 +08:00
parent 32e7dc2584
commit 91e6556610

View File

@@ -11,7 +11,6 @@ from typing import List, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr
from .utils import ( from .utils import (
PatchEmbed, PatchEmbed,
@@ -266,8 +265,7 @@ class Hiera(nn.Module):
) )
if weights_path is not None: if weights_path is not None:
with g_pathmgr.open(weights_path, "rb") as f: chkpt = torch.load(weights_path, map_location="cpu")
chkpt = torch.load(f, map_location="cpu")
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: