add sam2
This commit is contained in:
@@ -9,6 +9,8 @@ from iopaint.helper import download_model
|
||||
from iopaint.plugins.base_plugin import BasePlugin
|
||||
from iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
|
||||
from iopaint.plugins.segment_anything.predictor_hq import SamHQPredictor
|
||||
from iopaint.plugins.segment_anything2.build_sam import build_sam2
|
||||
from iopaint.plugins.segment_anything2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from iopaint.schema import RunPluginRequest
|
||||
|
||||
# 从小到大
|
||||
@@ -41,6 +43,22 @@ SEGMENT_ANYTHING_MODELS = {
|
||||
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
|
||||
"md5": "3560f6b6a5a6edacd814a1325c39640a",
|
||||
},
|
||||
"sam2_tiny": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
|
||||
"md5": "99eacccce4ada0b35153d4fd7af05297",
|
||||
},
|
||||
"sam2_small": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
|
||||
"md5": "7f320dbeb497330a2472da5a16c7324d",
|
||||
},
|
||||
"sam2_base": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
|
||||
"md5": "09dc5a3d7719f64aaea1d37341ef26f2",
|
||||
},
|
||||
"sam2_large": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
|
||||
"md5": "08083462423be3260cd6a5eef94dc01c",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -64,6 +82,11 @@ class InteractiveSeg(BasePlugin):
|
||||
self.predictor = SamHQPredictor(
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
|
||||
)
|
||||
elif model_name.startswith("sam2"):
|
||||
sam2_model = build_sam2(
|
||||
model_name, ckpt_path=model_path, device=self.device
|
||||
)
|
||||
self.predictor = SAM2ImagePredictor(sam2_model)
|
||||
else:
|
||||
self.predictor = SamPredictor(
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
|
||||
@@ -98,7 +121,7 @@ class InteractiveSeg(BasePlugin):
|
||||
self.prev_img_md5 = img_md5
|
||||
self.predictor.set_image(rgb_np_img)
|
||||
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
masks, _, _ = self.predictor.predict(
|
||||
point_coords=np.array(input_point),
|
||||
point_labels=np.array(input_label),
|
||||
multimask_output=False,
|
||||
|
||||
Reference in New Issue
Block a user