add switch interactiveSegModel
This commit is contained in:
@@ -37,16 +37,31 @@ class InteractiveSeg(BasePlugin):
|
||||
|
||||
def __init__(self, model_name, device):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self._init_session(model_name)
|
||||
|
||||
def _init_session(self, model_name: str):
|
||||
model_path = download_model(
|
||||
SEGMENT_ANYTHING_MODELS[model_name]["url"],
|
||||
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
|
||||
)
|
||||
logger.info(f"SegmentAnything model path: {model_path}")
|
||||
self.predictor = SamPredictor(
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(device)
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
|
||||
)
|
||||
self.prev_img_md5 = None
|
||||
|
||||
def switch_model(self, new_model_name):
|
||||
if self.model_name == new_model_name:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}"
|
||||
)
|
||||
self._init_session(new_model_name)
|
||||
self.model_name = new_model_name
|
||||
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||
|
||||
Reference in New Issue
Block a user