Add custom training files
parent
4a8283517f
commit
d40e423a42
|
@ -0,0 +1,69 @@
|
||||||
|
data:
|
||||||
|
name: drone2sat
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 24
|
||||||
|
val_size: 500
|
||||||
|
train_size: -1
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
geo_dataset:
|
||||||
|
uav_dataset_dir: /mnt/drive/uav_dataset
|
||||||
|
satellite_dataset_dir: /mnt/drive/tiles
|
||||||
|
misslabeled_images_path: /mnt/drive/misslabeled.txt
|
||||||
|
sat_zoom_level: 17
|
||||||
|
uav_patch_width: 400
|
||||||
|
uav_patch_height: 400
|
||||||
|
sat_patch_width: 400
|
||||||
|
sat_patch_height: 400
|
||||||
|
test_from_train_ratio: 0.0
|
||||||
|
transform_mean:
|
||||||
|
- 0.485
|
||||||
|
- 0.456
|
||||||
|
- 0.406
|
||||||
|
transform_std:
|
||||||
|
- 0.229
|
||||||
|
- 0.224
|
||||||
|
- 0.225
|
||||||
|
sat_availaible_years:
|
||||||
|
- "2023"
|
||||||
|
- "2021"
|
||||||
|
- "2019"
|
||||||
|
- "2016"
|
||||||
|
max_rotation_angle: 0
|
||||||
|
uav_image_scale: 2.0
|
||||||
|
use_heatmap: false
|
||||||
|
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
nms_radius: 3
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
features: superpoint
|
||||||
|
depth_confidence: -1
|
||||||
|
width_confidence: -1
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: true
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 40
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 1000 # 350000 / 4
|
||||||
|
lr: 1e-3
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,48 @@
|
||||||
|
data:
|
||||||
|
name: satellites
|
||||||
|
data_dir: /mnt/drive/tiles
|
||||||
|
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||||
|
train_size: null
|
||||||
|
val_size: null
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 24
|
||||||
|
homography:
|
||||||
|
difficulty: 0.9
|
||||||
|
max_angle: 359
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
force_num_keypoints: True
|
||||||
|
nms_radius: 3
|
||||||
|
trainable: False
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: true
|
||||||
|
checkpointed: true
|
||||||
|
weights: superpoint
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 5
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 500
|
||||||
|
lr: 1e-7
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,60 @@
|
||||||
|
data:
|
||||||
|
name: satellites
|
||||||
|
data_dir: /mnt/drive/tiles
|
||||||
|
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||||
|
train_size: null
|
||||||
|
val_size: null
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 24
|
||||||
|
homography:
|
||||||
|
difficulty: 0.5
|
||||||
|
max_angle: 359
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
nms_radius: 3
|
||||||
|
#extractor:
|
||||||
|
# name: gluefactory_nonfree.superpoint
|
||||||
|
# max_num_keypoints: 512
|
||||||
|
# force_num_keypoints: True
|
||||||
|
# detection_threshold: 0.0
|
||||||
|
# nms_radius: 3
|
||||||
|
# trainable: False
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
features: superpoint
|
||||||
|
depth_confidence: -1
|
||||||
|
width_confidence: -1
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: true
|
||||||
|
#matcher:
|
||||||
|
# name: matchers.lightglue_pretrained
|
||||||
|
# features: superpoint
|
||||||
|
# depth_confidence: -1
|
||||||
|
# width_confidence: -1
|
||||||
|
# filter_threshold: 0.1
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 5
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 500
|
||||||
|
lr: 1e-4
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,47 @@
|
||||||
|
data:
|
||||||
|
name: satellites
|
||||||
|
data_dir: /mnt/drive/tiles
|
||||||
|
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||||
|
train_size: null
|
||||||
|
val_size: null
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 24
|
||||||
|
homography:
|
||||||
|
difficulty: 0.5
|
||||||
|
max_angle: 180
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
nms_radius: 3
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
features: superpoint
|
||||||
|
depth_confidence: -1
|
||||||
|
width_confidence: -1
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: true
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 40
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 500
|
||||||
|
lr: 1e-4
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -33,6 +33,7 @@ train:
|
||||||
epochs: 40
|
epochs: 40
|
||||||
log_every_iter: 100
|
log_every_iter: 100
|
||||||
eval_every_iter: 500
|
eval_every_iter: 500
|
||||||
|
profile: true
|
||||||
lr: 1e-4
|
lr: 1e-4
|
||||||
lr_schedule:
|
lr_schedule:
|
||||||
start: 20
|
start: 20
|
||||||
|
|
|
@ -168,6 +168,11 @@ class BaseDataset(metaclass=ABCMeta):
|
||||||
sampler = None
|
sampler = None
|
||||||
if shuffle is None:
|
if shuffle is None:
|
||||||
shuffle = split == "train" and self.conf.shuffle_training
|
shuffle = split == "train" and self.conf.shuffle_training
|
||||||
|
shuffle = split == "val"
|
||||||
|
|
||||||
|
shuffle = True
|
||||||
|
|
||||||
|
print("Shuffle", shuffle)
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|
|
@ -0,0 +1,694 @@
|
||||||
|
"""
|
||||||
|
Simply load images from a folder or nested folders (does not have any split),
|
||||||
|
and apply homographic adaptations to it. Yields an image pair without border
|
||||||
|
artifacts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import omegaconf
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
from typing import List, Literal
|
||||||
|
from torchvision import transforms
|
||||||
|
import random
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
from ..geometry.homography import (
|
||||||
|
compute_homography,
|
||||||
|
sample_homography_corners,
|
||||||
|
warp_points,
|
||||||
|
)
|
||||||
|
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||||
|
from ..settings import DATA_PATH
|
||||||
|
from ..utils.image import read_image
|
||||||
|
from ..utils.tools import fork_rng
|
||||||
|
from ..visualization.viz2d import plot_image_grid
|
||||||
|
from .augmentations import IdentityAugmentation, augmentations
|
||||||
|
from .base_dataset import BaseDataset
|
||||||
|
from .s_utils import get_random_tiff_patch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class Drone2SatDataset(BaseDataset):
|
||||||
|
default_conf = {
|
||||||
|
# image search
|
||||||
|
"data_dir": "revisitop1m", # the top-level directory
|
||||||
|
"image_dir": "jpg/", # the subdirectory with the images
|
||||||
|
"image_list": "revisitop1m.txt", # optional: list or filename of list
|
||||||
|
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
|
||||||
|
"metadata_dir": None,
|
||||||
|
"uav_dataset_dir": None,
|
||||||
|
"satellite_dataset_dir": None,
|
||||||
|
# splits
|
||||||
|
"train_size": 100,
|
||||||
|
"val_size": 10,
|
||||||
|
"shuffle_seed": 0, # or None to skip
|
||||||
|
# image loading
|
||||||
|
"grayscale": False,
|
||||||
|
"triplet": False,
|
||||||
|
"right_only": False, # image0 is orig (rescaled), image1 is right
|
||||||
|
"reseed": False,
|
||||||
|
"homography": {
|
||||||
|
"difficulty": 0.8,
|
||||||
|
"translation": 1.0,
|
||||||
|
"max_angle": 60,
|
||||||
|
"n_angles": 10,
|
||||||
|
"patch_shape": [640, 480],
|
||||||
|
"min_convexity": 0.05,
|
||||||
|
},
|
||||||
|
"photometric": {
|
||||||
|
"name": "dark",
|
||||||
|
"p": 0.75,
|
||||||
|
# 'difficulty': 1.0, # currently unused
|
||||||
|
},
|
||||||
|
# feature loading
|
||||||
|
"load_features": {
|
||||||
|
"do": False,
|
||||||
|
**CacheLoader.default_conf,
|
||||||
|
"collate": False,
|
||||||
|
"thresh": 0.0,
|
||||||
|
"max_num_keypoints": -1,
|
||||||
|
"force_num_keypoints": False,
|
||||||
|
},
|
||||||
|
# Other geolocaliztion parameters
|
||||||
|
"geo_dataset" :{
|
||||||
|
"uav_dataset_dir": None,
|
||||||
|
"satellite_dataset_dir": None,
|
||||||
|
"misslabeled_images_path": None,
|
||||||
|
"sat_zoom_level": 17,
|
||||||
|
"uav_patch_width": 400,
|
||||||
|
"uav_patch_height": 400,
|
||||||
|
"sat_patch_width": 400,
|
||||||
|
"sat_patch_height": 400,
|
||||||
|
"heatmap_kernel_size": 33,
|
||||||
|
"test_from_train_ratio": 0.0,
|
||||||
|
"transform_mean": [0.485, 0.456, 0.406],
|
||||||
|
"transform_std": [0.229, 0.224, 0.225],
|
||||||
|
"sat_availaible_years": ["2023", "2021", "2019", "2016"],
|
||||||
|
"max_rotation_angle": 10,
|
||||||
|
"uav_image_scale": 1,
|
||||||
|
"use_heatmap": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _init(self, conf):
|
||||||
|
self.images = {"train": "train", "val": "val"}
|
||||||
|
|
||||||
|
def get_dataset(self, split):
|
||||||
|
if split == "val":
|
||||||
|
return GeoLocalizationDataset(
|
||||||
|
uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
|
||||||
|
satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
|
||||||
|
misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
|
||||||
|
dataset="test",
|
||||||
|
sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
|
||||||
|
uav_patch_width=self.conf.geo_dataset.uav_patch_width,
|
||||||
|
uav_patch_height=self.conf.geo_dataset.uav_patch_height,
|
||||||
|
sat_patch_width=self.conf.geo_dataset.sat_patch_width,
|
||||||
|
sat_patch_height=self.conf.geo_dataset.sat_patch_height,
|
||||||
|
heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
|
||||||
|
test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
|
||||||
|
transform_mean=self.conf.geo_dataset.transform_mean,
|
||||||
|
transform_std=self.conf.geo_dataset.transform_std,
|
||||||
|
sat_available_years=self.conf.geo_dataset.sat_availaible_years,
|
||||||
|
max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
|
||||||
|
uav_image_scale=self.conf.geo_dataset.uav_image_scale,
|
||||||
|
use_heatmap=self.conf.geo_dataset.use_heatmap,
|
||||||
|
subset_size=self.conf.val_size,
|
||||||
|
)
|
||||||
|
elif split == "train":
|
||||||
|
return GeoLocalizationDataset(
|
||||||
|
uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
|
||||||
|
satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
|
||||||
|
misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
|
||||||
|
dataset="train",
|
||||||
|
sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
|
||||||
|
uav_patch_width=self.conf.geo_dataset.uav_patch_width,
|
||||||
|
uav_patch_height=self.conf.geo_dataset.uav_patch_height,
|
||||||
|
sat_patch_width=self.conf.geo_dataset.sat_patch_width,
|
||||||
|
sat_patch_height=self.conf.geo_dataset.sat_patch_height,
|
||||||
|
heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
|
||||||
|
test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
|
||||||
|
transform_mean=self.conf.geo_dataset.transform_mean,
|
||||||
|
transform_std=self.conf.geo_dataset.transform_std,
|
||||||
|
sat_available_years=self.conf.geo_dataset.sat_availaible_years,
|
||||||
|
max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
|
||||||
|
uav_image_scale=self.conf.geo_dataset.uav_image_scale,
|
||||||
|
use_heatmap=self.conf.geo_dataset.use_heatmap,
|
||||||
|
subset_size=self.conf.train_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GeoLocalizationDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
uav_dataset_dir: str,
|
||||||
|
satellite_dataset_dir: str,
|
||||||
|
misslabeled_images_path: str,
|
||||||
|
dataset: Literal["train", "test"],
|
||||||
|
sat_zoom_level: int = 16,
|
||||||
|
uav_patch_width: int = 128,
|
||||||
|
uav_patch_height: int = 128,
|
||||||
|
sat_patch_width: int = 400,
|
||||||
|
sat_patch_height: int = 400,
|
||||||
|
heatmap_kernel_size: int = 33,
|
||||||
|
test_from_train_ratio: float = 0.0,
|
||||||
|
transform_mean: List[float] = [0.485, 0.456, 0.406],
|
||||||
|
transform_std: List[float] = [0.229, 0.224, 0.225],
|
||||||
|
sat_available_years: List[str] = ["2023", "2021", "2019", "2016"],
|
||||||
|
max_rotation_angle: int = 10,
|
||||||
|
uav_image_scale: float = 1,
|
||||||
|
use_heatmap: bool = True,
|
||||||
|
subset_size: int = None,
|
||||||
|
):
|
||||||
|
self.uav_dataset_dir = uav_dataset_dir
|
||||||
|
self.satellite_dataset_dir = satellite_dataset_dir
|
||||||
|
self.dataset = dataset
|
||||||
|
self.sat_zoom_level = sat_zoom_level
|
||||||
|
self.uav_patch_width = uav_patch_width
|
||||||
|
self.uav_patch_height = uav_patch_height
|
||||||
|
self.heatmap_kernel_size = heatmap_kernel_size
|
||||||
|
self.test_from_train_ratio = test_from_train_ratio
|
||||||
|
self.transform_mean = transform_mean
|
||||||
|
self.transform_std = transform_std
|
||||||
|
self.misslabeled_images_path = misslabeled_images_path
|
||||||
|
self.metadata_dict = {}
|
||||||
|
self.max_rotation_angle = max_rotation_angle
|
||||||
|
self.total_uav_samples = self.count_total_uav_samples()
|
||||||
|
self.misslabelled_images = self.read_misslabelled_images(self.misslabeled_images_path)
|
||||||
|
self.entry_paths = self.get_entry_paths(self.uav_dataset_dir)
|
||||||
|
self.cleanup_misslabelled_images()
|
||||||
|
self.transforms = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
#transforms.Normalize(self.transform_mean, self.transform_std),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.sat_available_years = sat_available_years
|
||||||
|
self.uav_image_scale = uav_image_scale
|
||||||
|
self.use_heatmap = use_heatmap
|
||||||
|
self.sat_patch_width = sat_patch_width
|
||||||
|
self.sat_patch_height = sat_patch_height
|
||||||
|
self.subset_size = subset_size
|
||||||
|
|
||||||
|
self.inverse_transforms = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[
|
||||||
|
-m / s for m, s in zip(self.transform_mean, self.transform_std)
|
||||||
|
],
|
||||||
|
std=[1 / s for s in self.transform_std],
|
||||||
|
),
|
||||||
|
transforms.ToPILImage(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.subset_size != -1:
|
||||||
|
ssize = int(self.subset_size / len(self.sat_available_years))
|
||||||
|
ssize = min(ssize, len(self.entry_paths))
|
||||||
|
self.entry_paths = self.entry_paths[:ssize]
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return (
|
||||||
|
len(self.entry_paths)
|
||||||
|
* len(self.sat_available_years)
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_misslabelled_images(
|
||||||
|
self, path: str = "misslabels/misslabeled.txt"
|
||||||
|
) -> List[str]:
|
||||||
|
with open(path, "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
return [line.strip() for line in lines]
|
||||||
|
|
||||||
|
def cleanup_misslabelled_images(self) -> None:
|
||||||
|
indices_to_delete = []
|
||||||
|
|
||||||
|
for image in self.misslabelled_images:
|
||||||
|
for image_path in self.entry_paths:
|
||||||
|
if image in image_path:
|
||||||
|
index = self.entry_paths.index(image_path)
|
||||||
|
indices_to_delete.append(index)
|
||||||
|
break
|
||||||
|
|
||||||
|
sorted_tuples = sorted(indices_to_delete, reverse=True)
|
||||||
|
|
||||||
|
for index in sorted_tuples:
|
||||||
|
self.entry_paths.pop(index)
|
||||||
|
|
||||||
|
def __getitem__(self, idx) -> dict:
|
||||||
|
"""
|
||||||
|
Retrieves a sample given its index, returning the preprocessed UAV
|
||||||
|
and satellite images, along with their associated heatmap and metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_path_index = idx // (
|
||||||
|
len(self.sat_available_years)
|
||||||
|
)
|
||||||
|
|
||||||
|
sat_year = self.sat_available_years[idx % len(self.sat_available_years)]
|
||||||
|
rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle)
|
||||||
|
|
||||||
|
image_path = self.entry_paths[image_path_index]
|
||||||
|
uav_image = Image.open(image_path).convert("RGB") # Ensure 3-channel image
|
||||||
|
|
||||||
|
original_uav_image_width = uav_image.width
|
||||||
|
original_uav_image_height = uav_image.height
|
||||||
|
|
||||||
|
lookup_str, file_number = self.extract_info_from_filename(image_path)
|
||||||
|
img_info = self.metadata_dict[lookup_str][file_number]
|
||||||
|
|
||||||
|
lat, lon = (
|
||||||
|
img_info["coordinate"]["latitude"],
|
||||||
|
img_info["coordinate"]["longitude"],
|
||||||
|
)
|
||||||
|
|
||||||
|
fov_vertical = img_info["fovVertical"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
agl_altitude = float(image_path.split("/")[-1].split("_")[2].split("m")[0])
|
||||||
|
except IndexError:
|
||||||
|
agl_altitude = 150.0
|
||||||
|
warnings.warn(
|
||||||
|
"Could not extract AGL altitude from filename, using default value of 150m."
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
satellite_patch,
|
||||||
|
x_sat,
|
||||||
|
y_sat,
|
||||||
|
x_offset,
|
||||||
|
y_offset,
|
||||||
|
patch_transform,
|
||||||
|
) = get_random_tiff_patch(
|
||||||
|
lat, lon, self.sat_patch_width, self.sat_patch_height, sat_year, self.satellite_dataset_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rotate crop center and transform image
|
||||||
|
h = np.ceil(uav_image.height // self.uav_image_scale).astype(int)
|
||||||
|
w = np.ceil(uav_image.width // self.uav_image_scale).astype(int)
|
||||||
|
|
||||||
|
uav_image = F.rotate(uav_image, rot_angle)
|
||||||
|
uav_image = F.resize(uav_image, [h, w])
|
||||||
|
uav_image = F.center_crop(
|
||||||
|
uav_image, (self.uav_patch_height, self.uav_patch_width)
|
||||||
|
)
|
||||||
|
uav_image = self.transforms(uav_image)
|
||||||
|
|
||||||
|
satellite_patch = satellite_patch.transpose(1, 2, 0)
|
||||||
|
satellite_patch_pytorch = self.transforms(satellite_patch)
|
||||||
|
del satellite_patch
|
||||||
|
|
||||||
|
if self.use_heatmap:
|
||||||
|
heatmap = self.get_heatmap_gt(
|
||||||
|
x_sat,
|
||||||
|
y_sat,
|
||||||
|
satellite_patch.shape[1],
|
||||||
|
satellite_patch.shape[2],
|
||||||
|
self.heatmap_kernel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
cropped_uav_image_width = self.calculate_cropped_uav_image_width(
|
||||||
|
fov_vertical,
|
||||||
|
original_uav_image_width,
|
||||||
|
original_uav_image_height,
|
||||||
|
self.uav_patch_width,
|
||||||
|
self.uav_patch_height,
|
||||||
|
agl_altitude,
|
||||||
|
)
|
||||||
|
|
||||||
|
satellite_tile_width = self.calculate_cropped_sat_image_width(
|
||||||
|
lat, self.sat_patch_width, patch_transform
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_factor = cropped_uav_image_width / satellite_tile_width
|
||||||
|
scale_factor *= 10
|
||||||
|
|
||||||
|
homography_matrix = self.compute_homography(
|
||||||
|
rot_angle,
|
||||||
|
x_sat,
|
||||||
|
y_sat,
|
||||||
|
self.uav_patch_width,
|
||||||
|
self.uav_patch_height,
|
||||||
|
scale_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.use_heatmap:
|
||||||
|
# Sample four points from the UAV image
|
||||||
|
points = self.sample_four_points(
|
||||||
|
self.uav_patch_width, self.uav_patch_height
|
||||||
|
)
|
||||||
|
|
||||||
|
# Transform the points
|
||||||
|
warped_points = self.warp_points(points, homography_matrix)
|
||||||
|
#img_info["warped_points_sat"] = warped_points
|
||||||
|
#img_info["warped_points_uav"] = points
|
||||||
|
|
||||||
|
# img_info["cropped_uav_image_width"] = cropped_uav_image_width
|
||||||
|
# img_info["satellite_tile_width"] = satellite_tile_width
|
||||||
|
# img_info["scale_factor"] = scale_factor
|
||||||
|
# img_info["filename"] = image_path
|
||||||
|
# img_info["rot_angle"] = rot_angle
|
||||||
|
# img_info["x_sat"] = x_sat
|
||||||
|
# img_info["y_sat"] = y_sat
|
||||||
|
# img_info["x_offset"] = x_offset
|
||||||
|
# img_info["y_offset"] = y_offset
|
||||||
|
# img_info["patch_transform"] = patch_transform
|
||||||
|
# img_info["uav_image_scale"] = self.uav_image_scale
|
||||||
|
# img_info["homography_matrix_uav_to_sat"] = homography_matrix
|
||||||
|
# img_info["homography_matrix_sat_to_uav"] = np.linalg.inv(homography_matrix)
|
||||||
|
# img_info["agl_altitude"] = agl_altitude
|
||||||
|
# img_info["original_uav_image_width"] = original_uav_image_width
|
||||||
|
# img_info["original_drone_image_height"] = original_uav_image_height
|
||||||
|
# img_info["fov_vertical"] = fov_vertical
|
||||||
|
#
|
||||||
|
|
||||||
|
inverse_homography_matrix = np.linalg.inv(homography_matrix)
|
||||||
|
uav_image_data = {
|
||||||
|
"image": uav_image,
|
||||||
|
"H_": homography_matrix,
|
||||||
|
"coords": points,
|
||||||
|
"image_size": np.array([self.uav_patch_width, self.uav_patch_height]),
|
||||||
|
}
|
||||||
|
|
||||||
|
satellite_patch_data = {
|
||||||
|
"image": satellite_patch_pytorch,
|
||||||
|
"H_": inverse_homography_matrix,
|
||||||
|
"coords": warped_points,
|
||||||
|
"image_size": np.array([self.sat_patch_width, self.sat_patch_height]),
|
||||||
|
}
|
||||||
|
|
||||||
|
#hm = self.compute_homography_points(
|
||||||
|
# warped_points, points, [self.uav_patch_width, self.uav_patch_height]
|
||||||
|
#)
|
||||||
|
|
||||||
|
hm = self.compute_homography_points(
|
||||||
|
points, warped_points, [1.0 , 1.0]
|
||||||
|
)
|
||||||
|
|
||||||
|
if np.array_equal(hm, np.eye(3)):
|
||||||
|
print("Singular matrix")
|
||||||
|
return self.__getitem__(random.randint(0, len(self) - 1))
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"name": f"{image_path}_{rot_angle}_{sat_year}",
|
||||||
|
"original_image_size": np.array(
|
||||||
|
[original_uav_image_width, original_uav_image_height]
|
||||||
|
),
|
||||||
|
"H_0to1": hm.astype(np.float32),
|
||||||
|
"view0": uav_image_data,
|
||||||
|
"view1": satellite_patch_data,
|
||||||
|
"idx": idx
|
||||||
|
}
|
||||||
|
del img_info
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform):
|
||||||
|
"""
|
||||||
|
Computes the width of the satellite image in world coordinate units
|
||||||
|
"""
|
||||||
|
length_of_degree = 111320 * np.cos(np.radians(latitude))
|
||||||
|
scale_x_meters = patch_transform[0] * length_of_degree
|
||||||
|
satellite_tile_width = patch_width * scale_x_meters
|
||||||
|
return satellite_tile_width
|
||||||
|
|
||||||
|
def calculate_cropped_uav_image_width(
|
||||||
|
self,
|
||||||
|
fov_vertical,
|
||||||
|
orig_width,
|
||||||
|
orig_height,
|
||||||
|
crop_width,
|
||||||
|
crop_height,
|
||||||
|
altitude=150.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Computes the width of the UAV image in world coordinate units
|
||||||
|
"""
|
||||||
|
# Convert fov from degrees to radians
|
||||||
|
fov_rad = np.radians(fov_vertical)
|
||||||
|
|
||||||
|
# Calculate the full width of the UAV image
|
||||||
|
full_width = 2 * (altitude * np.tan(fov_rad / 2))
|
||||||
|
|
||||||
|
# Determine the cropping ratio
|
||||||
|
crop_ratio_width = crop_width / orig_width
|
||||||
|
crop_ratio_height = crop_height / orig_height
|
||||||
|
|
||||||
|
# Calculate the adjusted horizontal fov
|
||||||
|
fov_horizontal = 2 * np.arctan(np.tan(fov_rad / 2) * (orig_width / orig_height))
|
||||||
|
adjusted_fov_horizontal = 2 * np.arctan(
|
||||||
|
np.tan(fov_horizontal / 2) * crop_ratio_width
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate the new full width using the adjusted horizontal fov
|
||||||
|
full_width = 2 * (altitude * np.tan(adjusted_fov_horizontal / 2))
|
||||||
|
|
||||||
|
# Adjust the width according to the crop ratio
|
||||||
|
cropped_width = full_width * crop_ratio_width
|
||||||
|
|
||||||
|
return cropped_width
|
||||||
|
|
||||||
|
def compute_homography_points(self, pts1_, pts2_, shape):
|
||||||
|
"""Compute the homography matrix from 4 point correspondences"""
|
||||||
|
# Rescale to actual size
|
||||||
|
shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x]
|
||||||
|
pts1 = pts1_ * np.expand_dims(shape, axis=0)
|
||||||
|
pts2 = pts2_ * np.expand_dims(shape, axis=0)
|
||||||
|
|
||||||
|
def ax(p, q):
|
||||||
|
return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
|
||||||
|
|
||||||
|
def ay(p, q):
|
||||||
|
return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
|
||||||
|
|
||||||
|
def flat2mat(H):
|
||||||
|
return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])
|
||||||
|
|
||||||
|
a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
|
||||||
|
p_mat = np.transpose(
|
||||||
|
np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
homography = np.transpose(np.linalg.solve(a_mat, p_mat))
|
||||||
|
except np.linalg.LinAlgError:
|
||||||
|
print("Singular matrix")
|
||||||
|
return np.eye(3)
|
||||||
|
return flat2mat(homography)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_homography(
|
||||||
|
self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor
|
||||||
|
):
|
||||||
|
# Adjust rot_angle if it's greater than 180 degrees
|
||||||
|
if rot_angle > 180:
|
||||||
|
rot_angle -= 360
|
||||||
|
# Convert rotation angle to radians
|
||||||
|
theta = np.radians(rot_angle)
|
||||||
|
|
||||||
|
# Rotation matrix
|
||||||
|
R = np.array(
|
||||||
|
[
|
||||||
|
[np.cos(theta), -np.sin(theta), 0],
|
||||||
|
[np.sin(theta), np.cos(theta), 0],
|
||||||
|
[0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scale matrix
|
||||||
|
S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
|
||||||
|
|
||||||
|
# Translation matrix to center the UAV image
|
||||||
|
T_uav = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]])
|
||||||
|
|
||||||
|
# Translation matrix to move to the satellite image position
|
||||||
|
T_sat = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]])
|
||||||
|
|
||||||
|
# Compute the combined homography matrix
|
||||||
|
H = np.dot(T_sat, np.dot(R, np.dot(S, T_uav)))
|
||||||
|
|
||||||
|
return H
|
||||||
|
|
||||||
|
def sample_four_points(self, width: int, height: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Samples four points from the UAV image.
|
||||||
|
"""
|
||||||
|
PADDING = 50
|
||||||
|
CENTER_PADDING = 10
|
||||||
|
points = np.array(
|
||||||
|
[
|
||||||
|
[random.randint(CENTER_PADDING, width - PADDING), random.randint(CENTER_PADDING, height - PADDING)]
|
||||||
|
for _ in range(4)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return points
|
||||||
|
|
||||||
|
def warp_points(
|
||||||
|
self, points: np.ndarray, homography_matrix: np.ndarray
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Warps the given points using the given homography matrix.
|
||||||
|
"""
|
||||||
|
points = np.array(points)
|
||||||
|
points = np.concatenate([points, np.ones((4, 1))], axis=1)
|
||||||
|
points = np.dot(homography_matrix, points.T).T
|
||||||
|
points = points[:, :2] / points[:, 2:]
|
||||||
|
return points
|
||||||
|
|
||||||
|
def count_total_uav_samples(self) -> int:
|
||||||
|
"""
|
||||||
|
Count the total number of uav image samples in the dataset
|
||||||
|
(train + test)
|
||||||
|
"""
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
for dirpath, dirnames, filenames in os.walk(self.uav_dataset_dir):
|
||||||
|
# Skip the test folder
|
||||||
|
for filename in filenames:
|
||||||
|
if filename.endswith(".jpeg"):
|
||||||
|
total_samples += 1
|
||||||
|
return total_samples
|
||||||
|
|
||||||
|
def get_number_of_city_samples(self) -> int:
|
||||||
|
"""
|
||||||
|
TODO: Count the total number of city samples in the dataset
|
||||||
|
"""
|
||||||
|
return 11
|
||||||
|
|
||||||
|
def get_entry_paths(self, directory: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Recursively retrieves paths to image and metadata files in the given directory.
|
||||||
|
"""
|
||||||
|
entry_paths = []
|
||||||
|
entries = os.listdir(directory)
|
||||||
|
|
||||||
|
images_to_take_per_folder = int(
|
||||||
|
self.total_uav_samples
|
||||||
|
* self.test_from_train_ratio
|
||||||
|
/ self.get_number_of_city_samples()
|
||||||
|
)
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
entry_path = os.path.join(directory, entry)
|
||||||
|
|
||||||
|
# If it's a directory, recurse into it
|
||||||
|
if os.path.isdir(entry_path):
|
||||||
|
entry_paths += self.get_entry_paths(entry_path)
|
||||||
|
|
||||||
|
# Handle train dataset
|
||||||
|
elif (self.dataset == "train" and "Train" in entry_path) or (
|
||||||
|
self.dataset == "train"
|
||||||
|
and self.test_from_train_ratio > 0
|
||||||
|
and "Test" in entry_path
|
||||||
|
):
|
||||||
|
if entry_path.endswith(".jpeg"):
|
||||||
|
_, number = self.extract_info_from_filename(entry_path)
|
||||||
|
else:
|
||||||
|
number = None
|
||||||
|
if entry_path.endswith(".json"):
|
||||||
|
self.get_metadata(entry_path)
|
||||||
|
if number is None:
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
number >= images_to_take_per_folder
|
||||||
|
): # Only include images beyond the ones taken for test
|
||||||
|
if entry_path.endswith(".jpeg"):
|
||||||
|
entry_paths.append(entry_path)
|
||||||
|
|
||||||
|
# Handle test dataset
|
||||||
|
elif self.dataset == "test":
|
||||||
|
if entry_path.endswith(".jpeg"):
|
||||||
|
_, number = self.extract_info_from_filename(entry_path)
|
||||||
|
else:
|
||||||
|
number = None
|
||||||
|
if entry_path.endswith(".json"):
|
||||||
|
self.get_metadata(entry_path)
|
||||||
|
|
||||||
|
if number is None:
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
("Test" in entry_path and number < images_to_take_per_folder)
|
||||||
|
or (number < images_to_take_per_folder and "Train" in entry_path)
|
||||||
|
or (self.test_from_train_ratio == 0.0 and "Test" in entry_path)
|
||||||
|
):
|
||||||
|
if entry_path.endswith(".jpeg"):
|
||||||
|
entry_paths.append(entry_path)
|
||||||
|
|
||||||
|
return sorted(entry_paths, key=self.extract_info_from_filename)
|
||||||
|
|
||||||
|
def get_metadata(self, path: str) -> None:
|
||||||
|
"""
|
||||||
|
Extracts metadata from a JSON file and stores it in the metadata dictionary.
|
||||||
|
"""
|
||||||
|
with open(path, newline="") as jsonfile:
|
||||||
|
json_dict = json.load(jsonfile)
|
||||||
|
path = path.split("/")[-1]
|
||||||
|
path = path.replace(".json", "")
|
||||||
|
self.metadata_dict[path] = json_dict["cameraFrames"]
|
||||||
|
|
||||||
|
def extract_info_from_filename(self, filename: str) -> (str, int):
|
||||||
|
"""
|
||||||
|
Extracts information from the filename.
|
||||||
|
"""
|
||||||
|
filename_without_ext = filename.replace(".jpeg", "")
|
||||||
|
segments = filename_without_ext.split("/")
|
||||||
|
info = segments[-1]
|
||||||
|
try:
|
||||||
|
number = int(info.split("_")[-1])
|
||||||
|
except ValueError:
|
||||||
|
print("Could not extract number from filename: ", filename)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
info = "_".join(info.split("_")[:-1])
|
||||||
|
|
||||||
|
return info, number
|
||||||
|
|
||||||
|
def get_heatmap_gt(
|
||||||
|
self, x: int, y: int, height: int, width: int, square_size: int = 33
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Returns 2D heatmap ground truth for the given x and y coordinates,
|
||||||
|
with the given square size.
|
||||||
|
"""
|
||||||
|
x_map, y_map = x, y
|
||||||
|
|
||||||
|
heatmap = torch.zeros((height, width))
|
||||||
|
|
||||||
|
half_size = square_size // 2
|
||||||
|
|
||||||
|
# Calculate the valid range for the square
|
||||||
|
start_x = max(0, x_map - half_size)
|
||||||
|
end_x = min(
|
||||||
|
width, x_map + half_size + 1
|
||||||
|
) # +1 to include the end_x in the square
|
||||||
|
start_y = max(0, y_map - half_size)
|
||||||
|
end_y = min(
|
||||||
|
height, y_map + half_size + 1
|
||||||
|
) # +1 to include the end_y in the square
|
||||||
|
|
||||||
|
heatmap[start_y:end_y, start_x:end_x] = 1
|
||||||
|
|
||||||
|
return heatmap
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from .. import logger # overwrite the logger
|
|
@ -0,0 +1,163 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
import mercantile
|
||||||
|
import rasterio
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
from rasterio.errors import NotGeoreferencedWarning
|
||||||
|
from rasterio.io import MemoryFile
|
||||||
|
from rasterio.transform import from_bounds
|
||||||
|
from rasterio.merge import merge
|
||||||
|
from PIL import Image
|
||||||
|
import gc
|
||||||
|
import random
|
||||||
|
from osgeo import gdal, osr
|
||||||
|
from affine import Affine
|
||||||
|
|
||||||
|
|
||||||
|
def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
|
||||||
|
neighbors = []
|
||||||
|
for main_neighbour in mercantile.neighbors(tile):
|
||||||
|
for sub_neighbour in mercantile.neighbors(main_neighbour):
|
||||||
|
if sub_neighbour not in neighbors:
|
||||||
|
neighbors.append(sub_neighbour)
|
||||||
|
return neighbors
|
||||||
|
|
||||||
|
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str) -> (np.ndarray, dict):
|
||||||
|
"""
|
||||||
|
Returns a TIFF map of the given tile using GDAL.
|
||||||
|
"""
|
||||||
|
tile_data = []
|
||||||
|
neighbors = get_5x5_neighbors(tile)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
for neighbor in neighbors:
|
||||||
|
west, south, east, north = mercantile.bounds(neighbor)
|
||||||
|
tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
|
||||||
|
if not os.path.exists(tile_path):
|
||||||
|
raise FileNotFoundError(f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found.")
|
||||||
|
|
||||||
|
img = Image.open(tile_path)
|
||||||
|
img_array = np.array(img)
|
||||||
|
|
||||||
|
# Create an in-memory GDAL dataset
|
||||||
|
mem_driver = gdal.GetDriverByName('MEM')
|
||||||
|
dataset = mem_driver.Create('', img_array.shape[1], img_array.shape[0], 3, gdal.GDT_Byte)
|
||||||
|
for i in range(3):
|
||||||
|
dataset.GetRasterBand(i + 1).WriteArray(img_array[:, :, i])
|
||||||
|
|
||||||
|
# Set GeoTransform and Projection
|
||||||
|
geotransform = (west, (east - west) / img_array.shape[1], 0, north, 0, -(north - south) / img_array.shape[0])
|
||||||
|
dataset.SetGeoTransform(geotransform)
|
||||||
|
srs = osr.SpatialReference()
|
||||||
|
srs.ImportFromEPSG(3857)
|
||||||
|
dataset.SetProjection(srs.ExportToWkt())
|
||||||
|
|
||||||
|
tile_data.append(dataset)
|
||||||
|
|
||||||
|
# Merge tiles using GDAL
|
||||||
|
vrt_options = gdal.BuildVRTOptions()
|
||||||
|
vrt = gdal.BuildVRT('', [td for td in tile_data], options=vrt_options)
|
||||||
|
mosaic = vrt.ReadAsArray()
|
||||||
|
|
||||||
|
# Get metadata
|
||||||
|
out_trans = vrt.GetGeoTransform()
|
||||||
|
out_crs = vrt.GetProjection()
|
||||||
|
out_trans = Affine.from_gdal(*out_trans)
|
||||||
|
out_meta = {
|
||||||
|
"driver": "GTiff",
|
||||||
|
"height": mosaic.shape[1],
|
||||||
|
"width": mosaic.shape[2],
|
||||||
|
"transform": out_trans,
|
||||||
|
"crs": out_crs,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
for td in tile_data:
|
||||||
|
td.FlushCache()
|
||||||
|
vrt = None
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return mosaic, out_meta
|
||||||
|
|
||||||
|
def get_random_tiff_patch(
|
||||||
|
lat: float,
|
||||||
|
lon: float,
|
||||||
|
patch_width: int,
|
||||||
|
patch_height: int,
|
||||||
|
sat_year: str,
|
||||||
|
satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
|
||||||
|
) -> (np.ndarray, int, int, int, int, rasterio.transform.Affine):
|
||||||
|
"""
|
||||||
|
Returns a random patch from the satellite image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tile = get_tile_from_coord(lat, lon, 17)
|
||||||
|
|
||||||
|
mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)
|
||||||
|
|
||||||
|
|
||||||
|
transform = out_meta["transform"]
|
||||||
|
del out_meta
|
||||||
|
|
||||||
|
x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# Temporal constant, replace with a better solution
|
||||||
|
KS = 120
|
||||||
|
|
||||||
|
x_offset_range = [
|
||||||
|
x_pixel - patch_width + KS + 1,
|
||||||
|
x_pixel - KS - 1,
|
||||||
|
]
|
||||||
|
y_offset_range = [
|
||||||
|
y_pixel - patch_height + KS + 1,
|
||||||
|
y_pixel - KS - 1,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Randomly select an offset within the valid range
|
||||||
|
x_offset = random.randint(*x_offset_range)
|
||||||
|
y_offset = random.randint(*y_offset_range)
|
||||||
|
|
||||||
|
x_offset = np.clip(x_offset, 0, mosaic.shape[-1] - patch_width)
|
||||||
|
y_offset = np.clip(y_offset, 0, mosaic.shape[-2] - patch_height)
|
||||||
|
|
||||||
|
# Update x, y to reflect the clamping of x_offset and y_offset
|
||||||
|
x, y = x_pixel - x_offset, y_pixel - y_offset
|
||||||
|
patch = mosaic[
|
||||||
|
:, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
|
||||||
|
]
|
||||||
|
|
||||||
|
patch_transform = rasterio.transform.Affine(
|
||||||
|
transform.a,
|
||||||
|
transform.b,
|
||||||
|
transform.c + x_offset * transform.a + y_offset * transform.b,
|
||||||
|
transform.d,
|
||||||
|
transform.e,
|
||||||
|
transform.f + x_offset * transform.d + y_offset * transform.e,
|
||||||
|
)
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return patch, x, y, x_offset, y_offset, patch_transform
|
||||||
|
|
||||||
|
def get_tile_from_coord(
|
||||||
|
lat: float, lng: float, zoom_level: int
|
||||||
|
) -> mercantile.Tile:
|
||||||
|
"""
|
||||||
|
Returns the tile containing the given coordinates.
|
||||||
|
"""
|
||||||
|
tile = mercantile.tile(lng, lat, zoom_level)
|
||||||
|
return tile
|
||||||
|
|
||||||
|
def geo_to_pixel_coordinates(
|
||||||
|
lat: float, lon: float, transform: rasterio.transform.Affine
|
||||||
|
) -> (int, int):
|
||||||
|
"""
|
||||||
|
Converts a pair of (lat, lon) coordinates to pixel coordinates.
|
||||||
|
"""
|
||||||
|
x_pixel, y_pixel = ~transform * (lon, lat)
|
||||||
|
return round(x_pixel), round(y_pixel)
|
|
@ -0,0 +1,109 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
import mercantile
|
||||||
|
import rasterio
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
from rasterio.errors import NotGeoreferencedWarning
|
||||||
|
from rasterio.io import MemoryFile
|
||||||
|
from rasterio.transform import from_bounds
|
||||||
|
from rasterio.merge import merge
|
||||||
|
from PIL import Image
|
||||||
|
import gc
|
||||||
|
|
||||||
|
|
||||||
|
def get_3x3_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
|
||||||
|
neighbors = []
|
||||||
|
for neighbour in mercantile.neighbors(tile):
|
||||||
|
if neighbour not in neighbors:
|
||||||
|
neighbors.append(neighbour)
|
||||||
|
|
||||||
|
neighbors.append(tile)
|
||||||
|
return neighbors
|
||||||
|
|
||||||
|
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir:str) -> (np.ndarray, dict):
|
||||||
|
"""
|
||||||
|
Returns a TIFF map of the given tile.
|
||||||
|
"""
|
||||||
|
tile_data = []
|
||||||
|
neighbors = get_3x3_neighbors(tile)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
|
||||||
|
for neighbor in neighbors:
|
||||||
|
west, south, east, north = mercantile.bounds(neighbor)
|
||||||
|
tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
|
||||||
|
if not os.path.exists(tile_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
|
||||||
|
)
|
||||||
|
|
||||||
|
with Image.open(tile_path) as img:
|
||||||
|
width, height = img.size
|
||||||
|
memfile = MemoryFile()
|
||||||
|
with memfile.open(
|
||||||
|
driver="GTiff",
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
count=3,
|
||||||
|
dtype="uint8",
|
||||||
|
crs="EPSG:3857",
|
||||||
|
transform=from_bounds(west, south, east, north, width, height),
|
||||||
|
) as dataset:
|
||||||
|
data = rasterio.open(tile_path).read()
|
||||||
|
dataset.write(data)
|
||||||
|
tile_data.append(memfile.open())
|
||||||
|
memfile.close()
|
||||||
|
|
||||||
|
mosaic, out_trans = merge(tile_data)
|
||||||
|
|
||||||
|
out_meta = tile_data[0].meta.copy()
|
||||||
|
out_meta.update(
|
||||||
|
{
|
||||||
|
"driver": "GTiff",
|
||||||
|
"height": mosaic.shape[1],
|
||||||
|
"width": mosaic.shape[2],
|
||||||
|
"transform": out_trans,
|
||||||
|
"crs": "EPSG:3857",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up MemoryFile instances to free up memory
|
||||||
|
for td in tile_data:
|
||||||
|
td.close()
|
||||||
|
|
||||||
|
del neighbors
|
||||||
|
del tile_data
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return mosaic, out_meta
|
||||||
|
|
||||||
|
def get_random_tiff_patch(
|
||||||
|
lat: float,
|
||||||
|
lon: float,
|
||||||
|
satellite_dataset_dir: str,
|
||||||
|
) -> (np.ndarray):
|
||||||
|
"""
|
||||||
|
Returns a random patch from the satellite image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tile = get_tile_from_coord(lat, lon, 17)
|
||||||
|
sat_years = ["2023", "2021", "2019", "2016"]
|
||||||
|
|
||||||
|
# Randomly select a satellite year
|
||||||
|
sat_year = random.choice(sat_years)
|
||||||
|
|
||||||
|
mosaic, _ = get_tiff_map(tile, sat_year, satellite_dataset_dir)
|
||||||
|
|
||||||
|
return mosaic
|
||||||
|
|
||||||
|
def get_tile_from_coord(
|
||||||
|
lat: float, lng: float, zoom_level: int
|
||||||
|
) -> mercantile.Tile:
|
||||||
|
"""
|
||||||
|
Returns the tile containing the given coordinates.
|
||||||
|
"""
|
||||||
|
tile = mercantile.tile(lng, lat, zoom_level)
|
||||||
|
return tile
|
|
@ -0,0 +1,297 @@
|
||||||
|
"""
|
||||||
|
Simply load images from a folder or nested folders (does not have any split),
|
||||||
|
and apply homographic adaptations to it. Yields an image pair without border
|
||||||
|
artifacts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
import mercantile
|
||||||
|
import rasterio
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import omegaconf
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..geometry.homography import (
|
||||||
|
compute_homography,
|
||||||
|
sample_homography_corners,
|
||||||
|
warp_points,
|
||||||
|
)
|
||||||
|
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||||
|
from ..settings import DATA_PATH
|
||||||
|
from ..utils.image import read_image
|
||||||
|
from ..utils.tools import fork_rng
|
||||||
|
from ..visualization.viz2d import plot_image_grid
|
||||||
|
from .augmentations import IdentityAugmentation, augmentations
|
||||||
|
from .base_dataset import BaseDataset
|
||||||
|
from .sat_utils import get_random_tiff_patch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_homography(img, conf: dict, size: list):
|
||||||
|
data = {}
|
||||||
|
H, _, coords, _ = sample_homography_corners(img.shape[:2][::-1], **conf)
|
||||||
|
data["image"] = cv2.warpPerspective(img, H, tuple(size))
|
||||||
|
data["H_"] = H.astype(np.float32)
|
||||||
|
data["coords"] = coords.astype(np.float32)
|
||||||
|
data["image_size"] = np.array(size, dtype=np.float32)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class SatelliteDataset(BaseDataset):
|
||||||
|
default_conf = {
|
||||||
|
# image search
|
||||||
|
"data_dir": "revisitop1m", # the top-level directory
|
||||||
|
"image_dir": "jpg/", # the subdirectory with the images
|
||||||
|
"image_list": "revisitop1m.txt", # optional: list or filename of list
|
||||||
|
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
|
||||||
|
"metadata_dir": None,
|
||||||
|
# splits
|
||||||
|
"train_size": 100,
|
||||||
|
"val_size": 10,
|
||||||
|
"shuffle_seed": 0, # or None to skip
|
||||||
|
# image loading
|
||||||
|
"grayscale": False,
|
||||||
|
"triplet": False,
|
||||||
|
"right_only": False, # image0 is orig (rescaled), image1 is right
|
||||||
|
"reseed": False,
|
||||||
|
"homography": {
|
||||||
|
"difficulty": 0.8,
|
||||||
|
"translation": 1.0,
|
||||||
|
"max_angle": 60,
|
||||||
|
"n_angles": 10,
|
||||||
|
"patch_shape": [640, 480],
|
||||||
|
"min_convexity": 0.05,
|
||||||
|
},
|
||||||
|
"photometric": {
|
||||||
|
"name": "dark",
|
||||||
|
"p": 0.75,
|
||||||
|
# 'difficulty': 1.0, # currently unused
|
||||||
|
},
|
||||||
|
# feature loading
|
||||||
|
"load_features": {
|
||||||
|
"do": False,
|
||||||
|
**CacheLoader.default_conf,
|
||||||
|
"collate": False,
|
||||||
|
"thresh": 0.0,
|
||||||
|
"max_num_keypoints": -1,
|
||||||
|
"force_num_keypoints": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _init(self, conf):
|
||||||
|
data_dir = conf.data_dir
|
||||||
|
coordinates_file = conf.metadata_dir
|
||||||
|
|
||||||
|
with open(coordinates_file, "r") as cf:
|
||||||
|
coordinates = cf.readlines()
|
||||||
|
|
||||||
|
parsed_coordinates = []
|
||||||
|
|
||||||
|
for coordinate in coordinates:
|
||||||
|
lat_part, lon_part = coordinate.split(',')
|
||||||
|
lat = float(lat_part.split(':')[-1].strip())
|
||||||
|
lon = float(lon_part.split(':')[-1].strip())
|
||||||
|
parsed_coordinates.append((lat, lon))
|
||||||
|
|
||||||
|
# Split into train and val 20% val
|
||||||
|
|
||||||
|
train_images = parsed_coordinates[:int(len(parsed_coordinates) * 0.8)]
|
||||||
|
val_images = parsed_coordinates[int(len(parsed_coordinates) * 0.8):]
|
||||||
|
|
||||||
|
self.images = {"train": train_images, "val": val_images}
|
||||||
|
|
||||||
|
def get_dataset(self, split):
|
||||||
|
return _Dataset(self.conf, self.images[split], split)
|
||||||
|
|
||||||
|
|
||||||
|
class _Dataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, conf, image_names, split):
|
||||||
|
self.conf = conf
|
||||||
|
self.split = split
|
||||||
|
self.image_names = np.array(image_names)
|
||||||
|
self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir
|
||||||
|
|
||||||
|
aug_conf = conf.photometric
|
||||||
|
aug_name = aug_conf.name
|
||||||
|
assert (
|
||||||
|
aug_name in augmentations.keys()
|
||||||
|
), f'{aug_name} not in {" ".join(augmentations.keys())}'
|
||||||
|
#self.left_augment = (
|
||||||
|
# IdentityAugmentation() if conf.right_only else self.photo_augment
|
||||||
|
#)
|
||||||
|
self.photo_augment = augmentations[aug_name](aug_conf)
|
||||||
|
self.left_augment = augmentations[aug_name](aug_conf)
|
||||||
|
self.img_to_tensor = IdentityAugmentation()
|
||||||
|
|
||||||
|
if conf.load_features.do:
|
||||||
|
self.feature_loader = CacheLoader(conf.load_features)
|
||||||
|
|
||||||
|
def _transform_keypoints(self, features, data):
|
||||||
|
"""Transform keypoints by a homography, threshold them,
|
||||||
|
and potentially keep only the best ones."""
|
||||||
|
# Warp points
|
||||||
|
features["keypoints"] = warp_points(
|
||||||
|
features["keypoints"], data["H_"], inverse=False
|
||||||
|
)
|
||||||
|
h, w = data["image"].shape[1:3]
|
||||||
|
valid = (
|
||||||
|
(features["keypoints"][:, 0] >= 0)
|
||||||
|
& (features["keypoints"][:, 0] <= w - 1)
|
||||||
|
& (features["keypoints"][:, 1] >= 0)
|
||||||
|
& (features["keypoints"][:, 1] <= h - 1)
|
||||||
|
)
|
||||||
|
features["keypoints"] = features["keypoints"][valid]
|
||||||
|
|
||||||
|
# Threshold
|
||||||
|
if self.conf.load_features.thresh > 0:
|
||||||
|
valid = features["keypoint_scores"] >= self.conf.load_features.thresh
|
||||||
|
features = {k: v[valid] for k, v in features.items()}
|
||||||
|
|
||||||
|
# Get the top keypoints and pad
|
||||||
|
n = self.conf.load_features.max_num_keypoints
|
||||||
|
if n > -1:
|
||||||
|
inds = np.argsort(-features["keypoint_scores"])
|
||||||
|
features = {k: v[inds[:n]] for k, v in features.items()}
|
||||||
|
|
||||||
|
if self.conf.load_features.force_num_keypoints:
|
||||||
|
features = pad_local_features(
|
||||||
|
features, self.conf.load_features.max_num_keypoints
|
||||||
|
)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if self.conf.reseed:
|
||||||
|
with fork_rng(self.conf.seed + idx, False):
|
||||||
|
return self.getitem(idx)
|
||||||
|
else:
|
||||||
|
return self.getitem(idx)
|
||||||
|
|
||||||
|
def _read_view(self, img, H_conf, ps, left=False):
|
||||||
|
data = sample_homography(img, H_conf, ps)
|
||||||
|
if left:
|
||||||
|
data["image"] = self.left_augment(data["image"], return_tensor=True)
|
||||||
|
else:
|
||||||
|
data["image"] = self.photo_augment(data["image"], return_tensor=True)
|
||||||
|
|
||||||
|
gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
|
||||||
|
if self.conf.grayscale:
|
||||||
|
data["image"] = (data["image"] * gs).sum(0, keepdim=True)
|
||||||
|
|
||||||
|
if self.conf.load_features.do:
|
||||||
|
features = self.feature_loader({k: [v] for k, v in data.items()})
|
||||||
|
features = self._transform_keypoints(features, data)
|
||||||
|
data["cache"] = features
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def getitem(self, idx):
|
||||||
|
# Generate a list of coordinates, based on the coordinate do the split
|
||||||
|
lat, lon = self.image_names[idx]
|
||||||
|
img = get_random_tiff_patch(lat, lon, self.conf.data_dir)
|
||||||
|
|
||||||
|
if img is None:
|
||||||
|
logging.warning("Image %f %f could not be read.", lat, lon)
|
||||||
|
img = np.zeros((1024, 1024) + (() if self.conf.grayscale else (3,)))
|
||||||
|
|
||||||
|
img = img.transpose(1,2,0)
|
||||||
|
img = img.astype(np.float32) / 255.0
|
||||||
|
#write_status = cv2.imwrite(f"/mnt/drive/{str(idx)}.jpg", img)
|
||||||
|
#if write_status == True:
|
||||||
|
# print("Writing success")
|
||||||
|
#else:
|
||||||
|
# print("fuck")
|
||||||
|
size = img.shape[:2][::-1]
|
||||||
|
ps = self.conf.homography.patch_shape
|
||||||
|
|
||||||
|
left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
|
||||||
|
if self.conf.right_only:
|
||||||
|
left_conf["difficulty"] = 0.0
|
||||||
|
|
||||||
|
data0 = self._read_view(img, left_conf, ps, left=True)
|
||||||
|
data1 = self._read_view(img, self.conf.homography, ps, left=False)
|
||||||
|
|
||||||
|
print("Data0 homography_matrix", data0["H_"])
|
||||||
|
print("Data1 homography matrix", data1["H_"])
|
||||||
|
|
||||||
|
# image_1 = data0["image"]
|
||||||
|
# image_1 = image_1.numpy()
|
||||||
|
# image_1 = image_1.transpose(1,2,0)
|
||||||
|
# image_1 = np.uint8(image_1 * 255)
|
||||||
|
#
|
||||||
|
# image_2 = data1["image"]
|
||||||
|
# image_2 = image_2.numpy()
|
||||||
|
# image_2 = image_2.transpose(1,2,0)
|
||||||
|
# image_2 = np.uint8(image_2 * 255)
|
||||||
|
#
|
||||||
|
# PIL_IMAGE = Image.fromarray(image_1)
|
||||||
|
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}.jpg")
|
||||||
|
# PIL_IMAGE = Image.fromarray(image_2)
|
||||||
|
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}-s.jpg")
|
||||||
|
#
|
||||||
|
# exit()
|
||||||
|
#
|
||||||
|
H = compute_homography(data0["coords"], data1["coords"], [1, 1])
|
||||||
|
|
||||||
|
print("COmputed homography", H)
|
||||||
|
|
||||||
|
name = f"lat{lat}_lon{lon}"
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"name": name,
|
||||||
|
"original_image_size": np.array(size),
|
||||||
|
"H_0to1": H.astype(np.float32),
|
||||||
|
"idx": idx,
|
||||||
|
"view0": data0,
|
||||||
|
"view1": data1,
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_names)
|
||||||
|
|
||||||
|
|
||||||
|
def visualize(args):
|
||||||
|
conf = {
|
||||||
|
"batch_size": 1,
|
||||||
|
"num_workers": 1,
|
||||||
|
"prefetch_factor": 1,
|
||||||
|
}
|
||||||
|
conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
|
||||||
|
dataset = SatelliteDataset(conf)
|
||||||
|
loader = dataset.get_data_loader("train")
|
||||||
|
logger.info("The dataset has %d elements.", len(loader))
|
||||||
|
|
||||||
|
with fork_rng(seed=dataset.conf.seed):
|
||||||
|
images = []
|
||||||
|
for _, data in zip(range(args.num_items), loader):
|
||||||
|
images.append(
|
||||||
|
(data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2))
|
||||||
|
)
|
||||||
|
plot_image_grid(images, dpi=args.dpi)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.imsave("implot.png")
|
||||||
|
#plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from .. import logger # overwrite the logger
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num_items", type=int, default=8)
|
||||||
|
parser.add_argument("--dpi", type=int, default=100)
|
||||||
|
parser.add_argument("dotlist", nargs="*")
|
||||||
|
args = parser.parse_intermixed_args()
|
||||||
|
visualize(args)
|
|
@ -53,6 +53,7 @@ def sample_homography_corners(
|
||||||
min_pts1 = create_center_patch(shape, (pwidth, pheight))
|
min_pts1 = create_center_patch(shape, (pwidth, pheight))
|
||||||
full = create_center_patch(shape)
|
full = create_center_patch(shape)
|
||||||
pts2 = create_center_patch(patch_shape)
|
pts2 = create_center_patch(patch_shape)
|
||||||
|
|
||||||
scale = min_pts1 - full
|
scale = min_pts1 - full
|
||||||
found_valid = False
|
found_valid = False
|
||||||
cnt = -1
|
cnt = -1
|
||||||
|
@ -68,7 +69,9 @@ def sample_homography_corners(
|
||||||
|
|
||||||
# Rotation
|
# Rotation
|
||||||
if n_angles > 0 and difficulty > 0:
|
if n_angles > 0 and difficulty > 0:
|
||||||
angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
|
#angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
|
||||||
|
# In our case the difficulty parameter should not affect the rotation
|
||||||
|
angles = np.linspace(-max_angle, max_angle, n_angles)
|
||||||
rng.shuffle(angles)
|
rng.shuffle(angles)
|
||||||
rng.shuffle(angles)
|
rng.shuffle(angles)
|
||||||
angles = np.concatenate([[0.0], angles], axis=0)
|
angles = np.concatenate([[0.0], angles], axis=0)
|
||||||
|
|
|
@ -12,6 +12,7 @@ import signal
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydoc import locate
|
from pydoc import locate
|
||||||
|
import gc
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -89,6 +90,7 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
||||||
for i, data in enumerate(
|
for i, data in enumerate(
|
||||||
tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
|
tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
|
||||||
):
|
):
|
||||||
|
|
||||||
data = batch_to_device(data, device, non_blocking=True)
|
data = batch_to_device(data, device, non_blocking=True)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pred = model(data)
|
pred = model(data)
|
||||||
|
@ -102,7 +104,6 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
||||||
pred[v["predictions"]],
|
pred[v["predictions"]],
|
||||||
mask=pred[v["mask"]] if "mask" in v.keys() else None,
|
mask=pred[v["mask"]] if "mask" in v.keys() else None,
|
||||||
)
|
)
|
||||||
del pred, data
|
|
||||||
numbers = {**metrics, **{"loss/" + k: v for k, v in losses.items()}}
|
numbers = {**metrics, **{"loss/" + k: v for k, v in losses.items()}}
|
||||||
for k, v in numbers.items():
|
for k, v in numbers.items():
|
||||||
if k not in results:
|
if k not in results:
|
||||||
|
@ -118,7 +119,9 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
||||||
if k in conf.recall_metrics.keys():
|
if k in conf.recall_metrics.keys():
|
||||||
q = conf.recall_metrics[k]
|
q = conf.recall_metrics[k]
|
||||||
results[k + f"_recall{int(q)}"].update(v)
|
results[k + f"_recall{int(q)}"].update(v)
|
||||||
del numbers
|
|
||||||
|
del pred, data, losses, metrics
|
||||||
|
gc.collect()
|
||||||
results = {k: results[k].compute() for k in results}
|
results = {k: results[k].compute() for k in results}
|
||||||
return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
|
return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
|
||||||
|
|
||||||
|
@ -257,6 +260,7 @@ def training(rank, conf, output_dir, args):
|
||||||
|
|
||||||
dataset = get_dataset(data_conf.name)(data_conf)
|
dataset = get_dataset(data_conf.name)(data_conf)
|
||||||
|
|
||||||
|
|
||||||
# Optionally load a different validation dataset than the training one
|
# Optionally load a different validation dataset than the training one
|
||||||
val_data_conf = conf.get("data_val", None)
|
val_data_conf = conf.get("data_val", None)
|
||||||
if val_data_conf is None:
|
if val_data_conf is None:
|
||||||
|
@ -408,7 +412,9 @@ def training(rank, conf, output_dir, args):
|
||||||
getattr(loader.dataset, conf.train.dataset_callback_fn)(
|
getattr(loader.dataset, conf.train.dataset_callback_fn)(
|
||||||
conf.train.seed + epoch
|
conf.train.seed + epoch
|
||||||
)
|
)
|
||||||
|
|
||||||
for it, data in enumerate(train_loader):
|
for it, data in enumerate(train_loader):
|
||||||
|
|
||||||
tot_it = (len(train_loader) * epoch + it) * (
|
tot_it = (len(train_loader) * epoch + it) * (
|
||||||
args.n_gpus if args.distributed else 1
|
args.n_gpus if args.distributed else 1
|
||||||
)
|
)
|
||||||
|
@ -427,10 +433,17 @@ def training(rank, conf, output_dir, args):
|
||||||
loss = torch.mean(losses["total"])
|
loss = torch.mean(losses["total"])
|
||||||
if torch.isnan(loss).any():
|
if torch.isnan(loss).any():
|
||||||
print(f"Detected NAN, skipping iteration {it}")
|
print(f"Detected NAN, skipping iteration {it}")
|
||||||
|
print("name", data["name"])
|
||||||
|
print("loss", loss)
|
||||||
|
print("losses", losses)
|
||||||
|
print("data", data)
|
||||||
|
print("pred", pred)
|
||||||
|
raise RuntimeError("Detected NAN in training.")
|
||||||
del pred, data, loss, losses
|
del pred, data, loss, losses
|
||||||
continue
|
continue
|
||||||
|
|
||||||
do_backward = loss.requires_grad
|
do_backward = loss.requires_grad
|
||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
do_backward = torch.tensor(do_backward).float().to(device)
|
do_backward = torch.tensor(do_backward).float().to(device)
|
||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
|
@ -469,7 +482,6 @@ def training(rank, conf, output_dir, args):
|
||||||
else:
|
else:
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.warning(f"Skip iteration {it} due to detach.")
|
logger.warning(f"Skip iteration {it} due to detach.")
|
||||||
|
|
||||||
if args.profile:
|
if args.profile:
|
||||||
prof.step()
|
prof.step()
|
||||||
|
|
||||||
|
@ -508,8 +520,11 @@ def training(rank, conf, output_dir, args):
|
||||||
norm = torch.norm(param.grad.detach(), 2)
|
norm = torch.norm(param.grad.detach(), 2)
|
||||||
grad_txt += f"{name} {norm.item():.3f} \n"
|
grad_txt += f"{name} {norm.item():.3f} \n"
|
||||||
writer.add_text("grad/summary", grad_txt, tot_n_samples)
|
writer.add_text("grad/summary", grad_txt, tot_n_samples)
|
||||||
del pred, data, loss, losses
|
|
||||||
|
|
||||||
|
pred.clear()
|
||||||
|
data.clear()
|
||||||
|
del pred, data, loss, losses
|
||||||
|
gc.collect()
|
||||||
# Run validation
|
# Run validation
|
||||||
if (
|
if (
|
||||||
(
|
(
|
||||||
|
@ -529,6 +544,7 @@ def training(rank, conf, output_dir, args):
|
||||||
pbar=(rank == -1),
|
pbar=(rank == -1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
str_results = [
|
str_results = [
|
||||||
f"{k} {v:.3E}"
|
f"{k} {v:.3E}"
|
||||||
|
@ -569,6 +585,10 @@ def training(rank, conf, output_dir, args):
|
||||||
f"figures/{i}_{name}", fig, tot_n_samples
|
f"figures/{i}_{name}", fig, tot_n_samples
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache() # should be cleared at the first iter
|
torch.cuda.empty_cache() # should be cleared at the first iter
|
||||||
|
str_results.clear()
|
||||||
|
pr_metrics.clear()
|
||||||
|
del str_results, figures, pr_metrics
|
||||||
|
|
||||||
|
|
||||||
if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
|
if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
|
||||||
if results is None:
|
if results is None:
|
||||||
|
@ -622,7 +642,7 @@ def training(rank, conf, output_dir, args):
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
def main_worker(rank, conf, output_dir, args):
|
def main_worker(rank, conf, output_dir, aprgs):
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
with capture_outputs(output_dir / "log.txt"):
|
with capture_outputs(output_dir / "log.txt"):
|
||||||
training(rank, conf, output_dir, args)
|
training(rank, conf, output_dir, args)
|
||||||
|
|
|
@ -33,6 +33,15 @@ def batch_to_device(batch, device, non_blocking=True):
|
||||||
|
|
||||||
return map_tensor(batch, _func)
|
return map_tensor(batch, _func)
|
||||||
|
|
||||||
|
def detach_tensors(batch):
|
||||||
|
"""
|
||||||
|
Detach all tensors in a batch recursively.
|
||||||
|
This is useful for detaching tensors from the computational graph to free up memory.
|
||||||
|
"""
|
||||||
|
def _detach(tensor):
|
||||||
|
return tensor.detach()
|
||||||
|
|
||||||
|
return map_tensor(batch, _detach)
|
||||||
|
|
||||||
def rbd(data: dict) -> dict:
|
def rbd(data: dict) -> dict:
|
||||||
"""Remove batch dimension from elements in data"""
|
"""Remove batch dimension from elements in data"""
|
||||||
|
|
|
@ -67,8 +67,12 @@ class MedianMetric:
|
||||||
else:
|
else:
|
||||||
return np.nanmedian(self._elements)
|
return np.nanmedian(self._elements)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._elements)
|
||||||
|
|
||||||
|
|
||||||
class PRMetric:
|
class PRMetric:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.labels = []
|
self.labels = []
|
||||||
self.predictions = []
|
self.predictions = []
|
||||||
|
|
|
@ -338,6 +338,14 @@ class SuperPoint(BaseModel):
|
||||||
for k, d in zip(keypoints, dense_desc)
|
for k, d in zip(keypoints, dense_desc)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if isinstance(desc, list):
|
||||||
|
if all(isinstance(d, torch.Tensor) for d in desc):
|
||||||
|
# If desc is a list of tensors
|
||||||
|
desc = torch.stack(desc)
|
||||||
|
else:
|
||||||
|
# If desc is a list of non-tensor elements
|
||||||
|
desc = torch.stack([torch.tensor(d) for d in desc])
|
||||||
|
|
||||||
pred = {
|
pred = {
|
||||||
"keypoints": keypoints + 0.5,
|
"keypoints": keypoints + 0.5,
|
||||||
"keypoint_scores": scores,
|
"keypoint_scores": scores,
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
#python -m gluefactory.train sp+lg_homography \
|
||||||
|
# --conf gluefactory/configs/superpoint+lightglue_homography.yaml \
|
||||||
|
# data.batch_size=32 # for 1x 1080 GPU
|
||||||
|
|
||||||
|
#python -m gluefactory.train satellites_aug_both_40 \
|
||||||
|
# --conf gluefactory/configs/satellites3.yaml \
|
||||||
|
# data.batch_size=4 # for 1x 1080 GPU
|
||||||
|
|
||||||
|
|
||||||
|
python -m gluefactory.train drone2sat_v1 \
|
||||||
|
--conf gluefactory/configs/drone2sat_v1.yaml \
|
||||||
|
data.batch_size=6 # for 1x 1080 GPU
|
|
@ -0,0 +1,22 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
LOGFILE="/var/log/h/system_stats.log"
|
||||||
|
|
||||||
|
echo "Logging CPU, Memory, Swap, Disk, and Network usage to $LOGFILE"
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
echo "----- $(date) -----" >> $LOGFILE
|
||||||
|
echo "CPU Usage:" >> $LOGFILE
|
||||||
|
mpstat -P ALL 1 1 >> $LOGFILE
|
||||||
|
|
||||||
|
echo "Disk Usage:" >> $LOGFILE
|
||||||
|
iostat >> $LOGFILE
|
||||||
|
|
||||||
|
echo "Memory and Swap Usage:" >> $LOGFILE
|
||||||
|
free -m >> $LOGFILE
|
||||||
|
|
||||||
|
echo "Network Usage:" >> $LOGFILE
|
||||||
|
ifstat 1 1 >> $LOGFILE
|
||||||
|
|
||||||
|
sleep 5
|
||||||
|
done
|
Loading…
Reference in New Issue