add sam_hq

This commit is contained in:
Qing
2024-02-20 09:03:11 +08:00
parent 6447e821cb
commit b358e6cbce
10 changed files with 1281 additions and 19 deletions

View File

@@ -8,6 +8,7 @@ from loguru import logger
from iopaint.helper import download_model
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
from iopaint.plugins.segment_anything.predictor_hq import SamHQPredictor
from iopaint.schema import RunPluginRequest
# 从小到大
@@ -28,6 +29,18 @@ SEGMENT_ANYTHING_MODELS = {
"url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
"md5": "f3c0d8cda613564d499310dab6c812cd",
},
"sam_hq_vit_b": {
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
"md5": "c6b8953247bcfdc8bb8ef91e36a6cacc",
},
"sam_hq_vit_l": {
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
"md5": "08947267966e4264fb39523eccc33f86",
},
"sam_hq_vit_h": {
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
"md5": "3560f6b6a5a6edacd814a1325c39640a",
},
}
@@ -47,9 +60,14 @@ class InteractiveSeg(BasePlugin):
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
)
logger.info(f"SegmentAnything model path: {model_path}")
self.predictor = SamPredictor(
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
)
if "sam_hq" in model_name:
self.predictor = SamHQPredictor(
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
)
else:
self.predictor = SamPredictor(
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
)
self.prev_img_md5 = None
def switch_model(self, new_model_name):