add adjust mask feature
This commit is contained in:
@@ -29,6 +29,7 @@ from lama_cleaner.helper import (
|
||||
numpy_to_bytes,
|
||||
concat_alpha_channel,
|
||||
gen_frontend_mask,
|
||||
adjust_mask,
|
||||
)
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_info import ModelInfo
|
||||
@@ -44,6 +45,7 @@ from lama_cleaner.schema import (
|
||||
RunPluginRequest,
|
||||
SDSampler,
|
||||
PluginInfo,
|
||||
AdjustMaskRequest,
|
||||
)
|
||||
|
||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||
@@ -150,6 +152,7 @@ class Api:
|
||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
|
||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||
# fmt: on
|
||||
|
||||
@@ -294,6 +297,13 @@ class Api:
|
||||
def api_samplers(self) -> List[str]:
|
||||
return [member.value for member in SDSampler.__members__.values()]
|
||||
|
||||
def api_adjust_mask(self, req: AdjustMaskRequest):
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
cv2.imwrite("tmp_adjust_mask_input.png", mask)
|
||||
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
||||
cv2.imwrite("tmp_adjust_mask.png", mask)
|
||||
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
||||
|
||||
def launch(self):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(
|
||||
|
||||
@@ -358,16 +358,19 @@ def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
|
||||
mask[mask >= 127] = 255
|
||||
mask[mask < 127] = 0
|
||||
# fronted brush color "ffcc00bb"
|
||||
kernel = cv2.getStructuringElement(
|
||||
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
|
||||
)
|
||||
if operate == "expand":
|
||||
mask = cv2.dilate(
|
||||
mask,
|
||||
np.ones((kernel_size, kernel_size), np.uint8),
|
||||
kernel,
|
||||
iterations=1,
|
||||
)
|
||||
else:
|
||||
mask = cv2.erode(
|
||||
mask,
|
||||
np.ones((kernel_size, kernel_size), np.uint8),
|
||||
kernel,
|
||||
iterations=1,
|
||||
)
|
||||
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
|
||||
|
||||
@@ -110,8 +110,8 @@ class ApiConfig(BaseModel):
|
||||
|
||||
|
||||
class InpaintRequest(BaseModel):
|
||||
image: Optional[str] = Field(..., description="base64 encoded image")
|
||||
mask: Optional[str] = Field(..., description="base64 encoded mask")
|
||||
image: Optional[str] = Field(None, description="base64 encoded image")
|
||||
mask: Optional[str] = Field(None, description="base64 encoded mask")
|
||||
|
||||
ldm_steps: int = Field(20, description="Steps for ldm model.")
|
||||
ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.")
|
||||
@@ -289,3 +289,12 @@ class ServerConfigResponse(BaseModel):
|
||||
|
||||
class SwitchModelRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
AdjustMaskOperate = Literal["expand", "shrink"]
|
||||
|
||||
|
||||
class AdjustMaskRequest(BaseModel):
|
||||
mask: str = Field(..., description="base64 encoded mask. 255 means area to do inpaint")
|
||||
operate: AdjustMaskOperate = Field(..., description="expand or shrink")
|
||||
kernel_size: int = Field(5, description="Kernel size for expanding mask")
|
||||
|
||||
15
lama_cleaner/tests/test_adjust_mask.py
Normal file
15
lama_cleaner/tests/test_adjust_mask.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import cv2
|
||||
from lama_cleaner.helper import adjust_mask
|
||||
from lama_cleaner.tests.utils import current_dir, save_dir
|
||||
|
||||
mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
|
||||
def test_adjust_mask():
|
||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||
res_mask = adjust_mask(mask, 0, "expand")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask)
|
||||
res_mask = adjust_mask(mask, 40, "expand")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask)
|
||||
res_mask = adjust_mask(mask, 20, "shrink")
|
||||
cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask)
|
||||
@@ -31,8 +31,6 @@ def assert_equal(
|
||||
mask_p=current_dir / "mask.png",
|
||||
):
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0]
|
||||
config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0]
|
||||
print(f"Input image shape: {img.shape}")
|
||||
res = model(img, mask, config)
|
||||
ok = cv2.imwrite(
|
||||
|
||||
Reference in New Issue
Block a user