update plugins
This commit is contained in:
@@ -210,26 +210,26 @@ class Api:
|
||||
)
|
||||
|
||||
def api_run_plugin(self, req: RunPluginRequest):
|
||||
ext = "png"
|
||||
if req.name not in self.plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_res = self.plugins[req.name].run(image, req)
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_np_img = self.plugins[req.name](rgb_np_img, req)
|
||||
torch_gc()
|
||||
if req.name == InteractiveSeg.name:
|
||||
return Response(
|
||||
content=numpy_to_bytes(bgr_res, "png"),
|
||||
media_type="image/png",
|
||||
content=numpy_to_bytes(bgr_np_img, ext),
|
||||
media_type=f"image/{ext}",
|
||||
)
|
||||
ext = "png"
|
||||
if req.name in [RemoveBG.name, AnimeSeg.name]:
|
||||
rgb_res = bgr_res
|
||||
if bgr_np_img.shape[2] == 4:
|
||||
rgba_np_img = bgr_np_img
|
||||
else:
|
||||
rgb_res = cv2.cvtColor(bgr_res, cv2.COLOR_BGR2RGB)
|
||||
rgb_res = concat_alpha_channel(rgb_res, alpha_channel)
|
||||
rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB)
|
||||
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
|
||||
|
||||
return Response(
|
||||
content=pil_to_bytes(
|
||||
Image.fromarray(rgb_res),
|
||||
Image.fromarray(rgba_np_img),
|
||||
ext=ext,
|
||||
quality=self.config.quality,
|
||||
infos=infos,
|
||||
|
||||
@@ -7,6 +7,7 @@ from PIL import Image
|
||||
|
||||
from lama_cleaner.helper import load_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class REBNCONV(nn.Module):
|
||||
@@ -425,7 +426,7 @@ class AnimeSeg(BasePlugin):
|
||||
ANIME_SEG_MODELS["md5"],
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
return self.forward(rgb_np_img)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from loguru import logger
|
||||
import numpy as np
|
||||
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
@@ -8,7 +11,8 @@ class BasePlugin:
|
||||
logger.error(err_msg)
|
||||
exit(-1)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest) -> np.array:
|
||||
# return RGBA np image or BGR np image
|
||||
...
|
||||
|
||||
def check_dep(self):
|
||||
|
||||
@@ -3,6 +3,7 @@ from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class GFPGANPlugin(BasePlugin):
|
||||
@@ -36,7 +37,7 @@ class GFPGANPlugin(BasePlugin):
|
||||
self.face_enhancer.face_helper.face_det.to(device)
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
weight = 0.5
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -7,6 +9,7 @@ from loguru import logger
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
# 从小到大
|
||||
SEGMENT_ANYTHING_MODELS = {
|
||||
@@ -44,11 +47,11 @@ class InteractiveSeg(BasePlugin):
|
||||
)
|
||||
self.prev_img_md5 = None
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
clicks = json.loads(form["clicks"])
|
||||
return self.forward(rgb_np_img, clicks, form["img_md5"])
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||
|
||||
def forward(self, rgb_np_img, clicks, img_md5):
|
||||
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
|
||||
input_point = []
|
||||
input_label = []
|
||||
for click in clicks:
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
from lama_cleaner.const import RealESRGANModel
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class RealESRGANUpscaler(BasePlugin):
|
||||
@@ -76,11 +77,10 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
device=device,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
scale = float(form["upscale"])
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {scale}")
|
||||
result = self.forward(bgr_np_img, scale)
|
||||
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
|
||||
result = self.forward(bgr_np_img, req.scale)
|
||||
logger.info(f"RealESRGAN output shape: {result.shape}")
|
||||
return result
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import numpy as np
|
||||
from torch.hub import get_dir
|
||||
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class RemoveBG(BasePlugin):
|
||||
@@ -19,7 +20,7 @@ class RemoveBG(BasePlugin):
|
||||
|
||||
self.session = new_session(model_name="u2net")
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
return self.forward(bgr_np_img)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
|
||||
|
||||
class RestoreFormerPlugin(BasePlugin):
|
||||
@@ -31,7 +32,7 @@ class RestoreFormerPlugin(BasePlugin):
|
||||
bg_upsampler=upscaler.model if upscaler is not None else None,
|
||||
)
|
||||
|
||||
def __call__(self, rgb_np_img, files, form):
|
||||
def __call__(self, rgb_np_img, req: RunPluginRequest):
|
||||
weight = 0.5
|
||||
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
|
||||
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
|
||||
|
||||
@@ -136,6 +136,12 @@ class InpaintRequest(BaseModel):
|
||||
extender_height: int = Field(640, description="Extend height for extender")
|
||||
extender_width: int = Field(640, description="Extend width for extender")
|
||||
|
||||
sd_scale: float = Field(
|
||||
1.0,
|
||||
description="Resize the image before doing sd inpainting, the area outside the mask will not lose quality.",
|
||||
gt=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
sd_mask_blur: int = Field(
|
||||
33,
|
||||
description="Blur the edge of mask area. The higher the number the smoother blend with the original image",
|
||||
@@ -143,6 +149,7 @@ class InpaintRequest(BaseModel):
|
||||
sd_strength: float = Field(
|
||||
1.0,
|
||||
description="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image",
|
||||
le=1.0,
|
||||
)
|
||||
sd_steps: int = Field(
|
||||
50,
|
||||
@@ -202,7 +209,9 @@ class InpaintRequest(BaseModel):
|
||||
|
||||
# ControlNet
|
||||
enable_controlnet: bool = Field(False, description="Enable controlnet")
|
||||
controlnet_conditioning_scale: float = Field(0.4, description="Conditioning scale")
|
||||
controlnet_conditioning_scale: float = Field(
|
||||
0.4, description="Conditioning scale", gt=0.0, le=1.0
|
||||
)
|
||||
controlnet_method: str = Field(
|
||||
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
||||
)
|
||||
@@ -214,6 +223,8 @@ class InpaintRequest(BaseModel):
|
||||
fitting_degree: float = Field(
|
||||
1.0,
|
||||
description="Control the fitting degree of the generated objects to the mask shape.",
|
||||
gt=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
@field_validator("sd_seed")
|
||||
@@ -226,7 +237,7 @@ class InpaintRequest(BaseModel):
|
||||
|
||||
class RunPluginRequest(BaseModel):
|
||||
name: str
|
||||
image: Optional[str] = Field(..., description="base64 encoded image")
|
||||
image: str = Field(..., description="base64 encoded image")
|
||||
clicks: List[List[int]] = Field(
|
||||
[], description="Clicks for interactive seg, [[x,y,0/1], [x2,y2,0/1]]"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import hashlib
|
||||
import os
|
||||
import time
|
||||
from PIL import Image
|
||||
|
||||
from lama_cleaner.helper import encode_pil_to_base64
|
||||
from lama_cleaner.plugins.anime_seg import AnimeSeg
|
||||
from lama_cleaner.schema import RunPluginRequest
|
||||
from lama_cleaner.tests.utils import check_device, current_dir, save_dir
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
@@ -22,6 +25,8 @@ img_p = current_dir / "bunny.jpeg"
|
||||
img_bytes = open(img_p, "rb").read()
|
||||
bgr_img = cv2.imread(str(img_p))
|
||||
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
|
||||
rgb_img_base64 = encode_pil_to_base64(Image.fromarray(rgb_img), 100, {})
|
||||
bgr_img_base64 = encode_pil_to_base64(Image.fromarray(bgr_img), 100, {})
|
||||
|
||||
|
||||
def _save(img, name):
|
||||
@@ -30,15 +35,18 @@ def _save(img, name):
|
||||
|
||||
def test_remove_bg():
|
||||
model = RemoveBG()
|
||||
res = model.forward(bgr_img)
|
||||
res = cv2.cvtColor(res, cv2.COLOR_RGBA2BGRA)
|
||||
rgba_np_img = model(
|
||||
rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64)
|
||||
)
|
||||
res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA)
|
||||
_save(res, "test_remove_bg.png")
|
||||
|
||||
|
||||
def test_anime_seg():
|
||||
model = AnimeSeg()
|
||||
img = cv2.imread(str(current_dir / "anime_test.png"))
|
||||
res = model.forward(img)
|
||||
img_base64 = encode_pil_to_base64(Image.fromarray(img), 100, {})
|
||||
res = model(img, RunPluginRequest(name=AnimeSeg.name, image=img_base64))
|
||||
assert len(res.shape) == 3
|
||||
assert res.shape[-1] == 4
|
||||
_save(res, "test_anime_seg.png")
|
||||
@@ -48,10 +56,16 @@ def test_anime_seg():
|
||||
def test_upscale(device):
|
||||
check_device(device)
|
||||
model = RealESRGANUpscaler("realesr-general-x4v3", device)
|
||||
res = model.forward(bgr_img, 2)
|
||||
res = model(
|
||||
rgb_img,
|
||||
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=2),
|
||||
)
|
||||
_save(res, f"test_upscale_x2_{device}.png")
|
||||
|
||||
res = model.forward(bgr_img, 4)
|
||||
res = model(
|
||||
rgb_img,
|
||||
RunPluginRequest(name=RealESRGANUpscaler.name, image=rgb_img_base64, scale=4),
|
||||
)
|
||||
_save(res, f"test_upscale_x4_{device}.png")
|
||||
|
||||
|
||||
@@ -59,7 +73,7 @@ def test_upscale(device):
|
||||
def test_gfpgan(device):
|
||||
check_device(device)
|
||||
model = GFPGANPlugin(device)
|
||||
res = model(rgb_img, None, None)
|
||||
res = model(rgb_img, RunPluginRequest(name=GFPGANPlugin.name, image=rgb_img_base64))
|
||||
_save(res, f"test_gfpgan_{device}.png")
|
||||
|
||||
|
||||
@@ -67,20 +81,24 @@ def test_gfpgan(device):
|
||||
def test_restoreformer(device):
|
||||
check_device(device)
|
||||
model = RestoreFormerPlugin(device)
|
||||
res = model(rgb_img, None, None)
|
||||
res = model(
|
||||
rgb_img, RunPluginRequest(name=RestoreFormerPlugin.name, image=rgb_img_base64)
|
||||
)
|
||||
_save(res, f"test_restoreformer_{device}.png")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
|
||||
def test_segment_anything(device):
|
||||
check_device(device)
|
||||
img_md5 = hashlib.md5(img_bytes).hexdigest()
|
||||
model = InteractiveSeg("vit_l", device)
|
||||
new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
|
||||
new_mask = model(
|
||||
rgb_img,
|
||||
RunPluginRequest(
|
||||
name=InteractiveSeg.name,
|
||||
image=rgb_img_base64,
|
||||
clicks=([[448 // 2, 394 // 2, 1]]),
|
||||
),
|
||||
)
|
||||
|
||||
save_name = f"test_segment_anything_{device}.png"
|
||||
_save(new_mask, save_name)
|
||||
|
||||
start = time.time()
|
||||
model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
|
||||
print(f"Time for {save_name}: {time.time() - start:.2f}s")
|
||||
|
||||
Reference in New Issue
Block a user