glue-factory-custom/gluefactory/datasets/drone2sat.py

694 lines
25 KiB
Python

"""
Simply load images from a folder or nested folders (does not have any split),
and apply homographic adaptations to it. Yields an image pair without border
artifacts.
"""
import argparse
import gc
import json
import logging
import os
import random
import shutil
import tarfile
import warnings
from pathlib import Path
from typing import List, Literal, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import torch
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
from tqdm import tqdm

from ..geometry.homography import (
    compute_homography,
    sample_homography_corners,
    warp_points,
)
from ..models.cache_loader import CacheLoader, pad_local_features
from ..settings import DATA_PATH
from ..utils.image import read_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
from .augmentations import IdentityAugmentation, augmentations
from .base_dataset import BaseDataset
from .s_utils import get_random_tiff_patch
logger = logging.getLogger(__name__)
class Drone2SatDataset(BaseDataset):
    """glue-factory dataset serving UAV / satellite image pairs.

    Both supported splits are backed by :class:`GeoLocalizationDataset`,
    configured from ``conf.geo_dataset``:

    * ``"train"`` -> the UAV train split, limited by ``conf.train_size``;
    * ``"val"``   -> the UAV test split, limited by ``conf.val_size``.
    """

    default_conf = {
        # image search
        "data_dir": "revisitop1m",  # the top-level directory
        "image_dir": "jpg/",  # the subdirectory with the images
        "image_list": "revisitop1m.txt",  # optional: list or filename of list
        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
        "metadata_dir": None,
        "uav_dataset_dir": None,
        "satellite_dataset_dir": None,
        # splits
        "train_size": 100,
        "val_size": 10,
        "shuffle_seed": 0,  # or None to skip
        # image loading
        "grayscale": False,
        "triplet": False,
        "right_only": False,  # image0 is orig (rescaled), image1 is right
        "reseed": False,
        "homography": {
            "difficulty": 0.8,
            "translation": 1.0,
            "max_angle": 60,
            "n_angles": 10,
            "patch_shape": [640, 480],
            "min_convexity": 0.05,
        },
        "photometric": {
            "name": "dark",
            "p": 0.75,
            # 'difficulty': 1.0,  # currently unused
        },
        # feature loading
        "load_features": {
            "do": False,
            **CacheLoader.default_conf,
            "collate": False,
            "thresh": 0.0,
            "max_num_keypoints": -1,
            "force_num_keypoints": False,
        },
        # Other geolocalization parameters (the only section get_dataset reads,
        # together with train_size / val_size).
        "geo_dataset": {
            "uav_dataset_dir": None,
            "satellite_dataset_dir": None,
            "misslabeled_images_path": None,
            "sat_zoom_level": 17,
            "uav_patch_width": 400,
            "uav_patch_height": 400,
            "sat_patch_width": 400,
            "sat_patch_height": 400,
            "heatmap_kernel_size": 33,
            "test_from_train_ratio": 0.0,
            "transform_mean": [0.485, 0.456, 0.406],
            "transform_std": [0.229, 0.224, 0.225],
            # Key spelling is kept as-is because existing configs use it.
            "sat_availaible_years": ["2023", "2021", "2019", "2016"],
            "max_rotation_angle": 10,
            "uav_image_scale": 1,
            "use_heatmap": False,
        },
    }

    def _init(self, conf):
        self.images = {"train": "train", "val": "val"}

    def _geo_dataset_kwargs(self) -> dict:
        """Collect the GeoLocalizationDataset arguments shared by all splits."""
        geo = self.conf.geo_dataset
        return dict(
            uav_dataset_dir=geo.uav_dataset_dir,
            satellite_dataset_dir=geo.satellite_dataset_dir,
            misslabeled_images_path=geo.misslabeled_images_path,
            sat_zoom_level=geo.sat_zoom_level,
            uav_patch_width=geo.uav_patch_width,
            uav_patch_height=geo.uav_patch_height,
            sat_patch_width=geo.sat_patch_width,
            sat_patch_height=geo.sat_patch_height,
            heatmap_kernel_size=geo.heatmap_kernel_size,
            test_from_train_ratio=geo.test_from_train_ratio,
            transform_mean=geo.transform_mean,
            transform_std=geo.transform_std,
            sat_available_years=geo.sat_availaible_years,
            max_rotation_angle=geo.max_rotation_angle,
            uav_image_scale=geo.uav_image_scale,
            use_heatmap=geo.use_heatmap,
        )

    def get_dataset(self, split):
        """Build the torch dataset for ``split``.

        Args:
            split: ``"train"`` or ``"val"``. ``"val"`` is served from the UAV
                *test* split.

        Returns:
            A configured :class:`GeoLocalizationDataset`.

        Raises:
            ValueError: if ``split`` is neither ``"train"`` nor ``"val"``
                (previously ``None`` was returned silently).
        """
        if split == "train":
            return GeoLocalizationDataset(
                dataset="train",
                subset_size=self.conf.train_size,
                **self._geo_dataset_kwargs(),
            )
        if split == "val":
            return GeoLocalizationDataset(
                dataset="test",
                subset_size=self.conf.val_size,
                **self._geo_dataset_kwargs(),
            )
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'val'.")
class GeoLocalizationDataset(torch.utils.data.Dataset):
def __init__(
self,
uav_dataset_dir: str,
satellite_dataset_dir: str,
misslabeled_images_path: str,
dataset: Literal["train", "test"],
sat_zoom_level: int = 16,
uav_patch_width: int = 128,
uav_patch_height: int = 128,
sat_patch_width: int = 400,
sat_patch_height: int = 400,
heatmap_kernel_size: int = 33,
test_from_train_ratio: float = 0.0,
transform_mean: List[float] = [0.485, 0.456, 0.406],
transform_std: List[float] = [0.229, 0.224, 0.225],
sat_available_years: List[str] = ["2023", "2021", "2019", "2016"],
max_rotation_angle: int = 10,
uav_image_scale: float = 1,
use_heatmap: bool = True,
subset_size: int = None,
):
self.uav_dataset_dir = uav_dataset_dir
self.satellite_dataset_dir = satellite_dataset_dir
self.dataset = dataset
self.sat_zoom_level = sat_zoom_level
self.uav_patch_width = uav_patch_width
self.uav_patch_height = uav_patch_height
self.heatmap_kernel_size = heatmap_kernel_size
self.test_from_train_ratio = test_from_train_ratio
self.transform_mean = transform_mean
self.transform_std = transform_std
self.misslabeled_images_path = misslabeled_images_path
self.metadata_dict = {}
self.max_rotation_angle = max_rotation_angle
self.total_uav_samples = self.count_total_uav_samples()
self.misslabelled_images = self.read_misslabelled_images(self.misslabeled_images_path)
self.entry_paths = self.get_entry_paths(self.uav_dataset_dir)
self.cleanup_misslabelled_images()
self.transforms = transforms.Compose(
[
transforms.ToTensor(),
#transforms.Normalize(self.transform_mean, self.transform_std),
]
)
self.sat_available_years = sat_available_years
self.uav_image_scale = uav_image_scale
self.use_heatmap = use_heatmap
self.sat_patch_width = sat_patch_width
self.sat_patch_height = sat_patch_height
self.subset_size = subset_size
self.inverse_transforms = transforms.Compose(
[
transforms.Normalize(
mean=[
-m / s for m, s in zip(self.transform_mean, self.transform_std)
],
std=[1 / s for s in self.transform_std],
),
transforms.ToPILImage(),
]
)
if self.subset_size != -1:
ssize = int(self.subset_size / len(self.sat_available_years))
ssize = min(ssize, len(self.entry_paths))
self.entry_paths = self.entry_paths[:ssize]
def __len__(self) -> int:
return (
len(self.entry_paths)
* len(self.sat_available_years)
)
def read_misslabelled_images(
self, path: str = "misslabels/misslabeled.txt"
) -> List[str]:
with open(path, "r") as f:
lines = f.readlines()
return [line.strip() for line in lines]
def cleanup_misslabelled_images(self) -> None:
indices_to_delete = []
for image in self.misslabelled_images:
for image_path in self.entry_paths:
if image in image_path:
index = self.entry_paths.index(image_path)
indices_to_delete.append(index)
break
sorted_tuples = sorted(indices_to_delete, reverse=True)
for index in sorted_tuples:
self.entry_paths.pop(index)
def __getitem__(self, idx) -> dict:
"""
Retrieves a sample given its index, returning the preprocessed UAV
and satellite images, along with their associated heatmap and metadata.
"""
image_path_index = idx // (
len(self.sat_available_years)
)
sat_year = self.sat_available_years[idx % len(self.sat_available_years)]
rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle)
image_path = self.entry_paths[image_path_index]
uav_image = Image.open(image_path).convert("RGB") # Ensure 3-channel image
original_uav_image_width = uav_image.width
original_uav_image_height = uav_image.height
lookup_str, file_number = self.extract_info_from_filename(image_path)
img_info = self.metadata_dict[lookup_str][file_number]
lat, lon = (
img_info["coordinate"]["latitude"],
img_info["coordinate"]["longitude"],
)
fov_vertical = img_info["fovVertical"]
try:
agl_altitude = float(image_path.split("/")[-1].split("_")[2].split("m")[0])
except IndexError:
agl_altitude = 150.0
warnings.warn(
"Could not extract AGL altitude from filename, using default value of 150m."
)
(
satellite_patch,
x_sat,
y_sat,
x_offset,
y_offset,
patch_transform,
) = get_random_tiff_patch(
lat, lon, self.sat_patch_width, self.sat_patch_height, sat_year, self.satellite_dataset_dir
)
# Rotate crop center and transform image
h = np.ceil(uav_image.height // self.uav_image_scale).astype(int)
w = np.ceil(uav_image.width // self.uav_image_scale).astype(int)
uav_image = F.rotate(uav_image, rot_angle)
uav_image = F.resize(uav_image, [h, w])
uav_image = F.center_crop(
uav_image, (self.uav_patch_height, self.uav_patch_width)
)
uav_image = self.transforms(uav_image)
satellite_patch = satellite_patch.transpose(1, 2, 0)
satellite_patch_pytorch = self.transforms(satellite_patch)
del satellite_patch
if self.use_heatmap:
heatmap = self.get_heatmap_gt(
x_sat,
y_sat,
satellite_patch.shape[1],
satellite_patch.shape[2],
self.heatmap_kernel_size,
)
cropped_uav_image_width = self.calculate_cropped_uav_image_width(
fov_vertical,
original_uav_image_width,
original_uav_image_height,
self.uav_patch_width,
self.uav_patch_height,
agl_altitude,
)
satellite_tile_width = self.calculate_cropped_sat_image_width(
lat, self.sat_patch_width, patch_transform
)
scale_factor = cropped_uav_image_width / satellite_tile_width
scale_factor *= 10
homography_matrix = self.compute_homography(
rot_angle,
x_sat,
y_sat,
self.uav_patch_width,
self.uav_patch_height,
scale_factor,
)
if not self.use_heatmap:
# Sample four points from the UAV image
points = self.sample_four_points(
self.uav_patch_width, self.uav_patch_height
)
# Transform the points
warped_points = self.warp_points(points, homography_matrix)
#img_info["warped_points_sat"] = warped_points
#img_info["warped_points_uav"] = points
# img_info["cropped_uav_image_width"] = cropped_uav_image_width
# img_info["satellite_tile_width"] = satellite_tile_width
# img_info["scale_factor"] = scale_factor
# img_info["filename"] = image_path
# img_info["rot_angle"] = rot_angle
# img_info["x_sat"] = x_sat
# img_info["y_sat"] = y_sat
# img_info["x_offset"] = x_offset
# img_info["y_offset"] = y_offset
# img_info["patch_transform"] = patch_transform
# img_info["uav_image_scale"] = self.uav_image_scale
# img_info["homography_matrix_uav_to_sat"] = homography_matrix
# img_info["homography_matrix_sat_to_uav"] = np.linalg.inv(homography_matrix)
# img_info["agl_altitude"] = agl_altitude
# img_info["original_uav_image_width"] = original_uav_image_width
# img_info["original_drone_image_height"] = original_uav_image_height
# img_info["fov_vertical"] = fov_vertical
#
inverse_homography_matrix = np.linalg.inv(homography_matrix)
uav_image_data = {
"image": uav_image,
"H_": homography_matrix,
"coords": points,
"image_size": np.array([self.uav_patch_width, self.uav_patch_height]),
}
satellite_patch_data = {
"image": satellite_patch_pytorch,
"H_": inverse_homography_matrix,
"coords": warped_points,
"image_size": np.array([self.sat_patch_width, self.sat_patch_height]),
}
#hm = self.compute_homography_points(
# warped_points, points, [self.uav_patch_width, self.uav_patch_height]
#)
hm = self.compute_homography_points(
points, warped_points, [1.0 , 1.0]
)
if np.array_equal(hm, np.eye(3)):
print("Singular matrix")
return self.__getitem__(random.randint(0, len(self) - 1))
data = {
"name": f"{image_path}_{rot_angle}_{sat_year}",
"original_image_size": np.array(
[original_uav_image_width, original_uav_image_height]
),
"H_0to1": hm.astype(np.float32),
"view0": uav_image_data,
"view1": satellite_patch_data,
"idx": idx
}
del img_info
gc.collect()
return data
def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform):
"""
Computes the width of the satellite image in world coordinate units
"""
length_of_degree = 111320 * np.cos(np.radians(latitude))
scale_x_meters = patch_transform[0] * length_of_degree
satellite_tile_width = patch_width * scale_x_meters
return satellite_tile_width
def calculate_cropped_uav_image_width(
self,
fov_vertical,
orig_width,
orig_height,
crop_width,
crop_height,
altitude=150.0,
):
"""
Computes the width of the UAV image in world coordinate units
"""
# Convert fov from degrees to radians
fov_rad = np.radians(fov_vertical)
# Calculate the full width of the UAV image
full_width = 2 * (altitude * np.tan(fov_rad / 2))
# Determine the cropping ratio
crop_ratio_width = crop_width / orig_width
crop_ratio_height = crop_height / orig_height
# Calculate the adjusted horizontal fov
fov_horizontal = 2 * np.arctan(np.tan(fov_rad / 2) * (orig_width / orig_height))
adjusted_fov_horizontal = 2 * np.arctan(
np.tan(fov_horizontal / 2) * crop_ratio_width
)
# Calculate the new full width using the adjusted horizontal fov
full_width = 2 * (altitude * np.tan(adjusted_fov_horizontal / 2))
# Adjust the width according to the crop ratio
cropped_width = full_width * crop_ratio_width
return cropped_width
def compute_homography_points(self, pts1_, pts2_, shape):
"""Compute the homography matrix from 4 point correspondences"""
# Rescale to actual size
shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x]
pts1 = pts1_ * np.expand_dims(shape, axis=0)
pts2 = pts2_ * np.expand_dims(shape, axis=0)
def ax(p, q):
return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
def ay(p, q):
return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
def flat2mat(H):
return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])
a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
p_mat = np.transpose(
np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
)
try:
homography = np.transpose(np.linalg.solve(a_mat, p_mat))
except np.linalg.LinAlgError:
print("Singular matrix")
return np.eye(3)
return flat2mat(homography)
def compute_homography(
self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor
):
# Adjust rot_angle if it's greater than 180 degrees
if rot_angle > 180:
rot_angle -= 360
# Convert rotation angle to radians
theta = np.radians(rot_angle)
# Rotation matrix
R = np.array(
[
[np.cos(theta), -np.sin(theta), 0],
[np.sin(theta), np.cos(theta), 0],
[0, 0, 1],
]
)
# Scale matrix
S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
# Translation matrix to center the UAV image
T_uav = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]])
# Translation matrix to move to the satellite image position
T_sat = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]])
# Compute the combined homography matrix
H = np.dot(T_sat, np.dot(R, np.dot(S, T_uav)))
return H
def sample_four_points(self, width: int, height: int) -> np.ndarray:
"""
Samples four points from the UAV image.
"""
PADDING = 50
CENTER_PADDING = 10
points = np.array(
[
[random.randint(CENTER_PADDING, width - PADDING), random.randint(CENTER_PADDING, height - PADDING)]
for _ in range(4)
]
)
return points
def warp_points(
self, points: np.ndarray, homography_matrix: np.ndarray
) -> np.ndarray:
"""
Warps the given points using the given homography matrix.
"""
points = np.array(points)
points = np.concatenate([points, np.ones((4, 1))], axis=1)
points = np.dot(homography_matrix, points.T).T
points = points[:, :2] / points[:, 2:]
return points
def count_total_uav_samples(self) -> int:
"""
Count the total number of uav image samples in the dataset
(train + test)
"""
total_samples = 0
for dirpath, dirnames, filenames in os.walk(self.uav_dataset_dir):
# Skip the test folder
for filename in filenames:
if filename.endswith(".jpeg"):
total_samples += 1
return total_samples
def get_number_of_city_samples(self) -> int:
"""
TODO: Count the total number of city samples in the dataset
"""
return 11
def get_entry_paths(self, directory: str) -> List[str]:
"""
Recursively retrieves paths to image and metadata files in the given directory.
"""
entry_paths = []
entries = os.listdir(directory)
images_to_take_per_folder = int(
self.total_uav_samples
* self.test_from_train_ratio
/ self.get_number_of_city_samples()
)
for entry in entries:
entry_path = os.path.join(directory, entry)
# If it's a directory, recurse into it
if os.path.isdir(entry_path):
entry_paths += self.get_entry_paths(entry_path)
# Handle train dataset
elif (self.dataset == "train" and "Train" in entry_path) or (
self.dataset == "train"
and self.test_from_train_ratio > 0
and "Test" in entry_path
):
if entry_path.endswith(".jpeg"):
_, number = self.extract_info_from_filename(entry_path)
else:
number = None
if entry_path.endswith(".json"):
self.get_metadata(entry_path)
if number is None:
continue
if (
number >= images_to_take_per_folder
): # Only include images beyond the ones taken for test
if entry_path.endswith(".jpeg"):
entry_paths.append(entry_path)
# Handle test dataset
elif self.dataset == "test":
if entry_path.endswith(".jpeg"):
_, number = self.extract_info_from_filename(entry_path)
else:
number = None
if entry_path.endswith(".json"):
self.get_metadata(entry_path)
if number is None:
continue
if (
("Test" in entry_path and number < images_to_take_per_folder)
or (number < images_to_take_per_folder and "Train" in entry_path)
or (self.test_from_train_ratio == 0.0 and "Test" in entry_path)
):
if entry_path.endswith(".jpeg"):
entry_paths.append(entry_path)
return sorted(entry_paths, key=self.extract_info_from_filename)
def get_metadata(self, path: str) -> None:
"""
Extracts metadata from a JSON file and stores it in the metadata dictionary.
"""
with open(path, newline="") as jsonfile:
json_dict = json.load(jsonfile)
path = path.split("/")[-1]
path = path.replace(".json", "")
self.metadata_dict[path] = json_dict["cameraFrames"]
def extract_info_from_filename(self, filename: str) -> (str, int):
"""
Extracts information from the filename.
"""
filename_without_ext = filename.replace(".jpeg", "")
segments = filename_without_ext.split("/")
info = segments[-1]
try:
number = int(info.split("_")[-1])
except ValueError:
print("Could not extract number from filename: ", filename)
return None, None
info = "_".join(info.split("_")[:-1])
return info, number
def get_heatmap_gt(
self, x: int, y: int, height: int, width: int, square_size: int = 33
) -> torch.Tensor:
"""
Returns 2D heatmap ground truth for the given x and y coordinates,
with the given square size.
"""
x_map, y_map = x, y
heatmap = torch.zeros((height, width))
half_size = square_size // 2
# Calculate the valid range for the square
start_x = max(0, x_map - half_size)
end_x = min(
width, x_map + half_size + 1
) # +1 to include the end_x in the square
start_y = max(0, y_map - half_size)
end_y = min(
height, y_map + half_size + 1
) # +1 to include the end_y in the square
heatmap[start_y:end_y, start_x:end_x] = 1
return heatmap
if __name__ == "__main__":
from .. import logger # overwrite the logger