import collections.abc as collections
from pathlib import Path
from typing import Optional, Tuple

import cv2
import kornia
import numpy as np
import torch
from omegaconf import OmegaConf

|
class ImagePreprocessor:
    """Resize/pad images for a network and track the applied geometric transform.

    The configuration (see ``default_conf``) is merged with OmegaConf, so callers
    may pass a plain dict or an OmegaConf node. Unknown keys are rejected.
    """

    default_conf = {
        "resize": None,  # target edge length (int) or (h, w) pair; None = no resizing
        "edge_divisible_by": None,  # round both edges down to a multiple of this
        "side": "long",  # which side `resize` refers to: short | long | vert | horz
        "interpolation": "bilinear",
        "align_corners": None,
        "antialias": True,
        "square_pad": False,  # zero-pad the output to a square
        "add_padding_mask": False,  # also return a bool mask of valid (unpadded) pixels
    }

    def __init__(self, conf) -> None:
        super().__init__()
        default_conf = OmegaConf.create(self.default_conf)
        # set_struct makes the merge below fail on keys absent from default_conf.
        OmegaConf.set_struct(default_conf, True)
        self.conf = OmegaConf.merge(default_conf, conf)

    def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
        """Resize and preprocess an image, return image and resize scale.

        Args:
            img: image tensor of shape (..., H, W); pixel values are not rescaled.
            interpolation: optional override of ``conf.interpolation``.

        Returns:
            dict with keys "image", "scales" (x, y), "image_size" (w, h),
            "transform" (3x3 original->resized), "original_image_size" (w, h),
            and optionally "padding_mask".
        """
        h, w = img.shape[-2:]
        size = h, w
        if self.conf.resize is not None:
            if interpolation is None:
                interpolation = self.conf.interpolation
            size = self.get_new_image_size(h, w)
            img = kornia.geometry.transform.resize(
                img,
                size,
                side=self.conf.side,
                antialias=self.conf.antialias,
                align_corners=self.conf.align_corners,
                interpolation=interpolation,
            )
        # Per-axis scale (x, y) actually applied, on the same device/dtype as img.
        scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
        # Homogeneous 3x3 transform mapping original -> resized pixel coordinates.
        # .item() keeps the numpy array numeric and works even if img is on CUDA.
        T = np.diag([scale[0].item(), scale[1].item(), 1.0])

        data = {
            "scales": scale,
            "image_size": np.array(size[::-1]),  # (w, h) before padding
            "transform": T,
            "original_image_size": np.array([w, h]),
        }
        if self.conf.square_pad:
            sl = max(img.shape[-2:])
            data["image"] = torch.zeros(
                *img.shape[:-2], sl, sl, device=img.device, dtype=img.dtype
            )
            # Index with `...` so batched inputs (..., C, H, W) are padded
            # correctly; `[:, :h, :w]` would slice the wrong dims for ndim > 3.
            data["image"][..., : img.shape[-2], : img.shape[-1]] = img
            if self.conf.add_padding_mask:
                data["padding_mask"] = torch.zeros(
                    *img.shape[:-3], 1, sl, sl, device=img.device, dtype=torch.bool
                )
                data["padding_mask"][..., : img.shape[-2], : img.shape[-1]] = True
        else:
            data["image"] = img
        return data

    def load_image(self, image_path: Path) -> dict:
        """Load an image from disk and preprocess it."""
        return self(load_image(image_path))

    def get_new_image_size(
        self,
        h: int,
        w: int,
    ) -> Tuple[int, int]:
        """Compute the (h, w) an image of size (h, w) will be resized to.

        ``conf.resize`` may be a (h, w) pair (used verbatim) or a single edge
        length applied to the side selected by ``conf.side``.
        """
        side = self.conf.side
        if isinstance(self.conf.resize, collections.Iterable):
            assert len(self.conf.resize) == 2
            return tuple(self.conf.resize)
        side_size = self.conf.resize
        aspect_ratio = w / h
        if side not in ("short", "long", "vert", "horz"):
            raise ValueError(
                f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
            )
        if side == "vert":
            size = side_size, int(side_size * aspect_ratio)
        elif side == "horz":
            size = int(side_size / aspect_ratio), side_size
        elif (side == "short") ^ (aspect_ratio < 1.0):
            # "short" and landscape, or "long" and portrait: height is the target side.
            size = side_size, int(side_size * aspect_ratio)
        else:
            size = int(side_size / aspect_ratio), side_size

        if self.conf.edge_divisible_by is not None:
            # Round both edges down to a multiple of df (may slightly shrink).
            df = self.conf.edge_divisible_by
            size = list(map(lambda x: int(x // df * df), size))
        return size


def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Read an image from path as RGB or grayscale.

    Args:
        path: location of the image file.
        grayscale: if True, load a single-channel image.

    Returns:
        The image as a numpy array, in RGB channel order unless grayscale.

    Raises:
        FileNotFoundError: if no file exists at `path`.
        IOError: if the file exists but cannot be decoded.
    """
    if not Path(path).exists():
        raise FileNotFoundError(f"No image at path {path}.")
    flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    image = cv2.imread(str(path), flags)
    if image is None:
        raise IOError(f"Could not read image at {path}.")
    # OpenCV decodes color images as BGR; reverse the channel axis to get RGB.
    return image if grayscale else image[..., ::-1]


def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Normalize the image tensor and reorder the dimensions.

    Accepts an HxWxC color image or an HxW grayscale image and returns a
    CxHxW float tensor with values in [0, 1].
    """
    ndim = image.ndim
    if ndim == 2:
        image = image[None]  # grayscale: prepend a channel axis
    elif ndim == 3:
        image = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    else:
        raise ValueError(f"Not an image: {image.shape}")
    # torch.tensor copies the data, which also handles non-contiguous inputs
    # (e.g. a channel-reversed BGR->RGB view).
    return torch.tensor(image / 255.0, dtype=torch.float)


def load_image(path: Path, grayscale=False) -> torch.Tensor:
    """Read the image at `path` and return it as a normalized CxHxW float tensor."""
    return numpy_image_to_torch(read_image(path, grayscale=grayscale))