update plugins
This commit is contained in:
@@ -7,6 +7,7 @@ from PIL import Image
|
||||
|
||||
from lama_cleaner.helper import load_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class REBNCONV(nn.Module):
|
||||
@@ -425,7 +426,7 @@ class AnimeSeg(BasePlugin):
|
||||
ANIME_SEG_MODELS["md5"],
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
return self.forward(rgb_np_img)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from loguru import logger
|
||||
import numpy as np
|
||||
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
@@ -8,7 +11,8 @@ class BasePlugin:
|
||||
logger.error(err_msg)
|
||||
exit(-1)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array:
|
||||
# return RGBA np image or BGR np image
|
||||
...
|
||||
|
||||
def check_dep(self):
|
||||
|
||||
@@ -3,6 +3,7 @@ from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class GFPGANPlugin(BasePlugin):
|
||||
@@ -36,7 +37,7 @@ class GFPGANPlugin(BasePlugin):
|
||||
self.face_enhancer.face_helper.face_det.to(device)
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
weight = 0.5
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -7,6 +9,7 @@ from loguru import logger
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
# 从小到大
|
||||
SEGMENT_ANYTHING_MODELS = {
|
||||
@@ -44,11 +47,11 @@ class InteractiveSeg(BasePlugin):
|
||||
)
|
||||
self.prev_img_md5 = None
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
clicks = json.loads(form["clicks"])
|
||||
return self.forward(rgb_np_img, clicks, form["img_md5"])
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||
|
||||
def forward(self, rgb_np_img, clicks, img_md5):
|
||||
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
|
||||
input_point = []
|
||||
input_label = []
|
||||
for click in clicks:
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
from lama_cleaner.const import RealESRGANModel
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class RealESRGANUpscaler(BasePlugin):
|
||||
@@ -76,11 +77,10 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
device=device,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
scale = float(form["upscale"])
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
|
||||
result = self.forward(bgr_np_img, scale)
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
||||
result = self.forward(bgr_np_img, req.scale)
|
||||
logger.info(f"RealESRGAN output shape: {result.shape}")
|
||||
return result
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import numpy as np
|
||||
from torch.hub import get_dir
|
||||
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class RemoveBG(BasePlugin):
|
||||
@@ -19,7 +20,7 @@ class RemoveBG(BasePlugin):
|
||||
|
||||
self.session = new_session(model_name="u2net")
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
return self.forward(bgr_np_img)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class RestoreFormerPlugin(BasePlugin):
|
||||
@@ -31,7 +32,7 @@ class RestoreFormerPlugin(BasePlugin):
|
||||
bg_upsampler=upscaler.model if upscaler is not None else None,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
weight = 0.5
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
|
||||
|
||||
Reference in New Issue
Block a user