make generate mask from RemoveBG && AnimeSeg work

This commit is contained in:
Qing
2024-01-02 22:32:40 +08:00
parent 6253016019
commit aca85543ca
22 changed files with 244 additions and 100 deletions

View File

@@ -4,6 +4,7 @@ from typing import List
import cv2
import numpy as np
import torch
from loguru import logger
from lama_cleaner.helper import download_model
@@ -34,6 +35,7 @@ SEGMENT_ANYTHING_MODELS = {
class InteractiveSeg(BasePlugin):
name = "InteractiveSeg"
support_gen_mask = True
def __init__(self, model_name, device):
super().__init__()
@@ -47,10 +49,11 @@ class InteractiveSeg(BasePlugin):
)
self.prev_img_md5 = None
def __call__(self, rgb_np_img, req: RunPluginRequest):
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
return self.forward(rgb_np_img, req.clicks, img_md5)
@torch.inference_mode()
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
input_point = []
input_label = []
@@ -70,13 +73,4 @@ class InteractiveSeg(BasePlugin):
multimask_output=False,
)
mask = masks[0].astype(np.uint8) * 255
# TODO: how to set kernel size?
kernel_size = 9
mask = cv2.dilate(
mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1
)
# fronted brush color "ffcc00bb"
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
return res_mask
return mask