from typing import Union

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch.transforms import ToTensorV2
from omegaconf import OmegaConf


class IdentityTransform(A.ImageOnlyTransform):
    """A no-op albumentations transform that returns the image unchanged.

    Used as a placeholder where the pipeline expects a transform object
    but no actual processing is wanted.
    """

    def apply(self, img, **params):
        """Pass the input image through untouched."""
        return img

    def get_transform_init_args_names(self):
        """No constructor arguments need serializing for replay."""
        return ()


class RandomAdditiveShade(A.ImageOnlyTransform):
    """Add smooth random shading (dark or bright elliptical blobs) to an image.

    Random filled ellipses are drawn into a mask, the mask is heavily
    Gaussian-blurred, and the image is multiplicatively shaded by it.

    Args:
        nb_ellipses: number of ellipses composing the shade mask.
        transparency_limit: (min, max) range for the shade strength; negative
            values brighten, positive values darken.
        kernel_size_limit: (min, max) range for the Gaussian blur kernel size.
        always_apply / p: standard albumentations probability controls.
    """

    def __init__(
        self,
        nb_ellipses=10,
        transparency_limit=(-0.5, 0.8),  # tuple default: avoids shared mutable list
        kernel_size_limit=(150, 350),
        always_apply=False,
        p=0.5,
    ):
        super().__init__(always_apply, p)
        self.nb_ellipses = nb_ellipses
        self.transparency_limit = transparency_limit
        self.kernel_size_limit = kernel_size_limit

    def apply(self, img, **params):
        """Shade `img`; supports float32 (assumed in [0, 1]) and uint8 inputs."""
        if img.dtype == np.float32:
            # Work in the 0..255 range, then scale back.
            shaded = self._py_additive_shade(img * 255.0)
            shaded /= 255.0
        elif img.dtype == np.uint8:
            shaded = self._py_additive_shade(img.astype(np.float32))
            shaded = shaded.astype(np.uint8)
        else:
            raise NotImplementedError(
                f"Data augmentation not available for type: {img.dtype}"
            )
        return shaded

    def _py_additive_shade(self, img):
        """Core shading routine; `img` is float HW or HWC in the 0..255 range."""
        grayscale = len(img.shape) == 2
        if grayscale:
            # Bug fix: append the channel axis (HW -> HW1). The previous
            # `img[None]` prepended it (1HW), so `img.shape[:2]` became
            # (1, H) and all ellipse geometry, the mask, and the final
            # squeeze operated on a 1-pixel-tall image.
            img = img[..., np.newaxis]
        min_dim = min(img.shape[:2]) / 4
        mask = np.zeros(img.shape[:2], img.dtype)
        for _ in range(self.nb_ellipses):
            # Random semi-axes, bounded away from degenerate (too-thin) shapes.
            ax = int(max(np.random.rand() * min_dim, min_dim / 5))
            ay = int(max(np.random.rand() * min_dim, min_dim / 5))
            max_rad = max(ax, ay)
            # Keep the ellipse fully inside the image.
            x = np.random.randint(max_rad, img.shape[1] - max_rad)  # center
            y = np.random.randint(max_rad, img.shape[0] - max_rad)
            angle = np.random.rand() * 90
            cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)

        transparency = np.random.uniform(*self.transparency_limit)
        ks = np.random.randint(*self.kernel_size_limit)
        if (ks % 2) == 0:  # kernel_size has to be odd
            ks += 1
        mask = cv2.GaussianBlur(mask.astype(np.float32), (ks, ks), 0)
        # Multiplicative shading; the blurred mask is normalized to 0..1.
        shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.0)
        out = np.clip(shaded, 0, 255)
        if grayscale:
            # Drop the channel axis we appended above.
            out = out.squeeze(-1)
        return out

    def get_transform_init_args_names(self):
        """Constructor argument names serialized for replay."""
        return "transparency_limit", "kernel_size_limit", "nb_ellipses"


