Qing
2025-03-17 16:49:36 +08:00
parent 96d944a40c
commit c61b3b1149
3 changed files with 12 additions and 15 deletions

View File

@@ -269,8 +269,9 @@ class Api:
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt) return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
def api_inpaint(self, req: InpaintRequest): def api_inpaint(self, req: InpaintRequest):
image, alpha_channel, infos = decode_base64_to_image(req.image) image, alpha_channel, infos, ext = decode_base64_to_image(req.image)
mask, _, _ = decode_base64_to_image(req.mask, gray=True) mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
logger.info(f"image ext: {ext}")
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1] mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
if image.shape[:2] != mask.shape[:2]: if image.shape[:2] != mask.shape[:2]:
@@ -279,11 +280,6 @@ class Api:
detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.", detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
) )
if req.paint_by_example_example_image:
paint_by_example_image, _, _ = decode_base64_to_image(
req.paint_by_example_example_image
)
start = time.time() start = time.time()
rgb_np_img = self.model_manager(image, mask, req) rgb_np_img = self.model_manager(image, mask, req)
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms") logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
@@ -292,7 +288,6 @@ class Api:
rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB) rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel) rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
ext = "png"
res_img_bytes = pil_to_bytes( res_img_bytes = pil_to_bytes(
Image.fromarray(rgb_res), Image.fromarray(rgb_res),
ext=ext, ext=ext,
@@ -316,7 +311,7 @@ class Api:
raise HTTPException( raise HTTPException(
status_code=422, detail="Plugin does not support output image" status_code=422, detail="Plugin does not support output image"
) )
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) rgb_np_img, alpha_channel, infos, _ = decode_base64_to_image(req.image)
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req) bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
torch_gc() torch_gc()
@@ -343,7 +338,7 @@ class Api:
raise HTTPException( raise HTTPException(
status_code=422, detail="Plugin does not support output image" status_code=422, detail="Plugin does not support output image"
) )
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image) rgb_np_img, _, _, _ = decode_base64_to_image(req.image)
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req) bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
torch_gc() torch_gc()
res_mask = gen_frontend_mask(bgr_or_gray_mask) res_mask = gen_frontend_mask(bgr_or_gray_mask)
@@ -356,7 +351,7 @@ class Api:
return [member.value for member in SDSampler.__members__.values()] return [member.value for member in SDSampler.__members__.values()]
def api_adjust_mask(self, req: AdjustMaskRequest): def api_adjust_mask(self, req: AdjustMaskRequest):
mask, _, _ = decode_base64_to_image(req.mask, gray=True) mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
mask = adjust_mask(mask, req.kernel_size, req.operate) mask = adjust_mask(mask, req.kernel_size, req.operate)
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png") return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")

View File

@@ -306,12 +306,14 @@ def get_image_ext(img_bytes):
def decode_base64_to_image( def decode_base64_to_image(
encoding: str, gray=False encoding: str, gray=False
) -> Tuple[np.array, Optional[np.array], Dict]: ) -> Tuple[np.array, Optional[np.array], Dict, str]:
if encoding.startswith("data:image/") or encoding.startswith( if encoding.startswith("data:image/") or encoding.startswith(
"data:application/octet-stream;base64," "data:application/octet-stream;base64,"
): ):
encoding = encoding.split(";")[1].split(",")[1] encoding = encoding.split(";")[1].split(",")[1]
image = Image.open(io.BytesIO(base64.b64decode(encoding))) image_bytes = base64.b64decode(encoding)
ext = get_image_ext(image_bytes)
image = Image.open(io.BytesIO(image_bytes))
alpha_channel = None alpha_channel = None
try: try:
@@ -333,7 +335,7 @@ def decode_base64_to_image(
image = image.convert("RGB") image = image.convert("RGB")
np_img = np.array(image) np_img = np.array(image)
return np_img, alpha_channel, infos return np_img, alpha_channel, infos, ext
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes: def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:

View File

@@ -49,7 +49,7 @@ class PaintByExample(DiffusionInpaintModel):
""" """
if config.paint_by_example_example_image is None: if config.paint_by_example_example_image is None:
raise ValueError("paint_by_example_example_image is required") raise ValueError("paint_by_example_example_image is required")
example_image, _, _ = decode_base64_to_image( example_image, _, _, _ = decode_base64_to_image(
config.paint_by_example_example_image config.paint_by_example_example_image
) )
output = self.model( output = self.model(