add sam2.1

This commit is contained in:
Qing
2024-11-01 18:45:32 +08:00
parent 2098d687d1
commit ec08ce3063
11 changed files with 565 additions and 184 deletions

View File

@@ -247,7 +247,7 @@ class MaskDecoder(nn.Module):
def _get_stability_scores(self, mask_logits):
"""
Compute stability scores of the mask logits based on the IoU between upper and
lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
lower thresholds.
"""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta