From c61b3b1149f0c0a1d38588dfe44ad4a283cfffc0 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 17 Mar 2025 16:49:36 +0800 Subject: [PATCH] fix https://github.com/Sanster/IOPaint/issues/617 --- iopaint/api.py | 17 ++++++----------- iopaint/helper.py | 8 +++++--- iopaint/model/paint_by_example.py | 2 +- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/iopaint/api.py b/iopaint/api.py index 363d23d..f95022e 100644 --- a/iopaint/api.py +++ b/iopaint/api.py @@ -269,8 +269,9 @@ class Api: return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt) def api_inpaint(self, req: InpaintRequest): - image, alpha_channel, infos = decode_base64_to_image(req.image) - mask, _, _ = decode_base64_to_image(req.mask, gray=True) + image, alpha_channel, infos, ext = decode_base64_to_image(req.image) + mask, _, _, _ = decode_base64_to_image(req.mask, gray=True) + logger.info(f"image ext: {ext}") mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] if image.shape[:2] != mask.shape[:2]: @@ -279,11 +280,6 @@ class Api: detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.", ) - if req.paint_by_example_example_image: - paint_by_example_image, _, _ = decode_base64_to_image( - req.paint_by_example_example_image - ) - start = time.time() rgb_np_img = self.model_manager(image, mask, req) logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms") @@ -292,7 +288,6 @@ class Api: rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB) rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel) - ext = "png" res_img_bytes = pil_to_bytes( Image.fromarray(rgb_res), ext=ext, @@ -316,7 +311,7 @@ class Api: raise HTTPException( status_code=422, detail="Plugin does not support output image" ) - rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) + rgb_np_img, alpha_channel, infos, _ = decode_base64_to_image(req.image) bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req) torch_gc() @@ -343,7 +338,7 @@ class Api: raise HTTPException( status_code=422, detail="Plugin does not support output image" ) - rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) + rgb_np_img, _, _, _ = decode_base64_to_image(req.image) bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req) torch_gc() res_mask = gen_frontend_mask(bgr_or_gray_mask) @@ -356,7 +351,7 @@ class Api: return [member.value for member in SDSampler.__members__.values()] def api_adjust_mask(self, req: AdjustMaskRequest): - mask, _, _ = decode_base64_to_image(req.mask, gray=True) + mask, _, _, _ = decode_base64_to_image(req.mask, gray=True) mask = adjust_mask(mask, req.kernel_size, req.operate) return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png") diff --git a/iopaint/helper.py b/iopaint/helper.py index 1c99dcf..2fa7e46 100644 --- a/iopaint/helper.py +++ b/iopaint/helper.py @@ -306,12 +306,14 @@ def get_image_ext(img_bytes): def decode_base64_to_image( encoding: str, gray=False -) -> Tuple[np.array, Optional[np.array], Dict]: +) -> Tuple[np.array, Optional[np.array], Dict, str]: if encoding.startswith("data:image/") or encoding.startswith( "data:application/octet-stream;base64," ): encoding = encoding.split(";")[1].split(",")[1] - image = Image.open(io.BytesIO(base64.b64decode(encoding))) + image_bytes = base64.b64decode(encoding) + ext = get_image_ext(image_bytes) + image = Image.open(io.BytesIO(image_bytes)) alpha_channel = None try: @@ -333,7 +335,7 @@ def decode_base64_to_image( image = image.convert("RGB") np_img = np.array(image) - return np_img, alpha_channel, infos + return np_img, alpha_channel, infos, ext def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes: diff --git a/iopaint/model/paint_by_example.py b/iopaint/model/paint_by_example.py index bf1e5b7..7d37a6b 100644 --- a/iopaint/model/paint_by_example.py +++ b/iopaint/model/paint_by_example.py @@ -49,7 +49,7 @@ class PaintByExample(DiffusionInpaintModel): """ if config.paint_by_example_example_image is None: raise ValueError("paint_by_example_example_image is required") - example_image, _, _ = decode_base64_to_image( + example_image, _, _, _ = decode_base64_to_image( config.paint_by_example_example_image ) output = self.model(