This commit is contained in:
Qing
2024-08-12 10:10:24 +08:00
parent 9afdbd1c0a
commit 2f833029aa
23 changed files with 3801 additions and 3 deletions

View File

@@ -9,6 +9,8 @@ from iopaint.helper import download_model
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
from iopaint.plugins.segment_anything.predictor_hq import SamHQPredictor
from iopaint.plugins.segment_anything2.build_sam import build_sam2
from iopaint.plugins.segment_anything2.sam2_image_predictor import SAM2ImagePredictor
from iopaint.schema import RunPluginRequest
# 从小到大
@@ -41,6 +43,22 @@ SEGMENT_ANYTHING_MODELS = {
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
"md5": "3560f6b6a5a6edacd814a1325c39640a",
},
"sam2_tiny": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
"md5": "99eacccce4ada0b35153d4fd7af05297",
},
"sam2_small": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
"md5": "7f320dbeb497330a2472da5a16c7324d",
},
"sam2_base": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
"md5": "09dc5a3d7719f64aaea1d37341ef26f2",
},
"sam2_large": {
"url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
"md5": "08083462423be3260cd6a5eef94dc01c",
},
}
@@ -64,6 +82,11 @@ class InteractiveSeg(BasePlugin):
self.predictor = SamHQPredictor(
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
)
elif model_name.startswith("sam2"):
sam2_model = build_sam2(
model_name, ckpt_path=model_path, device=self.device
)
self.predictor = SAM2ImagePredictor(sam2_model)
else:
self.predictor = SamPredictor(
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
@@ -98,7 +121,7 @@ class InteractiveSeg(BasePlugin):
self.prev_img_md5 = img_md5
self.predictor.set_image(rgb_np_img)
masks, scores, _ = self.predictor.predict(
masks, _, _ = self.predictor.predict(
point_coords=np.array(input_point),
point_labels=np.array(input_label),
multimask_output=False,