Add custom training files
parent
4a8283517f
commit
d40e423a42
|
@ -0,0 +1,69 @@
|
|||
data:
|
||||
name: drone2sat
|
||||
batch_size: 4
|
||||
num_workers: 24
|
||||
val_size: 500
|
||||
train_size: -1
|
||||
photometric:
|
||||
name: lg
|
||||
geo_dataset:
|
||||
uav_dataset_dir: /mnt/drive/uav_dataset
|
||||
satellite_dataset_dir: /mnt/drive/tiles
|
||||
misslabeled_images_path: /mnt/drive/misslabeled.txt
|
||||
sat_zoom_level: 17
|
||||
uav_patch_width: 400
|
||||
uav_patch_height: 400
|
||||
sat_patch_width: 400
|
||||
sat_patch_height: 400
|
||||
test_from_train_ratio: 0.0
|
||||
transform_mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
transform_std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
sat_availaible_years:
|
||||
- "2023"
|
||||
- "2021"
|
||||
- "2019"
|
||||
- "2016"
|
||||
max_rotation_angle: 0
|
||||
uav_image_scale: 2.0
|
||||
use_heatmap: false
|
||||
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
features: superpoint
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
flash: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 1000 # 350000 / 4
|
||||
lr: 1e-3
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,48 @@
|
|||
data:
|
||||
name: satellites
|
||||
data_dir: /mnt/drive/tiles
|
||||
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||
train_size: null
|
||||
val_size: null
|
||||
batch_size: 128
|
||||
num_workers: 24
|
||||
homography:
|
||||
difficulty: 0.9
|
||||
max_angle: 359
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
force_num_keypoints: True
|
||||
nms_radius: 3
|
||||
trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: true
|
||||
checkpointed: true
|
||||
weights: superpoint
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 5
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-7
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,60 @@
|
|||
data:
|
||||
name: satellites
|
||||
data_dir: /mnt/drive/tiles
|
||||
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||
train_size: null
|
||||
val_size: null
|
||||
batch_size: 128
|
||||
num_workers: 24
|
||||
homography:
|
||||
difficulty: 0.5
|
||||
max_angle: 359
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
#extractor:
|
||||
# name: gluefactory_nonfree.superpoint
|
||||
# max_num_keypoints: 512
|
||||
# force_num_keypoints: True
|
||||
# detection_threshold: 0.0
|
||||
# nms_radius: 3
|
||||
# trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
features: superpoint
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
flash: true
|
||||
#matcher:
|
||||
# name: matchers.lightglue_pretrained
|
||||
# features: superpoint
|
||||
# depth_confidence: -1
|
||||
# width_confidence: -1
|
||||
# filter_threshold: 0.1
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 5
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,47 @@
|
|||
data:
|
||||
name: satellites
|
||||
data_dir: /mnt/drive/tiles
|
||||
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||
train_size: null
|
||||
val_size: null
|
||||
batch_size: 128
|
||||
num_workers: 24
|
||||
homography:
|
||||
difficulty: 0.5
|
||||
max_angle: 180
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
features: superpoint
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
flash: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -33,6 +33,7 @@ train:
|
|||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
profile: true
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
|
|
|
@ -168,6 +168,11 @@ class BaseDataset(metaclass=ABCMeta):
|
|||
sampler = None
|
||||
if shuffle is None:
|
||||
shuffle = split == "train" and self.conf.shuffle_training
|
||||
shuffle = split == "val"
|
||||
|
||||
shuffle = True
|
||||
|
||||
print("Shuffle", shuffle)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
|
|
|
@ -0,0 +1,694 @@
|
|||
"""
|
||||
Simply load images from a folder or nested folders (does not have any split),
|
||||
and apply homographic adaptations to it. Yields an image pair without border
|
||||
artifacts.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from typing import List, Literal
|
||||
from torchvision import transforms
|
||||
import random
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
|
||||
|
||||
from ..geometry.homography import (
|
||||
compute_homography,
|
||||
sample_homography_corners,
|
||||
warp_points,
|
||||
)
|
||||
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||
from ..settings import DATA_PATH
|
||||
from ..utils.image import read_image
|
||||
from ..utils.tools import fork_rng
|
||||
from ..visualization.viz2d import plot_image_grid
|
||||
from .augmentations import IdentityAugmentation, augmentations
|
||||
from .base_dataset import BaseDataset
|
||||
from .s_utils import get_random_tiff_patch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Drone2SatDataset(BaseDataset):
|
||||
default_conf = {
|
||||
# image search
|
||||
"data_dir": "revisitop1m", # the top-level directory
|
||||
"image_dir": "jpg/", # the subdirectory with the images
|
||||
"image_list": "revisitop1m.txt", # optional: list or filename of list
|
||||
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
|
||||
"metadata_dir": None,
|
||||
"uav_dataset_dir": None,
|
||||
"satellite_dataset_dir": None,
|
||||
# splits
|
||||
"train_size": 100,
|
||||
"val_size": 10,
|
||||
"shuffle_seed": 0, # or None to skip
|
||||
# image loading
|
||||
"grayscale": False,
|
||||
"triplet": False,
|
||||
"right_only": False, # image0 is orig (rescaled), image1 is right
|
||||
"reseed": False,
|
||||
"homography": {
|
||||
"difficulty": 0.8,
|
||||
"translation": 1.0,
|
||||
"max_angle": 60,
|
||||
"n_angles": 10,
|
||||
"patch_shape": [640, 480],
|
||||
"min_convexity": 0.05,
|
||||
},
|
||||
"photometric": {
|
||||
"name": "dark",
|
||||
"p": 0.75,
|
||||
# 'difficulty': 1.0, # currently unused
|
||||
},
|
||||
# feature loading
|
||||
"load_features": {
|
||||
"do": False,
|
||||
**CacheLoader.default_conf,
|
||||
"collate": False,
|
||||
"thresh": 0.0,
|
||||
"max_num_keypoints": -1,
|
||||
"force_num_keypoints": False,
|
||||
},
|
||||
# Other geolocaliztion parameters
|
||||
"geo_dataset" :{
|
||||
"uav_dataset_dir": None,
|
||||
"satellite_dataset_dir": None,
|
||||
"misslabeled_images_path": None,
|
||||
"sat_zoom_level": 17,
|
||||
"uav_patch_width": 400,
|
||||
"uav_patch_height": 400,
|
||||
"sat_patch_width": 400,
|
||||
"sat_patch_height": 400,
|
||||
"heatmap_kernel_size": 33,
|
||||
"test_from_train_ratio": 0.0,
|
||||
"transform_mean": [0.485, 0.456, 0.406],
|
||||
"transform_std": [0.229, 0.224, 0.225],
|
||||
"sat_availaible_years": ["2023", "2021", "2019", "2016"],
|
||||
"max_rotation_angle": 10,
|
||||
"uav_image_scale": 1,
|
||||
"use_heatmap": False,
|
||||
}
|
||||
}
|
||||
|
||||
def _init(self, conf):
|
||||
self.images = {"train": "train", "val": "val"}
|
||||
|
||||
def get_dataset(self, split):
|
||||
if split == "val":
|
||||
return GeoLocalizationDataset(
|
||||
uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
|
||||
satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
|
||||
misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
|
||||
dataset="test",
|
||||
sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
|
||||
uav_patch_width=self.conf.geo_dataset.uav_patch_width,
|
||||
uav_patch_height=self.conf.geo_dataset.uav_patch_height,
|
||||
sat_patch_width=self.conf.geo_dataset.sat_patch_width,
|
||||
sat_patch_height=self.conf.geo_dataset.sat_patch_height,
|
||||
heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
|
||||
test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
|
||||
transform_mean=self.conf.geo_dataset.transform_mean,
|
||||
transform_std=self.conf.geo_dataset.transform_std,
|
||||
sat_available_years=self.conf.geo_dataset.sat_availaible_years,
|
||||
max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
|
||||
uav_image_scale=self.conf.geo_dataset.uav_image_scale,
|
||||
use_heatmap=self.conf.geo_dataset.use_heatmap,
|
||||
subset_size=self.conf.val_size,
|
||||
)
|
||||
elif split == "train":
|
||||
return GeoLocalizationDataset(
|
||||
uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
|
||||
satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
|
||||
misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
|
||||
dataset="train",
|
||||
sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
|
||||
uav_patch_width=self.conf.geo_dataset.uav_patch_width,
|
||||
uav_patch_height=self.conf.geo_dataset.uav_patch_height,
|
||||
sat_patch_width=self.conf.geo_dataset.sat_patch_width,
|
||||
sat_patch_height=self.conf.geo_dataset.sat_patch_height,
|
||||
heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
|
||||
test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
|
||||
transform_mean=self.conf.geo_dataset.transform_mean,
|
||||
transform_std=self.conf.geo_dataset.transform_std,
|
||||
sat_available_years=self.conf.geo_dataset.sat_availaible_years,
|
||||
max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
|
||||
uav_image_scale=self.conf.geo_dataset.uav_image_scale,
|
||||
use_heatmap=self.conf.geo_dataset.use_heatmap,
|
||||
subset_size=self.conf.train_size,
|
||||
)
|
||||
|
||||
|
||||
class GeoLocalizationDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
uav_dataset_dir: str,
|
||||
satellite_dataset_dir: str,
|
||||
misslabeled_images_path: str,
|
||||
dataset: Literal["train", "test"],
|
||||
sat_zoom_level: int = 16,
|
||||
uav_patch_width: int = 128,
|
||||
uav_patch_height: int = 128,
|
||||
sat_patch_width: int = 400,
|
||||
sat_patch_height: int = 400,
|
||||
heatmap_kernel_size: int = 33,
|
||||
test_from_train_ratio: float = 0.0,
|
||||
transform_mean: List[float] = [0.485, 0.456, 0.406],
|
||||
transform_std: List[float] = [0.229, 0.224, 0.225],
|
||||
sat_available_years: List[str] = ["2023", "2021", "2019", "2016"],
|
||||
max_rotation_angle: int = 10,
|
||||
uav_image_scale: float = 1,
|
||||
use_heatmap: bool = True,
|
||||
subset_size: int = None,
|
||||
):
|
||||
self.uav_dataset_dir = uav_dataset_dir
|
||||
self.satellite_dataset_dir = satellite_dataset_dir
|
||||
self.dataset = dataset
|
||||
self.sat_zoom_level = sat_zoom_level
|
||||
self.uav_patch_width = uav_patch_width
|
||||
self.uav_patch_height = uav_patch_height
|
||||
self.heatmap_kernel_size = heatmap_kernel_size
|
||||
self.test_from_train_ratio = test_from_train_ratio
|
||||
self.transform_mean = transform_mean
|
||||
self.transform_std = transform_std
|
||||
self.misslabeled_images_path = misslabeled_images_path
|
||||
self.metadata_dict = {}
|
||||
self.max_rotation_angle = max_rotation_angle
|
||||
self.total_uav_samples = self.count_total_uav_samples()
|
||||
self.misslabelled_images = self.read_misslabelled_images(self.misslabeled_images_path)
|
||||
self.entry_paths = self.get_entry_paths(self.uav_dataset_dir)
|
||||
self.cleanup_misslabelled_images()
|
||||
self.transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
#transforms.Normalize(self.transform_mean, self.transform_std),
|
||||
]
|
||||
)
|
||||
self.sat_available_years = sat_available_years
|
||||
self.uav_image_scale = uav_image_scale
|
||||
self.use_heatmap = use_heatmap
|
||||
self.sat_patch_width = sat_patch_width
|
||||
self.sat_patch_height = sat_patch_height
|
||||
self.subset_size = subset_size
|
||||
|
||||
self.inverse_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Normalize(
|
||||
mean=[
|
||||
-m / s for m, s in zip(self.transform_mean, self.transform_std)
|
||||
],
|
||||
std=[1 / s for s in self.transform_std],
|
||||
),
|
||||
transforms.ToPILImage(),
|
||||
]
|
||||
)
|
||||
|
||||
if self.subset_size != -1:
|
||||
ssize = int(self.subset_size / len(self.sat_available_years))
|
||||
ssize = min(ssize, len(self.entry_paths))
|
||||
self.entry_paths = self.entry_paths[:ssize]
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return (
|
||||
len(self.entry_paths)
|
||||
* len(self.sat_available_years)
|
||||
)
|
||||
|
||||
def read_misslabelled_images(
|
||||
self, path: str = "misslabels/misslabeled.txt"
|
||||
) -> List[str]:
|
||||
with open(path, "r") as f:
|
||||
lines = f.readlines()
|
||||
return [line.strip() for line in lines]
|
||||
|
||||
def cleanup_misslabelled_images(self) -> None:
|
||||
indices_to_delete = []
|
||||
|
||||
for image in self.misslabelled_images:
|
||||
for image_path in self.entry_paths:
|
||||
if image in image_path:
|
||||
index = self.entry_paths.index(image_path)
|
||||
indices_to_delete.append(index)
|
||||
break
|
||||
|
||||
sorted_tuples = sorted(indices_to_delete, reverse=True)
|
||||
|
||||
for index in sorted_tuples:
|
||||
self.entry_paths.pop(index)
|
||||
|
||||
def __getitem__(self, idx) -> dict:
|
||||
"""
|
||||
Retrieves a sample given its index, returning the preprocessed UAV
|
||||
and satellite images, along with their associated heatmap and metadata.
|
||||
"""
|
||||
|
||||
image_path_index = idx // (
|
||||
len(self.sat_available_years)
|
||||
)
|
||||
|
||||
sat_year = self.sat_available_years[idx % len(self.sat_available_years)]
|
||||
rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle)
|
||||
|
||||
image_path = self.entry_paths[image_path_index]
|
||||
uav_image = Image.open(image_path).convert("RGB") # Ensure 3-channel image
|
||||
|
||||
original_uav_image_width = uav_image.width
|
||||
original_uav_image_height = uav_image.height
|
||||
|
||||
lookup_str, file_number = self.extract_info_from_filename(image_path)
|
||||
img_info = self.metadata_dict[lookup_str][file_number]
|
||||
|
||||
lat, lon = (
|
||||
img_info["coordinate"]["latitude"],
|
||||
img_info["coordinate"]["longitude"],
|
||||
)
|
||||
|
||||
fov_vertical = img_info["fovVertical"]
|
||||
|
||||
try:
|
||||
agl_altitude = float(image_path.split("/")[-1].split("_")[2].split("m")[0])
|
||||
except IndexError:
|
||||
agl_altitude = 150.0
|
||||
warnings.warn(
|
||||
"Could not extract AGL altitude from filename, using default value of 150m."
|
||||
)
|
||||
|
||||
(
|
||||
satellite_patch,
|
||||
x_sat,
|
||||
y_sat,
|
||||
x_offset,
|
||||
y_offset,
|
||||
patch_transform,
|
||||
) = get_random_tiff_patch(
|
||||
lat, lon, self.sat_patch_width, self.sat_patch_height, sat_year, self.satellite_dataset_dir
|
||||
)
|
||||
|
||||
# Rotate crop center and transform image
|
||||
h = np.ceil(uav_image.height // self.uav_image_scale).astype(int)
|
||||
w = np.ceil(uav_image.width // self.uav_image_scale).astype(int)
|
||||
|
||||
uav_image = F.rotate(uav_image, rot_angle)
|
||||
uav_image = F.resize(uav_image, [h, w])
|
||||
uav_image = F.center_crop(
|
||||
uav_image, (self.uav_patch_height, self.uav_patch_width)
|
||||
)
|
||||
uav_image = self.transforms(uav_image)
|
||||
|
||||
satellite_patch = satellite_patch.transpose(1, 2, 0)
|
||||
satellite_patch_pytorch = self.transforms(satellite_patch)
|
||||
del satellite_patch
|
||||
|
||||
if self.use_heatmap:
|
||||
heatmap = self.get_heatmap_gt(
|
||||
x_sat,
|
||||
y_sat,
|
||||
satellite_patch.shape[1],
|
||||
satellite_patch.shape[2],
|
||||
self.heatmap_kernel_size,
|
||||
)
|
||||
|
||||
cropped_uav_image_width = self.calculate_cropped_uav_image_width(
|
||||
fov_vertical,
|
||||
original_uav_image_width,
|
||||
original_uav_image_height,
|
||||
self.uav_patch_width,
|
||||
self.uav_patch_height,
|
||||
agl_altitude,
|
||||
)
|
||||
|
||||
satellite_tile_width = self.calculate_cropped_sat_image_width(
|
||||
lat, self.sat_patch_width, patch_transform
|
||||
)
|
||||
|
||||
scale_factor = cropped_uav_image_width / satellite_tile_width
|
||||
scale_factor *= 10
|
||||
|
||||
homography_matrix = self.compute_homography(
|
||||
rot_angle,
|
||||
x_sat,
|
||||
y_sat,
|
||||
self.uav_patch_width,
|
||||
self.uav_patch_height,
|
||||
scale_factor,
|
||||
)
|
||||
|
||||
if not self.use_heatmap:
|
||||
# Sample four points from the UAV image
|
||||
points = self.sample_four_points(
|
||||
self.uav_patch_width, self.uav_patch_height
|
||||
)
|
||||
|
||||
# Transform the points
|
||||
warped_points = self.warp_points(points, homography_matrix)
|
||||
#img_info["warped_points_sat"] = warped_points
|
||||
#img_info["warped_points_uav"] = points
|
||||
|
||||
# img_info["cropped_uav_image_width"] = cropped_uav_image_width
|
||||
# img_info["satellite_tile_width"] = satellite_tile_width
|
||||
# img_info["scale_factor"] = scale_factor
|
||||
# img_info["filename"] = image_path
|
||||
# img_info["rot_angle"] = rot_angle
|
||||
# img_info["x_sat"] = x_sat
|
||||
# img_info["y_sat"] = y_sat
|
||||
# img_info["x_offset"] = x_offset
|
||||
# img_info["y_offset"] = y_offset
|
||||
# img_info["patch_transform"] = patch_transform
|
||||
# img_info["uav_image_scale"] = self.uav_image_scale
|
||||
# img_info["homography_matrix_uav_to_sat"] = homography_matrix
|
||||
# img_info["homography_matrix_sat_to_uav"] = np.linalg.inv(homography_matrix)
|
||||
# img_info["agl_altitude"] = agl_altitude
|
||||
# img_info["original_uav_image_width"] = original_uav_image_width
|
||||
# img_info["original_drone_image_height"] = original_uav_image_height
|
||||
# img_info["fov_vertical"] = fov_vertical
|
||||
#
|
||||
|
||||
inverse_homography_matrix = np.linalg.inv(homography_matrix)
|
||||
uav_image_data = {
|
||||
"image": uav_image,
|
||||
"H_": homography_matrix,
|
||||
"coords": points,
|
||||
"image_size": np.array([self.uav_patch_width, self.uav_patch_height]),
|
||||
}
|
||||
|
||||
satellite_patch_data = {
|
||||
"image": satellite_patch_pytorch,
|
||||
"H_": inverse_homography_matrix,
|
||||
"coords": warped_points,
|
||||
"image_size": np.array([self.sat_patch_width, self.sat_patch_height]),
|
||||
}
|
||||
|
||||
#hm = self.compute_homography_points(
|
||||
# warped_points, points, [self.uav_patch_width, self.uav_patch_height]
|
||||
#)
|
||||
|
||||
hm = self.compute_homography_points(
|
||||
points, warped_points, [1.0 , 1.0]
|
||||
)
|
||||
|
||||
if np.array_equal(hm, np.eye(3)):
|
||||
print("Singular matrix")
|
||||
return self.__getitem__(random.randint(0, len(self) - 1))
|
||||
|
||||
data = {
|
||||
"name": f"{image_path}_{rot_angle}_{sat_year}",
|
||||
"original_image_size": np.array(
|
||||
[original_uav_image_width, original_uav_image_height]
|
||||
),
|
||||
"H_0to1": hm.astype(np.float32),
|
||||
"view0": uav_image_data,
|
||||
"view1": satellite_patch_data,
|
||||
"idx": idx
|
||||
}
|
||||
del img_info
|
||||
gc.collect()
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform):
|
||||
"""
|
||||
Computes the width of the satellite image in world coordinate units
|
||||
"""
|
||||
length_of_degree = 111320 * np.cos(np.radians(latitude))
|
||||
scale_x_meters = patch_transform[0] * length_of_degree
|
||||
satellite_tile_width = patch_width * scale_x_meters
|
||||
return satellite_tile_width
|
||||
|
||||
def calculate_cropped_uav_image_width(
|
||||
self,
|
||||
fov_vertical,
|
||||
orig_width,
|
||||
orig_height,
|
||||
crop_width,
|
||||
crop_height,
|
||||
altitude=150.0,
|
||||
):
|
||||
"""
|
||||
Computes the width of the UAV image in world coordinate units
|
||||
"""
|
||||
# Convert fov from degrees to radians
|
||||
fov_rad = np.radians(fov_vertical)
|
||||
|
||||
# Calculate the full width of the UAV image
|
||||
full_width = 2 * (altitude * np.tan(fov_rad / 2))
|
||||
|
||||
# Determine the cropping ratio
|
||||
crop_ratio_width = crop_width / orig_width
|
||||
crop_ratio_height = crop_height / orig_height
|
||||
|
||||
# Calculate the adjusted horizontal fov
|
||||
fov_horizontal = 2 * np.arctan(np.tan(fov_rad / 2) * (orig_width / orig_height))
|
||||
adjusted_fov_horizontal = 2 * np.arctan(
|
||||
np.tan(fov_horizontal / 2) * crop_ratio_width
|
||||
)
|
||||
|
||||
# Calculate the new full width using the adjusted horizontal fov
|
||||
full_width = 2 * (altitude * np.tan(adjusted_fov_horizontal / 2))
|
||||
|
||||
# Adjust the width according to the crop ratio
|
||||
cropped_width = full_width * crop_ratio_width
|
||||
|
||||
return cropped_width
|
||||
|
||||
def compute_homography_points(self, pts1_, pts2_, shape):
|
||||
"""Compute the homography matrix from 4 point correspondences"""
|
||||
# Rescale to actual size
|
||||
shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x]
|
||||
pts1 = pts1_ * np.expand_dims(shape, axis=0)
|
||||
pts2 = pts2_ * np.expand_dims(shape, axis=0)
|
||||
|
||||
def ax(p, q):
|
||||
return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
|
||||
|
||||
def ay(p, q):
|
||||
return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
|
||||
|
||||
def flat2mat(H):
|
||||
return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])
|
||||
|
||||
a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
|
||||
p_mat = np.transpose(
|
||||
np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
|
||||
)
|
||||
|
||||
try:
|
||||
homography = np.transpose(np.linalg.solve(a_mat, p_mat))
|
||||
except np.linalg.LinAlgError:
|
||||
print("Singular matrix")
|
||||
return np.eye(3)
|
||||
return flat2mat(homography)
|
||||
|
||||
|
||||
def compute_homography(
|
||||
self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor
|
||||
):
|
||||
# Adjust rot_angle if it's greater than 180 degrees
|
||||
if rot_angle > 180:
|
||||
rot_angle -= 360
|
||||
# Convert rotation angle to radians
|
||||
theta = np.radians(rot_angle)
|
||||
|
||||
# Rotation matrix
|
||||
R = np.array(
|
||||
[
|
||||
[np.cos(theta), -np.sin(theta), 0],
|
||||
[np.sin(theta), np.cos(theta), 0],
|
||||
[0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Scale matrix
|
||||
S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
|
||||
|
||||
# Translation matrix to center the UAV image
|
||||
T_uav = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]])
|
||||
|
||||
# Translation matrix to move to the satellite image position
|
||||
T_sat = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]])
|
||||
|
||||
# Compute the combined homography matrix
|
||||
H = np.dot(T_sat, np.dot(R, np.dot(S, T_uav)))
|
||||
|
||||
return H
|
||||
|
||||
def sample_four_points(self, width: int, height: int) -> np.ndarray:
|
||||
"""
|
||||
Samples four points from the UAV image.
|
||||
"""
|
||||
PADDING = 50
|
||||
CENTER_PADDING = 10
|
||||
points = np.array(
|
||||
[
|
||||
[random.randint(CENTER_PADDING, width - PADDING), random.randint(CENTER_PADDING, height - PADDING)]
|
||||
for _ in range(4)
|
||||
]
|
||||
)
|
||||
return points
|
||||
|
||||
def warp_points(
|
||||
self, points: np.ndarray, homography_matrix: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Warps the given points using the given homography matrix.
|
||||
"""
|
||||
points = np.array(points)
|
||||
points = np.concatenate([points, np.ones((4, 1))], axis=1)
|
||||
points = np.dot(homography_matrix, points.T).T
|
||||
points = points[:, :2] / points[:, 2:]
|
||||
return points
|
||||
|
||||
def count_total_uav_samples(self) -> int:
|
||||
"""
|
||||
Count the total number of uav image samples in the dataset
|
||||
(train + test)
|
||||
"""
|
||||
total_samples = 0
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(self.uav_dataset_dir):
|
||||
# Skip the test folder
|
||||
for filename in filenames:
|
||||
if filename.endswith(".jpeg"):
|
||||
total_samples += 1
|
||||
return total_samples
|
||||
|
||||
def get_number_of_city_samples(self) -> int:
|
||||
"""
|
||||
TODO: Count the total number of city samples in the dataset
|
||||
"""
|
||||
return 11
|
||||
|
||||
def get_entry_paths(self, directory: str) -> List[str]:
|
||||
"""
|
||||
Recursively retrieves paths to image and metadata files in the given directory.
|
||||
"""
|
||||
entry_paths = []
|
||||
entries = os.listdir(directory)
|
||||
|
||||
images_to_take_per_folder = int(
|
||||
self.total_uav_samples
|
||||
* self.test_from_train_ratio
|
||||
/ self.get_number_of_city_samples()
|
||||
)
|
||||
|
||||
for entry in entries:
|
||||
entry_path = os.path.join(directory, entry)
|
||||
|
||||
# If it's a directory, recurse into it
|
||||
if os.path.isdir(entry_path):
|
||||
entry_paths += self.get_entry_paths(entry_path)
|
||||
|
||||
# Handle train dataset
|
||||
elif (self.dataset == "train" and "Train" in entry_path) or (
|
||||
self.dataset == "train"
|
||||
and self.test_from_train_ratio > 0
|
||||
and "Test" in entry_path
|
||||
):
|
||||
if entry_path.endswith(".jpeg"):
|
||||
_, number = self.extract_info_from_filename(entry_path)
|
||||
else:
|
||||
number = None
|
||||
if entry_path.endswith(".json"):
|
||||
self.get_metadata(entry_path)
|
||||
if number is None:
|
||||
continue
|
||||
if (
|
||||
number >= images_to_take_per_folder
|
||||
): # Only include images beyond the ones taken for test
|
||||
if entry_path.endswith(".jpeg"):
|
||||
entry_paths.append(entry_path)
|
||||
|
||||
# Handle test dataset
|
||||
elif self.dataset == "test":
|
||||
if entry_path.endswith(".jpeg"):
|
||||
_, number = self.extract_info_from_filename(entry_path)
|
||||
else:
|
||||
number = None
|
||||
if entry_path.endswith(".json"):
|
||||
self.get_metadata(entry_path)
|
||||
|
||||
if number is None:
|
||||
continue
|
||||
if (
|
||||
("Test" in entry_path and number < images_to_take_per_folder)
|
||||
or (number < images_to_take_per_folder and "Train" in entry_path)
|
||||
or (self.test_from_train_ratio == 0.0 and "Test" in entry_path)
|
||||
):
|
||||
if entry_path.endswith(".jpeg"):
|
||||
entry_paths.append(entry_path)
|
||||
|
||||
return sorted(entry_paths, key=self.extract_info_from_filename)
|
||||
|
||||
def get_metadata(self, path: str) -> None:
|
||||
"""
|
||||
Extracts metadata from a JSON file and stores it in the metadata dictionary.
|
||||
"""
|
||||
with open(path, newline="") as jsonfile:
|
||||
json_dict = json.load(jsonfile)
|
||||
path = path.split("/")[-1]
|
||||
path = path.replace(".json", "")
|
||||
self.metadata_dict[path] = json_dict["cameraFrames"]
|
||||
|
||||
def extract_info_from_filename(self, filename: str) -> (str, int):
|
||||
"""
|
||||
Extracts information from the filename.
|
||||
"""
|
||||
filename_without_ext = filename.replace(".jpeg", "")
|
||||
segments = filename_without_ext.split("/")
|
||||
info = segments[-1]
|
||||
try:
|
||||
number = int(info.split("_")[-1])
|
||||
except ValueError:
|
||||
print("Could not extract number from filename: ", filename)
|
||||
return None, None
|
||||
|
||||
info = "_".join(info.split("_")[:-1])
|
||||
|
||||
return info, number
|
||||
|
||||
def get_heatmap_gt(
|
||||
self, x: int, y: int, height: int, width: int, square_size: int = 33
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns 2D heatmap ground truth for the given x and y coordinates,
|
||||
with the given square size.
|
||||
"""
|
||||
x_map, y_map = x, y
|
||||
|
||||
heatmap = torch.zeros((height, width))
|
||||
|
||||
half_size = square_size // 2
|
||||
|
||||
# Calculate the valid range for the square
|
||||
start_x = max(0, x_map - half_size)
|
||||
end_x = min(
|
||||
width, x_map + half_size + 1
|
||||
) # +1 to include the end_x in the square
|
||||
start_y = max(0, y_map - half_size)
|
||||
end_y = min(
|
||||
height, y_map + half_size + 1
|
||||
) # +1 to include the end_y in the square
|
||||
|
||||
heatmap[start_y:end_y, start_x:end_x] = 1
|
||||
|
||||
return heatmap
|
||||
|
||||
if __name__ == "__main__":
|
||||
from .. import logger # overwrite the logger
|
|
@ -0,0 +1,163 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
import os
|
||||
import mercantile
|
||||
import rasterio
|
||||
import numpy as np
|
||||
import random
|
||||
import warnings
|
||||
from rasterio.errors import NotGeoreferencedWarning
|
||||
from rasterio.io import MemoryFile
|
||||
from rasterio.transform import from_bounds
|
||||
from rasterio.merge import merge
|
||||
from PIL import Image
|
||||
import gc
|
||||
import random
|
||||
from osgeo import gdal, osr
|
||||
from affine import Affine
|
||||
|
||||
|
||||
def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
|
||||
neighbors = []
|
||||
for main_neighbour in mercantile.neighbors(tile):
|
||||
for sub_neighbour in mercantile.neighbors(main_neighbour):
|
||||
if sub_neighbour not in neighbors:
|
||||
neighbors.append(sub_neighbour)
|
||||
return neighbors
|
||||
|
||||
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str) -> (np.ndarray, dict):
|
||||
"""
|
||||
Returns a TIFF map of the given tile using GDAL.
|
||||
"""
|
||||
tile_data = []
|
||||
neighbors = get_5x5_neighbors(tile)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
for neighbor in neighbors:
|
||||
west, south, east, north = mercantile.bounds(neighbor)
|
||||
tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
|
||||
if not os.path.exists(tile_path):
|
||||
raise FileNotFoundError(f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found.")
|
||||
|
||||
img = Image.open(tile_path)
|
||||
img_array = np.array(img)
|
||||
|
||||
# Create an in-memory GDAL dataset
|
||||
mem_driver = gdal.GetDriverByName('MEM')
|
||||
dataset = mem_driver.Create('', img_array.shape[1], img_array.shape[0], 3, gdal.GDT_Byte)
|
||||
for i in range(3):
|
||||
dataset.GetRasterBand(i + 1).WriteArray(img_array[:, :, i])
|
||||
|
||||
# Set GeoTransform and Projection
|
||||
geotransform = (west, (east - west) / img_array.shape[1], 0, north, 0, -(north - south) / img_array.shape[0])
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
srs = osr.SpatialReference()
|
||||
srs.ImportFromEPSG(3857)
|
||||
dataset.SetProjection(srs.ExportToWkt())
|
||||
|
||||
tile_data.append(dataset)
|
||||
|
||||
# Merge tiles using GDAL
|
||||
vrt_options = gdal.BuildVRTOptions()
|
||||
vrt = gdal.BuildVRT('', [td for td in tile_data], options=vrt_options)
|
||||
mosaic = vrt.ReadAsArray()
|
||||
|
||||
# Get metadata
|
||||
out_trans = vrt.GetGeoTransform()
|
||||
out_crs = vrt.GetProjection()
|
||||
out_trans = Affine.from_gdal(*out_trans)
|
||||
out_meta = {
|
||||
"driver": "GTiff",
|
||||
"height": mosaic.shape[1],
|
||||
"width": mosaic.shape[2],
|
||||
"transform": out_trans,
|
||||
"crs": out_crs,
|
||||
}
|
||||
|
||||
# Clean up
|
||||
for td in tile_data:
|
||||
td.FlushCache()
|
||||
vrt = None
|
||||
gc.collect()
|
||||
|
||||
return mosaic, out_meta
|
||||
|
||||
def get_random_tiff_patch(
|
||||
lat: float,
|
||||
lon: float,
|
||||
patch_width: int,
|
||||
patch_height: int,
|
||||
sat_year: str,
|
||||
satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
|
||||
) -> (np.ndarray, int, int, int, int, rasterio.transform.Affine):
|
||||
"""
|
||||
Returns a random patch from the satellite image.
|
||||
"""
|
||||
|
||||
tile = get_tile_from_coord(lat, lon, 17)
|
||||
|
||||
mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)
|
||||
|
||||
|
||||
transform = out_meta["transform"]
|
||||
del out_meta
|
||||
|
||||
x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)
|
||||
|
||||
# TODO
|
||||
# Temporal constant, replace with a better solution
|
||||
KS = 120
|
||||
|
||||
x_offset_range = [
|
||||
x_pixel - patch_width + KS + 1,
|
||||
x_pixel - KS - 1,
|
||||
]
|
||||
y_offset_range = [
|
||||
y_pixel - patch_height + KS + 1,
|
||||
y_pixel - KS - 1,
|
||||
]
|
||||
|
||||
# Randomly select an offset within the valid range
|
||||
x_offset = random.randint(*x_offset_range)
|
||||
y_offset = random.randint(*y_offset_range)
|
||||
|
||||
x_offset = np.clip(x_offset, 0, mosaic.shape[-1] - patch_width)
|
||||
y_offset = np.clip(y_offset, 0, mosaic.shape[-2] - patch_height)
|
||||
|
||||
# Update x, y to reflect the clamping of x_offset and y_offset
|
||||
x, y = x_pixel - x_offset, y_pixel - y_offset
|
||||
patch = mosaic[
|
||||
:, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
|
||||
]
|
||||
|
||||
patch_transform = rasterio.transform.Affine(
|
||||
transform.a,
|
||||
transform.b,
|
||||
transform.c + x_offset * transform.a + y_offset * transform.b,
|
||||
transform.d,
|
||||
transform.e,
|
||||
transform.f + x_offset * transform.d + y_offset * transform.e,
|
||||
)
|
||||
|
||||
gc.collect()
|
||||
|
||||
return patch, x, y, x_offset, y_offset, patch_transform
|
||||
|
||||
def get_tile_from_coord(
|
||||
lat: float, lng: float, zoom_level: int
|
||||
) -> mercantile.Tile:
|
||||
"""
|
||||
Returns the tile containing the given coordinates.
|
||||
"""
|
||||
tile = mercantile.tile(lng, lat, zoom_level)
|
||||
return tile
|
||||
|
||||
def geo_to_pixel_coordinates(
|
||||
lat: float, lon: float, transform: rasterio.transform.Affine
|
||||
) -> (int, int):
|
||||
"""
|
||||
Converts a pair of (lat, lon) coordinates to pixel coordinates.
|
||||
"""
|
||||
x_pixel, y_pixel = ~transform * (lon, lat)
|
||||
return round(x_pixel), round(y_pixel)
|
|
@ -0,0 +1,109 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
import os
|
||||
import mercantile
|
||||
import rasterio
|
||||
import numpy as np
|
||||
import random
|
||||
import warnings
|
||||
from rasterio.errors import NotGeoreferencedWarning
|
||||
from rasterio.io import MemoryFile
|
||||
from rasterio.transform import from_bounds
|
||||
from rasterio.merge import merge
|
||||
from PIL import Image
|
||||
import gc
|
||||
|
||||
|
||||
def get_3x3_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
|
||||
neighbors = []
|
||||
for neighbour in mercantile.neighbors(tile):
|
||||
if neighbour not in neighbors:
|
||||
neighbors.append(neighbour)
|
||||
|
||||
neighbors.append(tile)
|
||||
return neighbors
|
||||
|
||||
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir:str) -> (np.ndarray, dict):
|
||||
"""
|
||||
Returns a TIFF map of the given tile.
|
||||
"""
|
||||
tile_data = []
|
||||
neighbors = get_3x3_neighbors(tile)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
|
||||
for neighbor in neighbors:
|
||||
west, south, east, north = mercantile.bounds(neighbor)
|
||||
tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
|
||||
if not os.path.exists(tile_path):
|
||||
raise FileNotFoundError(
|
||||
f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
|
||||
)
|
||||
|
||||
with Image.open(tile_path) as img:
|
||||
width, height = img.size
|
||||
memfile = MemoryFile()
|
||||
with memfile.open(
|
||||
driver="GTiff",
|
||||
height=height,
|
||||
width=width,
|
||||
count=3,
|
||||
dtype="uint8",
|
||||
crs="EPSG:3857",
|
||||
transform=from_bounds(west, south, east, north, width, height),
|
||||
) as dataset:
|
||||
data = rasterio.open(tile_path).read()
|
||||
dataset.write(data)
|
||||
tile_data.append(memfile.open())
|
||||
memfile.close()
|
||||
|
||||
mosaic, out_trans = merge(tile_data)
|
||||
|
||||
out_meta = tile_data[0].meta.copy()
|
||||
out_meta.update(
|
||||
{
|
||||
"driver": "GTiff",
|
||||
"height": mosaic.shape[1],
|
||||
"width": mosaic.shape[2],
|
||||
"transform": out_trans,
|
||||
"crs": "EPSG:3857",
|
||||
}
|
||||
)
|
||||
|
||||
# Clean up MemoryFile instances to free up memory
|
||||
for td in tile_data:
|
||||
td.close()
|
||||
|
||||
del neighbors
|
||||
del tile_data
|
||||
gc.collect()
|
||||
|
||||
return mosaic, out_meta
|
||||
|
||||
def get_random_tiff_patch(
|
||||
lat: float,
|
||||
lon: float,
|
||||
satellite_dataset_dir: str,
|
||||
) -> (np.ndarray):
|
||||
"""
|
||||
Returns a random patch from the satellite image.
|
||||
"""
|
||||
|
||||
tile = get_tile_from_coord(lat, lon, 17)
|
||||
sat_years = ["2023", "2021", "2019", "2016"]
|
||||
|
||||
# Randomly select a satellite year
|
||||
sat_year = random.choice(sat_years)
|
||||
|
||||
mosaic, _ = get_tiff_map(tile, sat_year, satellite_dataset_dir)
|
||||
|
||||
return mosaic
|
||||
|
||||
def get_tile_from_coord(
|
||||
lat: float, lng: float, zoom_level: int
|
||||
) -> mercantile.Tile:
|
||||
"""
|
||||
Returns the tile containing the given coordinates.
|
||||
"""
|
||||
tile = mercantile.tile(lng, lat, zoom_level)
|
||||
return tile
|
|
@ -0,0 +1,297 @@
|
|||
"""
|
||||
Simply load images from a folder or nested folders (does not have any split),
|
||||
and apply homographic adaptations to it. Yields an image pair without border
|
||||
artifacts.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
import mercantile
|
||||
import rasterio
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from ..geometry.homography import (
|
||||
compute_homography,
|
||||
sample_homography_corners,
|
||||
warp_points,
|
||||
)
|
||||
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||
from ..settings import DATA_PATH
|
||||
from ..utils.image import read_image
|
||||
from ..utils.tools import fork_rng
|
||||
from ..visualization.viz2d import plot_image_grid
|
||||
from .augmentations import IdentityAugmentation, augmentations
|
||||
from .base_dataset import BaseDataset
|
||||
from .sat_utils import get_random_tiff_patch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sample_homography(img, conf: dict, size: list):
|
||||
data = {}
|
||||
H, _, coords, _ = sample_homography_corners(img.shape[:2][::-1], **conf)
|
||||
data["image"] = cv2.warpPerspective(img, H, tuple(size))
|
||||
data["H_"] = H.astype(np.float32)
|
||||
data["coords"] = coords.astype(np.float32)
|
||||
data["image_size"] = np.array(size, dtype=np.float32)
|
||||
return data
|
||||
|
||||
|
||||
class SatelliteDataset(BaseDataset):
|
||||
default_conf = {
|
||||
# image search
|
||||
"data_dir": "revisitop1m", # the top-level directory
|
||||
"image_dir": "jpg/", # the subdirectory with the images
|
||||
"image_list": "revisitop1m.txt", # optional: list or filename of list
|
||||
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
|
||||
"metadata_dir": None,
|
||||
# splits
|
||||
"train_size": 100,
|
||||
"val_size": 10,
|
||||
"shuffle_seed": 0, # or None to skip
|
||||
# image loading
|
||||
"grayscale": False,
|
||||
"triplet": False,
|
||||
"right_only": False, # image0 is orig (rescaled), image1 is right
|
||||
"reseed": False,
|
||||
"homography": {
|
||||
"difficulty": 0.8,
|
||||
"translation": 1.0,
|
||||
"max_angle": 60,
|
||||
"n_angles": 10,
|
||||
"patch_shape": [640, 480],
|
||||
"min_convexity": 0.05,
|
||||
},
|
||||
"photometric": {
|
||||
"name": "dark",
|
||||
"p": 0.75,
|
||||
# 'difficulty': 1.0, # currently unused
|
||||
},
|
||||
# feature loading
|
||||
"load_features": {
|
||||
"do": False,
|
||||
**CacheLoader.default_conf,
|
||||
"collate": False,
|
||||
"thresh": 0.0,
|
||||
"max_num_keypoints": -1,
|
||||
"force_num_keypoints": False,
|
||||
},
|
||||
}
|
||||
|
||||
def _init(self, conf):
|
||||
data_dir = conf.data_dir
|
||||
coordinates_file = conf.metadata_dir
|
||||
|
||||
with open(coordinates_file, "r") as cf:
|
||||
coordinates = cf.readlines()
|
||||
|
||||
parsed_coordinates = []
|
||||
|
||||
for coordinate in coordinates:
|
||||
lat_part, lon_part = coordinate.split(',')
|
||||
lat = float(lat_part.split(':')[-1].strip())
|
||||
lon = float(lon_part.split(':')[-1].strip())
|
||||
parsed_coordinates.append((lat, lon))
|
||||
|
||||
# Split into train and val 20% val
|
||||
|
||||
train_images = parsed_coordinates[:int(len(parsed_coordinates) * 0.8)]
|
||||
val_images = parsed_coordinates[int(len(parsed_coordinates) * 0.8):]
|
||||
|
||||
self.images = {"train": train_images, "val": val_images}
|
||||
|
||||
def get_dataset(self, split):
|
||||
return _Dataset(self.conf, self.images[split], split)
|
||||
|
||||
|
||||
class _Dataset(torch.utils.data.Dataset):
|
||||
def __init__(self, conf, image_names, split):
|
||||
self.conf = conf
|
||||
self.split = split
|
||||
self.image_names = np.array(image_names)
|
||||
self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir
|
||||
|
||||
aug_conf = conf.photometric
|
||||
aug_name = aug_conf.name
|
||||
assert (
|
||||
aug_name in augmentations.keys()
|
||||
), f'{aug_name} not in {" ".join(augmentations.keys())}'
|
||||
#self.left_augment = (
|
||||
# IdentityAugmentation() if conf.right_only else self.photo_augment
|
||||
#)
|
||||
self.photo_augment = augmentations[aug_name](aug_conf)
|
||||
self.left_augment = augmentations[aug_name](aug_conf)
|
||||
self.img_to_tensor = IdentityAugmentation()
|
||||
|
||||
if conf.load_features.do:
|
||||
self.feature_loader = CacheLoader(conf.load_features)
|
||||
|
||||
def _transform_keypoints(self, features, data):
|
||||
"""Transform keypoints by a homography, threshold them,
|
||||
and potentially keep only the best ones."""
|
||||
# Warp points
|
||||
features["keypoints"] = warp_points(
|
||||
features["keypoints"], data["H_"], inverse=False
|
||||
)
|
||||
h, w = data["image"].shape[1:3]
|
||||
valid = (
|
||||
(features["keypoints"][:, 0] >= 0)
|
||||
& (features["keypoints"][:, 0] <= w - 1)
|
||||
& (features["keypoints"][:, 1] >= 0)
|
||||
& (features["keypoints"][:, 1] <= h - 1)
|
||||
)
|
||||
features["keypoints"] = features["keypoints"][valid]
|
||||
|
||||
# Threshold
|
||||
if self.conf.load_features.thresh > 0:
|
||||
valid = features["keypoint_scores"] >= self.conf.load_features.thresh
|
||||
features = {k: v[valid] for k, v in features.items()}
|
||||
|
||||
# Get the top keypoints and pad
|
||||
n = self.conf.load_features.max_num_keypoints
|
||||
if n > -1:
|
||||
inds = np.argsort(-features["keypoint_scores"])
|
||||
features = {k: v[inds[:n]] for k, v in features.items()}
|
||||
|
||||
if self.conf.load_features.force_num_keypoints:
|
||||
features = pad_local_features(
|
||||
features, self.conf.load_features.max_num_keypoints
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.conf.reseed:
|
||||
with fork_rng(self.conf.seed + idx, False):
|
||||
return self.getitem(idx)
|
||||
else:
|
||||
return self.getitem(idx)
|
||||
|
||||
def _read_view(self, img, H_conf, ps, left=False):
|
||||
data = sample_homography(img, H_conf, ps)
|
||||
if left:
|
||||
data["image"] = self.left_augment(data["image"], return_tensor=True)
|
||||
else:
|
||||
data["image"] = self.photo_augment(data["image"], return_tensor=True)
|
||||
|
||||
gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
|
||||
if self.conf.grayscale:
|
||||
data["image"] = (data["image"] * gs).sum(0, keepdim=True)
|
||||
|
||||
if self.conf.load_features.do:
|
||||
features = self.feature_loader({k: [v] for k, v in data.items()})
|
||||
features = self._transform_keypoints(features, data)
|
||||
data["cache"] = features
|
||||
|
||||
return data
|
||||
|
||||
def getitem(self, idx):
|
||||
# Generate a list of coordinates, based on the coordinate do the split
|
||||
lat, lon = self.image_names[idx]
|
||||
img = get_random_tiff_patch(lat, lon, self.conf.data_dir)
|
||||
|
||||
if img is None:
|
||||
logging.warning("Image %f %f could not be read.", lat, lon)
|
||||
img = np.zeros((1024, 1024) + (() if self.conf.grayscale else (3,)))
|
||||
|
||||
img = img.transpose(1,2,0)
|
||||
img = img.astype(np.float32) / 255.0
|
||||
#write_status = cv2.imwrite(f"/mnt/drive/{str(idx)}.jpg", img)
|
||||
#if write_status == True:
|
||||
# print("Writing success")
|
||||
#else:
|
||||
# print("fuck")
|
||||
size = img.shape[:2][::-1]
|
||||
ps = self.conf.homography.patch_shape
|
||||
|
||||
left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
|
||||
if self.conf.right_only:
|
||||
left_conf["difficulty"] = 0.0
|
||||
|
||||
data0 = self._read_view(img, left_conf, ps, left=True)
|
||||
data1 = self._read_view(img, self.conf.homography, ps, left=False)
|
||||
|
||||
print("Data0 homography_matrix", data0["H_"])
|
||||
print("Data1 homography matrix", data1["H_"])
|
||||
|
||||
# image_1 = data0["image"]
|
||||
# image_1 = image_1.numpy()
|
||||
# image_1 = image_1.transpose(1,2,0)
|
||||
# image_1 = np.uint8(image_1 * 255)
|
||||
#
|
||||
# image_2 = data1["image"]
|
||||
# image_2 = image_2.numpy()
|
||||
# image_2 = image_2.transpose(1,2,0)
|
||||
# image_2 = np.uint8(image_2 * 255)
|
||||
#
|
||||
# PIL_IMAGE = Image.fromarray(image_1)
|
||||
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}.jpg")
|
||||
# PIL_IMAGE = Image.fromarray(image_2)
|
||||
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}-s.jpg")
|
||||
#
|
||||
# exit()
|
||||
#
|
||||
H = compute_homography(data0["coords"], data1["coords"], [1, 1])
|
||||
|
||||
print("COmputed homography", H)
|
||||
|
||||
name = f"lat{lat}_lon{lon}"
|
||||
|
||||
data = {
|
||||
"name": name,
|
||||
"original_image_size": np.array(size),
|
||||
"H_0to1": H.astype(np.float32),
|
||||
"idx": idx,
|
||||
"view0": data0,
|
||||
"view1": data1,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_names)
|
||||
|
||||
|
||||
def visualize(args):
|
||||
conf = {
|
||||
"batch_size": 1,
|
||||
"num_workers": 1,
|
||||
"prefetch_factor": 1,
|
||||
}
|
||||
conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
|
||||
dataset = SatelliteDataset(conf)
|
||||
loader = dataset.get_data_loader("train")
|
||||
logger.info("The dataset has %d elements.", len(loader))
|
||||
|
||||
with fork_rng(seed=dataset.conf.seed):
|
||||
images = []
|
||||
for _, data in zip(range(args.num_items), loader):
|
||||
images.append(
|
||||
(data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2))
|
||||
)
|
||||
plot_image_grid(images, dpi=args.dpi)
|
||||
plt.tight_layout()
|
||||
plt.imsave("implot.png")
|
||||
#plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from .. import logger # overwrite the logger
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num_items", type=int, default=8)
|
||||
parser.add_argument("--dpi", type=int, default=100)
|
||||
parser.add_argument("dotlist", nargs="*")
|
||||
args = parser.parse_intermixed_args()
|
||||
visualize(args)
|
|
@ -53,6 +53,7 @@ def sample_homography_corners(
|
|||
min_pts1 = create_center_patch(shape, (pwidth, pheight))
|
||||
full = create_center_patch(shape)
|
||||
pts2 = create_center_patch(patch_shape)
|
||||
|
||||
scale = min_pts1 - full
|
||||
found_valid = False
|
||||
cnt = -1
|
||||
|
@ -68,7 +69,9 @@ def sample_homography_corners(
|
|||
|
||||
# Rotation
|
||||
if n_angles > 0 and difficulty > 0:
|
||||
angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
|
||||
#angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
|
||||
# In our case the difficulty parameter should not affect the rotation
|
||||
angles = np.linspace(-max_angle, max_angle, n_angles)
|
||||
rng.shuffle(angles)
|
||||
rng.shuffle(angles)
|
||||
angles = np.concatenate([[0.0], angles], axis=0)
|
||||
|
|
|
@ -12,6 +12,7 @@ import signal
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from pydoc import locate
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -89,6 +90,7 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
|||
for i, data in enumerate(
|
||||
tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
|
||||
):
|
||||
|
||||
data = batch_to_device(data, device, non_blocking=True)
|
||||
with torch.no_grad():
|
||||
pred = model(data)
|
||||
|
@ -102,7 +104,6 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
|||
pred[v["predictions"]],
|
||||
mask=pred[v["mask"]] if "mask" in v.keys() else None,
|
||||
)
|
||||
del pred, data
|
||||
numbers = {**metrics, **{"loss/" + k: v for k, v in losses.items()}}
|
||||
for k, v in numbers.items():
|
||||
if k not in results:
|
||||
|
@ -118,7 +119,9 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
|||
if k in conf.recall_metrics.keys():
|
||||
q = conf.recall_metrics[k]
|
||||
results[k + f"_recall{int(q)}"].update(v)
|
||||
del numbers
|
||||
|
||||
del pred, data, losses, metrics
|
||||
gc.collect()
|
||||
results = {k: results[k].compute() for k in results}
|
||||
return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
|
||||
|
||||
|
@ -257,6 +260,7 @@ def training(rank, conf, output_dir, args):
|
|||
|
||||
dataset = get_dataset(data_conf.name)(data_conf)
|
||||
|
||||
|
||||
# Optionally load a different validation dataset than the training one
|
||||
val_data_conf = conf.get("data_val", None)
|
||||
if val_data_conf is None:
|
||||
|
@ -408,7 +412,9 @@ def training(rank, conf, output_dir, args):
|
|||
getattr(loader.dataset, conf.train.dataset_callback_fn)(
|
||||
conf.train.seed + epoch
|
||||
)
|
||||
|
||||
for it, data in enumerate(train_loader):
|
||||
|
||||
tot_it = (len(train_loader) * epoch + it) * (
|
||||
args.n_gpus if args.distributed else 1
|
||||
)
|
||||
|
@ -427,10 +433,17 @@ def training(rank, conf, output_dir, args):
|
|||
loss = torch.mean(losses["total"])
|
||||
if torch.isnan(loss).any():
|
||||
print(f"Detected NAN, skipping iteration {it}")
|
||||
print("name", data["name"])
|
||||
print("loss", loss)
|
||||
print("losses", losses)
|
||||
print("data", data)
|
||||
print("pred", pred)
|
||||
raise RuntimeError("Detected NAN in training.")
|
||||
del pred, data, loss, losses
|
||||
continue
|
||||
|
||||
do_backward = loss.requires_grad
|
||||
|
||||
if args.distributed:
|
||||
do_backward = torch.tensor(do_backward).float().to(device)
|
||||
torch.distributed.all_reduce(
|
||||
|
@ -469,7 +482,6 @@ def training(rank, conf, output_dir, args):
|
|||
else:
|
||||
if rank == 0:
|
||||
logger.warning(f"Skip iteration {it} due to detach.")
|
||||
|
||||
if args.profile:
|
||||
prof.step()
|
||||
|
||||
|
@ -508,8 +520,11 @@ def training(rank, conf, output_dir, args):
|
|||
norm = torch.norm(param.grad.detach(), 2)
|
||||
grad_txt += f"{name} {norm.item():.3f} \n"
|
||||
writer.add_text("grad/summary", grad_txt, tot_n_samples)
|
||||
del pred, data, loss, losses
|
||||
|
||||
pred.clear()
|
||||
data.clear()
|
||||
del pred, data, loss, losses
|
||||
gc.collect()
|
||||
# Run validation
|
||||
if (
|
||||
(
|
||||
|
@ -529,6 +544,7 @@ def training(rank, conf, output_dir, args):
|
|||
pbar=(rank == -1),
|
||||
)
|
||||
|
||||
|
||||
if rank == 0:
|
||||
str_results = [
|
||||
f"{k} {v:.3E}"
|
||||
|
@ -569,6 +585,10 @@ def training(rank, conf, output_dir, args):
|
|||
f"figures/{i}_{name}", fig, tot_n_samples
|
||||
)
|
||||
torch.cuda.empty_cache() # should be cleared at the first iter
|
||||
str_results.clear()
|
||||
pr_metrics.clear()
|
||||
del str_results, figures, pr_metrics
|
||||
|
||||
|
||||
if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
|
||||
if results is None:
|
||||
|
@ -622,7 +642,7 @@ def training(rank, conf, output_dir, args):
|
|||
writer.close()
|
||||
|
||||
|
||||
def main_worker(rank, conf, output_dir, args):
|
||||
def main_worker(rank, conf, output_dir, aprgs):
|
||||
if rank == 0:
|
||||
with capture_outputs(output_dir / "log.txt"):
|
||||
training(rank, conf, output_dir, args)
|
||||
|
|
|
@ -33,6 +33,15 @@ def batch_to_device(batch, device, non_blocking=True):
|
|||
|
||||
return map_tensor(batch, _func)
|
||||
|
||||
def detach_tensors(batch):
|
||||
"""
|
||||
Detach all tensors in a batch recursively.
|
||||
This is useful for detaching tensors from the computational graph to free up memory.
|
||||
"""
|
||||
def _detach(tensor):
|
||||
return tensor.detach()
|
||||
|
||||
return map_tensor(batch, _detach)
|
||||
|
||||
def rbd(data: dict) -> dict:
|
||||
"""Remove batch dimension from elements in data"""
|
||||
|
|
|
@ -67,8 +67,12 @@ class MedianMetric:
|
|||
else:
|
||||
return np.nanmedian(self._elements)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._elements)
|
||||
|
||||
|
||||
class PRMetric:
|
||||
|
||||
def __init__(self):
|
||||
self.labels = []
|
||||
self.predictions = []
|
||||
|
|
|
@ -338,6 +338,14 @@ class SuperPoint(BaseModel):
|
|||
for k, d in zip(keypoints, dense_desc)
|
||||
]
|
||||
|
||||
if isinstance(desc, list):
|
||||
if all(isinstance(d, torch.Tensor) for d in desc):
|
||||
# If desc is a list of tensors
|
||||
desc = torch.stack(desc)
|
||||
else:
|
||||
# If desc is a list of non-tensor elements
|
||||
desc = torch.stack([torch.tensor(d) for d in desc])
|
||||
|
||||
pred = {
|
||||
"keypoints": keypoints + 0.5,
|
||||
"keypoint_scores": scores,
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
#python -m gluefactory.train sp+lg_homography \
|
||||
# --conf gluefactory/configs/superpoint+lightglue_homography.yaml \
|
||||
# data.batch_size=32 # for 1x 1080 GPU
|
||||
|
||||
#python -m gluefactory.train satellites_aug_both_40 \
|
||||
# --conf gluefactory/configs/satellites3.yaml \
|
||||
# data.batch_size=4 # for 1x 1080 GPU
|
||||
|
||||
|
||||
python -m gluefactory.train drone2sat_v1 \
|
||||
--conf gluefactory/configs/drone2sat_v1.yaml \
|
||||
data.batch_size=6 # for 1x 1080 GPU
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
|
||||
LOGFILE="/var/log/h/system_stats.log"
|
||||
|
||||
echo "Logging CPU, Memory, Swap, Disk, and Network usage to $LOGFILE"
|
||||
|
||||
while true; do
|
||||
echo "----- $(date) -----" >> $LOGFILE
|
||||
echo "CPU Usage:" >> $LOGFILE
|
||||
mpstat -P ALL 1 1 >> $LOGFILE
|
||||
|
||||
echo "Disk Usage:" >> $LOGFILE
|
||||
iostat >> $LOGFILE
|
||||
|
||||
echo "Memory and Swap Usage:" >> $LOGFILE
|
||||
free -m >> $LOGFILE
|
||||
|
||||
echo "Network Usage:" >> $LOGFILE
|
||||
ifstat 1 1 >> $LOGFILE
|
||||
|
||||
sleep 5
|
||||
done
|
Loading…
Reference in New Issue