make generate mask from RemoveBG && AnimeSeg work
This commit is contained in:
@@ -416,6 +416,8 @@ ANIME_SEG_MODELS = {
|
||||
class AnimeSeg(BasePlugin):
|
||||
# Model from: https://github.com/SkyTNT/anime-segmentation
|
||||
name = "AnimeSeg"
|
||||
support_gen_image = True
|
||||
support_gen_mask = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -426,10 +428,19 @@ class AnimeSeg(BasePlugin):
|
||||
ANIME_SEG_MODELS["md5"],
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
mask = self.forward(rgb_np_img)
|
||||
mask = Image.fromarray(mask, mode="L")
|
||||
h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
|
||||
empty = Image.new("RGBA", (w0, h0), 0)
|
||||
img = Image.fromarray(rgb_np_img)
|
||||
cutout = Image.composite(img, empty, mask)
|
||||
return np.asarray(cutout)
|
||||
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
return self.forward(rgb_np_img)
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def forward(self, rgb_np_img):
|
||||
s = 1024
|
||||
|
||||
@@ -448,9 +459,4 @@ class AnimeSeg(BasePlugin):
|
||||
mask = self.model(tmpImg)
|
||||
mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
|
||||
mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
|
||||
mask = Image.fromarray((mask * 255).astype("uint8"), mode="L")
|
||||
|
||||
empty = Image.new("RGBA", (w0, h0), 0)
|
||||
img = Image.fromarray(rgb_np_img)
|
||||
cutout = Image.composite(img, empty, mask)
|
||||
return np.asarray(cutout)
|
||||
return (mask * 255).astype("uint8")
|
||||
|
||||
Reference in New Issue
Block a user