add plugin dep check
This commit is contained in:
@@ -4,6 +4,7 @@ import cv2
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
|
||||
|
||||
class RealESRGANModelName(str, Enum):
|
||||
@@ -15,7 +16,7 @@ class RealESRGANModelName(str, Enum):
|
||||
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
|
||||
|
||||
|
||||
class RealESRGANUpscaler:
|
||||
class RealESRGANUpscaler(BasePlugin):
|
||||
name = "RealESRGAN"
|
||||
|
||||
def __init__(self, name, device):
|
||||
@@ -84,7 +85,7 @@ class RealESRGANUpscaler:
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
scale = float(form['upscale'])
|
||||
scale = float(form["upscale"])
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
|
||||
result = self.forward(bgr_np_img, scale)
|
||||
logger.info(f"RealESRGAN output shape: {result.shape}")
|
||||
@@ -94,3 +95,9 @@ class RealESRGANUpscaler:
|
||||
# 输出是 BGR
|
||||
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
|
||||
return upsampled
|
||||
|
||||
def check_dep(self):
|
||||
try:
|
||||
import realesrgan
|
||||
except ImportError:
|
||||
return "RealESRGAN is not installed, please install it first. pip install realesrgan"
|
||||
|
||||
Reference in New Issue
Block a user