add adjust mask feature
This commit is contained in:
@@ -29,6 +29,7 @@ from lama_cleaner.helper import (
|
||||
numpy_to_bytes,
|
||||
concat_alpha_channel,
|
||||
gen_frontend_mask,
|
||||
adjust_mask,
|
||||
)
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_info import ModelInfo
|
||||
@@ -44,6 +45,7 @@ from lama_cleaner.schema import (
|
||||
RunPluginRequest,
|
||||
SDSampler,
|
||||
PluginInfo,
|
||||
AdjustMaskRequest,
|
||||
)
|
||||
|
||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||
@@ -150,6 +152,7 @@ class Api:
|
||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
|
||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||
# fmt: on
|
||||
|
||||
@@ -294,6 +297,13 @@ class Api:
|
||||
def api_samplers(self) -> List[str]:
|
||||
return [member.value for member in SDSampler.__members__.values()]
|
||||
|
||||
def api_adjust_mask(self, req: AdjustMaskRequest):
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
cv2.imwrite("tmp_adjust_mask_input.png", mask)
|
||||
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
||||
cv2.imwrite("tmp_adjust_mask.png", mask)
|
||||
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
||||
|
||||
def launch(self):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(
|
||||
|
||||
Reference in New Issue
Block a user