@@ -11,7 +11,6 @@ from typing import List, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from iopath.common.file_io import g_pathmgr
|
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
PatchEmbed,
|
PatchEmbed,
|
||||||
@@ -266,8 +265,7 @@ class Hiera(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if weights_path is not None:
|
if weights_path is not None:
|
||||||
with g_pathmgr.open(weights_path, "rb") as f:
|
chkpt = torch.load(weights_path, map_location="cpu")
|
||||||
chkpt = torch.load(f, map_location="cpu")
|
|
||||||
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
|
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
|
||||||
|
|
||||||
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
||||||
|
|||||||
Reference in New Issue
Block a user