make generate mask from RemoveBG && AnimeSeg work

This commit is contained in:
Qing
2024-01-02 22:32:40 +08:00
parent 6253016019
commit aca85543ca
22 changed files with 244 additions and 100 deletions

View File

@@ -28,11 +28,13 @@ from lama_cleaner.helper import (
pil_to_bytes,
numpy_to_bytes,
concat_alpha_channel,
gen_frontend_mask,
)
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_info import ModelInfo
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg
from lama_cleaner.plugins.base_plugin import BasePlugin
from lama_cleaner.schema import (
GenInfoResponse,
ApiConfig,
@@ -41,6 +43,7 @@ from lama_cleaner.schema import (
InpaintRequest,
RunPluginRequest,
SDSampler,
PluginInfo,
)
from lama_cleaner.file_manager import FileManager
@@ -145,7 +148,8 @@ class Api:
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
@@ -173,7 +177,14 @@ class Api:
def api_server_config(self) -> ServerConfigResponse:
return ServerConfigResponse(
plugins=list(self.plugins.keys()),
plugins=[
PluginInfo(
name=it.name,
support_gen_image=it.support_gen_image,
support_gen_mask=it.support_gen_mask,
)
for it in self.plugins.values()
],
enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet,
@@ -237,22 +248,22 @@ class Api:
headers={"X-Seed": str(req.sd_seed)},
)
def api_run_plugin(self, req: RunPluginRequest):
def api_run_plugin_gen_image(self, req: RunPluginRequest):
ext = "png"
if req.name not in self.plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
bgr_np_img = self.plugins[req.name](rgb_np_img, req)
torch_gc()
if req.name == InteractiveSeg.name:
return Response(
content=numpy_to_bytes(bgr_np_img, ext),
media_type=f"image/{ext}",
raise HTTPException(status_code=422, detail="Plugin not found")
if not self.plugins[req.name].support_gen_image:
raise HTTPException(
status_code=422, detail="Plugin does not support output image"
)
if bgr_np_img.shape[2] == 4:
rgba_np_img = bgr_np_img
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
torch_gc()
if bgr_or_rgba_np_img.shape[2] == 4:
rgba_np_img = bgr_or_rgba_np_img
else:
rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB)
rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
return Response(
@@ -265,6 +276,22 @@ class Api:
media_type=f"image/{ext}",
)
def api_run_plugin_gen_mask(self, req: RunPluginRequest):
if req.name not in self.plugins:
raise HTTPException(status_code=422, detail="Plugin not found")
if not self.plugins[req.name].support_gen_mask:
raise HTTPException(
status_code=422, detail="Plugin does not support output image"
)
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
torch_gc()
res_mask = gen_frontend_mask(bgr_or_gray_mask)
return Response(
content=numpy_to_bytes(res_mask, "png"),
media_type="image/png",
)
def api_samplers(self) -> List[str]:
return [member.value for member in SDSampler.__members__.values()]
@@ -290,7 +317,7 @@ class Api:
)
return None
def _build_plugins(self) -> Dict:
def _build_plugins(self) -> Dict[str, BasePlugin]:
return build_plugins(
self.config.enable_interactive_seg,
self.config.interactive_seg_model,