This commit is contained in:
@@ -269,8 +269,9 @@ class Api:
|
|||||||
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
|
return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)
|
||||||
|
|
||||||
def api_inpaint(self, req: InpaintRequest):
|
def api_inpaint(self, req: InpaintRequest):
|
||||||
image, alpha_channel, infos = decode_base64_to_image(req.image)
|
image, alpha_channel, infos, ext = decode_base64_to_image(req.image)
|
||||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||||
|
logger.info(f"image ext: {ext}")
|
||||||
|
|
||||||
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
|
||||||
if image.shape[:2] != mask.shape[:2]:
|
if image.shape[:2] != mask.shape[:2]:
|
||||||
@@ -279,11 +280,6 @@ class Api:
|
|||||||
detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
|
detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if req.paint_by_example_example_image:
|
|
||||||
paint_by_example_image, _, _ = decode_base64_to_image(
|
|
||||||
req.paint_by_example_example_image
|
|
||||||
)
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
rgb_np_img = self.model_manager(image, mask, req)
|
rgb_np_img = self.model_manager(image, mask, req)
|
||||||
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
|
logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
|
||||||
@@ -292,7 +288,6 @@ class Api:
|
|||||||
rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
|
||||||
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
|
rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)
|
||||||
|
|
||||||
ext = "png"
|
|
||||||
res_img_bytes = pil_to_bytes(
|
res_img_bytes = pil_to_bytes(
|
||||||
Image.fromarray(rgb_res),
|
Image.fromarray(rgb_res),
|
||||||
ext=ext,
|
ext=ext,
|
||||||
@@ -316,7 +311,7 @@ class Api:
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=422, detail="Plugin does not support output image"
|
status_code=422, detail="Plugin does not support output image"
|
||||||
)
|
)
|
||||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
rgb_np_img, alpha_channel, infos, _ = decode_base64_to_image(req.image)
|
||||||
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
|
bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
@@ -343,7 +338,7 @@ class Api:
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=422, detail="Plugin does not support output image"
|
status_code=422, detail="Plugin does not support output image"
|
||||||
)
|
)
|
||||||
rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
|
rgb_np_img, _, _, _ = decode_base64_to_image(req.image)
|
||||||
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
|
bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
|
||||||
torch_gc()
|
torch_gc()
|
||||||
res_mask = gen_frontend_mask(bgr_or_gray_mask)
|
res_mask = gen_frontend_mask(bgr_or_gray_mask)
|
||||||
@@ -356,7 +351,7 @@ class Api:
|
|||||||
return [member.value for member in SDSampler.__members__.values()]
|
return [member.value for member in SDSampler.__members__.values()]
|
||||||
|
|
||||||
def api_adjust_mask(self, req: AdjustMaskRequest):
|
def api_adjust_mask(self, req: AdjustMaskRequest):
|
||||||
mask, _, _ = decode_base64_to_image(req.mask, gray=True)
|
mask, _, _, _ = decode_base64_to_image(req.mask, gray=True)
|
||||||
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
mask = adjust_mask(mask, req.kernel_size, req.operate)
|
||||||
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")
|
||||||
|
|
||||||
|
|||||||
@@ -306,12 +306,14 @@ def get_image_ext(img_bytes):
|
|||||||
|
|
||||||
def decode_base64_to_image(
|
def decode_base64_to_image(
|
||||||
encoding: str, gray=False
|
encoding: str, gray=False
|
||||||
) -> Tuple[np.array, Optional[np.array], Dict]:
|
) -> Tuple[np.array, Optional[np.array], Dict, str]:
|
||||||
if encoding.startswith("data:image/") or encoding.startswith(
|
if encoding.startswith("data:image/") or encoding.startswith(
|
||||||
"data:application/octet-stream;base64,"
|
"data:application/octet-stream;base64,"
|
||||||
):
|
):
|
||||||
encoding = encoding.split(";")[1].split(",")[1]
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
image = Image.open(io.BytesIO(base64.b64decode(encoding)))
|
image_bytes = base64.b64decode(encoding)
|
||||||
|
ext = get_image_ext(image_bytes)
|
||||||
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
|
|
||||||
alpha_channel = None
|
alpha_channel = None
|
||||||
try:
|
try:
|
||||||
@@ -333,7 +335,7 @@ def decode_base64_to_image(
|
|||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
np_img = np.array(image)
|
np_img = np.array(image)
|
||||||
|
|
||||||
return np_img, alpha_channel, infos
|
return np_img, alpha_channel, infos, ext
|
||||||
|
|
||||||
|
|
||||||
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
|
def encode_pil_to_base64(image: Image, quality: int, infos: Dict) -> bytes:
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ class PaintByExample(DiffusionInpaintModel):
|
|||||||
"""
|
"""
|
||||||
if config.paint_by_example_example_image is None:
|
if config.paint_by_example_example_image is None:
|
||||||
raise ValueError("paint_by_example_example_image is required")
|
raise ValueError("paint_by_example_example_image is required")
|
||||||
example_image, _, _ = decode_base64_to_image(
|
example_image, _, _, _ = decode_base64_to_image(
|
||||||
config.paint_by_example_example_image
|
config.paint_by_example_example_image
|
||||||
)
|
)
|
||||||
output = self.model(
|
output = self.model(
|
||||||
|
|||||||
Reference in New Issue
Block a user