From ec08ce306393afe729e0ee5984d64334fc563a29 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 1 Nov 2024 18:45:32 +0800 Subject: [PATCH] add sam2.1 --- iopaint/cli.py | 2 +- iopaint/plugins/interactive_seg.py | 16 ++ .../plugins/segment_anything2/build_sam.py | 270 ++++++++++++------ .../modeling/backbones/hieradet.py | 32 ++- .../modeling/backbones/image_encoder.py | 2 +- .../modeling/position_encoding.py | 9 +- .../modeling/sam/mask_decoder.py | 2 +- .../segment_anything2/modeling/sam2_base.py | 234 ++++++++++----- .../segment_anything2/modeling/sam2_utils.py | 174 +++++++++++ iopaint/schema.py | 6 + iopaint/web_config.py | 2 +- 11 files changed, 565 insertions(+), 184 deletions(-) diff --git a/iopaint/cli.py b/iopaint/cli.py index 9ba686b..ae0b0c4 100644 --- a/iopaint/cli.py +++ b/iopaint/cli.py @@ -129,7 +129,7 @@ def start( quality: int = Option(95, help=QUALITY_HELP), enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP), interactive_seg_model: InteractiveSegModel = Option( - InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP + InteractiveSegModel.sam2_1_tiny, help=INTERACTIVE_SEG_MODEL_HELP ), interactive_seg_device: Device = Option(Device.cpu), enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP), diff --git a/iopaint/plugins/interactive_seg.py b/iopaint/plugins/interactive_seg.py index 27859fa..fd87224 100644 --- a/iopaint/plugins/interactive_seg.py +++ b/iopaint/plugins/interactive_seg.py @@ -59,6 +59,22 @@ SEGMENT_ANYTHING_MODELS = { "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt", "md5": "08083462423be3260cd6a5eef94dc01c", }, + "sam2_1_tiny": { + "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + "md5": "6aa6761c9da74fbaa74b4c790a0a2007", + }, + "sam2_1_small": { + "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + "md5": "51713b3d1994696d27f35f9c6de6f5ef", + }, + "sam2_1_base": { + "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + "md5": "ec7bd7d23d280d5e3cfa45984c02eda5", + }, + "sam2_1_large": { + "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", + "md5": "2b30654b6112c42a115563c638d238d9", + }, } diff --git a/iopaint/plugins/segment_anything2/build_sam.py b/iopaint/plugins/segment_anything2/build_sam.py index 5100f70..924e0d5 100644 --- a/iopaint/plugins/segment_anything2/build_sam.py +++ b/iopaint/plugins/segment_anything2/build_sam.py @@ -17,9 +17,6 @@ from .modeling.position_encoding import PositionEmbeddingSine from .modeling.sam.transformer import RoPEAttention from .modeling.sam2_base import SAM2Base -CURRENT_DIR = Path(__file__).parent -CONFIG_DIR = CURRENT_DIR / "sam2_configs" - common_kwargs = dict( num_maskmem=7, image_size=1024, @@ -44,6 +41,33 @@ common_kwargs = dict( compile_image_encoder=False, ) +common_kwargs_for_2_1 = dict( + num_maskmem=7, + image_size=1024, + sigmoid_scale_for_mem_enc=20.0, + sigmoid_bias_for_mem_enc=-10.0, + use_mask_input_as_output_without_sam=True, + directly_add_no_mem_embed=True, + no_obj_embed_spatial=True, + use_high_res_features_in_sam=True, + multimask_output_in_sam=True, + iou_prediction_use_sigmoid=True, + use_obj_ptrs_in_encoder=True, + add_tpos_enc_to_obj_ptrs=True, + proj_tpos_enc_in_obj_ptrs=True, + use_signed_tpos_enc_to_obj_ptrs=True, + only_obj_ptrs_in_the_past_for_eval=True, + pred_obj_scores=True, + pred_obj_scores_mlp=True, + fixed_no_obj_ptr=True, + multimask_output_for_tracking=True, + 
use_multimask_token_for_obj_ptr=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + use_mlp_for_obj_ptr_proj=True, + compile_image_encoder=False, +) + def build_memory_attention(): return MemoryAttention( @@ -103,32 +127,114 @@ def build_memory_encoder(): ) +def build_image_encoder_tiny(): + return ImageEncoder( + scalp=1, + trunk=Hiera( + embed_dim=96, + num_heads=1, + stages=(1, 2, 7, 2), + global_att_blocks=(5, 7, 9), + window_pos_embed_bkg_spatial_size=(7, 7), + window_spec=(8, 4, 14, 7), + ), + neck=FpnNeck( + position_encoding=PositionEmbeddingSine( + num_pos_feats=256, + normalize=True, + scale=None, + temperature=10000, + ), + d_model=256, + backbone_channel_list=[768, 384, 192, 96], + fpn_top_down_levels=[2, 3], + fpn_interp_model="nearest", + ), + ) + + +def build_image_encoder_small(): + return ImageEncoder( + scalp=1, + trunk=Hiera( + embed_dim=96, + num_heads=1, + stages=(1, 2, 11, 2), + global_att_blocks=(7, 10, 13), + window_pos_embed_bkg_spatial_size=(7, 7), + window_spec=(8, 4, 14, 7), + ), + neck=FpnNeck( + position_encoding=PositionEmbeddingSine( + num_pos_feats=256, + normalize=True, + scale=None, + temperature=10000, + ), + d_model=256, + backbone_channel_list=[768, 384, 192, 96], + fpn_top_down_levels=[2, 3], + fpn_interp_model="nearest", + ), + ) + + +def build_image_encoder_base(): + return ImageEncoder( + scalp=1, + trunk=Hiera( + embed_dim=112, + num_heads=2, + stages=(2, 3, 16, 3), + global_att_blocks=(12, 16, 20), + window_pos_embed_bkg_spatial_size=(14, 14), + window_spec=(8, 4, 14, 7), + ), + neck=FpnNeck( + position_encoding=PositionEmbeddingSine( + num_pos_feats=256, + normalize=True, + scale=None, + temperature=10000, + ), + d_model=256, + backbone_channel_list=[896, 448, 224, 112], + fpn_top_down_levels=[2, 3], + fpn_interp_model="nearest", + ), + ) + + +def build_image_encoder_large(): + return ImageEncoder( + scalp=1, + trunk=Hiera( + embed_dim=144, + num_heads=2, + stages=(2, 6, 36, 4), + global_att_blocks=(23, 33, 43), + window_pos_embed_bkg_spatial_size=(7, 7), + window_spec=(8, 4, 16, 8), + ), + neck=FpnNeck( + position_encoding=PositionEmbeddingSine( + num_pos_feats=256, + normalize=True, + scale=None, + temperature=10000, + ), + d_model=256, + backbone_channel_list=[1152, 576, 288, 144], + fpn_top_down_levels=[2, 3], + fpn_interp_model="nearest", + ), + ) + + def build_sam2_tiny(): return SAM2Base( **common_kwargs, - image_encoder=ImageEncoder( - scalp=1, - trunk=Hiera( - embed_dim=96, - num_heads=1, - stages=(1, 2, 7, 2), - global_att_blocks=(5, 7, 9), - window_pos_embed_bkg_spatial_size=(7, 7), - window_spec=(8, 4, 14, 7), - ), - neck=FpnNeck( - position_encoding=PositionEmbeddingSine( - num_pos_feats=256, - normalize=True, - scale=None, - temperature=10000, - ), - d_model=256, - backbone_channel_list=[768, 384, 192, 96], - fpn_top_down_levels=[2, 3], - fpn_interp_model="nearest", - ), - ), + image_encoder=build_image_encoder_tiny(), memory_attention=build_memory_attention(), memory_encoder=build_memory_encoder(), ) @@ -137,29 +243,7 @@ def build_sam2_tiny(): def build_sam2_small(): return SAM2Base( **common_kwargs, - image_encoder=ImageEncoder( - scalp=1, - trunk=Hiera( - embed_dim=96, - num_heads=1, - stages=(1, 2, 11, 2), - global_att_blocks=(7, 10, 13), - window_pos_embed_bkg_spatial_size=(7, 7), - window_spec=(8, 4, 14, 7), - ), - neck=FpnNeck( - position_encoding=PositionEmbeddingSine( - num_pos_feats=256, - normalize=True, - scale=None, - temperature=10000, - ), - d_model=256, - backbone_channel_list=[768, 384, 192, 96], - 
fpn_top_down_levels=[2, 3], - fpn_interp_model="nearest", - ), - ), + image_encoder=build_image_encoder_small(), memory_attention=build_memory_attention(), memory_encoder=build_memory_encoder(), ) @@ -168,29 +252,7 @@ def build_sam2_small(): def build_sam2_base(): return SAM2Base( **common_kwargs, - image_encoder=ImageEncoder( - scalp=1, - trunk=Hiera( - embed_dim=112, - num_heads=2, - stages=(2, 3, 16, 3), - global_att_blocks=(12, 16, 20), - window_pos_embed_bkg_spatial_size=(14, 14), - window_spec=(8, 4, 14, 7), - ), - neck=FpnNeck( - position_encoding=PositionEmbeddingSine( - num_pos_feats=256, - normalize=True, - scale=None, - temperature=10000, - ), - d_model=256, - backbone_channel_list=[896, 448, 224, 112], - fpn_top_down_levels=[2, 3], - fpn_interp_model="nearest", - ), - ), + image_encoder=build_image_encoder_base(), memory_attention=build_memory_attention(), memory_encoder=build_memory_encoder(), ) @@ -199,29 +261,43 @@ def build_sam2_base(): def build_sam2_large(): return SAM2Base( **common_kwargs, - image_encoder=ImageEncoder( - scalp=1, - trunk=Hiera( - embed_dim=144, - num_heads=2, - stages=(2, 6, 36, 4), - global_att_blocks=(23, 33, 43), - window_pos_embed_bkg_spatial_size=(7, 7), - window_spec=(8, 4, 16, 8), - ), - neck=FpnNeck( - position_encoding=PositionEmbeddingSine( - num_pos_feats=256, - normalize=True, - scale=None, - temperature=10000, - ), - d_model=256, - backbone_channel_list=[1152, 576, 288, 144], - fpn_top_down_levels=[2, 3], - fpn_interp_model="nearest", - ), - ), + image_encoder=build_image_encoder_large(), + memory_attention=build_memory_attention(), + memory_encoder=build_memory_encoder(), + ) + + +def build_sam2_1_tiny(): + return SAM2Base( + **common_kwargs_for_2_1, + image_encoder=build_image_encoder_tiny(), + memory_attention=build_memory_attention(), + memory_encoder=build_memory_encoder(), + ) + + +def build_sam2_1_small(): + return SAM2Base( + **common_kwargs_for_2_1, + image_encoder=build_image_encoder_small(), + memory_attention=build_memory_attention(), + memory_encoder=build_memory_encoder(), + ) + + +def build_sam2_1_base(): + return SAM2Base( + **common_kwargs_for_2_1, + image_encoder=build_image_encoder_base(), + memory_attention=build_memory_attention(), + memory_encoder=build_memory_encoder(), + ) + + +def build_sam2_1_large(): + return SAM2Base( + **common_kwargs_for_2_1, + image_encoder=build_image_encoder_large(), memory_attention=build_memory_attention(), memory_encoder=build_memory_encoder(), ) @@ -232,6 +308,10 @@ sam2_model_registry = { "sam2_small": build_sam2_small, "sam2_base": build_sam2_base, "sam2_large": build_sam2_large, + "sam2_1_tiny": build_sam2_1_tiny, + "sam2_1_small": build_sam2_1_small, + "sam2_1_base": build_sam2_1_base, + "sam2_1_large": build_sam2_1_large, } diff --git a/iopaint/plugins/segment_anything2/modeling/backbones/hieradet.py b/iopaint/plugins/segment_anything2/modeling/backbones/hieradet.py index 9375b6a..d9ab6c4 100644 --- a/iopaint/plugins/segment_anything2/modeling/backbones/hieradet.py +++ b/iopaint/plugins/segment_anything2/modeling/backbones/hieradet.py @@ -4,14 +4,16 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import logging from functools import partial from typing import List, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F +from iopath.common.file_io import g_pathmgr -from ..backbones.utils import ( +from .utils import ( PatchEmbed, window_partition, window_unpartition, @@ -46,11 +48,7 @@ class MultiScaleAttention(nn.Module): self.dim = dim self.dim_out = dim_out - self.num_heads = num_heads - head_dim = dim_out // num_heads - self.scale = head_dim**-0.5 - self.q_pool = q_pool self.qkv = nn.Linear(dim, dim_out * 3) self.proj = nn.Linear(dim_out, dim_out) @@ -197,6 +195,7 @@ class Hiera(nn.Module): 16, 20, ), + weights_path=None, return_interm_layers=True, # return feats from every stage ): super().__init__() @@ -266,6 +265,11 @@ class Hiera(nn.Module): else [self.blocks[-1].dim_out] ) + if weights_path is not None: + with g_pathmgr.open(weights_path, "rb") as f: + chkpt = torch.load(f, map_location="cpu") + logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False)) + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: h, w = hw window_embed = self.pos_embed_window @@ -293,3 +297,21 @@ class Hiera(nn.Module): outputs.append(feats) return outputs + + def get_layer_id(self, layer_name): + # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + num_layers = self.get_num_layers() + + if layer_name.find("rel_pos") != -1: + return num_layers + 1 + elif layer_name.find("pos_embed") != -1: + return 0 + elif layer_name.find("patch_embed") != -1: + return 0 + elif layer_name.find("blocks") != -1: + return int(layer_name.split("blocks")[1].split(".")[1]) + 1 + else: + return num_layers + 1 + + def get_num_layers(self) -> int: + return len(self.blocks) diff --git a/iopaint/plugins/segment_anything2/modeling/backbones/image_encoder.py b/iopaint/plugins/segment_anything2/modeling/backbones/image_encoder.py index 5f92baf..c3ffefe 100644 --- a/iopaint/plugins/segment_anything2/modeling/backbones/image_encoder.py +++ b/iopaint/plugins/segment_anything2/modeling/backbones/image_encoder.py @@ -71,6 +71,7 @@ class FpnNeck(nn.Module): self.position_encoding = position_encoding self.convs = nn.ModuleList() self.backbone_channel_list = backbone_channel_list + self.d_model = d_model for dim in backbone_channel_list: current = nn.Sequential() current.add_module( @@ -99,7 +100,6 @@ class FpnNeck(nn.Module): self.fpn_top_down_levels = list(fpn_top_down_levels) def forward(self, xs: List[torch.Tensor]): - out = [None] * len(self.convs) pos = [None] * len(self.convs) assert len(xs) == len(self.convs) diff --git a/iopaint/plugins/segment_anything2/modeling/position_encoding.py b/iopaint/plugins/segment_anything2/modeling/position_encoding.py index f4b57ae..52ac226 100644 --- a/iopaint/plugins/segment_anything2/modeling/position_encoding.py +++ b/iopaint/plugins/segment_anything2/modeling/position_encoding.py @@ -16,7 +16,7 @@ from torch import nn class PositionEmbeddingSine(nn.Module): """ This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. + used by the Attention Is All You Need paper, generalized to work on images. 
""" def __init__( @@ -211,6 +211,11 @@ def apply_rotary_enc( # repeat freqs along seq_len dim to match k seq_len if repeat_freqs_k: r = xk_.shape[-2] // xq_.shape[-2] - freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/iopaint/plugins/segment_anything2/modeling/sam/mask_decoder.py b/iopaint/plugins/segment_anything2/modeling/sam/mask_decoder.py index fb8bb05..a491a3d 100644 --- a/iopaint/plugins/segment_anything2/modeling/sam/mask_decoder.py +++ b/iopaint/plugins/segment_anything2/modeling/sam/mask_decoder.py @@ -247,7 +247,7 @@ class MaskDecoder(nn.Module): def _get_stability_scores(self, mask_logits): """ Compute stability scores of the mask logits based on the IoU between upper and - lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568. + lower thresholds. """ mask_logits = mask_logits.flatten(-2) stability_delta = self.dynamic_multimask_stability_delta diff --git a/iopaint/plugins/segment_anything2/modeling/sam2_base.py b/iopaint/plugins/segment_anything2/modeling/sam2_base.py index 7896060..5a19b6a 100644 --- a/iopaint/plugins/segment_anything2/modeling/sam2_base.py +++ b/iopaint/plugins/segment_anything2/modeling/sam2_base.py @@ -60,9 +60,6 @@ class SAM2Base(torch.nn.Module): # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. 
memory_temporal_stride_for_eval=1, - # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click - # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames - add_all_frames_to_correct_as_cond=False, # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) non_overlap_masks_for_mem_enc=False, # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder @@ -74,6 +71,9 @@ class SAM2Base(torch.nn.Module): # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) proj_tpos_enc_in_obj_ptrs=False, + # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers + # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + use_signed_tpos_enc_to_obj_ptrs=False, # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) only_obj_ptrs_in_the_past_for_eval=False, @@ -89,6 +89,8 @@ class SAM2Base(torch.nn.Module): # hope to make recovery easier if there is a mistake and mitigate accumulation of errors soft_no_obj_ptr: bool = False, use_mlp_for_obj_ptr_proj: bool = False, + # add no obj embedding to spatial frames + no_obj_embed_spatial: bool = False, # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. 
sam_mask_decoder_extra_args=None, compile_image_encoder: bool = False, @@ -111,12 +113,13 @@ class SAM2Base(torch.nn.Module): if proj_tpos_enc_in_obj_ptrs: assert add_tpos_enc_to_obj_ptrs # these options need to be used together self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval # Part 2: memory attention to condition current frame's visual features # with memories (and obj ptrs) from past frames self.memory_attention = memory_attention - self.hidden_dim = memory_attention.d_model + self.hidden_dim = image_encoder.neck.d_model # Part 3: memory encoder for the previous frame's outputs self.memory_encoder = memory_encoder @@ -171,9 +174,12 @@ class SAM2Base(torch.nn.Module): self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) trunc_normal_(self.no_obj_ptr, std=0.02) self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) self._build_sam_heads() - self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond self.max_cond_frames_in_attn = max_cond_frames_in_attn # Model compilation @@ -195,8 +201,8 @@ class SAM2Base(torch.nn.Module): def forward(self, *args, **kwargs): raise NotImplementedError( - "Please use the corresponding methods in SAM2VideoPredictor for inference." - "See notebooks/video_predictor_example.ipynb for an example." + "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning" + "See notebooks/video_predictor_example.ipynb for an inference example." ) def _build_sam_heads(self): @@ -389,8 +395,6 @@ class SAM2Base(torch.nn.Module): if self.pred_obj_scores: # Allow *soft* no obj ptr, unlike for masks if self.soft_no_obj_ptr: - # Only hard possible with gt - assert not self.teacher_force_obj_scores_for_mem lambda_is_obj_appearing = object_score_logits.sigmoid() else: lambda_is_obj_appearing = is_obj_appearing.float() @@ -514,6 +518,7 @@ class SAM2Base(torch.nn.Module): return pix_feat num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 # Step 1: condition the visual features of the current frame on previous memories if not is_init_cond_frame: # Retrieve the memories encoded with the maskmem backbone @@ -529,9 +534,9 @@ class SAM2Base(torch.nn.Module): t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 - # We also allow taking the memory frame non-consecutively (with r>1), in which case - # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. - r = self.memory_temporal_stride_for_eval + # We also allow taking the memory frame non-consecutively (with stride>1), in which case + # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame. 
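            # For illustration (worked example, not in the original source): with
            # num_maskmem=7, memory_temporal_stride_for_eval=2 (so stride=2 at eval
            # time) and frame_idx=20 while tracking forward, the t_rel >= 2 branch
            # below picks the stride-aligned frames 18, 16, 14, 12 and 10 (counting
            # back from frame_idx - 2), and the t_rel == 1 case takes the frame
            # immediately before the current one.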
+ stride = 1 if self.training else self.memory_temporal_stride_for_eval for t_pos in range(1, self.num_maskmem): t_rel = self.num_maskmem - t_pos # how many frames before current frame if t_rel == 1: @@ -547,15 +552,15 @@ class SAM2Base(torch.nn.Module): if not track_in_reverse: # first find the nearest frame among every r-th frames before this frame # for r=1, this would be (frame_idx - 2) - prev_frame_idx = ((frame_idx - 2) // r) * r + prev_frame_idx = ((frame_idx - 2) // stride) * stride # then seek further among every r-th frames - prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride else: # first find the nearest frame among every r-th frames after this frame # for r=1, this would be (frame_idx + 2) - prev_frame_idx = -(-(frame_idx + 2) // r) * r + prev_frame_idx = -(-(frame_idx + 2) // stride) * stride # then seek further among every r-th frames - prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) if out is None: # If an unselected conditioning frame is among the last (self.num_maskmem - 1) @@ -568,10 +573,10 @@ class SAM2Base(torch.nn.Module): continue # skip padding frames # "maskmem_features" might have been offloaded to CPU in demo use cases, # so we load it back to GPU (it's a no-op if it's already on GPU). - feats = prev["maskmem_features"].cuda(non_blocking=True) + feats = prev["maskmem_features"].to(device, non_blocking=True) to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) # Spatial positional encoding (it might have been offloaded to CPU in eval) - maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) # Temporal positional encoding maskmem_enc = ( @@ -594,7 +599,14 @@ class SAM2Base(torch.nn.Module): ptr_cond_outputs = selected_cond_outputs pos_and_ptrs = [ # Temporal pos encoding contains how far away each pointer is from current frame - (abs(frame_idx - t), out["obj_ptr"]) + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) for t, out in ptr_cond_outputs.items() ] # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame @@ -643,7 +655,7 @@ class SAM2Base(torch.nn.Module): pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem - # Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder) + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] @@ -667,6 +679,7 @@ class SAM2Base(torch.nn.Module): current_vision_feats, feat_sizes, pred_masks_high_res, + object_score_logits, is_mask_from_pts, ): """Encode the current image and its prediction into a memory feature.""" @@ -701,9 +714,104 @@ class SAM2Base(torch.nn.Module): ) maskmem_features = maskmem_out["vision_features"] maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. 
no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) return maskmem_features, maskmem_pos_enc + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. 
in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + def track_step( self, frame_idx, @@ -725,50 +833,20 @@ class SAM2Base(torch.nn.Module): # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). prev_sam_mask_logits=None, ): - current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} - # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW - if len(current_vision_feats) > 1: - high_res_features = [ - x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) - for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) - ] - else: - high_res_features = None - if mask_inputs is not None and self.use_mask_input_as_output_without_sam: - # When use_mask_input_as_output_without_sam=True, we directly output the mask input - # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. - pix_feat = current_vision_feats[-1].permute(1, 2, 0) - pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) - sam_outputs = self._use_mask_as_output( - pix_feat, high_res_features, mask_inputs - ) - else: - # fused the visual feature with previous memory features in the memory bank - pix_feat_with_mem = self._prepare_memory_conditioned_features( - frame_idx=frame_idx, - is_init_cond_frame=is_init_cond_frame, - current_vision_feats=current_vision_feats[-1:], - current_vision_pos_embeds=current_vision_pos_embeds[-1:], - feat_sizes=feat_sizes[-1:], - output_dict=output_dict, - num_frames=num_frames, - track_in_reverse=track_in_reverse, - ) - # apply SAM-style segmentation head - # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, - # e.g. 
in demo where such logits come from earlier interaction instead of correction sampling - # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) - if prev_sam_mask_logits is not None: - assert point_inputs is not None and mask_inputs is None - mask_inputs = prev_sam_mask_logits - multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) - sam_outputs = self._forward_sam_heads( - backbone_features=pix_feat_with_mem, - point_inputs=point_inputs, - mask_inputs=mask_inputs, - high_res_features=high_res_features, - multimask_output=multimask_output, - ) + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + ( _, _, @@ -776,28 +854,28 @@ class SAM2Base(torch.nn.Module): low_res_masks, high_res_masks, obj_ptr, - _, + object_score_logits, ) = sam_outputs current_out["pred_masks"] = low_res_masks current_out["pred_masks_high_res"] = high_res_masks current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits # Finally run the memory encoder on the predicted mask to encode # it into a new memory feature (that can be used in future frames) - if run_mem_encoder and self.num_maskmem > 0: - high_res_masks_for_mem_enc = high_res_masks - maskmem_features, maskmem_pos_enc = self._encode_new_memory( - current_vision_feats=current_vision_feats, - feat_sizes=feat_sizes, - pred_masks_high_res=high_res_masks_for_mem_enc, - is_mask_from_pts=(point_inputs is not None), - ) - current_out["maskmem_features"] = maskmem_features - current_out["maskmem_pos_enc"] = maskmem_pos_enc - else: - current_out["maskmem_features"] = None - current_out["maskmem_pos_enc"] = None + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) return current_out diff --git a/iopaint/plugins/segment_anything2/modeling/sam2_utils.py b/iopaint/plugins/segment_anything2/modeling/sam2_utils.py index 6d97059..ad00a76 100644 --- a/iopaint/plugins/segment_anything2/modeling/sam2_utils.py +++ b/iopaint/plugins/segment_anything2/modeling/sam2_utils.py @@ -6,11 +6,15 @@ import copy +from typing import Tuple +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from ..utils.misc import mask_to_box + def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): """ @@ -147,3 +151,173 @@ class LayerNorm2d(nn.Module): x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x + + +def sample_box_points( + masks: torch.Tensor, + noise: float = 0.1, # SAM default + noise_bound: int = 20, # SAM default + top_left_label: int = 2, + bottom_right_label: int = 3, +) -> Tuple[np.array, np.array]: + """ + Sample a noised version of the top left and bottom right corners of a given `bbox` + + Inputs: + - masks: [B, 1, H,W] boxes, dtype=torch.Tensor + - noise: noise as a fraction of box width and height, dtype=float + - noise_bound: maximum amount of noise (in pure pixesl), dtype=int + + Returns: + - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left 
and bottom right box corners, dtype=torch.float + - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32 + """ + device = masks.device + box_coords = mask_to_box(masks) + B, _, H, W = masks.shape + box_labels = torch.tensor( + [top_left_label, bottom_right_label], dtype=torch.int, device=device + ).repeat(B) + if noise > 0.0: + if not isinstance(noise_bound, torch.Tensor): + noise_bound = torch.tensor(noise_bound, device=device) + bbox_w = box_coords[..., 2] - box_coords[..., 0] + bbox_h = box_coords[..., 3] - box_coords[..., 1] + max_dx = torch.min(bbox_w * noise, noise_bound) + max_dy = torch.min(bbox_h * noise, noise_bound) + box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1 + box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1) + + box_coords = box_coords + box_noise + img_bounds = ( + torch.tensor([W, H, W, H], device=device) - 1 + ) # uncentered pixel coords + box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping + + box_coords = box_coords.reshape(-1, 2, 2) # always 2 points + box_labels = box_labels.reshape(-1, 2) + return box_coords, box_labels + + +def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1): + """ + Sample `num_pt` random points (along with their labels) independently from the error regions. + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - num_pt: int, number of points to sample independently for each of the B error maps + + Outputs: + - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means + negative clicks + """ + if pred_masks is None: # if pred_masks is not provided, treat it as empty + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + assert num_pt >= 0 + + B, _, H_im, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + # whether the prediction completely match the ground-truth on each mask + all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2) + all_correct = all_correct[..., None, None] + + # channel 0 is FP map, while channel 1 is FN map + pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device) + # sample a negative new click from FP region or a positive new click + # from FN region, depend on where the maximum falls, + # and in case the predictions are all correct (no FP or FN), we just + # sample a negative click from the background region + pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks) + pts_noise[..., 1] *= fn_masks + pts_idx = pts_noise.flatten(2).argmax(dim=2) + labels = (pts_idx % 2).to(torch.int32) + pts_idx = pts_idx // 2 + pts_x = pts_idx % W_im + pts_y = pts_idx // W_im + points = torch.stack([pts_x, pts_y], dim=2).to(torch.float) + return points, labels + + +def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True): + """ + Sample 1 random point (along with its label) from the center of each error region, + that is, the point 
with the largest distance to the boundary of each error region. + This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py + + Inputs: + - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool + - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None + - padding: if True, pad with boundary of 1 px for distance transform + + Outputs: + - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point + - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks + """ + import cv2 + + if pred_masks is None: + pred_masks = torch.zeros_like(gt_masks) + assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1 + assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape + + B, _, _, W_im = gt_masks.shape + device = gt_masks.device + + # false positive region, a new point sampled in this region should have + # negative label to correct the FP error + fp_masks = ~gt_masks & pred_masks + # false negative region, a new point sampled in this region should have + # positive label to correct the FN error + fn_masks = gt_masks & ~pred_masks + + fp_masks = fp_masks.cpu().numpy() + fn_masks = fn_masks.cpu().numpy() + points = torch.zeros(B, 1, 2, dtype=torch.float) + labels = torch.ones(B, 1, dtype=torch.int32) + for b in range(B): + fn_mask = fn_masks[b, 0] + fp_mask = fp_masks[b, 0] + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant") + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant") + # compute the distance of each point in FN/FP region to its boundary + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + # take the point in FN/FP region with the largest distance to its boundary + fn_mask_dt_flat = fn_mask_dt.reshape(-1) + fp_mask_dt_flat = fp_mask_dt.reshape(-1) + fn_argmax = np.argmax(fn_mask_dt_flat) + fp_argmax = np.argmax(fp_mask_dt_flat) + is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax] + pt_idx = fn_argmax if is_positive else fp_argmax + points[b, 0, 0] = pt_idx % W_im # x + points[b, 0, 1] = pt_idx // W_im # y + labels[b, 0] = int(is_positive) + + points = points.to(device) + labels = labels.to(device) + return points, labels + + +def get_next_point(gt_masks, pred_masks, method): + if method == "uniform": + return sample_random_points_from_errors(gt_masks, pred_masks) + elif method == "center": + return sample_one_point_from_error_center(gt_masks, pred_masks) + else: + raise ValueError(f"unknown sampling method {method}") diff --git a/iopaint/schema.py b/iopaint/schema.py index 0b01a93..5bbb896 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -175,11 +175,17 @@ class InteractiveSegModel(Choices): sam_hq_vit_l = "sam_hq_vit_l" sam_hq_vit_h = "sam_hq_vit_h" mobile_sam = "mobile_sam" + sam2_tiny = "sam2_tiny" sam2_small = "sam2_small" sam2_base = "sam2_base" sam2_large = "sam2_large" + sam2_1_tiny = "sam2_1_tiny" + sam2_1_small = "sam2_1_small" + sam2_1_base = "sam2_1_base" + sam2_1_large = "sam2_1_large" + class PluginInfo(BaseModel): name: str diff --git a/iopaint/web_config.py b/iopaint/web_config.py index 7501fd4..3b3f8d1 100644 --- a/iopaint/web_config.py +++ b/iopaint/web_config.py @@ -48,7 +48,7 @@ default_configs = dict( output_dir=None, quality=95, 
enable_interactive_seg=False, - interactive_seg_model=InteractiveSegModel.vit_b, + interactive_seg_model=InteractiveSegModel.sam2_1_tiny, interactive_seg_device=Device.cpu, enable_remove_bg=False, remove_bg_model=RemoveBGModel.briaai_rmbg_1_4,
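
A quick way to exercise the new SAM 2.1 entries end to end is sketched below. This is a minimal sketch, assuming the official .pt checkpoints keep their weights under a "model" key; the patch itself does not show how the interactive-seg plugin loads them.

import torch
from iopaint.plugins.segment_anything2.build_sam import sam2_model_registry

# Build a SAM 2.1 model through the new registry entry; this picks up
# common_kwargs_for_2_1 (no_obj_embed_spatial, signed temporal pos enc, etc.).
model = sam2_model_registry["sam2_1_tiny"]()

# Load the checkpoint listed under "sam2_1_tiny" in SEGMENT_ANYTHING_MODELS.
# Assumption: the weights sit under a "model" key; fall back to the raw dict otherwise.
ckpt = torch.load("sam2.1_hiera_tiny.pt", map_location="cpu")
model.load_state_dict(ckpt.get("model", ckpt), strict=False)
model.eval()

From the command line the same model is now the default for interactive segmentation, e.g. iopaint start --enable-interactive-seg --interactive-seg-model sam2_1_tiny --interactive-seg-device cpu (flag spellings assume typer's usual underscore-to-dash conversion of the parameters shown in cli.py).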
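
The non-CUDA branch added to apply_rotary_enc in modeling/position_encoding.py replaces repeat on a complex tensor with unsqueeze + expand + flatten; the standalone check below (shapes purely illustrative) confirms the two layouts match.

import torch

# freqs_cis in apply_rotary_enc is a 4-D complex tensor and dim 2 is the one
# repeated r times; the shape used here is only illustrative.
freqs_cis = torch.randn(1, 2, 3, 8, dtype=torch.complex64)
r = 4

tiled = torch.cat([freqs_cis] * r, dim=2)  # same layout as repeat(1, 1, r, 1)
fallback = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
assert torch.equal(tiled, fallback)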
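
SAM 2.1 also writes a learned no-object embedding into the spatial memory when a frame is predicted as occluded (object_score_logits <= 0). The toy rehearsal below mirrors the blend in _encode_new_memory, with stand-in tensors in place of the learned parameter and real memory features.

import torch

B, C, H, W = 1, 64, 4, 4                       # toy sizes; C stands in for mem_dim
maskmem_features = torch.zeros(B, C, H, W)
no_obj_embed_spatial = torch.randn(1, C)       # stand-in for the learned parameter
object_score_logits = torch.tensor([[-5.0]])   # frame predicted as "no object"

is_obj_appearing = (object_score_logits > 0).float()
blended = maskmem_features + (
    1 - is_obj_appearing[..., None, None]
) * no_obj_embed_spatial[..., None, None].expand(B, C, H, W)

# With a negative score the embedding is added at every spatial position;
# with a positive score the memory features would pass through unchanged.
assert torch.allclose(blended[0, :, 0, 0], no_obj_embed_spatial[0])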