add sam_hq
This commit is contained in:
@@ -8,6 +8,7 @@ from loguru import logger
|
||||
from iopaint.helper import download_model
|
||||
from iopaint.plugins.base_plugin import BasePlugin
|
||||
from iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
|
||||
from iopaint.plugins.segment_anything.predictor_hq import SamHQPredictor
|
||||
from iopaint.schema import RunPluginRequest
|
||||
|
||||
# 从小到大
|
||||
@@ -28,6 +29,18 @@ SEGMENT_ANYTHING_MODELS = {
|
||||
"url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
|
||||
"md5": "f3c0d8cda613564d499310dab6c812cd",
|
||||
},
|
||||
"sam_hq_vit_b": {
|
||||
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
|
||||
"md5": "c6b8953247bcfdc8bb8ef91e36a6cacc",
|
||||
},
|
||||
"sam_hq_vit_l": {
|
||||
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
|
||||
"md5": "08947267966e4264fb39523eccc33f86",
|
||||
},
|
||||
"sam_hq_vit_h": {
|
||||
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
|
||||
"md5": "3560f6b6a5a6edacd814a1325c39640a",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -47,9 +60,14 @@ class InteractiveSeg(BasePlugin):
|
||||
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
|
||||
)
|
||||
logger.info(f"SegmentAnything model path: {model_path}")
|
||||
self.predictor = SamPredictor(
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
|
||||
)
|
||||
if "sam_hq" in model_name:
|
||||
self.predictor = SamHQPredictor(
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
|
||||
)
|
||||
else:
|
||||
self.predictor = SamPredictor(
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
|
||||
)
|
||||
self.prev_img_md5 = None
|
||||
|
||||
def switch_model(self, new_model_name):
|
||||
|
||||
Reference in New Issue
Block a user