This commit is contained in:
@@ -269,8 +269,9 @@ class Api:
|
||||
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
|
||||
|
||||
def api_inpaint(self, req: InpaintRequest):
|
||||
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
image, alpha_channel, infos, ext = decode_base64_to_image(req.image)
|
||||
mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
logger.info(f"image ext: {ext}")
|
||||
|
||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||
if image.shape[:2] != mask.shape[:2]:
|
||||
@@ -279,11 +280,6 @@ class Api:
|
||||
detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
|
||||
)
|
||||
|
||||
if req.paint_by_example_example_image:
|
||||
paint_by_example_image, _, _ = decode_base64_to_image(
|
||||
req.paint_by_example_example_image
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
rgb_np_img = self.model_manager(image, mask, req)
|
||||
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
|
||||
@@ -292,7 +288,6 @@ class Api:
|
||||
rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
|
||||
|
||||
ext = "png"
|
||||
res_img_bytes = pil_to_bytes(
|
||||
Image.fromarray(rgb_res),
|
||||
ext=ext,
|
||||
@@ -316,7 +311,7 @@ class Api:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Plugin does not support output image"
|
||||
)
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
rgb_np_img, alpha_channel, infos, _ = decode_base64_to_image(req.image)
|
||||
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
|
||||
torch_gc()
|
||||
|
||||
@@ -343,7 +338,7 @@ class Api:
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Plugin does not support output image"
|
||||
)
|
||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
||||
rgb_np_img, _, _, _ = decode_base64_to_image(req.image)
|
||||
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
|
||||
torch_gc()
|
||||
res_mask = gen_frontend_mask(bgr_or_gray_mask)
|
||||
@@ -356,7 +351,7 @@ class Api:
|
||||
return [member.value for member in SDSampler.__members__.values()]
|
||||
|
||||
def api_adjust_mask(self, req: AdjustMaskRequest):
|
||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
||||
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
||||
|
||||
|
||||
@@ -306,12 +306,14 @@ def get_image_ext(img_bytes):
|
||||
|
||||
def decode_base64_to_image(
|
||||
encoding: str, gray=False
|
||||
) -> Tuple[np.array, Optional[np.array], Dict]:
|
||||
) -> Tuple[np.array, Optional[np.array], Dict, str]:
|
||||
if encoding.startswith("data:image/") or encoding.startswith(
|
||||
"data:application/octet-stream;base64,"
|
||||
):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
image = Image.open(io.BytesIO(base64.b64decode(encoding)))
|
||||
image_bytes = base64.b64decode(encoding)
|
||||
ext = get_image_ext(image_bytes)
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
alpha_channel = None
|
||||
try:
|
||||
@@ -333,7 +335,7 @@ def decode_base64_to_image(
|
||||
image = image.convert("RGB")
|
||||
np_img = np.array(image)
|
||||
|
||||
return np_img, alpha_channel, infos
|
||||
return np_img, alpha_channel, infos, ext
|
||||
|
||||
|
||||
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
|
||||
|
||||
@@ -49,7 +49,7 @@ class PaintByExample(DiffusionInpaintModel):
|
||||
"""
|
||||
if config.paint_by_example_example_image is None:
|
||||
raise ValueError("paint_by_example_example_image is required")
|
||||
example_image, _, _ = decode_base64_to_image(
|
||||
example_image, _, _, _ = decode_base64_to_image(
|
||||
config.paint_by_example_example_image
|
||||
)
|
||||
output = self.model(
|
||||
|
||||
Reference in New Issue
Block a user