add sam_hq
This commit is contained in:
@@ -17,6 +17,9 @@ from .modeling import (
|
||||
Sam,
|
||||
TwoWayTransformer,
|
||||
)
|
||||
from .modeling.image_encoder_hq import ImageEncoderViTHQ
|
||||
from .modeling.mask_decoder import MaskDecoderHQ
|
||||
from .modeling.sam_hq import SamHQ
|
||||
|
||||
|
||||
def build_sam_vit_h(checkpoint=None):
|
||||
@@ -29,9 +32,6 @@ def build_sam_vit_h(checkpoint=None):
|
||||
)
|
||||
|
||||
|
||||
build_sam = build_sam_vit_h
|
||||
|
||||
|
||||
def build_sam_vit_l(checkpoint=None):
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1024,
|
||||
@@ -104,11 +104,44 @@ def build_sam_vit_t(checkpoint=None):
|
||||
return mobile_sam
|
||||
|
||||
|
||||
def build_sam_vit_h_hq(checkpoint=None):
|
||||
return _build_sam_hq(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_depth=32,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[7, 15, 23, 31],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_l_hq(checkpoint=None):
|
||||
return _build_sam_hq(
|
||||
encoder_embed_dim=1024,
|
||||
encoder_depth=24,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[5, 11, 17, 23],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_b_hq(checkpoint=None):
|
||||
return _build_sam_hq(
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=12,
|
||||
encoder_global_attn_indexes=[2, 5, 8, 11],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
sam_model_registry = {
|
||||
"default": build_sam,
|
||||
"vit_h": build_sam,
|
||||
"default": build_sam_vit_h,
|
||||
"vit_h": build_sam_vit_h,
|
||||
"vit_l": build_sam_vit_l,
|
||||
"vit_b": build_sam_vit_b,
|
||||
"sam_hq_vit_h": build_sam_vit_h_hq,
|
||||
"sam_hq_vit_l": build_sam_vit_l_hq,
|
||||
"sam_hq_vit_b": build_sam_vit_b_hq,
|
||||
"mobile_sam": build_sam_vit_t,
|
||||
}
|
||||
|
||||
@@ -166,3 +199,71 @@ def _build_sam(
|
||||
state_dict = torch.load(f)
|
||||
sam.load_state_dict(state_dict)
|
||||
return sam
|
||||
|
||||
|
||||
def _build_sam_hq(
|
||||
encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
encoder_global_attn_indexes,
|
||||
checkpoint=None,
|
||||
):
|
||||
prompt_embed_dim = 256
|
||||
image_size = 1024
|
||||
vit_patch_size = 16
|
||||
image_embedding_size = image_size // vit_patch_size
|
||||
sam = SamHQ(
|
||||
image_encoder=ImageEncoderViTHQ(
|
||||
depth=encoder_depth,
|
||||
embed_dim=encoder_embed_dim,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=encoder_num_heads,
|
||||
patch_size=vit_patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=encoder_global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=prompt_embed_dim,
|
||||
),
|
||||
prompt_encoder=PromptEncoder(
|
||||
embed_dim=prompt_embed_dim,
|
||||
image_embedding_size=(image_embedding_size, image_embedding_size),
|
||||
input_image_size=(image_size, image_size),
|
||||
mask_in_chans=16,
|
||||
),
|
||||
mask_decoder=MaskDecoderHQ(
|
||||
num_multimask_outputs=3,
|
||||
transformer=TwoWayTransformer(
|
||||
depth=2,
|
||||
embedding_dim=prompt_embed_dim,
|
||||
mlp_dim=2048,
|
||||
num_heads=8,
|
||||
),
|
||||
transformer_dim=prompt_embed_dim,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
vit_dim=encoder_embed_dim,
|
||||
),
|
||||
pixel_mean=[123.675, 116.28, 103.53],
|
||||
pixel_std=[58.395, 57.12, 57.375],
|
||||
)
|
||||
sam.eval()
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
state_dict = torch.load(f, map_location=device)
|
||||
info = sam.load_state_dict(state_dict, strict=False)
|
||||
print(info)
|
||||
for n, p in sam.named_parameters():
|
||||
if (
|
||||
"hf_token" not in n
|
||||
and "hf_mlp" not in n
|
||||
and "compress_vit_feat" not in n
|
||||
and "embedding_encoder" not in n
|
||||
and "embedding_maskfeature" not in n
|
||||
):
|
||||
p.requires_grad = False
|
||||
|
||||
return sam
|
||||
|
||||
Reference in New Issue
Block a user