Make mask generation work for the RemoveBG and AnimeSeg plugins

This commit is contained in:
Qing
2024-01-02 22:32:40 +08:00
parent 6253016019
commit aca85543ca
22 changed files with 244 additions and 100 deletions

View File

@@ -1,11 +1,13 @@
from typing import Dict
from loguru import logger
from .interactive_seg import InteractiveSeg
from .remove_bg import RemoveBG
from .realesrgan import RealESRGANUpscaler
from .gfpgan_plugin import GFPGANPlugin
from .restoreformer import RestoreFormerPlugin
from .anime_seg import AnimeSeg
from .gfpgan_plugin import GFPGANPlugin
from .interactive_seg import InteractiveSeg
from .realesrgan import RealESRGANUpscaler
from .remove_bg import RemoveBG
from .restoreformer import RestoreFormerPlugin
from ..const import InteractiveSegModel, Device, RealESRGANModel
@@ -23,7 +25,7 @@ def build_plugins(
enable_restoreformer: bool,
restoreformer_device: Device,
no_half: bool,
):
) -> Dict:
plugins = {}
if enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")

View File

@@ -416,6 +416,8 @@ ANIME_SEG_MODELS = {
class AnimeSeg(BasePlugin):
# Model from: https://github.com/SkyTNT/anime-segmentation
name = "AnimeSeg"
support_gen_image = True
support_gen_mask = True
def __init__(self):
super().__init__()
@@ -426,10 +428,19 @@ class AnimeSeg(BasePlugin):
ANIME_SEG_MODELS["md5"],
)
def __call__(self, rgb_np_img, req: RunPluginRequest):
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
mask = self.forward(rgb_np_img)
mask = Image.fromarray(mask, mode="L")
h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
empty = Image.new("RGBA", (w0, h0), 0)
img = Image.fromarray(rgb_np_img)
cutout = Image.composite(img, empty, mask)
return np.asarray(cutout)
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
return self.forward(rgb_np_img)
@torch.no_grad()
@torch.inference_mode()
def forward(self, rgb_np_img):
s = 1024
@@ -448,9 +459,4 @@ class AnimeSeg(BasePlugin):
mask = self.model(tmpImg)
mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
mask = Image.fromarray((mask * 255).astype("uint8"), mode="L")
empty = Image.new("RGBA", (w0, h0), 0)
img = Image.fromarray(rgb_np_img)
cutout = Image.composite(img, empty, mask)
return np.asarray(cutout)
return (mask * 255).astype("uint8")

View File

@@ -5,15 +5,23 @@ from lama_cleaner.schema import RunPluginRequest
class BasePlugin:
name: str
support_gen_image: bool = False
support_gen_mask: bool = False
def __init__(self):
err_msg = self.check_dep()
if err_msg:
logger.error(err_msg)
exit(-1)
def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array:
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
# return RGBA np image or BGR np image
...
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
# return GRAY or BGR np image, 255 means foreground, 0 means background
...
def check_dep(self):
...

View File

@@ -1,4 +1,5 @@
import cv2
import numpy as np
from loguru import logger
from lama_cleaner.helper import download_model
@@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest
class GFPGANPlugin(BasePlugin):
name = "GFPGAN"
support_gen_image = True
def __init__(self, device, upscaler=None):
super().__init__()
@@ -37,7 +39,7 @@ class GFPGANPlugin(BasePlugin):
self.face_enhancer.face_helper.face_det.to(device)
)
def __call__(self, rgb_np_img, req: RunPluginRequest):
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")

View File

@@ -4,6 +4,7 @@ from typing import List
import cv2
import numpy as np
import torch
from loguru import logger
from lama_cleaner.helper import download_model
@@ -34,6 +35,7 @@ SEGMENT_ANYTHING_MODELS = {
class InteractiveSeg(BasePlugin):
name = "InteractiveSeg"
support_gen_mask = True
def __init__(self, model_name, device):
super().__init__()
@@ -47,10 +49,11 @@ class InteractiveSeg(BasePlugin):
)
self.prev_img_md5 = None
def __call__(self, rgb_np_img, req: RunPluginRequest):
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
return self.forward(rgb_np_img, req.clicks, img_md5)
@torch.inference_mode()
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
input_point = []
input_label = []
@@ -70,13 +73,4 @@ class InteractiveSeg(BasePlugin):
multimask_output=False,
)
mask = masks[0].astype(np.uint8) * 255
# TODO: how to set kernel size?
kernel_size = 9
mask = cv2.dilate(
mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1
)
# frontend brush color "ffcc00bb"
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
return res_mask
return mask

View File

@@ -1,6 +1,8 @@
from enum import Enum
import cv2
import numpy as np
import torch
from loguru import logger
from lama_cleaner.const import RealESRGANModel
@@ -11,6 +13,7 @@ from lama_cleaner.schema import RunPluginRequest
class RealESRGANUpscaler(BasePlugin):
name = "RealESRGAN"
support_gen_image = True
def __init__(self, name, device, no_half=False):
super().__init__()
@@ -77,13 +80,14 @@ class RealESRGANUpscaler(BasePlugin):
device=device,
)
def __call__(self, rgb_np_img, req: RunPluginRequest):
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
result = self.forward(bgr_np_img, req.scale)
logger.info(f"RealESRGAN output shape: {result.shape}")
return result
@torch.inference_mode()
def forward(self, bgr_np_img, scale: float):
# The output is BGR
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]

View File

@@ -9,6 +9,8 @@ from lama_cleaner.schema import RunPluginRequest
class RemoveBG(BasePlugin):
name = "RemoveBG"
support_gen_mask = True
support_gen_image = True
def __init__(self):
super().__init__()
@@ -20,17 +22,24 @@ class RemoveBG(BasePlugin):
self.session = new_session(model_name="u2net")
def __call__(self, rgb_np_img, req: RunPluginRequest):
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
return self.forward(bgr_np_img)
def forward(self, bgr_np_img) -> np.ndarray:
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
from rembg import remove
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGRA image
output = remove(bgr_np_img, session=self.session)
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
from rembg import remove
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGR image, 255 means foreground, 0 means background
output = remove(bgr_np_img, session=self.session, only_mask=True)
return output
def check_dep(self):
try:
import rembg

View File

@@ -1,4 +1,5 @@
import cv2
import numpy as np
from loguru import logger
from lama_cleaner.helper import download_model
@@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest
class RestoreFormerPlugin(BasePlugin):
name = "RestoreFormer"
support_gen_image = True
def __init__(self, device, upscaler=None):
super().__init__()
@@ -32,7 +34,7 @@ class RestoreFormerPlugin(BasePlugin):
bg_upsampler=upscaler.model if upscaler is not None else None,
)
def __call__(self, rgb_np_img, req: RunPluginRequest):
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")