rename to iopaint
This commit is contained in:
76
iopaint/plugins/interactive_seg.py
Normal file
76
iopaint/plugins/interactive_seg.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import hashlib
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.helper import download_model
|
||||
from iopaint.plugins.base_plugin import BasePlugin
|
||||
from iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
|
||||
from iopaint.schema import RunPluginRequest
|
||||
|
||||
# 从小到大
|
||||
SEGMENT_ANYTHING_MODELS = {
|
||||
"vit_b": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
|
||||
"md5": "01ec64d29a2fca3f0661936605ae66f8",
|
||||
},
|
||||
"vit_l": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
||||
"md5": "0b3195507c641ddb6910d2bb5adee89c",
|
||||
},
|
||||
"vit_h": {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
||||
"md5": "4b8939a88964f0f4ff5f5b2642c598a6",
|
||||
},
|
||||
"mobile_sam": {
|
||||
"url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
|
||||
"md5": "f3c0d8cda613564d499310dab6c812cd",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class InteractiveSeg(BasePlugin):
|
||||
name = "InteractiveSeg"
|
||||
support_gen_mask = True
|
||||
|
||||
def __init__(self, model_name, device):
|
||||
super().__init__()
|
||||
model_path = download_model(
|
||||
SEGMENT_ANYTHING_MODELS[model_name]["url"],
|
||||
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
|
||||
)
|
||||
logger.info(f"SegmentAnything model path: {model_path}")
|
||||
self.predictor = SamPredictor(
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(device)
|
||||
)
|
||||
self.prev_img_md5 = None
|
||||
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
|
||||
input_point = []
|
||||
input_label = []
|
||||
for click in clicks:
|
||||
x = click[0]
|
||||
y = click[1]
|
||||
input_point.append([x, y])
|
||||
input_label.append(click[2])
|
||||
|
||||
if img_md5 and img_md5 != self.prev_img_md5:
|
||||
self.prev_img_md5 = img_md5
|
||||
self.predictor.set_image(rgb_np_img)
|
||||
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
point_coords=np.array(input_point),
|
||||
point_labels=np.array(input_label),
|
||||
multimask_output=False,
|
||||
)
|
||||
mask = masks[0].astype(np.uint8) * 255
|
||||
return mask
|
||||
Reference in New Issue
Block a user