From d40e423a425b0839b7e2e7b0bfee329620aefc69 Mon Sep 17 00:00:00 2001
From: Jeba Kolega
Date: Wed, 14 Feb 2024 00:02:09 +0100
Subject: [PATCH] Add custom training files

---
 gluefactory/configs/drone2sat_v1.yaml         |  69 ++
 gluefactory/configs/satellites.yaml           |  48 ++
 gluefactory/configs/satellites2.yaml          |  60 ++
 gluefactory/configs/satellites3.yaml          |  47 ++
 .../superpoint+lightglue_homography.yaml      |   1 +
 gluefactory/datasets/base_dataset.py          |   5 +
 gluefactory/datasets/drone2sat.py             | 694 ++++++++++++++++++
 gluefactory/datasets/s_utils.py               | 163 ++++
 gluefactory/datasets/sat_utils.py             | 109 +++
 gluefactory/datasets/satellites.py            | 297 ++++++++
 gluefactory/geometry/homography.py            |   5 +-
 gluefactory/train.py                          |  30 +-
 gluefactory/utils/tensor.py                   |   9 +
 gluefactory/utils/tools.py                    |   4 +
 gluefactory_nonfree/superpoint.py             |   8 +
 run.sh                                        |  12 +
 stats.sh                                      |  22 +
 17 files changed, 1577 insertions(+), 6 deletions(-)
 create mode 100644 gluefactory/configs/drone2sat_v1.yaml
 create mode 100644 gluefactory/configs/satellites.yaml
 create mode 100644 gluefactory/configs/satellites2.yaml
 create mode 100644 gluefactory/configs/satellites3.yaml
 create mode 100644 gluefactory/datasets/drone2sat.py
 create mode 100644 gluefactory/datasets/s_utils.py
 create mode 100644 gluefactory/datasets/sat_utils.py
 create mode 100644 gluefactory/datasets/satellites.py
 create mode 100644 run.sh
 create mode 100644 stats.sh

diff --git a/gluefactory/configs/drone2sat_v1.yaml b/gluefactory/configs/drone2sat_v1.yaml
new file mode 100644
index 0000000..7bb776f
--- /dev/null
+++ b/gluefactory/configs/drone2sat_v1.yaml
@@ -0,0 +1,69 @@
+data:
+  name: drone2sat
+  batch_size: 4
+  num_workers: 24
+  val_size: 500
+  train_size: -1
+  photometric:
+    name: lg
+  geo_dataset:
+    uav_dataset_dir: /mnt/drive/uav_dataset
+    satellite_dataset_dir: /mnt/drive/tiles
+    misslabeled_images_path: /mnt/drive/misslabeled.txt
+    sat_zoom_level: 17
+    uav_patch_width: 400
+    uav_patch_height: 400
+    sat_patch_width: 400
+    sat_patch_height: 400
+    test_from_train_ratio: 0.0
+    transform_mean:
+      - 0.485
+      - 0.456
+      - 0.406
+    transform_std:
+      - 0.229
+      - 0.224
+      - 0.225
+    sat_availaible_years:
+      - "2023"
+      - "2021"
+      - "2019"
+      - "2016"
+    max_rotation_angle: 0
+    uav_image_scale: 2.0
+    use_heatmap: false
+
+model:
+  name: two_view_pipeline
+  extractor:
+    name: gluefactory_nonfree.superpoint
+    max_num_keypoints: 2048
+    detection_threshold: 0.0
+    nms_radius: 3
+  ground_truth:
+    name: matchers.homography_matcher
+    th_positive: 3
+    th_negative: 3
+  matcher:
+    name: matchers.lightglue
+    features: superpoint
+    depth_confidence: -1
+    width_confidence: -1
+    filter_threshold: 0.1
+    flash: true
+train:
+  seed: 0
+  epochs: 40
+  log_every_iter: 100
+  eval_every_iter: 1000  # 350000 / 4
+  lr: 1e-3
+  lr_schedule:
+    start: 20
+    type: exp
+    on_epoch: true
+    exp_div_10: 10
+benchmarks:
+  hpatches:
+    eval:
+      estimator: opencv
+      ransac_th: 0.5
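For reference, a minimal sketch of how such a config can be inspected with OmegaConf (the library glue-factory uses for configuration merging); the path below matches the file added by this patch:

```python
from omegaconf import OmegaConf

conf = OmegaConf.load("gluefactory/configs/drone2sat_v1.yaml")
# Dot access mirrors the self.conf.geo_dataset.* lookups in the dataset code.
print(conf.data.geo_dataset.sat_zoom_level)              # 17
print(list(conf.data.geo_dataset.sat_availaible_years))  # ['2023', '2021', '2019', '2016']
```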
diff --git a/gluefactory/configs/satellites.yaml b/gluefactory/configs/satellites.yaml
new file mode 100644
index 0000000..ce82382
--- /dev/null
+++ b/gluefactory/configs/satellites.yaml
@@ -0,0 +1,48 @@
+data:
+  name: satellites
+  data_dir: /mnt/drive/tiles
+  metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
+  train_size: null
+  val_size: null
+  batch_size: 128
+  num_workers: 24
+  homography:
+    difficulty: 0.9
+    max_angle: 359
+  photometric:
+    name: lg
+model:
+  name: two_view_pipeline
+  extractor:
+    name: gluefactory_nonfree.superpoint
+    max_num_keypoints: 2048
+    detection_threshold: 0.0
+    force_num_keypoints: True
+    nms_radius: 3
+    trainable: False
+  ground_truth:
+    name: matchers.homography_matcher
+    th_positive: 3
+    th_negative: 3
+  matcher:
+    name: matchers.lightglue
+    filter_threshold: 0.1
+    flash: true
+    checkpointed: true
+    weights: superpoint
+train:
+  seed: 0
+  epochs: 5
+  log_every_iter: 100
+  eval_every_iter: 500
+  lr: 1e-7
+  lr_schedule:
+    start: 20
+    type: exp
+    on_epoch: true
+    exp_div_10: 10
+benchmarks:
+  hpatches:
+    eval:
+      estimator: opencv
+      ransac_th: 0.5
diff --git a/gluefactory/configs/satellites2.yaml b/gluefactory/configs/satellites2.yaml
new file mode 100644
index 0000000..a47b41c
--- /dev/null
+++ b/gluefactory/configs/satellites2.yaml
@@ -0,0 +1,60 @@
+data:
+  name: satellites
+  data_dir: /mnt/drive/tiles
+  metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
+  train_size: null
+  val_size: null
+  batch_size: 128
+  num_workers: 24
+  homography:
+    difficulty: 0.5
+    max_angle: 359
+  photometric:
+    name: lg
+model:
+  name: two_view_pipeline
+  extractor:
+    name: gluefactory_nonfree.superpoint
+    max_num_keypoints: 2048
+    detection_threshold: 0.0
+    nms_radius: 3
+  #extractor:
+  #  name: gluefactory_nonfree.superpoint
+  #  max_num_keypoints: 512
+  #  force_num_keypoints: True
+  #  detection_threshold: 0.0
+  #  nms_radius: 3
+  #  trainable: False
+  ground_truth:
+    name: matchers.homography_matcher
+    th_positive: 3
+    th_negative: 3
+  matcher:
+    name: matchers.lightglue
+    features: superpoint
+    depth_confidence: -1
+    width_confidence: -1
+    filter_threshold: 0.1
+    flash: true
+  #matcher:
+  #  name: matchers.lightglue_pretrained
+  #  features: superpoint
+  #  depth_confidence: -1
+  #  width_confidence: -1
+  #  filter_threshold: 0.1
+train:
+  seed: 0
+  epochs: 5
+  log_every_iter: 100
+  eval_every_iter: 500
+  lr: 1e-4
+  lr_schedule:
+    start: 20
+    type: exp
+    on_epoch: true
+    exp_div_10: 10
+benchmarks:
+  hpatches:
+    eval:
+      estimator: opencv
+      ransac_th: 0.5
diff --git a/gluefactory/configs/satellites3.yaml b/gluefactory/configs/satellites3.yaml
new file mode 100644
index 0000000..6bf4916
--- /dev/null
+++ b/gluefactory/configs/satellites3.yaml
@@ -0,0 +1,47 @@
+data:
+  name: satellites
+  data_dir: /mnt/drive/tiles
+  metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
+  train_size: null
+  val_size: null
+  batch_size: 128
+  num_workers: 24
+  homography:
+    difficulty: 0.5
+    max_angle: 180
+  photometric:
+    name: lg
+model:
+  name: two_view_pipeline
+  extractor:
+    name: gluefactory_nonfree.superpoint
+    max_num_keypoints: 2048
+    detection_threshold: 0.0
+    nms_radius: 3
+  ground_truth:
+    name: matchers.homography_matcher
+    th_positive: 3
+    th_negative: 3
+  matcher:
+    name: matchers.lightglue
+    features: superpoint
+    depth_confidence: -1
+    width_confidence: -1
+    filter_threshold: 0.1
+    flash: true
+train:
+  seed: 0
+  epochs: 40
+  log_every_iter: 100
+  eval_every_iter: 500
+  lr: 1e-4
+  lr_schedule:
+    start: 20
+    type: exp
+    on_epoch: true
+    exp_div_10: 10
+benchmarks:
+  hpatches:
+    eval:
+      estimator: opencv
+      ransac_th: 0.5
diff --git a/gluefactory/configs/superpoint+lightglue_homography.yaml b/gluefactory/configs/superpoint+lightglue_homography.yaml
index 1f353b3..5ba322c 100644
--- a/gluefactory/configs/superpoint+lightglue_homography.yaml
+++ b/gluefactory/configs/superpoint+lightglue_homography.yaml
@@ -33,6 +33,7 @@ train:
   epochs: 40
   log_every_iter: 100
   eval_every_iter: 500
+  profile: true
   lr: 1e-4
   lr_schedule:
     start: 20
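A note on the `lr_schedule` block shared by these configs: assuming glue-factory's usual `exp` semantics (flat learning rate until `start`, then a tenfold decay every `exp_div_10` epochs), the schedule behaves like the sketch below. With `epochs: 5` and `start: 20`, the satellites configs never actually reach the decay phase.

```python
def exp_lr(base_lr, epoch, start=20, exp_div_10=10):
    """Flat until `start`, then decay by 10x every `exp_div_10` epochs (assumed semantics)."""
    if epoch < start:
        return base_lr
    return base_lr * 10 ** (-(epoch - start) / exp_div_10)

# With lr=1e-4: epoch 20 -> 1e-4, epoch 30 -> 1e-5, epoch 40 -> 1e-6.
for e in (0, 20, 30, 40):
    print(e, exp_lr(1e-4, e))
```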
diff --git a/gluefactory/datasets/base_dataset.py b/gluefactory/datasets/base_dataset.py
index ef622cb..591f7b0 100644
--- a/gluefactory/datasets/base_dataset.py
+++ b/gluefactory/datasets/base_dataset.py
@@ -168,6 +168,8 @@ class BaseDataset(metaclass=ABCMeta):
         sampler = None
         if shuffle is None:
             shuffle = split == "train" and self.conf.shuffle_training
+        # Debug override: always shuffle, regardless of the split.
+        shuffle = True
         return DataLoader(
             dataset,
             batch_size=batch_size,
diff --git a/gluefactory/datasets/drone2sat.py b/gluefactory/datasets/drone2sat.py
new file mode 100644
index 0000000..cd61ce9
--- /dev/null
+++ b/gluefactory/datasets/drone2sat.py
@@ -0,0 +1,694 @@
+"""
+Load UAV images and matching satellite tiles, and yield UAV-to-satellite
+image pairs related by a ground-truth homography for training.
+"""
+
+import gc
+import json
+import logging
+import os
+import random
+import warnings
+from typing import List, Literal
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms import functional as F
+
+from ..models.cache_loader import CacheLoader
+from .base_dataset import BaseDataset
+from .s_utils import get_random_tiff_patch
+
+logger = logging.getLogger(__name__)
+
+
+class Drone2SatDataset(BaseDataset):
+    default_conf = {
+        # image search
+        "data_dir": "revisitop1m",  # the top-level directory
+        "image_dir": "jpg/",  # the subdirectory with the images
+        "image_list": "revisitop1m.txt",  # optional: list or filename of list
+        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
+        "metadata_dir": None,
+        "uav_dataset_dir": None,
+        "satellite_dataset_dir": None,
+        # splits
+        "train_size": 100,
+        "val_size": 10,
+        "shuffle_seed": 0,  # or None to skip
+        # image loading
+        "grayscale": False,
+        "triplet": False,
+        "right_only": False,  # image0 is orig (rescaled), image1 is right
+        "reseed": False,
+        "homography": {
+            "difficulty": 0.8,
+            "translation": 1.0,
+            "max_angle": 60,
+            "n_angles": 10,
+            "patch_shape": [640, 480],
+            "min_convexity": 0.05,
+        },
+        "photometric": {
+            "name": "dark",
+            "p": 0.75,
+            # 'difficulty': 1.0,  # currently unused
+        },
+        # feature loading
+        "load_features": {
+            "do": False,
+            **CacheLoader.default_conf,
+            "collate": False,
+            "thresh": 0.0,
+            "max_num_keypoints": -1,
+            "force_num_keypoints": False,
+        },
+        # other geolocalization parameters
+        "geo_dataset": {
+            "uav_dataset_dir": None,
+            "satellite_dataset_dir": None,
+            "misslabeled_images_path": None,
+            "sat_zoom_level": 17,
+            "uav_patch_width": 400,
+            "uav_patch_height": 400,
+            "sat_patch_width": 400,
+            "sat_patch_height": 400,
+            "heatmap_kernel_size": 33,
+            "test_from_train_ratio": 0.0,
+            "transform_mean": [0.485, 0.456, 0.406],
+            "transform_std": [0.229, 0.224, 0.225],
+            "sat_availaible_years": ["2023", "2021", "2019", "2016"],
+            "max_rotation_angle": 10,
+            "uav_image_scale": 1,
+            "use_heatmap": False,
+        },
+    }
+
+    def _init(self, conf):
+        self.images = {"train": "train", "val": "val"}
+    def get_dataset(self, split):
+        if split == "val":
+            # The validation split maps to the held-out "test" portion.
+            return GeoLocalizationDataset(
+                uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
+                satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
+                misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
+                dataset="test",
+                sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
+                uav_patch_width=self.conf.geo_dataset.uav_patch_width,
+                uav_patch_height=self.conf.geo_dataset.uav_patch_height,
+                sat_patch_width=self.conf.geo_dataset.sat_patch_width,
+                sat_patch_height=self.conf.geo_dataset.sat_patch_height,
+                heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
+                test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
+                transform_mean=self.conf.geo_dataset.transform_mean,
+                transform_std=self.conf.geo_dataset.transform_std,
+                sat_available_years=self.conf.geo_dataset.sat_availaible_years,
+                max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
+                uav_image_scale=self.conf.geo_dataset.uav_image_scale,
+                use_heatmap=self.conf.geo_dataset.use_heatmap,
+                subset_size=self.conf.val_size,
+            )
+        elif split == "train":
+            return GeoLocalizationDataset(
+                uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
+                satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
+                misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
+                dataset="train",
+                sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
+                uav_patch_width=self.conf.geo_dataset.uav_patch_width,
+                uav_patch_height=self.conf.geo_dataset.uav_patch_height,
+                sat_patch_width=self.conf.geo_dataset.sat_patch_width,
+                sat_patch_height=self.conf.geo_dataset.sat_patch_height,
+                heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
+                test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
+                transform_mean=self.conf.geo_dataset.transform_mean,
+                transform_std=self.conf.geo_dataset.transform_std,
+                sat_available_years=self.conf.geo_dataset.sat_availaible_years,
+                max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
+                uav_image_scale=self.conf.geo_dataset.uav_image_scale,
+                use_heatmap=self.conf.geo_dataset.use_heatmap,
+                subset_size=self.conf.train_size,
+            )
+        else:
+            raise ValueError(f"Unknown split: {split}")
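A hedged usage sketch for the class defined next. The paths are placeholders: the constructor reads the misslabel list and walks the UAV directory, so real data must exist for this to actually run.

```python
from gluefactory.datasets.drone2sat import GeoLocalizationDataset

ds = GeoLocalizationDataset(
    uav_dataset_dir="/mnt/drive/uav_dataset",              # placeholder path
    satellite_dataset_dir="/mnt/drive/tiles",              # placeholder path
    misslabeled_images_path="/mnt/drive/misslabeled.txt",  # placeholder path
    dataset="train",
    sat_zoom_level=17,
    uav_patch_width=400,
    uav_patch_height=400,
    sat_patch_width=400,
    sat_patch_height=400,
    max_rotation_angle=0,
    uav_image_scale=2.0,
    use_heatmap=False,
    subset_size=-1,
)
sample = ds[0]
print(sample["view0"]["image"].shape)  # torch.Size([3, 400, 400])
print(sample["H_0to1"].shape)          # (3, 3)
```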
+
+
+class GeoLocalizationDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        uav_dataset_dir: str,
+        satellite_dataset_dir: str,
+        misslabeled_images_path: str,
+        dataset: Literal["train", "test"],
+        sat_zoom_level: int = 16,
+        uav_patch_width: int = 128,
+        uav_patch_height: int = 128,
+        sat_patch_width: int = 400,
+        sat_patch_height: int = 400,
+        heatmap_kernel_size: int = 33,
+        test_from_train_ratio: float = 0.0,
+        transform_mean: List[float] = [0.485, 0.456, 0.406],
+        transform_std: List[float] = [0.229, 0.224, 0.225],
+        sat_available_years: List[str] = ["2023", "2021", "2019", "2016"],
+        max_rotation_angle: int = 10,
+        uav_image_scale: float = 1,
+        use_heatmap: bool = True,
+        subset_size: int = None,
+    ):
+        self.uav_dataset_dir = uav_dataset_dir
+        self.satellite_dataset_dir = satellite_dataset_dir
+        self.dataset = dataset
+        self.sat_zoom_level = sat_zoom_level
+        self.uav_patch_width = uav_patch_width
+        self.uav_patch_height = uav_patch_height
+        self.heatmap_kernel_size = heatmap_kernel_size
+        self.test_from_train_ratio = test_from_train_ratio
+        self.transform_mean = transform_mean
+        self.transform_std = transform_std
+        self.misslabeled_images_path = misslabeled_images_path
+        self.metadata_dict = {}
+        self.max_rotation_angle = max_rotation_angle
+        self.total_uav_samples = self.count_total_uav_samples()
+        self.misslabelled_images = self.read_misslabelled_images(
+            self.misslabeled_images_path
+        )
+        self.entry_paths = self.get_entry_paths(self.uav_dataset_dir)
+        self.cleanup_misslabelled_images()
+        self.transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                # transforms.Normalize(self.transform_mean, self.transform_std),
+            ]
+        )
+        self.sat_available_years = sat_available_years
+        self.uav_image_scale = uav_image_scale
+        self.use_heatmap = use_heatmap
+        self.sat_patch_width = sat_patch_width
+        self.sat_patch_height = sat_patch_height
+        self.subset_size = subset_size
+
+        self.inverse_transforms = transforms.Compose(
+            [
+                transforms.Normalize(
+                    mean=[
+                        -m / s for m, s in zip(self.transform_mean, self.transform_std)
+                    ],
+                    std=[1 / s for s in self.transform_std],
+                ),
+                transforms.ToPILImage(),
+            ]
+        )
+
+        # Keep only a subset of the UAV images (-1 or None keeps everything).
+        if self.subset_size is not None and self.subset_size != -1:
+            ssize = int(self.subset_size / len(self.sat_available_years))
+            ssize = min(ssize, len(self.entry_paths))
+            self.entry_paths = self.entry_paths[:ssize]
+
+    def __len__(self) -> int:
+        return len(self.entry_paths) * len(self.sat_available_years)
+
+    def read_misslabelled_images(
+        self, path: str = "misslabels/misslabeled.txt"
+    ) -> List[str]:
+        with open(path, "r") as f:
+            lines = f.readlines()
+        return [line.strip() for line in lines]
+
+    def cleanup_misslabelled_images(self) -> None:
+        indices_to_delete = []
+
+        for image in self.misslabelled_images:
+            for image_path in self.entry_paths:
+                if image in image_path:
+                    index = self.entry_paths.index(image_path)
+                    indices_to_delete.append(index)
+                    break
+
+        # Delete from the back so earlier indices stay valid.
+        for index in sorted(indices_to_delete, reverse=True):
+            self.entry_paths.pop(index)
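The index arithmetic that `__len__` above and `__getitem__` below share: each UAV image appears once per available satellite year. A small self-contained sketch of the decomposition:

```python
years = ["2023", "2021", "2019", "2016"]
n_images = 3                        # stand-in for len(self.entry_paths)
for idx in range(n_images * len(years)):
    image_i = idx // len(years)     # which UAV image
    year = years[idx % len(years)]  # which satellite year
    print(idx, image_i, year)
```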
+    def __getitem__(self, idx) -> dict:
+        """
+        Retrieves a sample given its index, returning the preprocessed UAV
+        and satellite images together with the homography relating them.
+        """
+        image_path_index = idx // len(self.sat_available_years)
+        sat_year = self.sat_available_years[idx % len(self.sat_available_years)]
+        rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle)
+
+        image_path = self.entry_paths[image_path_index]
+        uav_image = Image.open(image_path).convert("RGB")  # ensure 3-channel image
+
+        original_uav_image_width = uav_image.width
+        original_uav_image_height = uav_image.height
+
+        lookup_str, file_number = self.extract_info_from_filename(image_path)
+        img_info = self.metadata_dict[lookup_str][file_number]
+
+        lat, lon = (
+            img_info["coordinate"]["latitude"],
+            img_info["coordinate"]["longitude"],
+        )
+
+        fov_vertical = img_info["fovVertical"]
+
+        try:
+            agl_altitude = float(image_path.split("/")[-1].split("_")[2].split("m")[0])
+        except (IndexError, ValueError):
+            agl_altitude = 150.0
+            warnings.warn(
+                "Could not extract AGL altitude from filename, using default value of 150m."
+            )
+
+        (
+            satellite_patch,
+            x_sat,
+            y_sat,
+            x_offset,
+            y_offset,
+            patch_transform,
+        ) = get_random_tiff_patch(
+            lat,
+            lon,
+            self.sat_patch_width,
+            self.sat_patch_height,
+            sat_year,
+            self.satellite_dataset_dir,
+        )
+
+        # Rotate, rescale, and center-crop the UAV image
+        h = int(np.ceil(uav_image.height / self.uav_image_scale))
+        w = int(np.ceil(uav_image.width / self.uav_image_scale))
+
+        uav_image = F.rotate(uav_image, rot_angle)
+        uav_image = F.resize(uav_image, [h, w])
+        uav_image = F.center_crop(
+            uav_image, (self.uav_patch_height, self.uav_patch_width)
+        )
+        uav_image = self.transforms(uav_image)
+
+        satellite_patch = satellite_patch.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
+        satellite_patch_pytorch = self.transforms(satellite_patch)
+        del satellite_patch
+
+        if self.use_heatmap:
+            # Use the tensor's (C, H, W) shape; the numpy patch was deleted above.
+            heatmap = self.get_heatmap_gt(
+                x_sat,
+                y_sat,
+                satellite_patch_pytorch.shape[1],
+                satellite_patch_pytorch.shape[2],
+                self.heatmap_kernel_size,
+            )
+
+        cropped_uav_image_width = self.calculate_cropped_uav_image_width(
+            fov_vertical,
+            original_uav_image_width,
+            original_uav_image_height,
+            self.uav_patch_width,
+            self.uav_patch_height,
+            agl_altitude,
+        )
+
+        satellite_tile_width = self.calculate_cropped_sat_image_width(
+            lat, self.sat_patch_width, patch_transform
+        )
+
+        scale_factor = cropped_uav_image_width / satellite_tile_width
+        scale_factor *= 10
+
+        homography_matrix = self.compute_homography(
+            rot_angle,
+            x_sat,
+            y_sat,
+            self.uav_patch_width,
+            self.uav_patch_height,
+            scale_factor,
+        )
+
+        # Sample four points from the UAV image and warp them into the
+        # satellite patch; both views below need them.
+        points = self.sample_four_points(self.uav_patch_width, self.uav_patch_height)
+        warped_points = self.warp_points(points, homography_matrix)
+
+        inverse_homography_matrix = np.linalg.inv(homography_matrix)
+        uav_image_data = {
+            "image": uav_image,
+            "H_": homography_matrix,
+            "coords": points,
+            "image_size": np.array([self.uav_patch_width, self.uav_patch_height]),
+        }
+
+        satellite_patch_data = {
+            "image": satellite_patch_pytorch,
+            "H_": inverse_homography_matrix,
+            "coords": warped_points,
+            "image_size": np.array([self.sat_patch_width, self.sat_patch_height]),
+        }
+
+        hm = self.compute_homography_points(points, warped_points, [1.0, 1.0])
+
+        if np.array_equal(hm, np.eye(3)):
+            logger.warning("Singular homography for %s, resampling.", image_path)
+            return self.__getitem__(random.randint(0, len(self) - 1))
+
+        data = {
+            "name": f"{image_path}_{rot_angle}_{sat_year}",
+            "original_image_size": np.array(
+                [original_uav_image_width, original_uav_image_height]
+            ),
+            "H_0to1": hm.astype(np.float32),
+            "view0": uav_image_data,
+            "view1": satellite_patch_data,
+            "idx": idx,
+        }
+        del img_info
+        gc.collect()
+
+        return data
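A sanity-check sketch for a returned sample: `H_0to1` should map the four sampled UAV points onto their satellite-patch correspondences. The commented lines assume a dataset instance `ds` as in the earlier usage sketch.

```python
import numpy as np

def apply_h(H, pts):
    """Apply a 3x3 homography to an (N, 2) array of points."""
    pts_h = np.concatenate([pts, np.ones((len(pts), 1))], axis=1)
    proj = (H @ pts_h.T).T
    return proj[:, :2] / proj[:, 2:]

# sample = ds[0]  # from the usage sketch above
# err = np.abs(
#     apply_h(sample["H_0to1"], sample["view0"]["coords"]) - sample["view1"]["coords"]
# ).max()
# assert err < 1e-3
```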
+    def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform):
+        """
+        Computes the width of the satellite patch in meters, assuming the
+        patch transform is expressed in degrees per pixel.
+        """
+        length_of_degree = 111320 * np.cos(np.radians(latitude))
+        scale_x_meters = patch_transform[0] * length_of_degree
+        satellite_tile_width = patch_width * scale_x_meters
+        return satellite_tile_width
+
+    def calculate_cropped_uav_image_width(
+        self,
+        fov_vertical,
+        orig_width,
+        orig_height,
+        crop_width,
+        crop_height,
+        altitude=150.0,
+    ):
+        """
+        Computes the ground footprint width of the cropped UAV image in meters.
+        """
+        # Convert fov from degrees to radians
+        fov_rad = np.radians(fov_vertical)
+
+        # Determine the cropping ratio
+        crop_ratio_width = crop_width / orig_width
+
+        # Derive the horizontal fov from the vertical fov and the aspect
+        # ratio, then shrink it according to the crop
+        fov_horizontal = 2 * np.arctan(np.tan(fov_rad / 2) * (orig_width / orig_height))
+        adjusted_fov_horizontal = 2 * np.arctan(
+            np.tan(fov_horizontal / 2) * crop_ratio_width
+        )
+
+        # Full ground footprint width for the adjusted horizontal fov
+        full_width = 2 * (altitude * np.tan(adjusted_fov_horizontal / 2))
+
+        # Adjust the width according to the crop ratio
+        cropped_width = full_width * crop_ratio_width
+
+        return cropped_width
+
+    def compute_homography_points(self, pts1_, pts2_, shape):
+        """Compute the homography matrix from 4 point correspondences (DLT)."""
+        # Rescale to actual size (a no-op when shape == [1.0, 1.0])
+        shape = np.array(shape[::-1], dtype=np.float32)  # different convention [y, x]
+        pts1 = pts1_ * np.expand_dims(shape, axis=0)
+        pts2 = pts2_ * np.expand_dims(shape, axis=0)
+
+        def ax(p, q):
+            return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
+
+        def ay(p, q):
+            return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
+
+        def flat2mat(H):
+            return np.reshape(
+                np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3]
+            )
+
+        a_mat = np.stack(
+            [f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0
+        )
+        p_mat = np.transpose(
+            np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
+        )
+
+        try:
+            homography = np.transpose(np.linalg.solve(a_mat, p_mat))
+        except np.linalg.LinAlgError:
+            logger.warning("Singular matrix in compute_homography_points.")
+            return np.eye(3)
+        return flat2mat(homography)
+
+    def compute_homography(
+        self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor
+    ):
+        # Adjust rot_angle if it's greater than 180 degrees
+        if rot_angle > 180:
+            rot_angle -= 360
+        # Convert rotation angle to radians
+        theta = np.radians(rot_angle)
+
+        # Rotation matrix
+        R = np.array(
+            [
+                [np.cos(theta), -np.sin(theta), 0],
+                [np.sin(theta), np.cos(theta), 0],
+                [0, 0, 1],
+            ]
+        )
+
+        # Scale matrix
+        S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
+
+        # Translation matrix to center the UAV image
+        T_uav = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]])
+
+        # Translation matrix to move to the satellite image position
+        T_sat = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]])
+
+        # Compute the combined homography matrix
+        H = np.dot(T_sat, np.dot(R, np.dot(S, T_uav)))
+
+        return H
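The 4-point solve in `compute_homography_points` is a direct linear transform (DLT); as a cross-check, OpenCV's `getPerspectiveTransform` should recover the same matrix up to scale from any exact 4-point correspondence. A self-contained sketch with a synthetic homography:

```python
import cv2
import numpy as np

pts1 = np.array([[10, 10], [350, 20], [340, 360], [15, 370]], dtype=np.float32)
H_true = np.array([[1.1, 0.02, 5.0], [-0.01, 0.95, -3.0], [1e-4, 0.0, 1.0]])
pts2_h = np.concatenate([pts1, np.ones((4, 1), dtype=np.float32)], axis=1) @ H_true.T
pts2 = (pts2_h[:, :2] / pts2_h[:, 2:]).astype(np.float32)

H_cv = cv2.getPerspectiveTransform(pts1, pts2)
# Homographies are defined up to scale, so normalize before comparing.
print(np.allclose(H_cv / H_cv[2, 2], H_true / H_true[2, 2], atol=1e-3))  # True
```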
+    def sample_four_points(self, width: int, height: int) -> np.ndarray:
+        """
+        Samples four random points from the UAV image.
+        """
+        PADDING = 50
+        CENTER_PADDING = 10
+        points = np.array(
+            [
+                [
+                    random.randint(CENTER_PADDING, width - PADDING),
+                    random.randint(CENTER_PADDING, height - PADDING),
+                ]
+                for _ in range(4)
+            ]
+        )
+        return points
+
+    def warp_points(
+        self, points: np.ndarray, homography_matrix: np.ndarray
+    ) -> np.ndarray:
+        """
+        Warps the given points using the given homography matrix.
+        """
+        points = np.array(points)
+        points = np.concatenate([points, np.ones((4, 1))], axis=1)
+        points = np.dot(homography_matrix, points.T).T
+        points = points[:, :2] / points[:, 2:]
+        return points
+
+    def count_total_uav_samples(self) -> int:
+        """
+        Count the total number of UAV image samples in the dataset
+        (train + test).
+        """
+        total_samples = 0
+
+        for dirpath, dirnames, filenames in os.walk(self.uav_dataset_dir):
+            for filename in filenames:
+                if filename.endswith(".jpeg"):
+                    total_samples += 1
+        return total_samples
+
+    def get_number_of_city_samples(self) -> int:
+        """
+        TODO: Count the total number of city samples in the dataset
+        """
+        return 11
+
+    def get_entry_paths(self, directory: str) -> List[str]:
+        """
+        Recursively retrieves paths to image and metadata files in the given directory.
+        """
+        entry_paths = []
+        entries = os.listdir(directory)
+
+        images_to_take_per_folder = int(
+            self.total_uav_samples
+            * self.test_from_train_ratio
+            / self.get_number_of_city_samples()
+        )
+
+        for entry in entries:
+            entry_path = os.path.join(directory, entry)
+
+            # If it's a directory, recurse into it
+            if os.path.isdir(entry_path):
+                entry_paths += self.get_entry_paths(entry_path)
+
+            # Handle train dataset
+            elif (self.dataset == "train" and "Train" in entry_path) or (
+                self.dataset == "train"
+                and self.test_from_train_ratio > 0
+                and "Test" in entry_path
+            ):
+                if entry_path.endswith(".jpeg"):
+                    _, number = self.extract_info_from_filename(entry_path)
+                else:
+                    number = None
+                if entry_path.endswith(".json"):
+                    self.get_metadata(entry_path)
+                if number is None:
+                    continue
+                if (
+                    number >= images_to_take_per_folder
+                ):  # Only include images beyond the ones taken for test
+                    if entry_path.endswith(".jpeg"):
+                        entry_paths.append(entry_path)
+
+            # Handle test dataset
+            elif self.dataset == "test":
+                if entry_path.endswith(".jpeg"):
+                    _, number = self.extract_info_from_filename(entry_path)
+                else:
+                    number = None
+                if entry_path.endswith(".json"):
+                    self.get_metadata(entry_path)
+
+                if number is None:
+                    continue
+                if (
+                    ("Test" in entry_path and number < images_to_take_per_folder)
+                    or (number < images_to_take_per_folder and "Train" in entry_path)
+                    or (self.test_from_train_ratio == 0.0 and "Test" in entry_path)
+                ):
+                    if entry_path.endswith(".jpeg"):
+                        entry_paths.append(entry_path)
+
+        return sorted(entry_paths, key=self.extract_info_from_filename)
+
+    def get_metadata(self, path: str) -> None:
+        """
+        Extracts metadata from a JSON file and stores it in the metadata dictionary.
+        """
+        with open(path) as jsonfile:
+            json_dict = json.load(jsonfile)
+        path = path.split("/")[-1]
+        path = path.replace(".json", "")
+        self.metadata_dict[path] = json_dict["cameraFrames"]
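From the lookups above (`cameraFrames`, `coordinate.latitude/longitude`, `fovVertical`), the metadata JSON is expected to look roughly like the sketch below; all values are hypothetical and only the accessed keys are known from the code.

```python
example = {
    "cameraFrames": [
        {
            "coordinate": {"latitude": 46.05, "longitude": 14.51},
            "fovVertical": 50.0,
            # other per-frame fields are not read by this dataset
        },
        # one entry per .jpeg frame, indexed by the trailing frame number
    ]
}
# get_metadata() stores example["cameraFrames"] under the JSON filename
# (without extension); __getitem__ then picks cameraFrames[file_number].
```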
+    def extract_info_from_filename(self, filename: str) -> (str, int):
+        """
+        Extracts the metadata lookup key and the frame number from a filename.
+        """
+        filename_without_ext = filename.replace(".jpeg", "")
+        segments = filename_without_ext.split("/")
+        info = segments[-1]
+        try:
+            number = int(info.split("_")[-1])
+        except ValueError:
+            logger.warning("Could not extract number from filename: %s", filename)
+            return None, None
+
+        info = "_".join(info.split("_")[:-1])
+
+        return info, number
+
+    def get_heatmap_gt(
+        self, x: int, y: int, height: int, width: int, square_size: int = 33
+    ) -> torch.Tensor:
+        """
+        Returns a 2D heatmap ground truth for the given x and y coordinates,
+        with the given square size.
+        """
+        x_map, y_map = x, y
+
+        heatmap = torch.zeros((height, width))
+
+        half_size = square_size // 2
+
+        # Calculate the valid range for the square
+        start_x = max(0, x_map - half_size)
+        end_x = min(width, x_map + half_size + 1)  # +1 to include end_x in the square
+        start_y = max(0, y_map - half_size)
+        end_y = min(height, y_map + half_size + 1)  # +1 to include end_y in the square
+
+        heatmap[start_y:end_y, start_x:end_x] = 1
+
+        return heatmap
+
+
+if __name__ == "__main__":
+    from .. import logger  # overwrite the logger
\ No newline at end of file
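A tiny worked example of `get_heatmap_gt`: a `square_size=3` square of ones centred on `(x=2, y=1)` in a 5x6 map, clipped at the borders.

```python
import torch

heatmap = torch.zeros((5, 6))
x, y, half = 2, 1, 3 // 2
heatmap[max(0, y - half):min(5, y + half + 1),
        max(0, x - half):min(6, x + half + 1)] = 1
print(heatmap)
# tensor([[0., 1., 1., 1., 0., 0.],
#         [0., 1., 1., 1., 0., 0.],
#         [0., 1., 1., 1., 0., 0.],
#         [0., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 0.]])
```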
diff --git a/gluefactory/datasets/s_utils.py b/gluefactory/datasets/s_utils.py
new file mode 100644
index 0000000..f830815
--- /dev/null
+++ b/gluefactory/datasets/s_utils.py
@@ -0,0 +1,163 @@
+#! /usr/bin/env python3
+
+import gc
+import os
+import random
+import warnings
+
+import mercantile
+import numpy as np
+import rasterio
+from affine import Affine
+from osgeo import gdal, osr
+from PIL import Image
+
+
+def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
+    # Neighbors of neighbors cover the full 5x5 block (center included).
+    neighbors = []
+    for main_neighbour in mercantile.neighbors(tile):
+        for sub_neighbour in mercantile.neighbors(main_neighbour):
+            if sub_neighbour not in neighbors:
+                neighbors.append(sub_neighbour)
+    return neighbors
+
+
+def get_tiff_map(
+    tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str
+) -> (np.ndarray, dict):
+    """
+    Returns a TIFF map of the given tile using GDAL.
+    """
+    tile_data = []
+    neighbors = get_5x5_neighbors(tile)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore")
+        for neighbor in neighbors:
+            west, south, east, north = mercantile.bounds(neighbor)
+            tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
+            if not os.path.exists(tile_path):
+                raise FileNotFoundError(
+                    f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
+                )
+
+            img = Image.open(tile_path)
+            img_array = np.array(img)
+
+            # Create an in-memory GDAL dataset
+            mem_driver = gdal.GetDriverByName("MEM")
+            dataset = mem_driver.Create(
+                "", img_array.shape[1], img_array.shape[0], 3, gdal.GDT_Byte
+            )
+            for i in range(3):
+                dataset.GetRasterBand(i + 1).WriteArray(img_array[:, :, i])
+
+            # Set GeoTransform and Projection
+            geotransform = (
+                west,
+                (east - west) / img_array.shape[1],
+                0,
+                north,
+                0,
+                -(north - south) / img_array.shape[0],
+            )
+            dataset.SetGeoTransform(geotransform)
+            srs = osr.SpatialReference()
+            srs.ImportFromEPSG(3857)
+            dataset.SetProjection(srs.ExportToWkt())
+
+            tile_data.append(dataset)
+
+    # Merge tiles using GDAL
+    vrt_options = gdal.BuildVRTOptions()
+    vrt = gdal.BuildVRT("", [td for td in tile_data], options=vrt_options)
+    mosaic = vrt.ReadAsArray()
+
+    # Get metadata
+    out_trans = vrt.GetGeoTransform()
+    out_crs = vrt.GetProjection()
+    out_trans = Affine.from_gdal(*out_trans)
+    out_meta = {
+        "driver": "GTiff",
+        "height": mosaic.shape[1],
+        "width": mosaic.shape[2],
+        "transform": out_trans,
+        "crs": out_crs,
+    }
+
+    # Clean up
+    for td in tile_data:
+        td.FlushCache()
+    vrt = None
+    gc.collect()
+
+    return mosaic, out_meta
+
+
+def get_random_tiff_patch(
+    lat: float,
+    lon: float,
+    patch_width: int,
+    patch_height: int,
+    sat_year: str,
+    satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
+) -> (np.ndarray, int, int, int, int, rasterio.transform.Affine):
+    """
+    Returns a random patch from the satellite image.
+    """
+    tile = get_tile_from_coord(lat, lon, 17)
+    mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)
+
+    transform = out_meta["transform"]
+    del out_meta
+
+    x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)
+
+    # TODO: temporary margin constant, replace with a better solution
+    KS = 120
+
+    x_offset_range = [
+        x_pixel - patch_width + KS + 1,
+        x_pixel - KS - 1,
+    ]
+    y_offset_range = [
+        y_pixel - patch_height + KS + 1,
+        y_pixel - KS - 1,
+    ]
+
+    # Randomly select an offset within the valid range
+    x_offset = random.randint(*x_offset_range)
+    y_offset = random.randint(*y_offset_range)
+
+    x_offset = np.clip(x_offset, 0, mosaic.shape[-1] - patch_width)
+    y_offset = np.clip(y_offset, 0, mosaic.shape[-2] - patch_height)
+
+    # Update x, y to reflect the clamping of x_offset and y_offset
+    x, y = x_pixel - x_offset, y_pixel - y_offset
+    patch = mosaic[
+        :, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
+    ]
+
+    patch_transform = rasterio.transform.Affine(
+        transform.a,
+        transform.b,
+        transform.c + x_offset * transform.a + y_offset * transform.b,
+        transform.d,
+        transform.e,
+        transform.f + x_offset * transform.d + y_offset * transform.e,
+    )
+
+    gc.collect()
+
+    return patch, x, y, x_offset, y_offset, patch_transform
+
+
+def get_tile_from_coord(lat: float, lng: float, zoom_level: int) -> mercantile.Tile:
+    """
+    Returns the tile containing the given coordinates.
+    """
+    tile = mercantile.tile(lng, lat, zoom_level)
+    return tile
+
+
+def geo_to_pixel_coordinates(
+    lat: float, lon: float, transform: rasterio.transform.Affine
+) -> (int, int):
+    """
+    Converts a pair of (lat, lon) coordinates to pixel coordinates.
+    """
+    x_pixel, y_pixel = ~transform * (lon, lat)
+    return round(x_pixel), round(y_pixel)
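These helpers lean on mercantile's tile arithmetic; a short sketch of the real API calls involved (the coordinates are arbitrary examples):

```python
import mercantile

tile = mercantile.tile(14.5058, 46.0569, 17)  # (lng, lat, zoom) -> Tile(x, y, z)
print(tile)
print(mercantile.bounds(tile))          # LngLatBbox(west, south, east, north) in degrees
print(len(mercantile.neighbors(tile)))  # 8 immediate neighbours
```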
diff --git a/gluefactory/datasets/sat_utils.py b/gluefactory/datasets/sat_utils.py
new file mode 100644
index 0000000..261c529
--- /dev/null
+++ b/gluefactory/datasets/sat_utils.py
@@ -0,0 +1,109 @@
+#! /usr/bin/env python3
+
+import gc
+import os
+import random
+import warnings
+
+import mercantile
+import numpy as np
+import rasterio
+from PIL import Image
+from rasterio.errors import NotGeoreferencedWarning
+from rasterio.io import MemoryFile
+from rasterio.merge import merge
+from rasterio.transform import from_bounds
+
+
+def get_3x3_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
+    neighbors = []
+    for neighbour in mercantile.neighbors(tile):
+        if neighbour not in neighbors:
+            neighbors.append(neighbour)
+    neighbors.append(tile)
+    return neighbors
+
+
+def get_tiff_map(
+    tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str
+) -> (np.ndarray, dict):
+    """
+    Returns a TIFF map of the given tile.
+    """
+    tile_data = []
+    memfiles = []
+    neighbors = get_3x3_neighbors(tile)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
+        for neighbor in neighbors:
+            west, south, east, north = mercantile.bounds(neighbor)
+            tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
+            if not os.path.exists(tile_path):
+                raise FileNotFoundError(
+                    f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
+                )
+
+            with Image.open(tile_path) as img:
+                width, height = img.size
+            memfile = MemoryFile()
+            with memfile.open(
+                driver="GTiff",
+                height=height,
+                width=width,
+                count=3,
+                dtype="uint8",
+                crs="EPSG:3857",
+                transform=from_bounds(west, south, east, north, width, height),
+            ) as dataset:
+                with rasterio.open(tile_path) as src:
+                    dataset.write(src.read())
+            tile_data.append(memfile.open())
+            # Keep the MemoryFile alive until after merge(); closing it here
+            # would free the buffer backing the dataset we just appended.
+            memfiles.append(memfile)
+
+    mosaic, out_trans = merge(tile_data)
+
+    out_meta = tile_data[0].meta.copy()
+    out_meta.update(
+        {
+            "driver": "GTiff",
+            "height": mosaic.shape[1],
+            "width": mosaic.shape[2],
+            "transform": out_trans,
+            "crs": "EPSG:3857",
+        }
+    )
+
+    # Clean up MemoryFile instances to free up memory
+    for td in tile_data:
+        td.close()
+    for mf in memfiles:
+        mf.close()
+
+    del neighbors
+    del tile_data
+    gc.collect()
+
+    return mosaic, out_meta
+
+
+def get_random_tiff_patch(
+    lat: float,
+    lon: float,
+    satellite_dataset_dir: str,
+) -> np.ndarray:
+    """
+    Returns the merged tile mosaic around the given coordinates,
+    for a randomly chosen satellite year.
+    """
+    tile = get_tile_from_coord(lat, lon, 17)
+    sat_years = ["2023", "2021", "2019", "2016"]
+
+    # Randomly select a satellite year
+    sat_year = random.choice(sat_years)
+
+    mosaic, _ = get_tiff_map(tile, sat_year, satellite_dataset_dir)
+
+    return mosaic
+
+
+def get_tile_from_coord(lat: float, lng: float, zoom_level: int) -> mercantile.Tile:
+    """
+    Returns the tile containing the given coordinates.
+    """
+    tile = mercantile.tile(lng, lat, zoom_level)
+    return tile
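A back-of-the-envelope check, assuming standard 256-pixel web-mercator tiles (the tile size is not stated in this patch): a 3x3 neighbourhood gives a 768x768 mosaic, while the 5x5 variant in s_utils.py gives 1280x1280 -- enough margin to cut a 400x400 patch at a random offset around the query pixel.

```python
TILE = 256  # assumed tile edge in pixels
print("3x3 mosaic:", 3 * TILE, "px; 5x5 mosaic:", 5 * TILE, "px")  # 768 / 1280
```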
diff --git a/gluefactory/datasets/satellites.py b/gluefactory/datasets/satellites.py
new file mode 100644
index 0000000..fc25ca9
--- /dev/null
+++ b/gluefactory/datasets/satellites.py
@@ -0,0 +1,297 @@
+"""
+Load satellite tile mosaics by coordinate (does not have any split of its own),
+and apply homographic adaptations to them. Yields an image pair without border
+artifacts.
+"""
+
+import argparse
+import logging
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import omegaconf
+import torch
+from omegaconf import OmegaConf
+
+from ..geometry.homography import (
+    compute_homography,
+    sample_homography_corners,
+    warp_points,
+)
+from ..models.cache_loader import CacheLoader, pad_local_features
+from ..settings import DATA_PATH
+from ..utils.tools import fork_rng
+from ..visualization.viz2d import plot_image_grid
+from .augmentations import IdentityAugmentation, augmentations
+from .base_dataset import BaseDataset
+from .sat_utils import get_random_tiff_patch
+
+logger = logging.getLogger(__name__)
+
+
+def sample_homography(img, conf: dict, size: list):
+    data = {}
+    H, _, coords, _ = sample_homography_corners(img.shape[:2][::-1], **conf)
+    data["image"] = cv2.warpPerspective(img, H, tuple(size))
+    data["H_"] = H.astype(np.float32)
+    data["coords"] = coords.astype(np.float32)
+    data["image_size"] = np.array(size, dtype=np.float32)
+    return data
+
+
+class SatelliteDataset(BaseDataset):
+    default_conf = {
+        # image search
+        "data_dir": "revisitop1m",  # the top-level directory
+        "image_dir": "jpg/",  # the subdirectory with the images
+        "image_list": "revisitop1m.txt",  # optional: list or filename of list
+        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
+        "metadata_dir": None,
+        # splits
+        "train_size": 100,
+        "val_size": 10,
+        "shuffle_seed": 0,  # or None to skip
+        # image loading
+        "grayscale": False,
+        "triplet": False,
+        "right_only": False,  # image0 is orig (rescaled), image1 is right
+        "reseed": False,
+        "homography": {
+            "difficulty": 0.8,
+            "translation": 1.0,
+            "max_angle": 60,
+            "n_angles": 10,
+            "patch_shape": [640, 480],
+            "min_convexity": 0.05,
+        },
+        "photometric": {
+            "name": "dark",
+            "p": 0.75,
+            # 'difficulty': 1.0,  # currently unused
+        },
+        # feature loading
+        "load_features": {
+            "do": False,
+            **CacheLoader.default_conf,
+            "collate": False,
+            "thresh": 0.0,
+            "max_num_keypoints": -1,
+            "force_num_keypoints": False,
+        },
+    }
+
+    def _init(self, conf):
+        coordinates_file = conf.metadata_dir
+
+        with open(coordinates_file, "r") as cf:
+            coordinates = cf.readlines()
+
+        parsed_coordinates = []
+        for coordinate in coordinates:
+            lat_part, lon_part = coordinate.split(",")
+            lat = float(lat_part.split(":")[-1].strip())
+            lon = float(lon_part.split(":")[-1].strip())
+            parsed_coordinates.append((lat, lon))
+
+        # Split into train and val (20% val)
+        train_images = parsed_coordinates[: int(len(parsed_coordinates) * 0.8)]
+        val_images = parsed_coordinates[int(len(parsed_coordinates) * 0.8) :]
+
+        self.images = {"train": train_images, "val": val_images}
+
+    def get_dataset(self, split):
+        return _Dataset(self.conf, self.images[split], split)
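The coordinate parser in `_init` implies a metadata file with one `<label>: <float>, <label>: <float>` entry per line; a sketch with hypothetical values:

```python
line = "lat: 46.0569, lon: 14.5058"  # hypothetical coords.txt entry
lat_part, lon_part = line.split(",")
lat = float(lat_part.split(":")[-1].strip())  # 46.0569
lon = float(lon_part.split(":")[-1].strip())  # 14.5058
```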
+
+
+class _Dataset(torch.utils.data.Dataset):
+    def __init__(self, conf, image_names, split):
+        self.conf = conf
+        self.split = split
+        self.image_names = np.array(image_names)
+        self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir
+
+        aug_conf = conf.photometric
+        aug_name = aug_conf.name
+        assert (
+            aug_name in augmentations.keys()
+        ), f'{aug_name} not in {" ".join(augmentations.keys())}'
+        # Both views receive the same photometric augmentation here
+        # (the identity/right_only variant is disabled).
+        self.photo_augment = augmentations[aug_name](aug_conf)
+        self.left_augment = augmentations[aug_name](aug_conf)
+        self.img_to_tensor = IdentityAugmentation()
+
+        if conf.load_features.do:
+            self.feature_loader = CacheLoader(conf.load_features)
+
+    def _transform_keypoints(self, features, data):
+        """Transform keypoints by a homography, threshold them,
+        and potentially keep only the best ones."""
+        # Warp points
+        features["keypoints"] = warp_points(
+            features["keypoints"], data["H_"], inverse=False
+        )
+        h, w = data["image"].shape[1:3]
+        valid = (
+            (features["keypoints"][:, 0] >= 0)
+            & (features["keypoints"][:, 0] <= w - 1)
+            & (features["keypoints"][:, 1] >= 0)
+            & (features["keypoints"][:, 1] <= h - 1)
+        )
+        features["keypoints"] = features["keypoints"][valid]
+
+        # Threshold
+        if self.conf.load_features.thresh > 0:
+            valid = features["keypoint_scores"] >= self.conf.load_features.thresh
+            features = {k: v[valid] for k, v in features.items()}
+
+        # Get the top keypoints and pad
+        n = self.conf.load_features.max_num_keypoints
+        if n > -1:
+            inds = np.argsort(-features["keypoint_scores"])
+            features = {k: v[inds[:n]] for k, v in features.items()}
+
+        if self.conf.load_features.force_num_keypoints:
+            features = pad_local_features(
+                features, self.conf.load_features.max_num_keypoints
+            )
+
+        return features
+
+    def __getitem__(self, idx):
+        if self.conf.reseed:
+            with fork_rng(self.conf.seed + idx, False):
+                return self.getitem(idx)
+        else:
+            return self.getitem(idx)
+
+    def _read_view(self, img, H_conf, ps, left=False):
+        data = sample_homography(img, H_conf, ps)
+        if left:
+            data["image"] = self.left_augment(data["image"], return_tensor=True)
+        else:
+            data["image"] = self.photo_augment(data["image"], return_tensor=True)
+
+        if self.conf.grayscale:
+            # BT.601 luma weights
+            gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
+            data["image"] = (data["image"] * gs).sum(0, keepdim=True)
+
+        if self.conf.load_features.do:
+            features = self.feature_loader({k: [v] for k, v in data.items()})
+            features = self._transform_keypoints(features, data)
+            data["cache"] = features
+
+        return data
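Both views are warped from the same mosaic, so conceptually the pair homography is `H1 @ inv(H0)`; the `getitem` code below instead derives it from the warped corner coordinates via `compute_homography`, which is equivalent for exact correspondences. A sketch of the composition with made-up matrices:

```python
import numpy as np

# Hypothetical warps original -> view0 and original -> view1:
H0 = np.array([[1.0, 0.0, 10.0], [0.0, 1.0, 5.0], [0.0, 0.0, 1.0]])
H1 = np.array([[0.9, 0.1, -4.0], [-0.1, 0.9, 2.0], [0.0, 0.0, 1.0]])
H_0to1 = H1 @ np.linalg.inv(H0)  # maps view0 pixels to view1 pixels
print(H_0to1)
```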
+    def getitem(self, idx):
+        # Pick the coordinates for this index and fetch a satellite mosaic
+        lat, lon = self.image_names[idx]
+        img = get_random_tiff_patch(lat, lon, self.conf.data_dir)
+
+        if img is None:
+            logger.warning("Image %f %f could not be read.", lat, lon)
+            img = np.zeros((3, 1024, 1024), dtype=np.uint8)  # (C, H, W) fallback
+
+        img = img.transpose(1, 2, 0)
+        img = img.astype(np.float32) / 255.0
+        size = img.shape[:2][::-1]
+        ps = self.conf.homography.patch_shape
+
+        left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
+        if self.conf.right_only:
+            left_conf["difficulty"] = 0.0
+
+        data0 = self._read_view(img, left_conf, ps, left=True)
+        data1 = self._read_view(img, self.conf.homography, ps, left=False)
+
+        H = compute_homography(data0["coords"], data1["coords"], [1, 1])
+
+        name = f"lat{lat}_lon{lon}"
+
+        data = {
+            "name": name,
+            "original_image_size": np.array(size),
+            "H_0to1": H.astype(np.float32),
+            "idx": idx,
+            "view0": data0,
+            "view1": data1,
+        }
+
+        return data
+
+    def __len__(self):
+        return len(self.image_names)
+
+
+def visualize(args):
+    conf = {
+        "batch_size": 1,
+        "num_workers": 1,
+        "prefetch_factor": 1,
+    }
+    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
+    dataset = SatelliteDataset(conf)
+    loader = dataset.get_data_loader("train")
+    logger.info("The dataset has %d elements.", len(loader))
+
+    with fork_rng(seed=dataset.conf.seed):
+        images = []
+        for _, data in zip(range(args.num_items), loader):
+            images.append(
+                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
+            )
+    plot_image_grid(images, dpi=args.dpi)
+    plt.tight_layout()
+    plt.savefig("implot.png")
+
+
+if __name__ == "__main__":
+    from .. import logger  # overwrite the logger
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num_items", type=int, default=8)
+    parser.add_argument("--dpi", type=int, default=100)
+    parser.add_argument("dotlist", nargs="*")
+    args = parser.parse_intermixed_args()
+    visualize(args)
diff --git a/gluefactory/geometry/homography.py b/gluefactory/geometry/homography.py
index f87b9f9..2585500 100644
--- a/gluefactory/geometry/homography.py
+++ b/gluefactory/geometry/homography.py
@@ -53,6 +53,7 @@ def sample_homography_corners(
     min_pts1 = create_center_patch(shape, (pwidth, pheight))
     full = create_center_patch(shape)
     pts2 = create_center_patch(patch_shape)
+    scale = min_pts1 - full
     found_valid = False
     cnt = -1
@@ -68,7 +69,9 @@ def sample_homography_corners(
 
     # Rotation
     if n_angles > 0 and difficulty > 0:
-        angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
+        # In our case the difficulty parameter should not affect the rotation,
+        # so max_angle is used directly.
+        angles = np.linspace(-max_angle, max_angle, n_angles)
         rng.shuffle(angles)
         rng.shuffle(angles)
         angles = np.concatenate([[0.0], angles], axis=0)
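The effect of the `homography.py` change in isolation: rotation candidates no longer shrink with `difficulty`, so `max_angle` is used at full strength. A before/after sketch:

```python
import numpy as np

max_angle, difficulty, n_angles = 60, 0.5, 10
before = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
after = np.linspace(-max_angle, max_angle, n_angles)
print(before.min(), before.max())  # -30.0 30.0
print(after.min(), after.max())    # -60.0 60.0
```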
diff --git a/gluefactory/train.py b/gluefactory/train.py
index debf212..1cddb26 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -12,6 +12,7 @@ import signal
 from collections import defaultdict
 from pathlib import Path
 from pydoc import locate
+import gc
 
 import numpy as np
 import torch
@@ -89,6 +90,7 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
     for i, data in enumerate(
         tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
     ):
+
         data = batch_to_device(data, device, non_blocking=True)
         with torch.no_grad():
             pred = model(data)
@@ -102,7 +104,6 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
                     pred[v["predictions"]],
                     mask=pred[v["mask"]] if "mask" in v.keys() else None,
                 )
-            del pred, data
         numbers = {**metrics, **{"loss/" + k: v for k, v in losses.items()}}
         for k, v in numbers.items():
             if k not in results:
@@ -118,7 +119,9 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
                 if k in conf.recall_metrics.keys():
                     q = conf.recall_metrics[k]
                     results[k + f"_recall{int(q)}"].update(v)
-        del numbers
+
+        del pred, data, losses, metrics
+        gc.collect()
     results = {k: results[k].compute() for k in results}
     return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
@@ -257,6 +260,7 @@ def training(rank, conf, output_dir, args):
 
     dataset = get_dataset(data_conf.name)(data_conf)
 
+    # Optionally load a different validation dataset than the training one
     val_data_conf = conf.get("data_val", None)
     if val_data_conf is None:
@@ -408,7 +412,9 @@ def training(rank, conf, output_dir, args):
             getattr(loader.dataset, conf.train.dataset_callback_fn)(
                 conf.train.seed + epoch
             )
+
         for it, data in enumerate(train_loader):
+
             tot_it = (len(train_loader) * epoch + it) * (
                 args.n_gpus if args.distributed else 1
             )
@@ -427,10 +433,17 @@ def training(rank, conf, output_dir, args):
                 loss = torch.mean(losses["total"])
                 if torch.isnan(loss).any():
-                    print(f"Detected NAN, skipping iteration {it}")
-                    del pred, data, loss, losses
-                    continue
+                    # Dump the offending batch, then fail hard instead of
+                    # silently skipping the iteration.
+                    print(f"Detected NAN at iteration {it}")
+                    print("name", data["name"])
+                    print("loss", loss)
+                    print("losses", losses)
+                    print("data", data)
+                    print("pred", pred)
+                    raise RuntimeError("Detected NAN in training.")
 
                 do_backward = loss.requires_grad
+
                 if args.distributed:
                     do_backward = torch.tensor(do_backward).float().to(device)
                     torch.distributed.all_reduce(
@@ -469,7 +482,6 @@ def training(rank, conf, output_dir, args):
             else:
                 if rank == 0:
                     logger.warning(f"Skip iteration {it} due to detach.")
-
             if args.profile:
                 prof.step()
@@ -508,8 +520,11 @@ def training(rank, conf, output_dir, args):
                         norm = torch.norm(param.grad.detach(), 2)
                         grad_txt += f"{name} {norm.item():.3f} \n"
                     writer.add_text("grad/summary", grad_txt, tot_n_samples)
-                del pred, data, loss, losses
 
+                pred.clear()
+                data.clear()
+                del pred, data, loss, losses
+                gc.collect()
                 # Run validation
                 if (
                     (
@@ -529,6 +544,7 @@ def training(rank, conf, output_dir, args):
                         pbar=(rank == -1),
                     )
 
+
                     if rank == 0:
                         str_results = [
                             f"{k} {v:.3E}"
@@ -569,6 +585,10 @@ def training(rank, conf, output_dir, args):
                                     f"figures/{i}_{name}", fig, tot_n_samples
                                 )
                     torch.cuda.empty_cache()  # should be cleared at the first iter
+                    str_results.clear()
+                    pr_metrics.clear()
+                    del str_results, figures, pr_metrics
+
                 if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
                     if results is None:
diff --git a/gluefactory/utils/tensor.py b/gluefactory/utils/tensor.py
index d0a8ca5..b382d4e 100644
--- a/gluefactory/utils/tensor.py
+++ b/gluefactory/utils/tensor.py
@@ -33,6 +33,15 @@ def batch_to_device(batch, device, non_blocking=True):
     return map_tensor(batch, _func)
 
 
+def detach_tensors(batch):
+    """
+    Detach all tensors in a batch recursively.
+    This is useful for releasing the computational graph to free up memory.
+    """
+    def _detach(tensor):
+        return tensor.detach()
+
+    return map_tensor(batch, _detach)
 
 def rbd(data: dict) -> dict:
     """Remove batch dimension from elements in data"""
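A usage sketch for the new `detach_tensors` helper, assuming `map_tensor` recurses through nested dicts and lists (which `batch_to_device` in the same module already relies on):

```python
import torch
from gluefactory.utils.tensor import detach_tensors

batch = {
    "loss": torch.tensor(1.0, requires_grad=True) * 2,  # grad-tracking tensor
    "meta": {"idx": torch.arange(3)},
}
detached = detach_tensors(batch)
print(detached["loss"].requires_grad)  # False
```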
diff --git a/gluefactory/utils/tools.py b/gluefactory/utils/tools.py
index 6a27f4a..6f0e4d9 100644
--- a/gluefactory/utils/tools.py
+++ b/gluefactory/utils/tools.py
@@ -67,8 +67,11 @@ class MedianMetric:
         else:
             return np.nanmedian(self._elements)
 
+    def __len__(self):
+        return len(self._elements)
+
 
 class PRMetric:
     def __init__(self):
         self.labels = []
         self.predictions = []
diff --git a/gluefactory_nonfree/superpoint.py b/gluefactory_nonfree/superpoint.py
index 682d392..e25a6b4 100644
--- a/gluefactory_nonfree/superpoint.py
+++ b/gluefactory_nonfree/superpoint.py
@@ -338,6 +338,14 @@ class SuperPoint(BaseModel):
             for k, d in zip(keypoints, dense_desc)
         ]
 
+        if isinstance(desc, list):
+            if all(isinstance(d, torch.Tensor) for d in desc):
+                # desc is a list of tensors: stack along a new batch dimension
+                desc = torch.stack(desc)
+            else:
+                # desc contains non-tensor elements: convert before stacking
+                desc = torch.stack([torch.tensor(d) for d in desc])
+
         pred = {
             "keypoints": keypoints + 0.5,
             "keypoint_scores": scores,
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..e92a1dc
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,12 @@
+#python -m gluefactory.train sp+lg_homography \
+#    --conf gluefactory/configs/superpoint+lightglue_homography.yaml \
+#    data.batch_size=32  # for 1x 1080 GPU
+
+#python -m gluefactory.train satellites_aug_both_40 \
+#    --conf gluefactory/configs/satellites3.yaml \
+#    data.batch_size=4  # for 1x 1080 GPU
+
+
+python -m gluefactory.train drone2sat_v1 \
+    --conf gluefactory/configs/drone2sat_v1.yaml \
+    data.batch_size=6  # for 1x 1080 GPU
diff --git a/stats.sh b/stats.sh
new file mode 100644
index 0000000..c3ef45d
--- /dev/null
+++ b/stats.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+LOGFILE="/var/log/h/system_stats.log"
+
+echo "Logging CPU, Memory, Swap, Disk, and Network usage to $LOGFILE"
+
+while true; do
+    echo "----- $(date) -----" >> "$LOGFILE"
+    echo "CPU Usage:" >> "$LOGFILE"
+    mpstat -P ALL 1 1 >> "$LOGFILE"
+
+    echo "Disk Usage:" >> "$LOGFILE"
+    iostat >> "$LOGFILE"
+
+    echo "Memory and Swap Usage:" >> "$LOGFILE"
+    free -m >> "$LOGFILE"
+
+    echo "Network Usage:" >> "$LOGFILE"
+    ifstat 1 1 >> "$LOGFILE"
+
+    sleep 5
+done