add plugin dep check

This commit is contained in:
Qing
2023-03-26 12:37:58 +08:00
parent 1433d21b9f
commit d938f2da3c
7 changed files with 48 additions and 14 deletions

View File

@@ -1,3 +1,4 @@
from .interactive_seg import InteractiveSeg, Click
from .remove_bg import RemoveBG
from .realesrgan import RealESRGANUpscaler
from .gif import MakeGIF

View File

@@ -0,0 +1,15 @@
from loguru import logger
class BasePlugin:
def __init__(self):
err_msg = self.check_dep()
if err_msg:
logger.error(err_msg)
exit(-1)
def __call__(self, rgb_np_img, files, form):
...
def check_dep(self):
...

View File

@@ -4,6 +4,7 @@ import math
from PIL import Image, ImageDraw
from lama_cleaner.helper import load_img
from lama_cleaner.plugins.base_plugin import BasePlugin
def keep_ratio_resize(img, size, resample=Image.BILINEAR):
@@ -117,7 +118,7 @@ def make_compare_gif(
[(right, 0), (right, height)], width=splitter_width, fill=splitter_color
)
images.append(new_frame)
for _ in range(30):
images.append(clean_img)
@@ -135,7 +136,7 @@ def make_compare_gif(
return img_byte_arr.getvalue()
class MakeGIF:
class MakeGIF(BasePlugin):
name = "MakeGIF"
def __call__(self, rgb_np_img, files, form):

View File

@@ -1,5 +1,4 @@
import json
import json
import os
from typing import Tuple, List
@@ -14,6 +13,7 @@ from lama_cleaner.helper import (
load_jit_model,
load_img,
)
from lama_cleaner.plugins.base_plugin import BasePlugin
class Click(BaseModel):
@@ -195,10 +195,11 @@ INTERACTIVE_SEG_MODEL_MD5 = os.environ.get(
)
class InteractiveSeg:
class InteractiveSeg(BasePlugin):
name = "InteractiveSeg"
def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
super().__init__()
device = torch.device("cpu")
model = load_jit_model(
INTERACTIVE_SEG_MODEL_URL, device, INTERACTIVE_SEG_MODEL_MD5

View File

@@ -4,6 +4,7 @@ import cv2
from loguru import logger
from lama_cleaner.helper import download_model
from lama_cleaner.plugins.base_plugin import BasePlugin
class RealESRGANModelName(str, Enum):
@@ -15,7 +16,7 @@ class RealESRGANModelName(str, Enum):
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
class RealESRGANUpscaler:
class RealESRGANUpscaler(BasePlugin):
name = "RealESRGAN"
def __init__(self, name, device):
@@ -84,7 +85,7 @@ class RealESRGANUpscaler:
def __call__(self, rgb_np_img, files, form):
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
scale = float(form['upscale'])
scale = float(form["upscale"])
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
result = self.forward(bgr_np_img, scale)
logger.info(f"RealESRGAN output shape: {result.shape}")
@@ -94,3 +95,9 @@ class RealESRGANUpscaler:
# 输出是 BGR
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
return upsampled
def check_dep(self):
try:
import realesrgan
except ImportError:
return "RealESRGAN is not installed, please install it first. pip install realesrgan"

View File

@@ -1,11 +1,14 @@
import cv2
import numpy as np
from lama_cleaner.plugins.base_plugin import BasePlugin
class RemoveBG:
class RemoveBG(BasePlugin):
name = "RemoveBG"
def __init__(self):
super().__init__()
from rembg import new_session
self.session = new_session(model_name="u2net")
@@ -20,3 +23,11 @@ class RemoveBG:
# return BGRA image
output = remove(bgr_np_img, session=self.session)
return output
def check_dep(self):
try:
import rembg
except ImportError:
return (
"RemoveBG is not installed, please install it first. pip install rembg"
)