""" Simply load images from a folder or nested folders (does not have any split), and apply homographic adaptations to it. Yields an image pair without border artifacts. """ import argparse import logging import shutil import tarfile from pathlib import Path import os import json import gc import cv2 import matplotlib.pyplot as plt import numpy as np import omegaconf import torch from omegaconf import OmegaConf from tqdm import tqdm from PIL import Image from typing import List, Literal from torchvision import transforms import random from torchvision.transforms import functional as F from ..geometry.homography import ( compute_homography, sample_homography_corners, warp_points, ) from ..models.cache_loader import CacheLoader, pad_local_features from ..settings import DATA_PATH from ..utils.image import read_image from ..utils.tools import fork_rng from ..visualization.viz2d import plot_image_grid from .augmentations import IdentityAugmentation, augmentations from .base_dataset import BaseDataset from .s_utils import get_random_tiff_patch logger = logging.getLogger(__name__) class Drone2SatDataset(BaseDataset): default_conf = { # image search "data_dir": "revisitop1m", # the top-level directory "image_dir": "jpg/", # the subdirectory with the images "image_list": "revisitop1m.txt", # optional: list or filename of list "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"], "metadata_dir": None, "uav_dataset_dir": None, "satellite_dataset_dir": None, # splits "train_size": 100, "val_size": 10, "shuffle_seed": 0, # or None to skip # image loading "grayscale": False, "triplet": False, "right_only": False, # image0 is orig (rescaled), image1 is right "reseed": False, "homography": { "difficulty": 0.8, "translation": 1.0, "max_angle": 60, "n_angles": 10, "patch_shape": [640, 480], "min_convexity": 0.05, }, "photometric": { "name": "dark", "p": 0.75, # 'difficulty': 1.0, # currently unused }, # feature loading "load_features": { "do": False, **CacheLoader.default_conf, "collate": False, "thresh": 0.0, "max_num_keypoints": -1, "force_num_keypoints": False, }, # Other geolocaliztion parameters "geo_dataset" :{ "uav_dataset_dir": None, "satellite_dataset_dir": None, "misslabeled_images_path": None, "sat_zoom_level": 17, "uav_patch_width": 400, "uav_patch_height": 400, "sat_patch_width": 400, "sat_patch_height": 400, "heatmap_kernel_size": 33, "test_from_train_ratio": 0.0, "transform_mean": [0.485, 0.456, 0.406], "transform_std": [0.229, 0.224, 0.225], "sat_availaible_years": ["2023", "2021", "2019", "2016"], "max_rotation_angle": 10, "uav_image_scale": 1, "use_heatmap": False, } } def _init(self, conf): self.images = {"train": "train", "val": "val"} def get_dataset(self, split): if split == "val": return GeoLocalizationDataset( uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir, satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir, misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path, dataset="test", sat_zoom_level=self.conf.geo_dataset.sat_zoom_level, uav_patch_width=self.conf.geo_dataset.uav_patch_width, uav_patch_height=self.conf.geo_dataset.uav_patch_height, sat_patch_width=self.conf.geo_dataset.sat_patch_width, sat_patch_height=self.conf.geo_dataset.sat_patch_height, heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size, test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio, transform_mean=self.conf.geo_dataset.transform_mean, transform_std=self.conf.geo_dataset.transform_std, sat_available_years=self.conf.geo_dataset.sat_availaible_years, max_rotation_angle=self.conf.geo_dataset.max_rotation_angle, uav_image_scale=self.conf.geo_dataset.uav_image_scale, use_heatmap=self.conf.geo_dataset.use_heatmap, subset_size=self.conf.val_size, ) elif split == "train": return GeoLocalizationDataset( uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir, satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir, misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path, dataset="train", sat_zoom_level=self.conf.geo_dataset.sat_zoom_level, uav_patch_width=self.conf.geo_dataset.uav_patch_width, uav_patch_height=self.conf.geo_dataset.uav_patch_height, sat_patch_width=self.conf.geo_dataset.sat_patch_width, sat_patch_height=self.conf.geo_dataset.sat_patch_height, heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size, test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio, transform_mean=self.conf.geo_dataset.transform_mean, transform_std=self.conf.geo_dataset.transform_std, sat_available_years=self.conf.geo_dataset.sat_availaible_years, max_rotation_angle=self.conf.geo_dataset.max_rotation_angle, uav_image_scale=self.conf.geo_dataset.uav_image_scale, use_heatmap=self.conf.geo_dataset.use_heatmap, subset_size=self.conf.train_size, ) class GeoLocalizationDataset(torch.utils.data.Dataset): def __init__( self, uav_dataset_dir: str, satellite_dataset_dir: str, misslabeled_images_path: str, dataset: Literal["train", "test"], sat_zoom_level: int = 16, uav_patch_width: int = 128, uav_patch_height: int = 128, sat_patch_width: int = 400, sat_patch_height: int = 400, heatmap_kernel_size: int = 33, test_from_train_ratio: float = 0.0, transform_mean: List[float] = [0.485, 0.456, 0.406], transform_std: List[float] = [0.229, 0.224, 0.225], sat_available_years: List[str] = ["2023", "2021", "2019", "2016"], max_rotation_angle: int = 10, uav_image_scale: float = 1, use_heatmap: bool = True, subset_size: int = None, ): self.uav_dataset_dir = uav_dataset_dir self.satellite_dataset_dir = satellite_dataset_dir self.dataset = dataset self.sat_zoom_level = sat_zoom_level self.uav_patch_width = uav_patch_width self.uav_patch_height = uav_patch_height self.heatmap_kernel_size = heatmap_kernel_size self.test_from_train_ratio = test_from_train_ratio self.transform_mean = transform_mean self.transform_std = transform_std self.misslabeled_images_path = misslabeled_images_path self.metadata_dict = {} self.max_rotation_angle = max_rotation_angle self.total_uav_samples = self.count_total_uav_samples() self.misslabelled_images = self.read_misslabelled_images(self.misslabeled_images_path) self.entry_paths = self.get_entry_paths(self.uav_dataset_dir) self.cleanup_misslabelled_images() self.transforms = transforms.Compose( [ transforms.ToTensor(), #transforms.Normalize(self.transform_mean, self.transform_std), ] ) self.sat_available_years = sat_available_years self.uav_image_scale = uav_image_scale self.use_heatmap = use_heatmap self.sat_patch_width = sat_patch_width self.sat_patch_height = sat_patch_height self.subset_size = subset_size self.inverse_transforms = transforms.Compose( [ transforms.Normalize( mean=[ -m / s for m, s in zip(self.transform_mean, self.transform_std) ], std=[1 / s for s in self.transform_std], ), transforms.ToPILImage(), ] ) if self.subset_size != -1: ssize = int(self.subset_size / len(self.sat_available_years)) ssize = min(ssize, len(self.entry_paths)) self.entry_paths = self.entry_paths[:ssize] def __len__(self) -> int: return ( len(self.entry_paths) * len(self.sat_available_years) ) def read_misslabelled_images( self, path: str = "misslabels/misslabeled.txt" ) -> List[str]: with open(path, "r") as f: lines = f.readlines() return [line.strip() for line in lines] def cleanup_misslabelled_images(self) -> None: indices_to_delete = [] for image in self.misslabelled_images: for image_path in self.entry_paths: if image in image_path: index = self.entry_paths.index(image_path) indices_to_delete.append(index) break sorted_tuples = sorted(indices_to_delete, reverse=True) for index in sorted_tuples: self.entry_paths.pop(index) def __getitem__(self, idx) -> dict: """ Retrieves a sample given its index, returning the preprocessed UAV and satellite images, along with their associated heatmap and metadata. """ image_path_index = idx // ( len(self.sat_available_years) ) sat_year = self.sat_available_years[idx % len(self.sat_available_years)] rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle) image_path = self.entry_paths[image_path_index] uav_image = Image.open(image_path).convert("RGB") # Ensure 3-channel image original_uav_image_width = uav_image.width original_uav_image_height = uav_image.height lookup_str, file_number = self.extract_info_from_filename(image_path) img_info = self.metadata_dict[lookup_str][file_number] lat, lon = ( img_info["coordinate"]["latitude"], img_info["coordinate"]["longitude"], ) fov_vertical = img_info["fovVertical"] try: agl_altitude = float(image_path.split("/")[-1].split("_")[2].split("m")[0]) except IndexError: agl_altitude = 150.0 warnings.warn( "Could not extract AGL altitude from filename, using default value of 150m." ) ( satellite_patch, x_sat, y_sat, x_offset, y_offset, patch_transform, ) = get_random_tiff_patch( lat, lon, self.sat_patch_width, self.sat_patch_height, sat_year, self.satellite_dataset_dir ) # Rotate crop center and transform image h = np.ceil(uav_image.height // self.uav_image_scale).astype(int) w = np.ceil(uav_image.width // self.uav_image_scale).astype(int) uav_image = F.rotate(uav_image, rot_angle) uav_image = F.resize(uav_image, [h, w]) uav_image = F.center_crop( uav_image, (self.uav_patch_height, self.uav_patch_width) ) uav_image = self.transforms(uav_image) satellite_patch = satellite_patch.transpose(1, 2, 0) satellite_patch_pytorch = self.transforms(satellite_patch) del satellite_patch if self.use_heatmap: heatmap = self.get_heatmap_gt( x_sat, y_sat, satellite_patch.shape[1], satellite_patch.shape[2], self.heatmap_kernel_size, ) cropped_uav_image_width = self.calculate_cropped_uav_image_width( fov_vertical, original_uav_image_width, original_uav_image_height, self.uav_patch_width, self.uav_patch_height, agl_altitude, ) satellite_tile_width = self.calculate_cropped_sat_image_width( lat, self.sat_patch_width, patch_transform ) scale_factor = cropped_uav_image_width / satellite_tile_width scale_factor *= 10 homography_matrix = self.compute_homography( rot_angle, x_sat, y_sat, self.uav_patch_width, self.uav_patch_height, scale_factor, ) if not self.use_heatmap: # Sample four points from the UAV image points = self.sample_four_points( self.uav_patch_width, self.uav_patch_height ) # Transform the points warped_points = self.warp_points(points, homography_matrix) #img_info["warped_points_sat"] = warped_points #img_info["warped_points_uav"] = points # img_info["cropped_uav_image_width"] = cropped_uav_image_width # img_info["satellite_tile_width"] = satellite_tile_width # img_info["scale_factor"] = scale_factor # img_info["filename"] = image_path # img_info["rot_angle"] = rot_angle # img_info["x_sat"] = x_sat # img_info["y_sat"] = y_sat # img_info["x_offset"] = x_offset # img_info["y_offset"] = y_offset # img_info["patch_transform"] = patch_transform # img_info["uav_image_scale"] = self.uav_image_scale # img_info["homography_matrix_uav_to_sat"] = homography_matrix # img_info["homography_matrix_sat_to_uav"] = np.linalg.inv(homography_matrix) # img_info["agl_altitude"] = agl_altitude # img_info["original_uav_image_width"] = original_uav_image_width # img_info["original_drone_image_height"] = original_uav_image_height # img_info["fov_vertical"] = fov_vertical # inverse_homography_matrix = np.linalg.inv(homography_matrix) uav_image_data = { "image": uav_image, "H_": homography_matrix, "coords": points, "image_size": np.array([self.uav_patch_width, self.uav_patch_height]), } satellite_patch_data = { "image": satellite_patch_pytorch, "H_": inverse_homography_matrix, "coords": warped_points, "image_size": np.array([self.sat_patch_width, self.sat_patch_height]), } #hm = self.compute_homography_points( # warped_points, points, [self.uav_patch_width, self.uav_patch_height] #) hm = self.compute_homography_points( points, warped_points, [1.0 , 1.0] ) if np.array_equal(hm, np.eye(3)): print("Singular matrix") return self.__getitem__(random.randint(0, len(self) - 1)) data = { "name": f"{image_path}_{rot_angle}_{sat_year}", "original_image_size": np.array( [original_uav_image_width, original_uav_image_height] ), "H_0to1": hm.astype(np.float32), "view0": uav_image_data, "view1": satellite_patch_data, "idx": idx } del img_info gc.collect() return data def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform): """ Computes the width of the satellite image in world coordinate units """ length_of_degree = 111320 * np.cos(np.radians(latitude)) scale_x_meters = patch_transform[0] * length_of_degree satellite_tile_width = patch_width * scale_x_meters return satellite_tile_width def calculate_cropped_uav_image_width( self, fov_vertical, orig_width, orig_height, crop_width, crop_height, altitude=150.0, ): """ Computes the width of the UAV image in world coordinate units """ # Convert fov from degrees to radians fov_rad = np.radians(fov_vertical) # Calculate the full width of the UAV image full_width = 2 * (altitude * np.tan(fov_rad / 2)) # Determine the cropping ratio crop_ratio_width = crop_width / orig_width crop_ratio_height = crop_height / orig_height # Calculate the adjusted horizontal fov fov_horizontal = 2 * np.arctan(np.tan(fov_rad / 2) * (orig_width / orig_height)) adjusted_fov_horizontal = 2 * np.arctan( np.tan(fov_horizontal / 2) * crop_ratio_width ) # Calculate the new full width using the adjusted horizontal fov full_width = 2 * (altitude * np.tan(adjusted_fov_horizontal / 2)) # Adjust the width according to the crop ratio cropped_width = full_width * crop_ratio_width return cropped_width def compute_homography_points(self, pts1_, pts2_, shape): """Compute the homography matrix from 4 point correspondences""" # Rescale to actual size shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x] pts1 = pts1_ * np.expand_dims(shape, axis=0) pts2 = pts2_ * np.expand_dims(shape, axis=0) def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] def flat2mat(H): return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3]) a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0) p_mat = np.transpose( np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0) ) try: homography = np.transpose(np.linalg.solve(a_mat, p_mat)) except np.linalg.LinAlgError: print("Singular matrix") return np.eye(3) return flat2mat(homography) def compute_homography( self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor ): # Adjust rot_angle if it's greater than 180 degrees if rot_angle > 180: rot_angle -= 360 # Convert rotation angle to radians theta = np.radians(rot_angle) # Rotation matrix R = np.array( [ [np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1], ] ) # Scale matrix S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]]) # Translation matrix to center the UAV image T_uav = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]]) # Translation matrix to move to the satellite image position T_sat = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]]) # Compute the combined homography matrix H = np.dot(T_sat, np.dot(R, np.dot(S, T_uav))) return H def sample_four_points(self, width: int, height: int) -> np.ndarray: """ Samples four points from the UAV image. """ PADDING = 50 CENTER_PADDING = 10 points = np.array( [ [random.randint(CENTER_PADDING, width - PADDING), random.randint(CENTER_PADDING, height - PADDING)] for _ in range(4) ] ) return points def warp_points( self, points: np.ndarray, homography_matrix: np.ndarray ) -> np.ndarray: """ Warps the given points using the given homography matrix. """ points = np.array(points) points = np.concatenate([points, np.ones((4, 1))], axis=1) points = np.dot(homography_matrix, points.T).T points = points[:, :2] / points[:, 2:] return points def count_total_uav_samples(self) -> int: """ Count the total number of uav image samples in the dataset (train + test) """ total_samples = 0 for dirpath, dirnames, filenames in os.walk(self.uav_dataset_dir): # Skip the test folder for filename in filenames: if filename.endswith(".jpeg"): total_samples += 1 return total_samples def get_number_of_city_samples(self) -> int: """ TODO: Count the total number of city samples in the dataset """ return 11 def get_entry_paths(self, directory: str) -> List[str]: """ Recursively retrieves paths to image and metadata files in the given directory. """ entry_paths = [] entries = os.listdir(directory) images_to_take_per_folder = int( self.total_uav_samples * self.test_from_train_ratio / self.get_number_of_city_samples() ) for entry in entries: entry_path = os.path.join(directory, entry) # If it's a directory, recurse into it if os.path.isdir(entry_path): entry_paths += self.get_entry_paths(entry_path) # Handle train dataset elif (self.dataset == "train" and "Train" in entry_path) or ( self.dataset == "train" and self.test_from_train_ratio > 0 and "Test" in entry_path ): if entry_path.endswith(".jpeg"): _, number = self.extract_info_from_filename(entry_path) else: number = None if entry_path.endswith(".json"): self.get_metadata(entry_path) if number is None: continue if ( number >= images_to_take_per_folder ): # Only include images beyond the ones taken for test if entry_path.endswith(".jpeg"): entry_paths.append(entry_path) # Handle test dataset elif self.dataset == "test": if entry_path.endswith(".jpeg"): _, number = self.extract_info_from_filename(entry_path) else: number = None if entry_path.endswith(".json"): self.get_metadata(entry_path) if number is None: continue if ( ("Test" in entry_path and number < images_to_take_per_folder) or (number < images_to_take_per_folder and "Train" in entry_path) or (self.test_from_train_ratio == 0.0 and "Test" in entry_path) ): if entry_path.endswith(".jpeg"): entry_paths.append(entry_path) return sorted(entry_paths, key=self.extract_info_from_filename) def get_metadata(self, path: str) -> None: """ Extracts metadata from a JSON file and stores it in the metadata dictionary. """ with open(path, newline="") as jsonfile: json_dict = json.load(jsonfile) path = path.split("/")[-1] path = path.replace(".json", "") self.metadata_dict[path] = json_dict["cameraFrames"] def extract_info_from_filename(self, filename: str) -> (str, int): """ Extracts information from the filename. """ filename_without_ext = filename.replace(".jpeg", "") segments = filename_without_ext.split("/") info = segments[-1] try: number = int(info.split("_")[-1]) except ValueError: print("Could not extract number from filename: ", filename) return None, None info = "_".join(info.split("_")[:-1]) return info, number def get_heatmap_gt( self, x: int, y: int, height: int, width: int, square_size: int = 33 ) -> torch.Tensor: """ Returns 2D heatmap ground truth for the given x and y coordinates, with the given square size. """ x_map, y_map = x, y heatmap = torch.zeros((height, width)) half_size = square_size // 2 # Calculate the valid range for the square start_x = max(0, x_map - half_size) end_x = min( width, x_map + half_size + 1 ) # +1 to include the end_x in the square start_y = max(0, y_map - half_size) end_y = min( height, y_map + half_size + 1 ) # +1 to include the end_y in the square heatmap[start_y:end_y, start_x:end_x] = 1 return heatmap if __name__ == "__main__": from .. import logger # overwrite the logger