add sam2.1
This commit is contained in:
@@ -129,7 +129,7 @@ def start(
|
||||
quality: int = Option(95, help=QUALITY_HELP),
|
||||
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
|
||||
interactive_seg_model: InteractiveSegModel = Option(
|
||||
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
|
||||
InteractiveSegModel.sam2_1_tiny, help=INTERACTIVE_SEG_MODEL_HELP
|
||||
),
|
||||
interactive_seg_device: Device = Option(Device.cpu),
|
||||
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
|
||||
|
||||
@@ -59,6 +59,22 @@ SEGMENT_ANYTHING_MODELS = {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
|
||||
"md5": "08083462423be3260cd6a5eef94dc01c",
|
||||
},
|
||||
"sam2_1_tiny": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
|
||||
"md5": "6aa6761c9da74fbaa74b4c790a0a2007",
|
||||
},
|
||||
"sam2_1_small": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
|
||||
"md5": "51713b3d1994696d27f35f9c6de6f5ef",
|
||||
},
|
||||
"sam2_1_base": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
|
||||
"md5": "ec7bd7d23d280d5e3cfa45984c02eda5",
|
||||
},
|
||||
"sam2_1_large": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
|
||||
"md5": "2b30654b6112c42a115563c638d238d9",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -17,9 +17,6 @@ from .modeling.position_encoding import PositionEmbeddingSine
|
||||
from .modeling.sam.transformer import RoPEAttention
|
||||
from .modeling.sam2_base import SAM2Base
|
||||
|
||||
CURRENT_DIR = Path(__file__).parent
|
||||
CONFIG_DIR = CURRENT_DIR / "sam2_configs"
|
||||
|
||||
common_kwargs = dict(
|
||||
num_maskmem=7,
|
||||
image_size=1024,
|
||||
@@ -44,6 +41,33 @@ common_kwargs = dict(
|
||||
compile_image_encoder=False,
|
||||
)
|
||||
|
||||
common_kwargs_for_2_1 = dict(
|
||||
num_maskmem=7,
|
||||
image_size=1024,
|
||||
sigmoid_scale_for_mem_enc=20.0,
|
||||
sigmoid_bias_for_mem_enc=-10.0,
|
||||
use_mask_input_as_output_without_sam=True,
|
||||
directly_add_no_mem_embed=True,
|
||||
no_obj_embed_spatial=True,
|
||||
use_high_res_features_in_sam=True,
|
||||
multimask_output_in_sam=True,
|
||||
iou_prediction_use_sigmoid=True,
|
||||
use_obj_ptrs_in_encoder=True,
|
||||
add_tpos_enc_to_obj_ptrs=True,
|
||||
proj_tpos_enc_in_obj_ptrs=True,
|
||||
use_signed_tpos_enc_to_obj_ptrs=True,
|
||||
only_obj_ptrs_in_the_past_for_eval=True,
|
||||
pred_obj_scores=True,
|
||||
pred_obj_scores_mlp=True,
|
||||
fixed_no_obj_ptr=True,
|
||||
multimask_output_for_tracking=True,
|
||||
use_multimask_token_for_obj_ptr=True,
|
||||
multimask_min_pt_num=0,
|
||||
multimask_max_pt_num=1,
|
||||
use_mlp_for_obj_ptr_proj=True,
|
||||
compile_image_encoder=False,
|
||||
)
|
||||
|
||||
|
||||
def build_memory_attention():
|
||||
return MemoryAttention(
|
||||
@@ -103,32 +127,114 @@ def build_memory_encoder():
|
||||
)
|
||||
|
||||
|
||||
def build_image_encoder_tiny():
|
||||
return ImageEncoder(
|
||||
scalp=1,
|
||||
trunk=Hiera(
|
||||
embed_dim=96,
|
||||
num_heads=1,
|
||||
stages=(1, 2, 7, 2),
|
||||
global_att_blocks=(5, 7, 9),
|
||||
window_pos_embed_bkg_spatial_size=(7, 7),
|
||||
window_spec=(8, 4, 14, 7),
|
||||
),
|
||||
neck=FpnNeck(
|
||||
position_encoding=PositionEmbeddingSine(
|
||||
num_pos_feats=256,
|
||||
normalize=True,
|
||||
scale=None,
|
||||
temperature=10000,
|
||||
),
|
||||
d_model=256,
|
||||
backbone_channel_list=[768, 384, 192, 96],
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def build_image_encoder_small():
|
||||
return ImageEncoder(
|
||||
scalp=1,
|
||||
trunk=Hiera(
|
||||
embed_dim=96,
|
||||
num_heads=1,
|
||||
stages=(1, 2, 11, 2),
|
||||
global_att_blocks=(7, 10, 13),
|
||||
window_pos_embed_bkg_spatial_size=(7, 7),
|
||||
window_spec=(8, 4, 14, 7),
|
||||
),
|
||||
neck=FpnNeck(
|
||||
position_encoding=PositionEmbeddingSine(
|
||||
num_pos_feats=256,
|
||||
normalize=True,
|
||||
scale=None,
|
||||
temperature=10000,
|
||||
),
|
||||
d_model=256,
|
||||
backbone_channel_list=[768, 384, 192, 96],
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def build_image_encoder_base():
|
||||
return ImageEncoder(
|
||||
scalp=1,
|
||||
trunk=Hiera(
|
||||
embed_dim=112,
|
||||
num_heads=2,
|
||||
stages=(2, 3, 16, 3),
|
||||
global_att_blocks=(12, 16, 20),
|
||||
window_pos_embed_bkg_spatial_size=(14, 14),
|
||||
window_spec=(8, 4, 14, 7),
|
||||
),
|
||||
neck=FpnNeck(
|
||||
position_encoding=PositionEmbeddingSine(
|
||||
num_pos_feats=256,
|
||||
normalize=True,
|
||||
scale=None,
|
||||
temperature=10000,
|
||||
),
|
||||
d_model=256,
|
||||
backbone_channel_list=[896, 448, 224, 112],
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def build_image_encoder_large():
|
||||
return ImageEncoder(
|
||||
scalp=1,
|
||||
trunk=Hiera(
|
||||
embed_dim=144,
|
||||
num_heads=2,
|
||||
stages=(2, 6, 36, 4),
|
||||
global_att_blocks=(23, 33, 43),
|
||||
window_pos_embed_bkg_spatial_size=(7, 7),
|
||||
window_spec=(8, 4, 16, 8),
|
||||
),
|
||||
neck=FpnNeck(
|
||||
position_encoding=PositionEmbeddingSine(
|
||||
num_pos_feats=256,
|
||||
normalize=True,
|
||||
scale=None,
|
||||
temperature=10000,
|
||||
),
|
||||
d_model=256,
|
||||
backbone_channel_list=[1152, 576, 288, 144],
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_tiny():
|
||||
return SAM2Base(
|
||||
**common_kwargs,
|
||||
image_encoder=ImageEncoder(
|
||||
scalp=1,
|
||||
trunk=Hiera(
|
||||
embed_dim=96,
|
||||
num_heads=1,
|
||||
stages=(1, 2, 7, 2),
|
||||
global_att_blocks=(5, 7, 9),
|
||||
window_pos_embed_bkg_spatial_size=(7, 7),
|
||||
window_spec=(8, 4, 14, 7),
|
||||
),
|
||||
neck=FpnNeck(
|
||||
position_encoding=PositionEmbeddingSine(
|
||||
num_pos_feats=256,
|
||||
normalize=True,
|
||||
scale=None,
|
||||
temperature=10000,
|
||||
),
|
||||
d_model=256,
|
||||
backbone_channel_list=[768, 384, 192, 96],
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
),
|
||||
image_encoder=build_image_encoder_tiny(),
|
||||
memory_attention=build_memory_attention(),
|
||||
memory_encoder=build_memory_encoder(),
|
||||
)
|
||||
@@ -137,29 +243,7 @@ def build_sam2_tiny():
|
||||
def build_sam2_small():
|
||||
return SAM2Base(
|
||||
**common_kwargs,
|
||||
image_encoder=ImageEncoder(
|
||||
scalp=1,
|
||||
trunk=Hiera(
|
||||
embed_dim=96,
|
||||
num_heads=1,
|
||||
stages=(1, 2, 11, 2),
|
||||
global_att_blocks=(7, 10, 13),
|
||||
window_pos_embed_bkg_spatial_size=(7, 7),
|
||||
window_spec=(8, 4, 14, 7),
|
||||
),
|
||||
neck=FpnNeck(
|
||||
position_encoding=PositionEmbeddingSine(
|
||||
num_pos_feats=256,
|
||||
normalize=True,
|
||||
scale=None,
|
||||
temperature=10000,
|
||||
),
|
||||
d_model=256,
|
||||
backbone_channel_list=[768, 384, 192, 96],
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
),
|
||||
image_encoder=build_image_encoder_small(),
|
||||
memory_attention=build_memory_attention(),
|
||||
memory_encoder=build_memory_encoder(),
|
||||
)
|
||||
@@ -168,29 +252,7 @@ def build_sam2_small():
|
||||
def build_sam2_base():
|
||||
return SAM2Base(
|
||||
**common_kwargs,
|
||||
image_encoder=ImageEncoder(
|
||||
scalp=1,
|
||||
trunk=Hiera(
|
||||
embed_dim=112,
|
||||
num_heads=2,
|
||||
stages=(2, 3, 16, 3),
|
||||
global_att_blocks=(12, 16, 20),
|
||||
window_pos_embed_bkg_spatial_size=(14, 14),
|
||||
window_spec=(8, 4, 14, 7),
|
||||
),
|
||||
neck=FpnNeck(
|
||||
position_encoding=PositionEmbeddingSine(
|
||||
num_pos_feats=256,
|
||||
normalize=True,
|
||||
scale=None,
|
||||
temperature=10000,
|
||||
),
|
||||
d_model=256,
|
||||
backbone_channel_list=[896, 448, 224, 112],
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
),
|
||||
image_encoder=build_image_encoder_base(),
|
||||
memory_attention=build_memory_attention(),
|
||||
memory_encoder=build_memory_encoder(),
|
||||
)
|
||||
@@ -199,29 +261,43 @@ def build_sam2_base():
|
||||
def build_sam2_large():
|
||||
return SAM2Base(
|
||||
**common_kwargs,
|
||||
image_encoder=ImageEncoder(
|
||||
scalp=1,
|
||||
trunk=Hiera(
|
||||
embed_dim=144,
|
||||
num_heads=2,
|
||||
stages=(2, 6, 36, 4),
|
||||
global_att_blocks=(23, 33, 43),
|
||||
window_pos_embed_bkg_spatial_size=(7, 7),
|
||||
window_spec=(8, 4, 16, 8),
|
||||
),
|
||||
neck=FpnNeck(
|
||||
position_encoding=PositionEmbeddingSine(
|
||||
num_pos_feats=256,
|
||||
normalize=True,
|
||||
scale=None,
|
||||
temperature=10000,
|
||||
),
|
||||
d_model=256,
|
||||
backbone_channel_list=[1152, 576, 288, 144],
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
),
|
||||
image_encoder=build_image_encoder_large(),
|
||||
memory_attention=build_memory_attention(),
|
||||
memory_encoder=build_memory_encoder(),
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_1_tiny():
|
||||
return SAM2Base(
|
||||
**common_kwargs_for_2_1,
|
||||
image_encoder=build_image_encoder_tiny(),
|
||||
memory_attention=build_memory_attention(),
|
||||
memory_encoder=build_memory_encoder(),
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_1_small():
|
||||
return SAM2Base(
|
||||
**common_kwargs_for_2_1,
|
||||
image_encoder=build_image_encoder_small(),
|
||||
memory_attention=build_memory_attention(),
|
||||
memory_encoder=build_memory_encoder(),
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_1_base():
|
||||
return SAM2Base(
|
||||
**common_kwargs_for_2_1,
|
||||
image_encoder=build_image_encoder_base(),
|
||||
memory_attention=build_memory_attention(),
|
||||
memory_encoder=build_memory_encoder(),
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_1_large():
|
||||
return SAM2Base(
|
||||
**common_kwargs_for_2_1,
|
||||
image_encoder=build_image_encoder_large(),
|
||||
memory_attention=build_memory_attention(),
|
||||
memory_encoder=build_memory_encoder(),
|
||||
)
|
||||
@@ -232,6 +308,10 @@ sam2_model_registry = {
|
||||
"sam2_small": build_sam2_small,
|
||||
"sam2_base": build_sam2_base,
|
||||
"sam2_large": build_sam2_large,
|
||||
"sam2_1_tiny": build_sam2_1_tiny,
|
||||
"sam2_1_small": build_sam2_1_small,
|
||||
"sam2_1_base": build_sam2_1_base,
|
||||
"sam2_1_large": build_sam2_1_large,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,14 +4,16 @@
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from ..backbones.utils import (
|
||||
from .utils import (
|
||||
PatchEmbed,
|
||||
window_partition,
|
||||
window_unpartition,
|
||||
@@ -46,11 +48,7 @@ class MultiScaleAttention(nn.Module):
|
||||
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim_out // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_pool = q_pool
|
||||
self.qkv = nn.Linear(dim, dim_out * 3)
|
||||
self.proj = nn.Linear(dim_out, dim_out)
|
||||
@@ -197,6 +195,7 @@ class Hiera(nn.Module):
|
||||
16,
|
||||
20,
|
||||
),
|
||||
weights_path=None,
|
||||
return_interm_layers=True, # return feats from every stage
|
||||
):
|
||||
super().__init__()
|
||||
@@ -266,6 +265,11 @@ class Hiera(nn.Module):
|
||||
else [self.blocks[-1].dim_out]
|
||||
)
|
||||
|
||||
if weights_path is not None:
|
||||
with g_pathmgr.open(weights_path, "rb") as f:
|
||||
chkpt = torch.load(f, map_location="cpu")
|
||||
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
|
||||
|
||||
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
||||
h, w = hw
|
||||
window_embed = self.pos_embed_window
|
||||
@@ -293,3 +297,21 @@ class Hiera(nn.Module):
|
||||
outputs.append(feats)
|
||||
|
||||
return outputs
|
||||
|
||||
def get_layer_id(self, layer_name):
|
||||
# https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
|
||||
num_layers = self.get_num_layers()
|
||||
|
||||
if layer_name.find("rel_pos") != -1:
|
||||
return num_layers + 1
|
||||
elif layer_name.find("pos_embed") != -1:
|
||||
return 0
|
||||
elif layer_name.find("patch_embed") != -1:
|
||||
return 0
|
||||
elif layer_name.find("blocks") != -1:
|
||||
return int(layer_name.split("blocks")[1].split(".")[1]) + 1
|
||||
else:
|
||||
return num_layers + 1
|
||||
|
||||
def get_num_layers(self) -> int:
|
||||
return len(self.blocks)
|
||||
|
||||
@@ -71,6 +71,7 @@ class FpnNeck(nn.Module):
|
||||
self.position_encoding = position_encoding
|
||||
self.convs = nn.ModuleList()
|
||||
self.backbone_channel_list = backbone_channel_list
|
||||
self.d_model = d_model
|
||||
for dim in backbone_channel_list:
|
||||
current = nn.Sequential()
|
||||
current.add_module(
|
||||
@@ -99,7 +100,6 @@ class FpnNeck(nn.Module):
|
||||
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
||||
|
||||
def forward(self, xs: List[torch.Tensor]):
|
||||
|
||||
out = [None] * len(self.convs)
|
||||
pos = [None] * len(self.convs)
|
||||
assert len(xs) == len(self.convs)
|
||||
|
||||
@@ -16,7 +16,7 @@ from torch import nn
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
used by the Attention Is All You Need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -211,6 +211,11 @@ def apply_rotary_enc(
|
||||
# repeat freqs along seq_len dim to match k seq_len
|
||||
if repeat_freqs_k:
|
||||
r = xk_.shape[-2] // xq_.shape[-2]
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
if freqs_cis.is_cuda:
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
else:
|
||||
# torch.repeat on complex numbers may not be supported on non-CUDA devices
|
||||
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
|
||||
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||
|
||||
@@ -247,7 +247,7 @@ class MaskDecoder(nn.Module):
|
||||
def _get_stability_scores(self, mask_logits):
|
||||
"""
|
||||
Compute stability scores of the mask logits based on the IoU between upper and
|
||||
lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
|
||||
lower thresholds.
|
||||
"""
|
||||
mask_logits = mask_logits.flatten(-2)
|
||||
stability_delta = self.dynamic_multimask_stability_delta
|
||||
|
||||
@@ -60,9 +60,6 @@ class SAM2Base(torch.nn.Module):
|
||||
# For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
|
||||
# (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
|
||||
memory_temporal_stride_for_eval=1,
|
||||
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
||||
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
||||
add_all_frames_to_correct_as_cond=False,
|
||||
# whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
|
||||
non_overlap_masks_for_mem_enc=False,
|
||||
# whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
||||
@@ -74,6 +71,9 @@ class SAM2Base(torch.nn.Module):
|
||||
# whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
|
||||
# with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
|
||||
proj_tpos_enc_in_obj_ptrs=False,
|
||||
# whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
|
||||
# (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
|
||||
use_signed_tpos_enc_to_obj_ptrs=False,
|
||||
# whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
|
||||
# (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
|
||||
only_obj_ptrs_in_the_past_for_eval=False,
|
||||
@@ -89,6 +89,8 @@ class SAM2Base(torch.nn.Module):
|
||||
# hope to make recovery easier if there is a mistake and mitigate accumulation of errors
|
||||
soft_no_obj_ptr: bool = False,
|
||||
use_mlp_for_obj_ptr_proj: bool = False,
|
||||
# add no obj embedding to spatial frames
|
||||
no_obj_embed_spatial: bool = False,
|
||||
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
|
||||
sam_mask_decoder_extra_args=None,
|
||||
compile_image_encoder: bool = False,
|
||||
@@ -111,12 +113,13 @@ class SAM2Base(torch.nn.Module):
|
||||
if proj_tpos_enc_in_obj_ptrs:
|
||||
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
|
||||
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
|
||||
self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
|
||||
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
|
||||
|
||||
# Part 2: memory attention to condition current frame's visual features
|
||||
# with memories (and obj ptrs) from past frames
|
||||
self.memory_attention = memory_attention
|
||||
self.hidden_dim = memory_attention.d_model
|
||||
self.hidden_dim = image_encoder.neck.d_model
|
||||
|
||||
# Part 3: memory encoder for the previous frame's outputs
|
||||
self.memory_encoder = memory_encoder
|
||||
@@ -171,9 +174,12 @@ class SAM2Base(torch.nn.Module):
|
||||
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
|
||||
trunc_normal_(self.no_obj_ptr, std=0.02)
|
||||
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
|
||||
self.no_obj_embed_spatial = None
|
||||
if no_obj_embed_spatial:
|
||||
self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
|
||||
trunc_normal_(self.no_obj_embed_spatial, std=0.02)
|
||||
|
||||
self._build_sam_heads()
|
||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
||||
self.max_cond_frames_in_attn = max_cond_frames_in_attn
|
||||
|
||||
# Model compilation
|
||||
@@ -195,8 +201,8 @@ class SAM2Base(torch.nn.Module):
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Please use the corresponding methods in SAM2VideoPredictor for inference."
|
||||
"See notebooks/video_predictor_example.ipynb for an example."
|
||||
"Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
|
||||
"See notebooks/video_predictor_example.ipynb for an inference example."
|
||||
)
|
||||
|
||||
def _build_sam_heads(self):
|
||||
@@ -389,8 +395,6 @@ class SAM2Base(torch.nn.Module):
|
||||
if self.pred_obj_scores:
|
||||
# Allow *soft* no obj ptr, unlike for masks
|
||||
if self.soft_no_obj_ptr:
|
||||
# Only hard possible with gt
|
||||
assert not self.teacher_force_obj_scores_for_mem
|
||||
lambda_is_obj_appearing = object_score_logits.sigmoid()
|
||||
else:
|
||||
lambda_is_obj_appearing = is_obj_appearing.float()
|
||||
@@ -514,6 +518,7 @@ class SAM2Base(torch.nn.Module):
|
||||
return pix_feat
|
||||
|
||||
num_obj_ptr_tokens = 0
|
||||
tpos_sign_mul = -1 if track_in_reverse else 1
|
||||
# Step 1: condition the visual features of the current frame on previous memories
|
||||
if not is_init_cond_frame:
|
||||
# Retrieve the memories encoded with the maskmem backbone
|
||||
@@ -529,9 +534,9 @@ class SAM2Base(torch.nn.Module):
|
||||
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
|
||||
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
|
||||
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
|
||||
# We also allow taking the memory frame non-consecutively (with r>1), in which case
|
||||
# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
|
||||
r = self.memory_temporal_stride_for_eval
|
||||
# We also allow taking the memory frame non-consecutively (with stride>1), in which case
|
||||
# we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
|
||||
stride = 1 if self.training else self.memory_temporal_stride_for_eval
|
||||
for t_pos in range(1, self.num_maskmem):
|
||||
t_rel = self.num_maskmem - t_pos # how many frames before current frame
|
||||
if t_rel == 1:
|
||||
@@ -547,15 +552,15 @@ class SAM2Base(torch.nn.Module):
|
||||
if not track_in_reverse:
|
||||
# first find the nearest frame among every r-th frames before this frame
|
||||
# for r=1, this would be (frame_idx - 2)
|
||||
prev_frame_idx = ((frame_idx - 2) // r) * r
|
||||
prev_frame_idx = ((frame_idx - 2) // stride) * stride
|
||||
# then seek further among every r-th frames
|
||||
prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
|
||||
prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
|
||||
else:
|
||||
# first find the nearest frame among every r-th frames after this frame
|
||||
# for r=1, this would be (frame_idx + 2)
|
||||
prev_frame_idx = -(-(frame_idx + 2) // r) * r
|
||||
prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
|
||||
# then seek further among every r-th frames
|
||||
prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
|
||||
prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
|
||||
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
|
||||
if out is None:
|
||||
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
|
||||
@@ -568,10 +573,10 @@ class SAM2Base(torch.nn.Module):
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
feats = prev["maskmem_features"].to(device, non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = (
|
||||
@@ -594,7 +599,14 @@ class SAM2Base(torch.nn.Module):
|
||||
ptr_cond_outputs = selected_cond_outputs
|
||||
pos_and_ptrs = [
|
||||
# Temporal pos encoding contains how far away each pointer is from current frame
|
||||
(abs(frame_idx - t), out["obj_ptr"])
|
||||
(
|
||||
(
|
||||
(frame_idx - t) * tpos_sign_mul
|
||||
if self.use_signed_tpos_enc_to_obj_ptrs
|
||||
else abs(frame_idx - t)
|
||||
),
|
||||
out["obj_ptr"],
|
||||
)
|
||||
for t, out in ptr_cond_outputs.items()
|
||||
]
|
||||
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
|
||||
@@ -643,7 +655,7 @@ class SAM2Base(torch.nn.Module):
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
|
||||
# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
|
||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
||||
|
||||
@@ -667,6 +679,7 @@ class SAM2Base(torch.nn.Module):
|
||||
current_vision_feats,
|
||||
feat_sizes,
|
||||
pred_masks_high_res,
|
||||
object_score_logits,
|
||||
is_mask_from_pts,
|
||||
):
|
||||
"""Encode the current image and its prediction into a memory feature."""
|
||||
@@ -701,9 +714,104 @@ class SAM2Base(torch.nn.Module):
|
||||
)
|
||||
maskmem_features = maskmem_out["vision_features"]
|
||||
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
||||
# add a no-object embedding to the spatial memory to indicate that the frame
|
||||
# is predicted to be occluded (i.e. no object is appearing in the frame)
|
||||
if self.no_obj_embed_spatial is not None:
|
||||
is_obj_appearing = (object_score_logits > 0).float()
|
||||
maskmem_features += (
|
||||
1 - is_obj_appearing[..., None, None]
|
||||
) * self.no_obj_embed_spatial[..., None, None].expand(
|
||||
*maskmem_features.shape
|
||||
)
|
||||
|
||||
return maskmem_features, maskmem_pos_enc
|
||||
|
||||
def _track_step(
|
||||
self,
|
||||
frame_idx,
|
||||
is_init_cond_frame,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
mask_inputs,
|
||||
output_dict,
|
||||
num_frames,
|
||||
track_in_reverse,
|
||||
prev_sam_mask_logits,
|
||||
):
|
||||
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
|
||||
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
||||
if len(current_vision_feats) > 1:
|
||||
high_res_features = [
|
||||
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
||||
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
||||
]
|
||||
else:
|
||||
high_res_features = None
|
||||
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
|
||||
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
|
||||
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
||||
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
||||
sam_outputs = self._use_mask_as_output(
|
||||
pix_feat, high_res_features, mask_inputs
|
||||
)
|
||||
else:
|
||||
# fused the visual feature with previous memory features in the memory bank
|
||||
pix_feat = self._prepare_memory_conditioned_features(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
current_vision_feats=current_vision_feats[-1:],
|
||||
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
|
||||
feat_sizes=feat_sizes[-1:],
|
||||
output_dict=output_dict,
|
||||
num_frames=num_frames,
|
||||
track_in_reverse=track_in_reverse,
|
||||
)
|
||||
# apply SAM-style segmentation head
|
||||
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
|
||||
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
|
||||
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
|
||||
if prev_sam_mask_logits is not None:
|
||||
assert point_inputs is not None and mask_inputs is None
|
||||
mask_inputs = prev_sam_mask_logits
|
||||
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
|
||||
sam_outputs = self._forward_sam_heads(
|
||||
backbone_features=pix_feat,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
high_res_features=high_res_features,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
|
||||
return current_out, sam_outputs, high_res_features, pix_feat
|
||||
|
||||
def _encode_memory_in_output(
|
||||
self,
|
||||
current_vision_feats,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
run_mem_encoder,
|
||||
high_res_masks,
|
||||
object_score_logits,
|
||||
current_out,
|
||||
):
|
||||
if run_mem_encoder and self.num_maskmem > 0:
|
||||
high_res_masks_for_mem_enc = high_res_masks
|
||||
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
||||
current_vision_feats=current_vision_feats,
|
||||
feat_sizes=feat_sizes,
|
||||
pred_masks_high_res=high_res_masks_for_mem_enc,
|
||||
object_score_logits=object_score_logits,
|
||||
is_mask_from_pts=(point_inputs is not None),
|
||||
)
|
||||
current_out["maskmem_features"] = maskmem_features
|
||||
current_out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
else:
|
||||
current_out["maskmem_features"] = None
|
||||
current_out["maskmem_pos_enc"] = None
|
||||
|
||||
def track_step(
|
||||
self,
|
||||
frame_idx,
|
||||
@@ -725,50 +833,20 @@ class SAM2Base(torch.nn.Module):
|
||||
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
|
||||
prev_sam_mask_logits=None,
|
||||
):
|
||||
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
|
||||
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
||||
if len(current_vision_feats) > 1:
|
||||
high_res_features = [
|
||||
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
||||
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
||||
]
|
||||
else:
|
||||
high_res_features = None
|
||||
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
|
||||
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
|
||||
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
||||
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
||||
sam_outputs = self._use_mask_as_output(
|
||||
pix_feat, high_res_features, mask_inputs
|
||||
)
|
||||
else:
|
||||
# fused the visual feature with previous memory features in the memory bank
|
||||
pix_feat_with_mem = self._prepare_memory_conditioned_features(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
current_vision_feats=current_vision_feats[-1:],
|
||||
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
|
||||
feat_sizes=feat_sizes[-1:],
|
||||
output_dict=output_dict,
|
||||
num_frames=num_frames,
|
||||
track_in_reverse=track_in_reverse,
|
||||
)
|
||||
# apply SAM-style segmentation head
|
||||
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
|
||||
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
|
||||
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
|
||||
if prev_sam_mask_logits is not None:
|
||||
assert point_inputs is not None and mask_inputs is None
|
||||
mask_inputs = prev_sam_mask_logits
|
||||
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
|
||||
sam_outputs = self._forward_sam_heads(
|
||||
backbone_features=pix_feat_with_mem,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
high_res_features=high_res_features,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
current_out, sam_outputs, _, _ = self._track_step(
|
||||
frame_idx,
|
||||
is_init_cond_frame,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
mask_inputs,
|
||||
output_dict,
|
||||
num_frames,
|
||||
track_in_reverse,
|
||||
prev_sam_mask_logits,
|
||||
)
|
||||
|
||||
(
|
||||
_,
|
||||
_,
|
||||
@@ -776,28 +854,28 @@ class SAM2Base(torch.nn.Module):
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
_,
|
||||
object_score_logits,
|
||||
) = sam_outputs
|
||||
|
||||
current_out["pred_masks"] = low_res_masks
|
||||
current_out["pred_masks_high_res"] = high_res_masks
|
||||
current_out["obj_ptr"] = obj_ptr
|
||||
if not self.training:
|
||||
# Only add this in inference (to avoid unused param in activation checkpointing;
|
||||
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
|
||||
current_out["object_score_logits"] = object_score_logits
|
||||
|
||||
# Finally run the memory encoder on the predicted mask to encode
|
||||
# it into a new memory feature (that can be used in future frames)
|
||||
if run_mem_encoder and self.num_maskmem > 0:
|
||||
high_res_masks_for_mem_enc = high_res_masks
|
||||
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
||||
current_vision_feats=current_vision_feats,
|
||||
feat_sizes=feat_sizes,
|
||||
pred_masks_high_res=high_res_masks_for_mem_enc,
|
||||
is_mask_from_pts=(point_inputs is not None),
|
||||
)
|
||||
current_out["maskmem_features"] = maskmem_features
|
||||
current_out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
else:
|
||||
current_out["maskmem_features"] = None
|
||||
current_out["maskmem_pos_enc"] = None
|
||||
self._encode_memory_in_output(
|
||||
current_vision_feats,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
run_mem_encoder,
|
||||
high_res_masks,
|
||||
object_score_logits,
|
||||
current_out,
|
||||
)
|
||||
|
||||
return current_out
|
||||
|
||||
|
||||
@@ -6,11 +6,15 @@
|
||||
|
||||
|
||||
import copy
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils.misc import mask_to_box
|
||||
|
||||
|
||||
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
|
||||
"""
|
||||
@@ -147,3 +151,173 @@ class LayerNorm2d(nn.Module):
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
||||
|
||||
|
||||
def sample_box_points(
|
||||
masks: torch.Tensor,
|
||||
noise: float = 0.1, # SAM default
|
||||
noise_bound: int = 20, # SAM default
|
||||
top_left_label: int = 2,
|
||||
bottom_right_label: int = 3,
|
||||
) -> Tuple[np.array, np.array]:
|
||||
"""
|
||||
Sample a noised version of the top left and bottom right corners of a given `bbox`
|
||||
|
||||
Inputs:
|
||||
- masks: [B, 1, H,W] boxes, dtype=torch.Tensor
|
||||
- noise: noise as a fraction of box width and height, dtype=float
|
||||
- noise_bound: maximum amount of noise (in pure pixesl), dtype=int
|
||||
|
||||
Returns:
|
||||
- box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
|
||||
- box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32
|
||||
"""
|
||||
device = masks.device
|
||||
box_coords = mask_to_box(masks)
|
||||
B, _, H, W = masks.shape
|
||||
box_labels = torch.tensor(
|
||||
[top_left_label, bottom_right_label], dtype=torch.int, device=device
|
||||
).repeat(B)
|
||||
if noise > 0.0:
|
||||
if not isinstance(noise_bound, torch.Tensor):
|
||||
noise_bound = torch.tensor(noise_bound, device=device)
|
||||
bbox_w = box_coords[..., 2] - box_coords[..., 0]
|
||||
bbox_h = box_coords[..., 3] - box_coords[..., 1]
|
||||
max_dx = torch.min(bbox_w * noise, noise_bound)
|
||||
max_dy = torch.min(bbox_h * noise, noise_bound)
|
||||
box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
|
||||
box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
|
||||
|
||||
box_coords = box_coords + box_noise
|
||||
img_bounds = (
|
||||
torch.tensor([W, H, W, H], device=device) - 1
|
||||
) # uncentered pixel coords
|
||||
box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
|
||||
|
||||
box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
|
||||
box_labels = box_labels.reshape(-1, 2)
|
||||
return box_coords, box_labels
|
||||
|
||||
|
||||
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
|
||||
"""
|
||||
Sample `num_pt` random points (along with their labels) independently from the error regions.
|
||||
|
||||
Inputs:
|
||||
- gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
|
||||
- pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
|
||||
- num_pt: int, number of points to sample independently for each of the B error maps
|
||||
|
||||
Outputs:
|
||||
- points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
|
||||
- labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
|
||||
negative clicks
|
||||
"""
|
||||
if pred_masks is None: # if pred_masks is not provided, treat it as empty
|
||||
pred_masks = torch.zeros_like(gt_masks)
|
||||
assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
|
||||
assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
|
||||
assert num_pt >= 0
|
||||
|
||||
B, _, H_im, W_im = gt_masks.shape
|
||||
device = gt_masks.device
|
||||
|
||||
# false positive region, a new point sampled in this region should have
|
||||
# negative label to correct the FP error
|
||||
fp_masks = ~gt_masks & pred_masks
|
||||
# false negative region, a new point sampled in this region should have
|
||||
# positive label to correct the FN error
|
||||
fn_masks = gt_masks & ~pred_masks
|
||||
# whether the prediction completely match the ground-truth on each mask
|
||||
all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
|
||||
all_correct = all_correct[..., None, None]
|
||||
|
||||
# channel 0 is FP map, while channel 1 is FN map
|
||||
pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
|
||||
# sample a negative new click from FP region or a positive new click
|
||||
# from FN region, depend on where the maximum falls,
|
||||
# and in case the predictions are all correct (no FP or FN), we just
|
||||
# sample a negative click from the background region
|
||||
pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
|
||||
pts_noise[..., 1] *= fn_masks
|
||||
pts_idx = pts_noise.flatten(2).argmax(dim=2)
|
||||
labels = (pts_idx % 2).to(torch.int32)
|
||||
pts_idx = pts_idx // 2
|
||||
pts_x = pts_idx % W_im
|
||||
pts_y = pts_idx // W_im
|
||||
points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
|
||||
return points, labels
|
||||
|
||||
|
||||
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
|
||||
"""
|
||||
Sample 1 random point (along with its label) from the center of each error region,
|
||||
that is, the point with the largest distance to the boundary of each error region.
|
||||
This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
|
||||
|
||||
Inputs:
|
||||
- gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
|
||||
- pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
|
||||
- padding: if True, pad with boundary of 1 px for distance transform
|
||||
|
||||
Outputs:
|
||||
- points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
|
||||
- labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
|
||||
"""
|
||||
import cv2
|
||||
|
||||
if pred_masks is None:
|
||||
pred_masks = torch.zeros_like(gt_masks)
|
||||
assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
|
||||
assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
|
||||
|
||||
B, _, _, W_im = gt_masks.shape
|
||||
device = gt_masks.device
|
||||
|
||||
# false positive region, a new point sampled in this region should have
|
||||
# negative label to correct the FP error
|
||||
fp_masks = ~gt_masks & pred_masks
|
||||
# false negative region, a new point sampled in this region should have
|
||||
# positive label to correct the FN error
|
||||
fn_masks = gt_masks & ~pred_masks
|
||||
|
||||
fp_masks = fp_masks.cpu().numpy()
|
||||
fn_masks = fn_masks.cpu().numpy()
|
||||
points = torch.zeros(B, 1, 2, dtype=torch.float)
|
||||
labels = torch.ones(B, 1, dtype=torch.int32)
|
||||
for b in range(B):
|
||||
fn_mask = fn_masks[b, 0]
|
||||
fp_mask = fp_masks[b, 0]
|
||||
if padding:
|
||||
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
|
||||
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
|
||||
# compute the distance of each point in FN/FP region to its boundary
|
||||
fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
|
||||
fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
|
||||
if padding:
|
||||
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
|
||||
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
|
||||
|
||||
# take the point in FN/FP region with the largest distance to its boundary
|
||||
fn_mask_dt_flat = fn_mask_dt.reshape(-1)
|
||||
fp_mask_dt_flat = fp_mask_dt.reshape(-1)
|
||||
fn_argmax = np.argmax(fn_mask_dt_flat)
|
||||
fp_argmax = np.argmax(fp_mask_dt_flat)
|
||||
is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
|
||||
pt_idx = fn_argmax if is_positive else fp_argmax
|
||||
points[b, 0, 0] = pt_idx % W_im # x
|
||||
points[b, 0, 1] = pt_idx // W_im # y
|
||||
labels[b, 0] = int(is_positive)
|
||||
|
||||
points = points.to(device)
|
||||
labels = labels.to(device)
|
||||
return points, labels
|
||||
|
||||
|
||||
def get_next_point(gt_masks, pred_masks, method):
|
||||
if method == "uniform":
|
||||
return sample_random_points_from_errors(gt_masks, pred_masks)
|
||||
elif method == "center":
|
||||
return sample_one_point_from_error_center(gt_masks, pred_masks)
|
||||
else:
|
||||
raise ValueError(f"unknown sampling method {method}")
|
||||
|
||||
@@ -175,11 +175,17 @@ class InteractiveSegModel(Choices):
|
||||
sam_hq_vit_l = "sam_hq_vit_l"
|
||||
sam_hq_vit_h = "sam_hq_vit_h"
|
||||
mobile_sam = "mobile_sam"
|
||||
|
||||
sam2_tiny = "sam2_tiny"
|
||||
sam2_small = "sam2_small"
|
||||
sam2_base = "sam2_base"
|
||||
sam2_large = "sam2_large"
|
||||
|
||||
sam2_1_tiny = "sam2_1_tiny"
|
||||
sam2_1_small = "sam2_1_small"
|
||||
sam2_1_base = "sam2_1_base"
|
||||
sam2_1_large = "sam2_1_large"
|
||||
|
||||
|
||||
class PluginInfo(BaseModel):
|
||||
name: str
|
||||
|
||||
@@ -48,7 +48,7 @@ default_configs = dict(
|
||||
output_dir=None,
|
||||
quality=95,
|
||||
enable_interactive_seg=False,
|
||||
interactive_seg_model=InteractiveSegModel.vit_b,
|
||||
interactive_seg_model=InteractiveSegModel.sam2_1_tiny,
|
||||
interactive_seg_device=Device.cpu,
|
||||
enable_remove_bg=False,
|
||||
remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
|
||||
|
||||
Reference in New Issue
Block a user