add adjust mask feature

This commit is contained in:
Qing
2024-01-05 14:57:30 +08:00
parent 2996544e75
commit e889e527ab
18 changed files with 507 additions and 76 deletions

View File

@@ -29,6 +29,7 @@ from lama_cleaner.helper import (
numpy_to_bytes,
concat_alpha_channel,
gen_frontend_mask,
adjust_mask,
)
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_info import ModelInfo
@@ -44,6 +45,7 @@ from lama_cleaner.schema import (
RunPluginRequest,
SDSampler,
PluginInfo,
AdjustMaskRequest,
)
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@@ -150,6 +152,7 @@ class Api:
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
@@ -294,6 +297,13 @@ class Api:
def api_samplers(self) -> List[str]:
return [member.value for member in SDSampler.__members__.values()]
def api_adjust_mask(self, req: AdjustMaskRequest):
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
cv2.imwrite("tmp_adjust_mask_input.png", mask)
mask = adjust_mask(mask, req.kernel_size, req.operate)
cv2.imwrite("tmp_adjust_mask.png", mask)
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
def launch(self):
self.app.include_router(self.router)
uvicorn.run(