def kw(entry: Union[float, dict], n=None, **default):
    """Normalize a config entry into an OmegaConf node merged over defaults.

    Args:
        entry: either a bare probability (float, wrapped as ``{"p": entry}``)
            or a dict of overrides. The dict may carry per-transform
            sub-configs keyed by name (see ``n``).
        n: optional sub-entry name; if ``entry`` contains that key, the named
            sub-config is used instead of the whole entry.
        **default: default keyword arguments for the transform.

    Returns:
        An OmegaConf config: ``default`` overridden by the (sub-)entry.
    """
    if not isinstance(entry, dict):
        entry = {"p": entry}
    entry = OmegaConf.create(entry)
    if n is not None:
        # Bug fix: look the named sub-config up in `entry`, not in `default`.
        # `default` only holds kwargs such as `p`/`blur_limit`, never
        # transform names, so `default.get(n, entry)` was always a no-op and
        # per-transform overrides (e.g. {"blur": {...}}) leaked into the
        # merge as junk keys instead of being selected.
        entry = entry.get(n, entry)
    return OmegaConf.merge(default, entry)


def kwi(entry: Union[float, dict], n=None, **default):
    """Like `kw`, but return a plain dict restricted to the default keys plus "p"."""
    conf = kw(entry, n=n, **default)
    wanted = set(default) | {"p"}
    return {key: conf[key] for key in wanted}


def replay_str(transforms, s="Replay:\n", log_inactive=True):
    """Render an albumentations replay record as a readable multi-line string.

    Recurses into composed transforms; each leaf is logged as
    "<class name> <applied flag>". If `log_inactive` is False, leaves that
    were not applied are skipped (note: only at the current level — the
    recursive call uses the default).
    """
    for entry in transforms:
        if "transforms" in entry:
            # Composite transform (e.g. OneOf): descend into its children.
            s = replay_str(entry["transforms"], s=s)
        elif entry["applied"] or log_inactive:
            s += f"{entry['__class_fullname__']} {entry['applied']}\n"
    return s


class BaseAugmentation(object):
    """Base class for albumentations-backed augmentation pipelines.

    Child classes override `default_conf` and `_init` to populate
    `self.transforms`; instances are then called directly on images.
    """

    # Options shared by every pipeline; "???" is OmegaConf's mandatory marker.
    base_default_conf = {
        "name": "???",
        "shuffle": False,
        "p": 1.0,
        "verbose": False,
        "dtype": "uint8",  # (byte, float)
    }

    # Pipeline-specific defaults, overridden by subclasses.
    default_conf = {}

    def __init__(self, conf=None):
        """Perform some logic and call the _init method of the child model."""
        # Bug fix: `conf={}` was a shared mutable default argument; use a
        # None sentinel instead (backward compatible for all callers).
        if conf is None:
            conf = {}
        default_conf = OmegaConf.merge(
            OmegaConf.create(self.base_default_conf),
            OmegaConf.create(self.default_conf),
        )
        # Struct mode: reject unknown keys coming from `conf`.
        OmegaConf.set_struct(default_conf, True)
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(self.conf, True)
        self._init(self.conf)

        # NOTE(review): this re-merge looks redundant (conf was already merged
        # above); presumably it is here to yield a fresh, non-read-only config
        # object after _init — TODO confirm before removing.
        self.conf = OmegaConf.merge(self.conf, conf)
        # ReplayCompose records which transforms fired so they can be printed.
        if self.conf.verbose:
            self.compose = A.ReplayCompose
        else:
            self.compose = A.Compose
        if self.conf.dtype == "uint8":
            self.dtype = np.uint8
            self.preprocess = A.FromFloat(always_apply=True, dtype="uint8")
            self.postprocess = A.ToFloat(always_apply=True)
        elif self.conf.dtype == "float32":
            self.dtype = np.float32
            self.preprocess = A.ToFloat(always_apply=True)
            self.postprocess = IdentityTransform()
        else:
            raise ValueError(f"Unsupported dtype {self.conf.dtype}")
        self.to_tensor = ToTensorV2()

    def _init(self, conf):
        """Child class overwrites this, setting up a list of transforms"""
        self.transforms = []

    def __call__(self, image, return_tensor=False):
        """Augment a single image.

        Args:
            image: HW or HWC array (or torch.Tensor, converted to numpy).
            return_tensor: if True, return a torch tensor via ToTensorV2;
                otherwise return a numpy array.
        """
        if isinstance(image, torch.Tensor):
            image = image.cpu().detach().numpy()
        data = {"image": image}
        # Convert to the pipeline's working dtype only when needed.
        if image.dtype != self.dtype:
            data = self.preprocess(**data)
        transforms = self.transforms
        if self.conf.shuffle:
            # Apply the transforms in a random order.
            order = list(range(len(transforms)))
            np.random.shuffle(order)
            transforms = [transforms[i] for i in order]
        transformed = self.compose(transforms, p=self.conf.p)(**data)
        if self.conf.verbose:
            # Only available with ReplayCompose (verbose mode).
            print(replay_str(transformed["replay"]["transforms"]))
        transformed = self.postprocess(**transformed)
        if return_tensor:
            return self.to_tensor(**transformed)["image"]
        else:
            return transformed["image"]


