update
This commit is contained in:
@@ -62,7 +62,7 @@ def load_model(model: torch.nn.Module, url_or_path, device):
|
||||
model_path = download_model(url_or_path)
|
||||
|
||||
try:
|
||||
state_dict = torch.load(model_path, map_location='cpu')
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
logger.info(f"Load model from: {model_path}")
|
||||
@@ -85,26 +85,43 @@ def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
|
||||
return image_bytes
|
||||
|
||||
|
||||
def load_img(img_bytes, gray: bool = False):
|
||||
def pil_to_bytes(pil_img, ext: str, exif=None) -> bytes:
|
||||
with io.BytesIO() as output:
|
||||
pil_img.save(output, format=ext, exif=exif)
|
||||
image_bytes = output.getvalue()
|
||||
return image_bytes
|
||||
|
||||
|
||||
def load_img(img_bytes, gray: bool = False, return_exif: bool = False):
|
||||
alpha_channel = None
|
||||
image = Image.open(io.BytesIO(img_bytes))
|
||||
|
||||
try:
|
||||
if return_exif:
|
||||
exif = image.getexif()
|
||||
except:
|
||||
exif = None
|
||||
logger.error("Failed to extract exif from image")
|
||||
|
||||
try:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
except:
|
||||
pass
|
||||
|
||||
if gray:
|
||||
image = image.convert('L')
|
||||
image = image.convert("L")
|
||||
np_img = np.array(image)
|
||||
else:
|
||||
if image.mode == 'RGBA':
|
||||
if image.mode == "RGBA":
|
||||
np_img = np.array(image)
|
||||
alpha_channel = np_img[:, :, -1]
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_RGBA2RGB)
|
||||
else:
|
||||
image = image.convert('RGB')
|
||||
image = image.convert("RGB")
|
||||
np_img = np.array(image)
|
||||
|
||||
if return_exif:
|
||||
return np_img, alpha_channel, exif
|
||||
return np_img, alpha_channel
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user