make generate mask from RemoveBG && AnimeSeg work
This commit is contained in:
@@ -28,11 +28,13 @@ from lama_cleaner.helper import (
|
||||
pil_to_bytes,
|
||||
numpy_to_bytes,
|
||||
concat_alpha_channel,
|
||||
gen_frontend_mask,
|
||||
)
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_info import ModelInfo
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.plugins import build_plugins, InteractiveSeg, RemoveBG, AnimeSeg
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
from lama_cleaner.schema import (
|
||||
GenInfoResponse,
|
||||
ApiConfig,
|
||||
@@ -41,6 +43,7 @@ from lama_cleaner.schema import (
|
||||
InpaintRequest,
|
||||
RunPluginRequest,
|
||||
SDSampler,
|
||||
PluginInfo,
|
||||
)
|
||||
from lama_cleaner.file_manager import FileManager
|
||||
|
||||
@@ -145,7 +148,8 @@ class Api:
|
||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin", self.api_run_plugin, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||
# fmt: on
|
||||
@@ -173,7 +177,14 @@ class Api:
|
||||
|
||||
def api_server_config(self) -> ServerConfigResponse:
|
||||
return ServerConfigResponse(
|
||||
plugins=list(self.plugins.keys()),
|
||||
plugins=[
|
||||
PluginInfo(
|
||||
name=it.name,
|
||||
support_gen_image=it.support_gen_image,
|
||||
support_gen_mask=it.support_gen_mask,
|
||||
)
|
||||
for it in self.plugins.values()
|
||||
],
|
||||
enableFileManager=self.file_manager is not None,
|
||||
enableAutoSaving=self.config.output_dir is not None,
|
||||
enableControlnet=self.model_manager.enable_controlnet,
|
||||
@@ -237,22 +248,22 @@ class Api:
|
||||
headers={"X-Seed": str(req.sd_seed)},
|
||||
)
|
||||
|
||||
def api_run_plugin(self, req: RunPluginRequest):
|
||||
def api_run_plugin_gen_image(self, req: RunPluginRequest):
|
||||
ext = "png"
|
||||
if req.name not in self.plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_np_img = self.plugins[req.name](rgb_np_img, req)
|
||||
torch_gc()
|
||||
if req.name == InteractiveSeg.name:
|
||||
return Response(
|
||||
content=numpy_to_bytes(bgr_np_img, ext),
|
||||
media_type=f"image/{ext}",
|
||||
raise HTTPException(status_code=422, detail="Plugin not found")
|
||||
if not self.plugins[req.name].support_gen_image:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Plugin does not support output image"
|
||||
)
|
||||
if bgr_np_img.shape[2] == 4:
|
||||
rgba_np_img = bgr_np_img
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
|
||||
torch_gc()
|
||||
|
||||
if bgr_or_rgba_np_img.shape[2] == 4:
|
||||
rgba_np_img = bgr_or_rgba_np_img
|
||||
else:
|
||||
rgba_np_img = cv2.cvtColor(bgr_np_img, cv2.COLOR_BGR2RGB)
|
||||
rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
|
||||
rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)
|
||||
|
||||
return Response(
|
||||
@@ -265,6 +276,22 @@ class Api:
|
||||
media_type=f"image/{ext}",
|
||||
)
|
||||
|
||||
def api_run_plugin_gen_mask(self, req: RunPluginRequest):
|
||||
if req.name not in self.plugins:
|
||||
raise HTTPException(status_code=422, detail="Plugin not found")
|
||||
if not self.plugins[req.name].support_gen_mask:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Plugin does not support output image"
|
||||
)
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
|
||||
torch_gc()
|
||||
res_mask = gen_frontend_mask(bgr_or_gray_mask)
|
||||
return Response(
|
||||
content=numpy_to_bytes(res_mask, "png"),
|
||||
media_type="image/png",
|
||||
)
|
||||
|
||||
def api_samplers(self) -> List[str]:
|
||||
return [member.value for member in SDSampler.__members__.values()]
|
||||
|
||||
@@ -290,7 +317,7 @@ class Api:
|
||||
)
|
||||
return None
|
||||
|
||||
def _build_plugins(self) -> Dict:
|
||||
def _build_plugins(self) -> Dict[str, BasePlugin]:
|
||||
return build_plugins(
|
||||
self.config.enable_interactive_seg,
|
||||
self.config.interactive_seg_model,
|
||||
|
||||
Reference in New Issue
Block a user