add sam2.1

commit ec08ce3063
parent 2098d687d1
Author: Qing
Date: 2024-11-01 18:45:32 +08:00
11 changed files with 565 additions and 184 deletions

View File

@@ -129,7 +129,7 @@ def start(
quality: int = Option(95, help=QUALITY_HELP),
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
interactive_seg_model: InteractiveSegModel = Option(
-InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
+InteractiveSegModel.sam2_1_tiny, help=INTERACTIVE_SEG_MODEL_HELP
),
interactive_seg_device: Device = Option(Device.cpu),
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),

View File

@@ -59,6 +59,22 @@ SEGMENT_ANYTHING_MODELS = {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
"md5": "08083462423be3260cd6a5eef94dc01c",
},
"sam2_1_tiny": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
"md5": "6aa6761c9da74fbaa74b4c790a0a2007",
},
"sam2_1_small": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
"md5": "51713b3d1994696d27f35f9c6de6f5ef",
},
"sam2_1_base": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
"md5": "ec7bd7d23d280d5e3cfa45984c02eda5",
},
"sam2_1_large": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
"md5": "2b30654b6112c42a115563c638d238d9",
},
}
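For context, a minimal sketch of how one of the checkpoints registered above could be fetched and verified against its md5. This is illustrative only, not the project's own download helper, and the cache directory is a made-up example:

import hashlib
import urllib.request
from pathlib import Path

def fetch_checkpoint(name: str, models: dict, cache_dir: str = "~/.cache/sam2") -> Path:
    # Download the checkpoint for `name` (e.g. "sam2_1_tiny") if it is not cached yet,
    # then check it against the md5 recorded in the registry above.
    info = models[name]
    target = Path(cache_dir).expanduser() / Path(info["url"]).name
    target.parent.mkdir(parents=True, exist_ok=True)
    if not target.exists():
        urllib.request.urlretrieve(info["url"], target)
    digest = hashlib.md5(target.read_bytes()).hexdigest()
    if digest != info["md5"]:
        raise ValueError(f"md5 mismatch for {name}: got {digest}, expected {info['md5']}")
    return target

# e.g. fetch_checkpoint("sam2_1_tiny", SEGMENT_ANYTHING_MODELS)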

View File

@@ -17,9 +17,6 @@ from .modeling.position_encoding import PositionEmbeddingSine
from .modeling.sam.transformer import RoPEAttention
from .modeling.sam2_base import SAM2Base
-CURRENT_DIR = Path(__file__).parent
-CONFIG_DIR = CURRENT_DIR / "sam2_configs"
common_kwargs = dict(
num_maskmem=7,
image_size=1024,
@@ -44,6 +41,33 @@ common_kwargs = dict(
compile_image_encoder=False,
)
common_kwargs_for_2_1 = dict(
num_maskmem=7,
image_size=1024,
sigmoid_scale_for_mem_enc=20.0,
sigmoid_bias_for_mem_enc=-10.0,
use_mask_input_as_output_without_sam=True,
directly_add_no_mem_embed=True,
no_obj_embed_spatial=True,
use_high_res_features_in_sam=True,
multimask_output_in_sam=True,
iou_prediction_use_sigmoid=True,
use_obj_ptrs_in_encoder=True,
add_tpos_enc_to_obj_ptrs=True,
proj_tpos_enc_in_obj_ptrs=True,
use_signed_tpos_enc_to_obj_ptrs=True,
only_obj_ptrs_in_the_past_for_eval=True,
pred_obj_scores=True,
pred_obj_scores_mlp=True,
fixed_no_obj_ptr=True,
multimask_output_for_tracking=True,
use_multimask_token_for_obj_ptr=True,
multimask_min_pt_num=0,
multimask_max_pt_num=1,
use_mlp_for_obj_ptr_proj=True,
compile_image_encoder=False,
)
def build_memory_attention():
return MemoryAttention(
@@ -103,10 +127,8 @@ def build_memory_encoder():
)
-def build_sam2_tiny():
-return SAM2Base(
-**common_kwargs,
-image_encoder=ImageEncoder(
+def build_image_encoder_tiny():
+return ImageEncoder(
scalp=1,
trunk=Hiera(
embed_dim=96,
@@ -128,16 +150,11 @@ def build_sam2_tiny():
fpn_top_down_levels=[2, 3],
fpn_interp_model="nearest",
),
-),
-memory_attention=build_memory_attention(),
-memory_encoder=build_memory_encoder(),
)
-def build_sam2_small():
-return SAM2Base(
-**common_kwargs,
-image_encoder=ImageEncoder(
+def build_image_encoder_small():
+return ImageEncoder(
scalp=1,
trunk=Hiera(
embed_dim=96,
@@ -159,16 +176,11 @@ def build_sam2_small():
fpn_top_down_levels=[2, 3],
fpn_interp_model="nearest",
),
-),
-memory_attention=build_memory_attention(),
-memory_encoder=build_memory_encoder(),
)
-def build_sam2_base():
-return SAM2Base(
-**common_kwargs,
-image_encoder=ImageEncoder(
+def build_image_encoder_base():
+return ImageEncoder(
scalp=1,
trunk=Hiera(
embed_dim=112,
@@ -190,16 +202,11 @@ def build_sam2_base():
fpn_top_down_levels=[2, 3],
fpn_interp_model="nearest",
),
-),
-memory_attention=build_memory_attention(),
-memory_encoder=build_memory_encoder(),
)
-def build_sam2_large():
-return SAM2Base(
-**common_kwargs,
-image_encoder=ImageEncoder(
+def build_image_encoder_large():
+return ImageEncoder(
scalp=1,
trunk=Hiera(
embed_dim=144,
@@ -221,7 +228,76 @@ def build_sam2_large():
fpn_top_down_levels=[2, 3],
fpn_interp_model="nearest",
),
-),
+)
def build_sam2_tiny():
return SAM2Base(
**common_kwargs,
image_encoder=build_image_encoder_tiny(),
memory_attention=build_memory_attention(),
memory_encoder=build_memory_encoder(),
)
def build_sam2_small():
return SAM2Base(
**common_kwargs,
image_encoder=build_image_encoder_small(),
memory_attention=build_memory_attention(),
memory_encoder=build_memory_encoder(),
)
def build_sam2_base():
return SAM2Base(
**common_kwargs,
image_encoder=build_image_encoder_base(),
memory_attention=build_memory_attention(),
memory_encoder=build_memory_encoder(),
)
def build_sam2_large():
return SAM2Base(
**common_kwargs,
image_encoder=build_image_encoder_large(),
memory_attention=build_memory_attention(),
memory_encoder=build_memory_encoder(),
)
def build_sam2_1_tiny():
return SAM2Base(
**common_kwargs_for_2_1,
image_encoder=build_image_encoder_tiny(),
memory_attention=build_memory_attention(),
memory_encoder=build_memory_encoder(),
)
def build_sam2_1_small():
return SAM2Base(
**common_kwargs_for_2_1,
image_encoder=build_image_encoder_small(),
memory_attention=build_memory_attention(),
memory_encoder=build_memory_encoder(),
)
def build_sam2_1_base():
return SAM2Base(
**common_kwargs_for_2_1,
image_encoder=build_image_encoder_base(),
memory_attention=build_memory_attention(),
memory_encoder=build_memory_encoder(),
)
def build_sam2_1_large():
return SAM2Base(
**common_kwargs_for_2_1,
image_encoder=build_image_encoder_large(),
memory_attention=build_memory_attention(),
memory_encoder=build_memory_encoder(),
)
@@ -232,6 +308,10 @@ sam2_model_registry = {
"sam2_small": build_sam2_small,
"sam2_base": build_sam2_base,
"sam2_large": build_sam2_large,
"sam2_1_tiny": build_sam2_1_tiny,
"sam2_1_small": build_sam2_1_small,
"sam2_1_base": build_sam2_1_base,
"sam2_1_large": build_sam2_1_large,
}
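A minimal sketch of how the new registry entries could be used. Illustrative only: the checkpoint path is a placeholder, and it assumes the official SAM 2.1 checkpoints store their weights under a "model" key as in upstream build_sam.py:

import torch

build_fn = sam2_model_registry["sam2_1_tiny"]
model = build_fn()  # SAM2Base built with common_kwargs_for_2_1 and the tiny image encoder
state_dict = torch.load("sam2.1_hiera_tiny.pt", map_location="cpu")["model"]
missing, unexpected = model.load_state_dict(state_dict, strict=False)
model.eval()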

View File

@@ -4,14 +4,16 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from functools import partial
from typing import List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr
-from ..backbones.utils import (
+from .utils import (
PatchEmbed,
window_partition,
window_unpartition,
@@ -46,11 +48,7 @@ class MultiScaleAttention(nn.Module):
self.dim = dim
self.dim_out = dim_out
self.num_heads = num_heads
-head_dim = dim_out // num_heads
-self.scale = head_dim**-0.5
self.q_pool = q_pool
self.qkv = nn.Linear(dim, dim_out * 3)
self.proj = nn.Linear(dim_out, dim_out)
@@ -197,6 +195,7 @@ class Hiera(nn.Module):
16,
20,
),
weights_path=None,
return_interm_layers=True, # return feats from every stage
):
super().__init__()
@@ -266,6 +265,11 @@ class Hiera(nn.Module):
else [self.blocks[-1].dim_out]
)
if weights_path is not None:
with g_pathmgr.open(weights_path, "rb") as f:
chkpt = torch.load(f, map_location="cpu")
logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
h, w = hw
window_embed = self.pos_embed_window
@@ -293,3 +297,21 @@ class Hiera(nn.Module):
outputs.append(feats)
return outputs
def get_layer_id(self, layer_name):
# https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
num_layers = self.get_num_layers()
if layer_name.find("rel_pos") != -1:
return num_layers + 1
elif layer_name.find("pos_embed") != -1:
return 0
elif layer_name.find("patch_embed") != -1:
return 0
elif layer_name.find("blocks") != -1:
return int(layer_name.split("blocks")[1].split(".")[1]) + 1
else:
return num_layers + 1
def get_num_layers(self) -> int:
return len(self.blocks)
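The two helpers above follow the BEiT-style layer indexing referenced in the comment; a hedged sketch of how they could drive layer-wise learning-rate decay during fine-tuning (illustrative, not part of this commit):

def lr_scale_for(param_name: str, trunk, decay: float = 0.9) -> float:
    # Later blocks get a scale closer to 1.0; earlier blocks are decayed more strongly.
    num_layers = trunk.get_num_layers()
    layer_id = trunk.get_layer_id(param_name)
    return decay ** (num_layers + 1 - layer_id)

# e.g. optimizer param groups for a Hiera trunk:
# groups = [{"params": [p], "lr": base_lr * lr_scale_for(n, trunk)}
#           for n, p in trunk.named_parameters()]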

View File

@@ -71,6 +71,7 @@ class FpnNeck(nn.Module):
self.position_encoding = position_encoding
self.convs = nn.ModuleList()
self.backbone_channel_list = backbone_channel_list
self.d_model = d_model
for dim in backbone_channel_list:
current = nn.Sequential()
current.add_module(
@@ -99,7 +100,6 @@ class FpnNeck(nn.Module):
self.fpn_top_down_levels = list(fpn_top_down_levels)
def forward(self, xs: List[torch.Tensor]):
out = [None] * len(self.convs)
pos = [None] * len(self.convs)
assert len(xs) == len(self.convs)

View File

@@ -16,7 +16,7 @@ from torch import nn
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
-used by the Attention is all you need paper, generalized to work on images.
+used by the Attention Is All You Need paper, generalized to work on images.
"""
def __init__(
@@ -211,6 +211,11 @@ def apply_rotary_enc(
# repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
if freqs_cis.is_cuda:
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
else:
# torch.repeat on complex numbers may not be supported on non-CUDA devices
# (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
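The else-branch above avoids torch.repeat on complex tensors off-CUDA; a small illustrative check (made-up shapes) that expand + flatten yields the same layout as tiling along the sequence dimension:

import torch

freqs_cis = torch.polar(torch.ones(2, 3, 4, 8), torch.rand(2, 3, 4, 8))  # complex, [B, H, S, D/2]
r = 5
tiled = torch.cat([freqs_cis] * r, dim=2)  # what repeat(..., r, 1) on dim 2 would produce
fallback = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
assert torch.equal(tiled, fallback)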

View File

@@ -247,7 +247,7 @@ class MaskDecoder(nn.Module):
def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
-lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
+lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
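For reference, the stability score described in the docstring is essentially the IoU between the mask thresholded at +delta and at -delta. A hedged sketch written against the names visible above, not the full class:

import torch

def stability_scores(mask_logits: torch.Tensor, stability_delta: float = 0.05) -> torch.Tensor:
    # IoU between the "upper" (logits > +delta) and "lower" (logits > -delta) masks.
    flat = mask_logits.flatten(-2)
    area_i = torch.sum(flat > stability_delta, dim=-1).float()
    area_u = torch.sum(flat > -stability_delta, dim=-1).float()
    return torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))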

View File

@@ -60,9 +60,6 @@ class SAM2Base(torch.nn.Module):
# For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
# (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
memory_temporal_stride_for_eval=1,
-# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
-# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
-add_all_frames_to_correct_as_cond=False,
# whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
non_overlap_masks_for_mem_enc=False,
# whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
@@ -74,6 +71,9 @@ class SAM2Base(torch.nn.Module):
# whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
# with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
proj_tpos_enc_in_obj_ptrs=False,
# whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
# (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
use_signed_tpos_enc_to_obj_ptrs=False,
# whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
# (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
only_obj_ptrs_in_the_past_for_eval=False,
@@ -89,6 +89,8 @@ class SAM2Base(torch.nn.Module):
# hope to make recovery easier if there is a mistake and mitigate accumulation of errors
soft_no_obj_ptr: bool = False,
use_mlp_for_obj_ptr_proj: bool = False,
# add no obj embedding to spatial frames
no_obj_embed_spatial: bool = False,
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
sam_mask_decoder_extra_args=None,
compile_image_encoder: bool = False,
@@ -111,12 +113,13 @@ class SAM2Base(torch.nn.Module):
if proj_tpos_enc_in_obj_ptrs:
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
# Part 2: memory attention to condition current frame's visual features
# with memories (and obj ptrs) from past frames
self.memory_attention = memory_attention
-self.hidden_dim = memory_attention.d_model
+self.hidden_dim = image_encoder.neck.d_model
# Part 3: memory encoder for the previous frame's outputs
self.memory_encoder = memory_encoder
@@ -171,9 +174,12 @@ class SAM2Base(torch.nn.Module):
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
trunc_normal_(self.no_obj_ptr, std=0.02)
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
self.no_obj_embed_spatial = None
if no_obj_embed_spatial:
self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
trunc_normal_(self.no_obj_embed_spatial, std=0.02)
self._build_sam_heads()
-self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
self.max_cond_frames_in_attn = max_cond_frames_in_attn
# Model compilation
@@ -195,8 +201,8 @@ class SAM2Base(torch.nn.Module):
def forward(self, *args, **kwargs):
raise NotImplementedError(
-"Please use the corresponding methods in SAM2VideoPredictor for inference."
-"See notebooks/video_predictor_example.ipynb for an example."
+"Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
+"See notebooks/video_predictor_example.ipynb for an inference example."
)
def _build_sam_heads(self):
@@ -389,8 +395,6 @@ class SAM2Base(torch.nn.Module):
if self.pred_obj_scores:
# Allow *soft* no obj ptr, unlike for masks
if self.soft_no_obj_ptr:
-# Only hard possible with gt
-assert not self.teacher_force_obj_scores_for_mem
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
lambda_is_obj_appearing = is_obj_appearing.float()
@@ -514,6 +518,7 @@ class SAM2Base(torch.nn.Module):
return pix_feat
num_obj_ptr_tokens = 0
tpos_sign_mul = -1 if track_in_reverse else 1
# Step 1: condition the visual features of the current frame on previous memories
if not is_init_cond_frame:
# Retrieve the memories encoded with the maskmem backbone
@@ -529,9 +534,9 @@ class SAM2Base(torch.nn.Module):
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
-# We also allow taking the memory frame non-consecutively (with r>1), in which case
-# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
-r = self.memory_temporal_stride_for_eval
+# We also allow taking the memory frame non-consecutively (with stride>1), in which case
+# we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
+stride = 1 if self.training else self.memory_temporal_stride_for_eval
for t_pos in range(1, self.num_maskmem):
t_rel = self.num_maskmem - t_pos # how many frames before current frame
if t_rel == 1:
@@ -547,15 +552,15 @@ class SAM2Base(torch.nn.Module):
if not track_in_reverse:
# first find the nearest frame among every r-th frames before this frame
# for r=1, this would be (frame_idx - 2)
-prev_frame_idx = ((frame_idx - 2) // r) * r
+prev_frame_idx = ((frame_idx - 2) // stride) * stride
# then seek further among every r-th frames
-prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
+prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
else:
# first find the nearest frame among every r-th frames after this frame
# for r=1, this would be (frame_idx + 2)
-prev_frame_idx = -(-(frame_idx + 2) // r) * r
+prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
# then seek further among every r-th frames
-prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
if out is None:
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
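A worked example of the stride-based frame selection in the hunk above, for forward tracking (illustrative values only, not part of the diff):

frame_idx, stride, num_maskmem = 15, 4, 7
selected = []
for t_pos in range(1, num_maskmem):
    t_rel = num_maskmem - t_pos
    if t_rel == 1:
        prev_frame_idx = frame_idx - 1  # always the immediately previous frame
    else:
        prev_frame_idx = ((frame_idx - 2) // stride) * stride - (t_rel - 2) * stride
    selected.append(prev_frame_idx)
print(selected)  # [-4, 0, 4, 8, 12, 14]; out-of-range indices fall through to the padding branch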
@@ -568,10 +573,10 @@ class SAM2Base(torch.nn.Module):
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
# so we load it back to GPU (it's a no-op if it's already on GPU).
-feats = prev["maskmem_features"].cuda(non_blocking=True)
+feats = prev["maskmem_features"].to(device, non_blocking=True)
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
# Spatial positional encoding (it might have been offloaded to CPU in eval)
-maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
# Temporal positional encoding
maskmem_enc = (
@@ -594,7 +599,14 @@ class SAM2Base(torch.nn.Module):
ptr_cond_outputs = selected_cond_outputs
pos_and_ptrs = [
# Temporal pos encoding contains how far away each pointer is from current frame
-(abs(frame_idx - t), out["obj_ptr"])
(
(
(frame_idx - t) * tpos_sign_mul
if self.use_signed_tpos_enc_to_obj_ptrs
else abs(frame_idx - t)
),
out["obj_ptr"],
)
for t, out in ptr_cond_outputs.items()
]
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
@@ -643,7 +655,7 @@ class SAM2Base(torch.nn.Module):
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
return pix_feat_with_mem
-# Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder)
+# Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
@@ -667,6 +679,7 @@ class SAM2Base(torch.nn.Module):
current_vision_feats,
feat_sizes,
pred_masks_high_res,
object_score_logits,
is_mask_from_pts,
):
"""Encode the current image and its prediction into a memory feature."""
@@ -701,9 +714,104 @@ class SAM2Base(torch.nn.Module):
)
maskmem_features = maskmem_out["vision_features"]
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
# add a no-object embedding to the spatial memory to indicate that the frame
# is predicted to be occluded (i.e. no object is appearing in the frame)
if self.no_obj_embed_spatial is not None:
is_obj_appearing = (object_score_logits > 0).float()
maskmem_features += (
1 - is_obj_appearing[..., None, None]
) * self.no_obj_embed_spatial[..., None, None].expand(
*maskmem_features.shape
)
return maskmem_features, maskmem_pos_enc
def _track_step(
self,
frame_idx,
is_init_cond_frame,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
point_inputs,
mask_inputs,
output_dict,
num_frames,
track_in_reverse,
prev_sam_mask_logits,
):
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
if len(current_vision_feats) > 1:
high_res_features = [
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
]
else:
high_res_features = None
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
sam_outputs = self._use_mask_as_output(
pix_feat, high_res_features, mask_inputs
)
else:
# fused the visual feature with previous memory features in the memory bank
pix_feat = self._prepare_memory_conditioned_features(
frame_idx=frame_idx,
is_init_cond_frame=is_init_cond_frame,
current_vision_feats=current_vision_feats[-1:],
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
feat_sizes=feat_sizes[-1:],
output_dict=output_dict,
num_frames=num_frames,
track_in_reverse=track_in_reverse,
)
# apply SAM-style segmentation head
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
if prev_sam_mask_logits is not None:
assert point_inputs is not None and mask_inputs is None
mask_inputs = prev_sam_mask_logits
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
sam_outputs = self._forward_sam_heads(
backbone_features=pix_feat,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
high_res_features=high_res_features,
multimask_output=multimask_output,
)
return current_out, sam_outputs, high_res_features, pix_feat
def _encode_memory_in_output(
self,
current_vision_feats,
feat_sizes,
point_inputs,
run_mem_encoder,
high_res_masks,
object_score_logits,
current_out,
):
if run_mem_encoder and self.num_maskmem > 0:
high_res_masks_for_mem_enc = high_res_masks
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
current_vision_feats=current_vision_feats,
feat_sizes=feat_sizes,
pred_masks_high_res=high_res_masks_for_mem_enc,
object_score_logits=object_score_logits,
is_mask_from_pts=(point_inputs is not None),
)
current_out["maskmem_features"] = maskmem_features
current_out["maskmem_pos_enc"] = maskmem_pos_enc
else:
current_out["maskmem_features"] = None
current_out["maskmem_pos_enc"] = None
def track_step(
self,
frame_idx,
@@ -725,50 +833,20 @@ class SAM2Base(torch.nn.Module):
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
prev_sam_mask_logits=None,
):
-current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
-# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
-if len(current_vision_feats) > 1:
-high_res_features = [
-x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
-for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
-]
-else:
-high_res_features = None
-if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
-# When use_mask_input_as_output_without_sam=True, we directly output the mask input
-# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
-pix_feat = current_vision_feats[-1].permute(1, 2, 0)
-pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
-sam_outputs = self._use_mask_as_output(
-pix_feat, high_res_features, mask_inputs
-)
-else:
-# fused the visual feature with previous memory features in the memory bank
-pix_feat_with_mem = self._prepare_memory_conditioned_features(
-frame_idx=frame_idx,
-is_init_cond_frame=is_init_cond_frame,
-current_vision_feats=current_vision_feats[-1:],
-current_vision_pos_embeds=current_vision_pos_embeds[-1:],
-feat_sizes=feat_sizes[-1:],
-output_dict=output_dict,
-num_frames=num_frames,
-track_in_reverse=track_in_reverse,
-)
-# apply SAM-style segmentation head
-# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
-# e.g. in demo where such logits come from earlier interaction instead of correction sampling
-# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
-if prev_sam_mask_logits is not None:
-assert point_inputs is not None and mask_inputs is None
-mask_inputs = prev_sam_mask_logits
-multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
-sam_outputs = self._forward_sam_heads(
-backbone_features=pix_feat_with_mem,
-point_inputs=point_inputs,
-mask_inputs=mask_inputs,
-high_res_features=high_res_features,
-multimask_output=multimask_output,
-)
+current_out, sam_outputs, _, _ = self._track_step(
+frame_idx,
+is_init_cond_frame,
+current_vision_feats,
+current_vision_pos_embeds,
+feat_sizes,
+point_inputs,
+mask_inputs,
+output_dict,
+num_frames,
+track_in_reverse,
+prev_sam_mask_logits,
+)
(
_,
_,
@@ -776,28 +854,28 @@ class SAM2Base(torch.nn.Module):
low_res_masks,
high_res_masks,
obj_ptr,
-_,
+object_score_logits,
) = sam_outputs
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
current_out["obj_ptr"] = obj_ptr
if not self.training:
# Only add this in inference (to avoid unused param in activation checkpointing;
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
current_out["object_score_logits"] = object_score_logits
# Finally run the memory encoder on the predicted mask to encode
# it into a new memory feature (that can be used in future frames)
-if run_mem_encoder and self.num_maskmem > 0:
-high_res_masks_for_mem_enc = high_res_masks
-maskmem_features, maskmem_pos_enc = self._encode_new_memory(
-current_vision_feats=current_vision_feats,
-feat_sizes=feat_sizes,
-pred_masks_high_res=high_res_masks_for_mem_enc,
-is_mask_from_pts=(point_inputs is not None),
-)
-current_out["maskmem_features"] = maskmem_features
-current_out["maskmem_pos_enc"] = maskmem_pos_enc
-else:
-current_out["maskmem_features"] = None
-current_out["maskmem_pos_enc"] = None
+self._encode_memory_in_output(
+current_vision_feats,
+feat_sizes,
+point_inputs,
+run_mem_encoder,
+high_res_masks,
+object_score_logits,
+current_out,
+)
return current_out

View File

@@ -6,11 +6,15 @@
import copy
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils.misc import mask_to_box
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
"""
@@ -147,3 +151,173 @@ class LayerNorm2d(nn.Module):
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
def sample_box_points(
masks: torch.Tensor,
noise: float = 0.1, # SAM default
noise_bound: int = 20, # SAM default
top_left_label: int = 2,
bottom_right_label: int = 3,
) -> Tuple[np.array, np.array]:
"""
Sample a noised version of the top left and bottom right corners of a given `bbox`
Inputs:
- masks: [B, 1, H,W] boxes, dtype=torch.Tensor
- noise: noise as a fraction of box width and height, dtype=float
- noise_bound: maximum amount of noise (in pure pixesl), dtype=int
Returns:
- box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
- box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32
"""
device = masks.device
box_coords = mask_to_box(masks)
B, _, H, W = masks.shape
box_labels = torch.tensor(
[top_left_label, bottom_right_label], dtype=torch.int, device=device
).repeat(B)
if noise > 0.0:
if not isinstance(noise_bound, torch.Tensor):
noise_bound = torch.tensor(noise_bound, device=device)
bbox_w = box_coords[..., 2] - box_coords[..., 0]
bbox_h = box_coords[..., 3] - box_coords[..., 1]
max_dx = torch.min(bbox_w * noise, noise_bound)
max_dy = torch.min(bbox_h * noise, noise_bound)
box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
box_coords = box_coords + box_noise
img_bounds = (
torch.tensor([W, H, W, H], device=device) - 1
) # uncentered pixel coords
box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
box_labels = box_labels.reshape(-1, 2)
return box_coords, box_labels
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
"""
Sample `num_pt` random points (along with their labels) independently from the error regions.
Inputs:
- gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
- pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
- num_pt: int, number of points to sample independently for each of the B error maps
Outputs:
- points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
- labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
negative clicks
"""
if pred_masks is None: # if pred_masks is not provided, treat it as empty
pred_masks = torch.zeros_like(gt_masks)
assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
assert num_pt >= 0
B, _, H_im, W_im = gt_masks.shape
device = gt_masks.device
# false positive region, a new point sampled in this region should have
# negative label to correct the FP error
fp_masks = ~gt_masks & pred_masks
# false negative region, a new point sampled in this region should have
# positive label to correct the FN error
fn_masks = gt_masks & ~pred_masks
# whether the prediction completely match the ground-truth on each mask
all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
all_correct = all_correct[..., None, None]
# channel 0 is FP map, while channel 1 is FN map
pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
# sample a negative new click from FP region or a positive new click
# from FN region, depend on where the maximum falls,
# and in case the predictions are all correct (no FP or FN), we just
# sample a negative click from the background region
pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
pts_noise[..., 1] *= fn_masks
pts_idx = pts_noise.flatten(2).argmax(dim=2)
labels = (pts_idx % 2).to(torch.int32)
pts_idx = pts_idx // 2
pts_x = pts_idx % W_im
pts_y = pts_idx // W_im
points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
return points, labels
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
"""
Sample 1 random point (along with its label) from the center of each error region,
that is, the point with the largest distance to the boundary of each error region.
This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
Inputs:
- gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
- pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
- padding: if True, pad with boundary of 1 px for distance transform
Outputs:
- points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
- labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
"""
import cv2
if pred_masks is None:
pred_masks = torch.zeros_like(gt_masks)
assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
B, _, _, W_im = gt_masks.shape
device = gt_masks.device
# false positive region, a new point sampled in this region should have
# negative label to correct the FP error
fp_masks = ~gt_masks & pred_masks
# false negative region, a new point sampled in this region should have
# positive label to correct the FN error
fn_masks = gt_masks & ~pred_masks
fp_masks = fp_masks.cpu().numpy()
fn_masks = fn_masks.cpu().numpy()
points = torch.zeros(B, 1, 2, dtype=torch.float)
labels = torch.ones(B, 1, dtype=torch.int32)
for b in range(B):
fn_mask = fn_masks[b, 0]
fp_mask = fp_masks[b, 0]
if padding:
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
# compute the distance of each point in FN/FP region to its boundary
fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
if padding:
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
# take the point in FN/FP region with the largest distance to its boundary
fn_mask_dt_flat = fn_mask_dt.reshape(-1)
fp_mask_dt_flat = fp_mask_dt.reshape(-1)
fn_argmax = np.argmax(fn_mask_dt_flat)
fp_argmax = np.argmax(fp_mask_dt_flat)
is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
pt_idx = fn_argmax if is_positive else fp_argmax
points[b, 0, 0] = pt_idx % W_im # x
points[b, 0, 1] = pt_idx // W_im # y
labels[b, 0] = int(is_positive)
points = points.to(device)
labels = labels.to(device)
return points, labels
def get_next_point(gt_masks, pred_masks, method):
if method == "uniform":
return sample_random_points_from_errors(gt_masks, pred_masks)
elif method == "center":
return sample_one_point_from_error_center(gt_masks, pred_masks)
else:
raise ValueError(f"unknown sampling method {method}")
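A small, self-contained sketch exercising the new prompt-sampling helpers above with dummy masks. Illustrative only: it assumes the functions are imported from this module, and sample_one_point_from_error_center additionally requires OpenCV:

import torch

gt = torch.zeros(1, 1, 64, 64, dtype=torch.bool)
gt[:, :, 16:48, 16:48] = True          # ground-truth square
pred = torch.zeros_like(gt)
pred[:, :, 16:40, 16:48] = True        # prediction that misses the bottom strip

pts, labels = get_next_point(gt, pred, method="uniform")         # random click inside the error region
ctr_pts, ctr_labels = get_next_point(gt, pred, method="center")  # RITM-style center click
box_pts, box_labels = sample_box_points(gt)                      # noised top-left / bottom-right corner prompts
print(pts.shape, labels, box_pts.shape, box_labels)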

View File

@@ -175,11 +175,17 @@ class InteractiveSegModel(Choices):
sam_hq_vit_l = "sam_hq_vit_l"
sam_hq_vit_h = "sam_hq_vit_h"
mobile_sam = "mobile_sam"
sam2_tiny = "sam2_tiny"
sam2_small = "sam2_small"
sam2_base = "sam2_base"
sam2_large = "sam2_large"
sam2_1_tiny = "sam2_1_tiny"
sam2_1_small = "sam2_1_small"
sam2_1_base = "sam2_1_base"
sam2_1_large = "sam2_1_large"
class PluginInfo(BaseModel):
name: str

View File

@@ -48,7 +48,7 @@ default_configs = dict(
output_dir=None,
quality=95,
enable_interactive_seg=False,
-interactive_seg_model=InteractiveSegModel.vit_b,
+interactive_seg_model=InteractiveSegModel.sam2_1_tiny,
interactive_seg_device=Device.cpu,
enable_remove_bg=False,
remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,