diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 16fa64e..a6b67b8 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -1,9 +1,11 @@ +import io import os import sys from typing import List, Optional from urllib.parse import urlparse import cv2 +from PIL import Image, ImageOps import numpy as np import torch from loguru import logger @@ -85,16 +87,23 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes: def load_img(img_bytes, gray: bool = False): alpha_channel = None - nparr = np.frombuffer(img_bytes, np.uint8) + image = Image.open(io.BytesIO(img_bytes)) + try: + image = ImageOps.exif_transpose(image) + except: + pass + if gray: - np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) + image = image.convert('L') + np_img = np.array(image) else: - np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED) - if len(np_img.shape) == 3 and np_img.shape[2] == 4: + if image.mode == 'RGBA': + np_img = np.array(image) alpha_channel = np_img[:, :, -1] - np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB) + np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB) else: - np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB) + image = image.convert('RGB') + np_img = np.array(image) return np_img, alpha_channel diff --git a/lama_cleaner/tests/test_load_img.py b/lama_cleaner/tests/test_load_img.py new file mode 100644 index 0000000..6028a60 --- /dev/null +++ b/lama_cleaner/tests/test_load_img.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from lama_cleaner.helper import load_img + +current_dir = Path(__file__).parent.absolute().resolve() +png_img_p = current_dir / "image.png" +jpg_img_p = current_dir / "bunny.jpeg" + + +def test_load_png_image(): + with open(png_img_p, "rb") as f: + np_img, alpha_channel = load_img(f.read()) + assert np_img.shape == (256, 256, 3) + assert alpha_channel.shape == (256, 256) + + +def test_load_jpg_image(): + with open(jpg_img_p, "rb") as f: + np_img, alpha_channel = load_img(f.read()) + assert np_img.shape == (394, 448, 3) + assert alpha_channel is None