add adjust mask feature

This commit is contained in:
Qing
2024-01-05 14:57:30 +08:00
parent 2996544e75
commit e889e527ab
18 changed files with 507 additions and 76 deletions

View File

@@ -29,6 +29,7 @@ from lama_cleaner.helper import (
numpy_to_bytes,
concat_alpha_channel,
gen_frontend_mask,
adjust_mask,
)
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_info import ModelInfo
@@ -44,6 +45,7 @@ from lama_cleaner.schema import (
RunPluginRequest,
SDSampler,
PluginInfo,
AdjustMaskRequest,
)
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@@ -150,6 +152,7 @@ class Api:
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
@@ -294,6 +297,13 @@ class Api:
def api_samplers(self) -> List[str]:
return [member.value for member in SDSampler.__members__.values()]
def api_adjust_mask(self, req: AdjustMaskRequest):
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
cv2.imwrite("tmp_adjust_mask_input.png", mask)
mask = adjust_mask(mask, req.kernel_size, req.operate)
cv2.imwrite("tmp_adjust_mask.png", mask)
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
def launch(self):
self.app.include_router(self.router)
uvicorn.run(

View File

@@ -358,16 +358,19 @@ def adjust_mask(mask: np.ndarray, kernel_size: int, operate):
mask[mask >= 127] = 255
mask[mask < 127] = 0
# fronted brush color "ffcc00bb"
kernel = cv2.getStructuringElement(
cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
)
if operate == "expand":
mask = cv2.dilate(
mask,
np.ones((kernel_size, kernel_size), np.uint8),
kernel,
iterations=1,
)
else:
mask = cv2.erode(
mask,
np.ones((kernel_size, kernel_size), np.uint8),
kernel,
iterations=1,
)
res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)

View File

@@ -110,8 +110,8 @@ class ApiConfig(BaseModel):
class InpaintRequest(BaseModel):
image: Optional[str] = Field(..., description="base64 encoded image")
mask: Optional[str] = Field(..., description="base64 encoded mask")
image: Optional[str] = Field(None, description="base64 encoded image")
mask: Optional[str] = Field(None, description="base64 encoded mask")
ldm_steps: int = Field(20, description="Steps for ldm model.")
ldm_sampler: str = Field(LDMSampler.plms, discription="Sampler for ldm model.")
@@ -289,3 +289,12 @@ class ServerConfigResponse(BaseModel):
class SwitchModelRequest(BaseModel):
name: str
AdjustMaskOperate = Literal["expand", "shrink"]
class AdjustMaskRequest(BaseModel):
mask: str = Field(..., description="base64 encoded mask. 255 means area to do inpaint")
operate: AdjustMaskOperate = Field(..., description="expand or shrink")
kernel_size: int = Field(5, description="Kernel size for expanding mask")

View File

@@ -0,0 +1,15 @@
import cv2
from lama_cleaner.helper import adjust_mask
from lama_cleaner.tests.utils import current_dir, save_dir
mask_p = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
def test_adjust_mask():
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
res_mask = adjust_mask(mask, 0, "expand")
cv2.imwrite(str(save_dir / "adjust_mask_original.png"), res_mask)
res_mask = adjust_mask(mask, 40, "expand")
cv2.imwrite(str(save_dir / "adjust_mask_expand.png"), res_mask)
res_mask = adjust_mask(mask, 20, "shrink")
cv2.imwrite(str(save_dir / "adjust_mask_shrink.png"), res_mask)

View File

@@ -31,8 +31,6 @@ def assert_equal(
mask_p=current_dir / "mask.png",
):
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
config.image = encode_pil_to_base64(Image.fromarray(img), 95, {})[0]
config.mask = encode_pil_to_base64(Image.fromarray(mask), 95, {})[0]
print(f"Input image shape: {img.shape}")
res = model(img, mask, config)
ok = cv2.imwrite(