This commit is contained in:
Qing
2023-02-06 22:00:47 +08:00
parent 24bff09534
commit 3f6bc8fada
9 changed files with 307 additions and 91 deletions

View File

@@ -62,7 +62,7 @@ def load_model(model: torch.nn.Module, url_or_path, device):
model_path = download_model(url_or_path)
try:
state_dict = torch.load(model_path, map_location='cpu')
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.to(device)
logger.info(f"Load model from: {model_path}")
@@ -85,26 +85,43 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
return image_bytes
def load_img(img_bytes, gray: bool = False):
def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
with io.BytesIO() as output:
pil_img.save(output, format=ext, exif=exif)
image_bytes = output.getvalue()
return image_bytes
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
alpha_channel = None
image = Image.open(io.BytesIO(img_bytes))
try:
if return_exif:
exif = image.getexif()
except:
exif = None
logger.error("Failed to extract exif from image")
try:
image = ImageOps.exif_transpose(image)
except:
pass
if gray:
image = image.convert('L')
image = image.convert("L")
np_img = np.array(image)
else:
if image.mode == 'RGBA':
if image.mode == "RGBA":
np_img = np.array(image)
alpha_channel = np_img[:, :, -1]
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
else:
image = image.convert('RGB')
image = image.convert("RGB")
np_img = np.array(image)
if return_exif:
return np_img, alpha_channel, exif
return np_img, alpha_channel