make generate mask from RemoveBG && AnimeSeg work
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
@@ -34,6 +35,7 @@ SEGMENT_ANYTHING_MODELS = {
|
||||
|
||||
class InteractiveSeg(BasePlugin):
|
||||
name = "InteractiveSeg"
|
||||
support_gen_mask = True
|
||||
|
||||
def __init__(self, model_name, device):
|
||||
super().__init__()
|
||||
@@ -47,10 +49,11 @@ class InteractiveSeg(BasePlugin):
|
||||
)
|
||||
self.prev_img_md5 = None
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
|
||||
input_point = []
|
||||
input_label = []
|
||||
@@ -70,13 +73,4 @@ class InteractiveSeg(BasePlugin):
|
||||
multimask_output=False,
|
||||
)
|
||||
mask = masks[0].astype(np.uint8) * 255
|
||||
# TODO: how to set kernel size?
|
||||
kernel_size = 9
|
||||
mask = cv2.dilate(
|
||||
mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1
|
||||
)
|
||||
# fronted brush color "ffcc00bb"
|
||||
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
|
||||
res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
|
||||
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
|
||||
return res_mask
|
||||
return mask
|
||||
|
||||
Reference in New Issue
Block a user