make generate mask from RemoveBG && AnimeSeg work
This commit is contained in:
@@ -28,11 +28,13 @@ from lama_cleaner.helper import (
|
||||
pil_to_bytes,
|
||||
numpy_to_bytes,
|
||||
concat_alpha_channel,
|
||||
gen_frontend_mask,
|
||||
)
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_info import ModelInfo
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import (
|
||||
GenInfoResponse,
|
||||
ApiConfig,
|
||||
@@ -41,6 +43,7 @@ from lama_cleaner.schema import (
|
||||
InpaintRequest,
|
||||
RunPluginRequest,
|
||||
SDSampler,
|
||||
PluginInfo,
|
||||
)
|
||||
from lama_cleaner.file_manager import FileManager
|
||||
|
||||
@@ -145,7 +148,8 @@ class Api:
|
||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||
# fmt: on
|
||||
@@ -173,7 +177,14 @@ class Api:
|
||||
|
||||
def api_server_config(self) -> ServerConfigResponse:
|
||||
return ServerConfigResponse(
|
||||
plugins=list(self.plugins.keys()),
|
||||
plugins=[
|
||||
PluginInfo(
|
||||
name=it.name,
|
||||
support_gen_image=it.support_gen_image,
|
||||
support_gen_mask=it.support_gen_mask,
|
||||
)
|
||||
for it in self.plugins.values()
|
||||
],
|
||||
enableFileManager=self.file_manager is not None,
|
||||
enableAutoSaving=self.config.output_dir is not None,
|
||||
enableControlnet=self.model_manager.enable_controlnet,
|
||||
@@ -237,22 +248,22 @@ class Api:
|
||||
headers={"X-Seed": str(req.sd_seed)},
|
||||
)
|
||||
|
||||
def api_run_plugin(self, req: RunPluginRequest):
|
||||
def api_run_plugin_gen_image(self, req: RunPluginRequest):
|
||||
ext = "png"
|
||||
if req.name not in self.plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_np_img = self.plugins[req.name](rgb_np_img, req)
|
||||
torch_gc()
|
||||
if req.name == InteractiveSeg.name:
|
||||
return Response(
|
||||
content=numpy_to_bytes(bgr_np_img, ext),
|
||||
media_type=f"image/{ext}",
|
||||
raise HTTPException(status_code=422, detail="Plugin not found")
|
||||
if not self.plugins[req.name].support_gen_image:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Plugin does not support output image"
|
||||
)
|
||||
if bgr_np_img.shape[2] == 4:
|
||||
rgba_np_img = bgr_np_img
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
|
||||
torch_gc()
|
||||
|
||||
if bgr_or_rgba_np_img.shape[2] == 4:
|
||||
rgba_np_img = bgr_or_rgba_np_img
|
||||
else:
|
||||
rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB)
|
||||
rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
|
||||
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
|
||||
|
||||
return Response(
|
||||
@@ -265,6 +276,22 @@ class Api:
|
||||
media_type=f"image/{ext}",
|
||||
)
|
||||
|
||||
def api_run_plugin_gen_mask(self, req: RunPluginRequest):
|
||||
if req.name not in self.plugins:
|
||||
raise HTTPException(status_code=422, detail="Plugin not found")
|
||||
if not self.plugins[req.name].support_gen_mask:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Plugin does not support output image"
|
||||
)
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
|
||||
torch_gc()
|
||||
res_mask = gen_frontend_mask(bgr_or_gray_mask)
|
||||
return Response(
|
||||
content=numpy_to_bytes(res_mask, "png"),
|
||||
media_type="image/png",
|
||||
)
|
||||
|
||||
def api_samplers(self) -> List[str]:
|
||||
return [member.value for member in SDSampler.__members__.values()]
|
||||
|
||||
@@ -290,7 +317,7 @@ class Api:
|
||||
)
|
||||
return None
|
||||
|
||||
def _build_plugins(self) -> Dict:
|
||||
def _build_plugins(self) -> Dict[str, BasePlugin]:
|
||||
return build_plugins(
|
||||
self.config.enable_interactive_seg,
|
||||
self.config.interactive_seg_model,
|
||||
|
||||
@@ -350,3 +350,23 @@ def concat_alpha_channel(rgb_np_img, alpha_channel) -> np.ndarray:
|
||||
(rgb_np_img, alpha_channel[:, :, np.newaxis]), axis=-1
|
||||
)
|
||||
return rgb_np_img
|
||||
|
||||
|
||||
def gen_frontend_mask(bgr_or_gray_mask):
|
||||
if len(bgr_or_gray_mask.shape) == 3 and bgr_or_gray_mask.shape[2] != 1:
|
||||
bgr_or_gray_mask = cv2.cvtColor(bgr_or_gray_mask, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# fronted brush color "ffcc00bb"
|
||||
# TODO: how to set kernel size?
|
||||
kernel_size = 9
|
||||
bgr_or_gray_mask = cv2.dilate(
|
||||
bgr_or_gray_mask,
|
||||
np.ones((kernel_size, kernel_size), np.uint8),
|
||||
iterations=1,
|
||||
)
|
||||
res_mask = np.zeros(
|
||||
(bgr_or_gray_mask.shape[0], bgr_or_gray_mask.shape[1], 4), dtype=np.uint8
|
||||
)
|
||||
res_mask[bgr_or_gray_mask > 128] = [255, 203, 0, int(255 * 0.73)]
|
||||
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
|
||||
return res_mask
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from pydantic import computed_field, BaseModel
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .interactive_seg import InteractiveSeg
|
||||
from .remove_bg import RemoveBG
|
||||
from .realesrgan import RealESRGANUpscaler
|
||||
from .gfpgan_plugin import GFPGANPlugin
|
||||
from .restoreformer import RestoreFormerPlugin
|
||||
from .anime_seg import AnimeSeg
|
||||
from .gfpgan_plugin import GFPGANPlugin
|
||||
from .interactive_seg import InteractiveSeg
|
||||
from .realesrgan import RealESRGANUpscaler
|
||||
from .remove_bg import RemoveBG
|
||||
from .restoreformer import RestoreFormerPlugin
|
||||
from ..const import InteractiveSegModel, Device, RealESRGANModel
|
||||
|
||||
|
||||
@@ -23,7 +25,7 @@ def build_plugins(
|
||||
enable_restoreformer: bool,
|
||||
restoreformer_device: Device,
|
||||
no_half: bool,
|
||||
):
|
||||
) -> Dict:
|
||||
plugins = {}
|
||||
if enable_interactive_seg:
|
||||
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||
|
||||
@@ -416,6 +416,8 @@ ANIME_SEG_MODELS = {
|
||||
class AnimeSeg(BasePlugin):
|
||||
# Model from: https://github.com/SkyTNT/anime-segmentation
|
||||
name = "AnimeSeg"
|
||||
support_gen_image = True
|
||||
support_gen_mask = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -426,10 +428,19 @@ class AnimeSeg(BasePlugin):
|
||||
ANIME_SEG_MODELS["md5"],
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
mask = self.forward(rgb_np_img)
|
||||
mask = Image.fromarray(mask, mode="L")
|
||||
h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
|
||||
empty = Image.new("RGBA", (w0, h0), 0)
|
||||
img = Image.fromarray(rgb_np_img)
|
||||
cutout = Image.composite(img, empty, mask)
|
||||
return np.asarray(cutout)
|
||||
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
return self.forward(rgb_np_img)
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def forward(self, rgb_np_img):
|
||||
s = 1024
|
||||
|
||||
@@ -448,9 +459,4 @@ class AnimeSeg(BasePlugin):
|
||||
mask = self.model(tmpImg)
|
||||
mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
|
||||
mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
|
||||
mask = Image.fromarray((mask * 255).astype("uint8"), mode="L")
|
||||
|
||||
empty = Image.new("RGBA", (w0, h0), 0)
|
||||
img = Image.fromarray(rgb_np_img)
|
||||
cutout = Image.composite(img, empty, mask)
|
||||
return np.asarray(cutout)
|
||||
return (mask * 255).astype("uint8")
|
||||
|
||||
@@ -5,15 +5,23 @@ from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
name: str
|
||||
support_gen_image: bool = False
|
||||
support_gen_mask: bool = False
|
||||
|
||||
def __init__(self):
|
||||
err_msg = self.check_dep()
|
||||
if err_msg:
|
||||
logger.error(err_msg)
|
||||
exit(-1)
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array:
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
# return RGBA np image or BGR np image
|
||||
...
|
||||
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
# return GRAY or BGR np image, 255 means foreground, 0 means background
|
||||
...
|
||||
|
||||
def check_dep(self):
|
||||
...
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
@@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
class GFPGANPlugin(BasePlugin):
|
||||
name = "GFPGAN"
|
||||
support_gen_image = True
|
||||
|
||||
def __init__(self, device, upscaler=None):
|
||||
super().__init__()
|
||||
@@ -37,7 +39,7 @@ class GFPGANPlugin(BasePlugin):
|
||||
self.face_enhancer.face_helper.face_det.to(device)
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
weight = 0.5
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
@@ -34,6 +35,7 @@ SEGMENT_ANYTHING_MODELS = {
|
||||
|
||||
class InteractiveSeg(BasePlugin):
|
||||
name = "InteractiveSeg"
|
||||
support_gen_mask = True
|
||||
|
||||
def __init__(self, model_name, device):
|
||||
super().__init__()
|
||||
@@ -47,10 +49,11 @@ class InteractiveSeg(BasePlugin):
|
||||
)
|
||||
self.prev_img_md5 = None
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
|
||||
input_point = []
|
||||
input_label = []
|
||||
@@ -70,13 +73,4 @@ class InteractiveSeg(BasePlugin):
|
||||
multimask_output=False,
|
||||
)
|
||||
mask = masks[0].astype(np.uint8) * 255
|
||||
# TODO: how to set kernel size?
|
||||
kernel_size = 9
|
||||
mask = cv2.dilate(
|
||||
mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1
|
||||
)
|
||||
# fronted brush color "ffcc00bb"
|
||||
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
|
||||
res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
|
||||
res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
|
||||
return res_mask
|
||||
return mask
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import RealESRGANModel
|
||||
@@ -11,6 +13,7 @@ from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
class RealESRGANUpscaler(BasePlugin):
|
||||
name = "RealESRGAN"
|
||||
support_gen_image = True
|
||||
|
||||
def __init__(self, name, device, no_half=False):
|
||||
super().__init__()
|
||||
@@ -77,13 +80,14 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
device=device,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
||||
result = self.forward(bgr_np_img, req.scale)
|
||||
logger.info(f"RealESRGAN output shape: {result.shape}")
|
||||
return result
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, bgr_np_img, scale: float):
|
||||
# 输出是 BGR
|
||||
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
|
||||
|
||||
@@ -9,6 +9,8 @@ from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
class RemoveBG(BasePlugin):
|
||||
name = "RemoveBG"
|
||||
support_gen_mask = True
|
||||
support_gen_image = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -20,17 +22,24 @@ class RemoveBG(BasePlugin):
|
||||
|
||||
self.session = new_session(model_name="u2net")
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
return self.forward(bgr_np_img)
|
||||
|
||||
def forward(self, bgr_np_img) -> np.ndarray:
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
from rembg import remove
|
||||
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# return BGRA image
|
||||
output = remove(bgr_np_img, session=self.session)
|
||||
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
|
||||
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
from rembg import remove
|
||||
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# return BGR image, 255 means foreground, 0 means background
|
||||
output = remove(bgr_np_img, session=self.session, only_mask=True)
|
||||
return output
|
||||
|
||||
def check_dep(self):
|
||||
try:
|
||||
import rembg
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
@@ -8,6 +9,7 @@ from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
class RestoreFormerPlugin(BasePlugin):
|
||||
name = "RestoreFormer"
|
||||
support_gen_image = True
|
||||
|
||||
def __init__(self, device, upscaler=None):
|
||||
super().__init__()
|
||||
@@ -32,7 +34,7 @@ class RestoreFormerPlugin(BasePlugin):
|
||||
bg_upsampler=upscaler.model if upscaler is not None else None,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
weight = 0.5
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
|
||||
|
||||
@@ -3,12 +3,17 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal, List
|
||||
|
||||
from PIL.Image import Image
|
||||
from pydantic import BaseModel, Field, validator, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from lama_cleaner.const import Device, InteractiveSegModel, RealESRGANModel
|
||||
|
||||
|
||||
class PluginInfo(BaseModel):
|
||||
name: str
|
||||
support_gen_image: bool = False
|
||||
support_gen_mask: bool = False
|
||||
|
||||
|
||||
class CV2Flag(str, Enum):
|
||||
INPAINT_NS = "INPAINT_NS"
|
||||
INPAINT_TELEA = "INPAINT_TELEA"
|
||||
@@ -272,7 +277,7 @@ class GenInfoResponse(BaseModel):
|
||||
|
||||
|
||||
class ServerConfigResponse(BaseModel):
|
||||
plugins: List[str]
|
||||
plugins: List[PluginInfo]
|
||||
enableFileManager: bool
|
||||
enableAutoSaving: bool
|
||||
enableControlnet: bool
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import time
|
||||
from PIL import Image
|
||||
|
||||
from lama_cleaner.helper import encode_pil_to_base64
|
||||
from lama_cleaner.helper import encode_pil_to_base64, gen_frontend_mask
|
||||
from lama_cleaner.plugins.anime_seg import AnimeSeg
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
from lama_cleaner.tests.utils import check_device, current_dir, save_dir
|
||||
@@ -35,34 +35,48 @@ def _save(img, name):
|
||||
|
||||
def test_remove_bg():
|
||||
model = RemoveBG()
|
||||
rgba_np_img = model(
|
||||
rgba_np_img = model.gen_image(
|
||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||
)
|
||||
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
|
||||
_save(res, "test_remove_bg.png")
|
||||
|
||||
bgr_np_img = model.gen_mask(
|
||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||
)
|
||||
|
||||
res_mask = gen_frontend_mask(bgr_np_img)
|
||||
_save(res_mask, "test_remove_bg_frontend_mask.png")
|
||||
|
||||
assert len(bgr_np_img.shape) == 2
|
||||
_save(bgr_np_img, "test_remove_bg_mask.jpeg")
|
||||
|
||||
|
||||
def test_anime_seg():
|
||||
model = AnimeSeg()
|
||||
img = cv2.imread(str(current_dir / "anime_test.png"))
|
||||
img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})
|
||||
res = model(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||
res = model.gen_image(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||
assert len(res.shape) == 3
|
||||
assert res.shape[-1] == 4
|
||||
_save(res, "test_anime_seg.png")
|
||||
|
||||
res = model.gen_mask(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||
assert len(res.shape) == 2
|
||||
_save(res, "test_anime_seg_mask.png")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||
def test_upscale(device):
|
||||
check_device(device)
|
||||
model = RealESRGANUpscaler("realesr-general-x4v3", device)
|
||||
res = model(
|
||||
res = model.gen_image(
|
||||
rgb_img,
|
||||
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2),
|
||||
)
|
||||
_save(res, f"test_upscale_x2_{device}.png")
|
||||
|
||||
res = model(
|
||||
res = model.gen_image(
|
||||
rgb_img,
|
||||
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4),
|
||||
)
|
||||
@@ -73,7 +87,9 @@ def test_upscale(device):
|
||||
def test_gfpgan(device):
|
||||
check_device(device)
|
||||
model = GFPGANPlugin(device)
|
||||
res = model(rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64))
|
||||
res = model.gen_image(
|
||||
rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64)
|
||||
)
|
||||
_save(res, f"test_gfpgan_{device}.png")
|
||||
|
||||
|
||||
@@ -81,7 +97,7 @@ def test_gfpgan(device):
|
||||
def test_restoreformer(device):
|
||||
check_device(device)
|
||||
model = RestoreFormerPlugin(device)
|
||||
res = model(
|
||||
res = model.gen_image(
|
||||
rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
|
||||
)
|
||||
_save(res, f"test_restoreformer_{device}.png")
|
||||
@@ -91,7 +107,7 @@ def test_restoreformer(device):
|
||||
def test_segment_anything(device):
|
||||
check_device(device)
|
||||
model = InteractiveSeg("vit_l", device)
|
||||
new_mask = model(
|
||||
new_mask = model.gen_mask(
|
||||
rgb_img,
|
||||
RunPluginRequest(
|
||||
name=InteractiveSeg.name,
|
||||
|
||||
Reference in New Issue
Block a user