add switch interactiveSegModel

This commit is contained in:
Qing
2024-02-10 12:34:56 +08:00
parent 9aa5a7e0ba
commit ec2db92ad9
5 changed files with 103 additions and 42 deletions

View File

@@ -37,16 +37,31 @@ class InteractiveSeg(BasePlugin):
def __init__(self, model_name, device):
super().__init__()
self.model_name = model_name
self.device = device
self._init_session(model_name)
def _init_session(self, model_name: str):
model_path = download_model(
SEGMENT_ANYTHING_MODELS[model_name]["url"],
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
)
logger.info(f"SegmentAnything model path: {model_path}")
self.predictor = SamPredictor(
sam_model_registry[model_name](checkpoint=model_path).to(device)
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
)
self.prev_img_md5 = None
def switch_model(self, new_model_name):
if self.model_name == new_model_name:
return
logger.info(
f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}"
)
self._init_session(new_model_name)
self.model_name = new_model_name
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
return self.forward(rgb_np_img, req.clicks, img_md5)