make generate mask from RemoveBG && AnimeSeg work
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import RealESRGANModel
|
||||
@@ -11,6 +13,7 @@ from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
class RealESRGANUpscaler(BasePlugin):
|
||||
name = "RealESRGAN"
|
||||
support_gen_image = True
|
||||
|
||||
def __init__(self, name, device, no_half=False):
|
||||
super().__init__()
|
||||
@@ -77,13 +80,14 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
device=device,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
||||
result = self.forward(bgr_np_img, req.scale)
|
||||
logger.info(f"RealESRGAN output shape: {result.shape}")
|
||||
return result
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, bgr_np_img, scale: float):
|
||||
# 输出是 BGR
|
||||
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
|
||||
|
||||
Reference in New Issue
Block a user