class IdentityAugmentation(BaseAugmentation):
    """Pipeline that leaves images unchanged (disabled-augmentation stub)."""

    default_conf = {}

    def _init(self, conf):
        # One always-on no-op transform keeps the pipeline machinery happy.
        self.transforms = [IdentityTransform(p=1.0)]


class DarkAugmentation(BaseAugmentation):
    """Photometric augmentation simulating dark / low-light captures."""

    default_conf = {"p": 0.75}

    def _init(self, conf):
        # Default probabilities for the tunable sub-transforms; `kw`/`kwi`
        # let the conf override them per transform.
        p_brightness_contrast = 0.5
        p_blur = 0.1
        p_gamma = 0.1
        p_hue = 0.1
        self.transforms = [
            A.RandomRain(p=0.2),
            A.RandomBrightnessContrast(
                **kw(
                    p_brightness_contrast,
                    brightness_limit=(-0.4, 0.0),
                    contrast_limit=(-0.3, 0.0),
                )
            ),
            # Exactly one degradation (blur / noise / compression) at a time.
            A.OneOf(
                [
                    A.Blur(**kwi(p_blur, p=0.1, blur_limit=(3, 9), n="blur")),
                    A.MotionBlur(
                        **kwi(p_blur, p=0.2, blur_limit=(3, 25), n="motion_blur")
                    ),
                    A.ISONoise(),
                    A.ImageCompression(),
                ],
                **kwi(p_blur, p=0.1),
            ),
            A.RandomGamma(**kw(p_gamma, gamma_limit=(15, 65))),
            # Exactly one tone/color remap at a time.
            A.OneOf(
                [
                    A.Equalize(),
                    A.CLAHE(p=0.2),
                    A.ToGray(),
                    A.ToSepia(p=0.1),
                    A.HueSaturationValue(**kw(p_hue, val_shift_limit=(-100, -40))),
                ],
                p=0.5,
            ),
        ]


class LGAugmentation(BaseAugmentation):
    """Photometric augmentation pipeline with fixed (non-configurable) probabilities."""

    default_conf = {"p": 0.95}

    def _init(self, conf):
        # Exactly one degradation (blur / noise / compression) at a time.
        degrade = A.OneOf(
            [
                A.Blur(blur_limit=(3, 9)),
                A.MotionBlur(blur_limit=(3, 25)),
                A.ISONoise(),
                A.ImageCompression(),
            ],
            p=0.1,
        )
        self.transforms = [
            A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
            A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)),
            degrade,
            A.Blur(p=0.1, blur_limit=(3, 9)),
            A.MotionBlur(p=0.1, blur_limit=(3, 25)),
            A.RandomBrightnessContrast(
                p=0.5, brightness_limit=(-0.4, 0.0), contrast_limit=(-0.3, 0.0)
            ),
            A.CLAHE(p=0.2),
        ]


# Registry mapping config names to augmentation pipeline classes.
augmentations = dict(
    dark=DarkAugmentation,
    lg=LGAugmentation,
    identity=IdentityAugmentation,
)