Add custom training files

main
Jeba Kolega 2024-02-14 00:02:09 +01:00
parent 4a8283517f
commit d40e423a42
17 changed files with 1577 additions and 6 deletions

View File

@ -0,0 +1,69 @@
data:
name: drone2sat
batch_size: 4
num_workers: 24
val_size: 500
train_size: -1
photometric:
name: lg
geo_dataset:
uav_dataset_dir: /mnt/drive/uav_dataset
satellite_dataset_dir: /mnt/drive/tiles
misslabeled_images_path: /mnt/drive/misslabeled.txt
sat_zoom_level: 17
uav_patch_width: 400
uav_patch_height: 400
sat_patch_width: 400
sat_patch_height: 400
test_from_train_ratio: 0.0
transform_mean:
- 0.485
- 0.456
- 0.406
transform_std:
- 0.229
- 0.224
- 0.225
sat_availaible_years:
- "2023"
- "2021"
- "2019"
- "2016"
max_rotation_angle: 0
uav_image_scale: 2.0
use_heatmap: false
model:
name: two_view_pipeline
extractor:
name: gluefactory_nonfree.superpoint
max_num_keypoints: 2048
detection_threshold: 0.0
nms_radius: 3
ground_truth:
name: matchers.homography_matcher
th_positive: 3
th_negative: 3
matcher:
name: matchers.lightglue
features: superpoint
depth_confidence: -1
width_confidence: -1
filter_threshold: 0.1
flash: true
train:
seed: 0
epochs: 40
log_every_iter: 100
eval_every_iter: 1000 # 350000 / 4
lr: 1e-3
lr_schedule:
start: 20
type: exp
on_epoch: true
exp_div_10: 10
benchmarks:
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
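
A config like the one above can be loaded and inspected with OmegaConf, which glue-factory already uses throughout these files; the path below is a hypothetical location for this config, not one fixed by the commit.

from omegaconf import OmegaConf

# Hypothetical path; adjust to wherever this YAML is saved.
conf = OmegaConf.load("gluefactory/configs/drone2sat_v1.yaml")
print(conf.data.geo_dataset.sat_zoom_level)  # -> 17
print(conf.train.epochs)                     # -> 40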

View File

@ -0,0 +1,48 @@
data:
name: satellites
data_dir: /mnt/drive/tiles
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
train_size: null
val_size: null
batch_size: 128
num_workers: 24
homography:
difficulty: 0.9
max_angle: 359
photometric:
name: lg
model:
name: two_view_pipeline
extractor:
name: gluefactory_nonfree.superpoint
max_num_keypoints: 2048
detection_threshold: 0.0
force_num_keypoints: True
nms_radius: 3
trainable: False
ground_truth:
name: matchers.homography_matcher
th_positive: 3
th_negative: 3
matcher:
name: matchers.lightglue
filter_threshold: 0.1
flash: true
checkpointed: true
weights: superpoint
train:
seed: 0
epochs: 5
log_every_iter: 100
eval_every_iter: 500
lr: 1e-7
lr_schedule:
start: 20
type: exp
on_epoch: true
exp_div_10: 10
benchmarks:
hpatches:
eval:
estimator: opencv
ransac_th: 0.5

View File

@ -0,0 +1,60 @@
data:
name: satellites
data_dir: /mnt/drive/tiles
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
train_size: null
val_size: null
batch_size: 128
num_workers: 24
homography:
difficulty: 0.5
max_angle: 359
photometric:
name: lg
model:
name: two_view_pipeline
extractor:
name: gluefactory_nonfree.superpoint
max_num_keypoints: 2048
detection_threshold: 0.0
nms_radius: 3
#extractor:
# name: gluefactory_nonfree.superpoint
# max_num_keypoints: 512
# force_num_keypoints: True
# detection_threshold: 0.0
# nms_radius: 3
# trainable: False
ground_truth:
name: matchers.homography_matcher
th_positive: 3
th_negative: 3
matcher:
name: matchers.lightglue
features: superpoint
depth_confidence: -1
width_confidence: -1
filter_threshold: 0.1
flash: true
#matcher:
# name: matchers.lightglue_pretrained
# features: superpoint
# depth_confidence: -1
# width_confidence: -1
# filter_threshold: 0.1
train:
seed: 0
epochs: 5
log_every_iter: 100
eval_every_iter: 500
lr: 1e-4
lr_schedule:
start: 20
type: exp
on_epoch: true
exp_div_10: 10
benchmarks:
hpatches:
eval:
estimator: opencv
ransac_th: 0.5

View File

@ -0,0 +1,47 @@
data:
name: satellites
data_dir: /mnt/drive/tiles
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
train_size: null
val_size: null
batch_size: 128
num_workers: 24
homography:
difficulty: 0.5
max_angle: 180
photometric:
name: lg
model:
name: two_view_pipeline
extractor:
name: gluefactory_nonfree.superpoint
max_num_keypoints: 2048
detection_threshold: 0.0
nms_radius: 3
ground_truth:
name: matchers.homography_matcher
th_positive: 3
th_negative: 3
matcher:
name: matchers.lightglue
features: superpoint
depth_confidence: -1
width_confidence: -1
filter_threshold: 0.1
flash: true
train:
seed: 0
epochs: 40
log_every_iter: 100
eval_every_iter: 500
lr: 1e-4
lr_schedule:
start: 20
type: exp
on_epoch: true
exp_div_10: 10
benchmarks:
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
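
The satellite configs above all share the same lr_schedule block. A plausible reading of its keys, sketched below as a standalone function (this illustrates the intent of the config, not glue-factory's actual scheduler code): hold the base lr until epoch `start`, then decay it by a factor of 10 every `exp_div_10` epochs.

def lr_at_epoch(base_lr, epoch, start=20, exp_div_10=10):
    # Flat until `start`, then exponential decay: /10 per exp_div_10 epochs.
    if epoch < start:
        return base_lr
    return base_lr * 10 ** (-(epoch - start) / exp_div_10)

print(lr_at_epoch(1e-4, 10))  # 0.0001, before decay starts
print(lr_at_epoch(1e-4, 30))  # ~1e-05, one decade lower 10 epochs after start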

View File

@ -33,6 +33,7 @@ train:
epochs: 40
log_every_iter: 100
eval_every_iter: 500
profile: true
lr: 1e-4
lr_schedule:
start: 20

View File

@ -168,6 +168,11 @@ class BaseDataset(metaclass=ABCMeta):
sampler = None
if shuffle is None:
shuffle = split == "train" and self.conf.shuffle_training
shuffle = split == "val"
shuffle = True  # debug override: this final assignment wins, forcing shuffling for every split
print("Shuffle", shuffle)
return DataLoader(
dataset,
batch_size=batch_size,

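Because the override above forces shuffle=True for every split, validation batches are drawn in a different order on each run. A seeded generator, as in this hypothetical usage sketch (not part of the commit), is one way to keep the forced shuffling reproducible:

import torch
from torch.utils.data import DataLoader

g = torch.Generator().manual_seed(0)  # fixed seed for a reproducible order
loader = DataLoader(list(range(8)), batch_size=4, shuffle=True, generator=g)
print([batch.tolist() for batch in loader])  # same two batches on every run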
View File

@ -0,0 +1,694 @@
"""
Simply load images from a folder or nested folders (does not have any split),
and apply homographic adaptations to it. Yields an image pair without border
artifacts.
"""
import argparse
import logging
import shutil
import tarfile
from pathlib import Path
import os
import json
import gc
import cv2
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
from PIL import Image
from typing import List, Literal
from torchvision import transforms
import random
import warnings
from torchvision.transforms import functional as F
from ..geometry.homography import (
compute_homography,
sample_homography_corners,
warp_points,
)
from ..models.cache_loader import CacheLoader, pad_local_features
from ..settings import DATA_PATH
from ..utils.image import read_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
from .augmentations import IdentityAugmentation, augmentations
from .base_dataset import BaseDataset
from .s_utils import get_random_tiff_patch
logger = logging.getLogger(__name__)
class Drone2SatDataset(BaseDataset):
default_conf = {
# image search
"data_dir": "revisitop1m", # the top-level directory
"image_dir": "jpg/", # the subdirectory with the images
"image_list": "revisitop1m.txt", # optional: list or filename of list
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
"metadata_dir": None,
"uav_dataset_dir": None,
"satellite_dataset_dir": None,
# splits
"train_size": 100,
"val_size": 10,
"shuffle_seed": 0, # or None to skip
# image loading
"grayscale": False,
"triplet": False,
"right_only": False, # image0 is orig (rescaled), image1 is right
"reseed": False,
"homography": {
"difficulty": 0.8,
"translation": 1.0,
"max_angle": 60,
"n_angles": 10,
"patch_shape": [640, 480],
"min_convexity": 0.05,
},
"photometric": {
"name": "dark",
"p": 0.75,
# 'difficulty': 1.0, # currently unused
},
# feature loading
"load_features": {
"do": False,
**CacheLoader.default_conf,
"collate": False,
"thresh": 0.0,
"max_num_keypoints": -1,
"force_num_keypoints": False,
},
# Other geolocalization parameters
"geo_dataset": {
"uav_dataset_dir": None,
"satellite_dataset_dir": None,
"misslabeled_images_path": None,
"sat_zoom_level": 17,
"uav_patch_width": 400,
"uav_patch_height": 400,
"sat_patch_width": 400,
"sat_patch_height": 400,
"heatmap_kernel_size": 33,
"test_from_train_ratio": 0.0,
"transform_mean": [0.485, 0.456, 0.406],
"transform_std": [0.229, 0.224, 0.225],
"sat_availaible_years": ["2023", "2021", "2019", "2016"],
"max_rotation_angle": 10,
"uav_image_scale": 1,
"use_heatmap": False,
}
}
def _init(self, conf):
self.images = {"train": "train", "val": "val"}
def get_dataset(self, split):
if split == "val":
return GeoLocalizationDataset(
uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
dataset="test",
sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
uav_patch_width=self.conf.geo_dataset.uav_patch_width,
uav_patch_height=self.conf.geo_dataset.uav_patch_height,
sat_patch_width=self.conf.geo_dataset.sat_patch_width,
sat_patch_height=self.conf.geo_dataset.sat_patch_height,
heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
transform_mean=self.conf.geo_dataset.transform_mean,
transform_std=self.conf.geo_dataset.transform_std,
sat_available_years=self.conf.geo_dataset.sat_availaible_years,
max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
uav_image_scale=self.conf.geo_dataset.uav_image_scale,
use_heatmap=self.conf.geo_dataset.use_heatmap,
subset_size=self.conf.val_size,
)
elif split == "train":
return GeoLocalizationDataset(
uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
dataset="train",
sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
uav_patch_width=self.conf.geo_dataset.uav_patch_width,
uav_patch_height=self.conf.geo_dataset.uav_patch_height,
sat_patch_width=self.conf.geo_dataset.sat_patch_width,
sat_patch_height=self.conf.geo_dataset.sat_patch_height,
heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
transform_mean=self.conf.geo_dataset.transform_mean,
transform_std=self.conf.geo_dataset.transform_std,
sat_available_years=self.conf.geo_dataset.sat_availaible_years,
max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
uav_image_scale=self.conf.geo_dataset.uav_image_scale,
use_heatmap=self.conf.geo_dataset.use_heatmap,
subset_size=self.conf.train_size,
)
class GeoLocalizationDataset(torch.utils.data.Dataset):
def __init__(
self,
uav_dataset_dir: str,
satellite_dataset_dir: str,
misslabeled_images_path: str,
dataset: Literal["train", "test"],
sat_zoom_level: int = 16,
uav_patch_width: int = 128,
uav_patch_height: int = 128,
sat_patch_width: int = 400,
sat_patch_height: int = 400,
heatmap_kernel_size: int = 33,
test_from_train_ratio: float = 0.0,
transform_mean: List[float] = [0.485, 0.456, 0.406],
transform_std: List[float] = [0.229, 0.224, 0.225],
sat_available_years: List[str] = ["2023", "2021", "2019", "2016"],
max_rotation_angle: int = 10,
uav_image_scale: float = 1,
use_heatmap: bool = True,
subset_size: int = None,
):
self.uav_dataset_dir = uav_dataset_dir
self.satellite_dataset_dir = satellite_dataset_dir
self.dataset = dataset
self.sat_zoom_level = sat_zoom_level
self.uav_patch_width = uav_patch_width
self.uav_patch_height = uav_patch_height
self.heatmap_kernel_size = heatmap_kernel_size
self.test_from_train_ratio = test_from_train_ratio
self.transform_mean = transform_mean
self.transform_std = transform_std
self.misslabeled_images_path = misslabeled_images_path
self.metadata_dict = {}
self.max_rotation_angle = max_rotation_angle
self.total_uav_samples = self.count_total_uav_samples()
self.misslabelled_images = self.read_misslabelled_images(self.misslabeled_images_path)
self.entry_paths = self.get_entry_paths(self.uav_dataset_dir)
self.cleanup_misslabelled_images()
self.transforms = transforms.Compose(
[
transforms.ToTensor(),
#transforms.Normalize(self.transform_mean, self.transform_std),
]
)
self.sat_available_years = sat_available_years
self.uav_image_scale = uav_image_scale
self.use_heatmap = use_heatmap
self.sat_patch_width = sat_patch_width
self.sat_patch_height = sat_patch_height
self.subset_size = subset_size
self.inverse_transforms = transforms.Compose(
[
transforms.Normalize(
mean=[
-m / s for m, s in zip(self.transform_mean, self.transform_std)
],
std=[1 / s for s in self.transform_std],
),
transforms.ToPILImage(),
]
)
if self.subset_size is not None and self.subset_size != -1:
ssize = int(self.subset_size / len(self.sat_available_years))
ssize = min(ssize, len(self.entry_paths))
self.entry_paths = self.entry_paths[:ssize]
def __len__(self) -> int:
return (
len(self.entry_paths)
* len(self.sat_available_years)
)
def read_misslabelled_images(
self, path: str = "misslabels/misslabeled.txt"
) -> List[str]:
with open(path, "r") as f:
lines = f.readlines()
return [line.strip() for line in lines]
def cleanup_misslabelled_images(self) -> None:
indices_to_delete = []
for image in self.misslabelled_images:
for image_path in self.entry_paths:
if image in image_path:
index = self.entry_paths.index(image_path)
indices_to_delete.append(index)
break
for index in sorted(indices_to_delete, reverse=True):
self.entry_paths.pop(index)
def __getitem__(self, idx) -> dict:
"""
Retrieves a sample given its index, returning the preprocessed UAV
and satellite images, along with their associated heatmap and metadata.
"""
image_path_index = idx // (
len(self.sat_available_years)
)
sat_year = self.sat_available_years[idx % len(self.sat_available_years)]
rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle)
image_path = self.entry_paths[image_path_index]
uav_image = Image.open(image_path).convert("RGB") # Ensure 3-channel image
original_uav_image_width = uav_image.width
original_uav_image_height = uav_image.height
lookup_str, file_number = self.extract_info_from_filename(image_path)
img_info = self.metadata_dict[lookup_str][file_number]
lat, lon = (
img_info["coordinate"]["latitude"],
img_info["coordinate"]["longitude"],
)
fov_vertical = img_info["fovVertical"]
try:
agl_altitude = float(image_path.split("/")[-1].split("_")[2].split("m")[0])
except IndexError:
agl_altitude = 150.0
warnings.warn(
"Could not extract AGL altitude from filename, using default value of 150m."
)
(
satellite_patch,
x_sat,
y_sat,
x_offset,
y_offset,
patch_transform,
) = get_random_tiff_patch(
lat, lon, self.sat_patch_width, self.sat_patch_height, sat_year, self.satellite_dataset_dir
)
# Rotate, rescale, and center-crop the UAV image
h = np.ceil(uav_image.height / self.uav_image_scale).astype(int)
w = np.ceil(uav_image.width / self.uav_image_scale).astype(int)
uav_image = F.rotate(uav_image, rot_angle)
uav_image = F.resize(uav_image, [h, w])
uav_image = F.center_crop(
uav_image, (self.uav_patch_height, self.uav_patch_width)
)
uav_image = self.transforms(uav_image)
satellite_patch = satellite_patch.transpose(1, 2, 0)
satellite_patch_pytorch = self.transforms(satellite_patch)
del satellite_patch
if self.use_heatmap:
# Use the tensor's (C, H, W) shape; the numpy patch was deleted above
heatmap = self.get_heatmap_gt(
x_sat,
y_sat,
satellite_patch_pytorch.shape[1],
satellite_patch_pytorch.shape[2],
self.heatmap_kernel_size,
)
cropped_uav_image_width = self.calculate_cropped_uav_image_width(
fov_vertical,
original_uav_image_width,
original_uav_image_height,
self.uav_patch_width,
self.uav_patch_height,
agl_altitude,
)
satellite_tile_width = self.calculate_cropped_sat_image_width(
lat, self.sat_patch_width, patch_transform
)
scale_factor = cropped_uav_image_width / satellite_tile_width
scale_factor *= 10
homography_matrix = self.compute_homography(
rot_angle,
x_sat,
y_sat,
self.uav_patch_width,
self.uav_patch_height,
scale_factor,
)
if not self.use_heatmap:
# Sample four points from the UAV image
points = self.sample_four_points(
self.uav_patch_width, self.uav_patch_height
)
# Transform the points
warped_points = self.warp_points(points, homography_matrix)
#img_info["warped_points_sat"] = warped_points
#img_info["warped_points_uav"] = points
# img_info["cropped_uav_image_width"] = cropped_uav_image_width
# img_info["satellite_tile_width"] = satellite_tile_width
# img_info["scale_factor"] = scale_factor
# img_info["filename"] = image_path
# img_info["rot_angle"] = rot_angle
# img_info["x_sat"] = x_sat
# img_info["y_sat"] = y_sat
# img_info["x_offset"] = x_offset
# img_info["y_offset"] = y_offset
# img_info["patch_transform"] = patch_transform
# img_info["uav_image_scale"] = self.uav_image_scale
# img_info["homography_matrix_uav_to_sat"] = homography_matrix
# img_info["homography_matrix_sat_to_uav"] = np.linalg.inv(homography_matrix)
# img_info["agl_altitude"] = agl_altitude
# img_info["original_uav_image_width"] = original_uav_image_width
# img_info["original_drone_image_height"] = original_uav_image_height
# img_info["fov_vertical"] = fov_vertical
#
inverse_homography_matrix = np.linalg.inv(homography_matrix)
uav_image_data = {
"image": uav_image,
"H_": homography_matrix,
"coords": points,
"image_size": np.array([self.uav_patch_width, self.uav_patch_height]),
}
satellite_patch_data = {
"image": satellite_patch_pytorch,
"H_": inverse_homography_matrix,
"coords": warped_points,
"image_size": np.array([self.sat_patch_width, self.sat_patch_height]),
}
#hm = self.compute_homography_points(
# warped_points, points, [self.uav_patch_width, self.uav_patch_height]
#)
hm = self.compute_homography_points(
points, warped_points, [1.0 , 1.0]
)
if np.array_equal(hm, np.eye(3)):
print("Singular matrix")
return self.__getitem__(random.randint(0, len(self) - 1))
data = {
"name": f"{image_path}_{rot_angle}_{sat_year}",
"original_image_size": np.array(
[original_uav_image_width, original_uav_image_height]
),
"H_0to1": hm.astype(np.float32),
"view0": uav_image_data,
"view1": satellite_patch_data,
"idx": idx
}
del img_info
gc.collect()
return data
def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform):
"""
Computes the width of the satellite image in world coordinate units
"""
length_of_degree = 111320 * np.cos(np.radians(latitude))
scale_x_meters = patch_transform[0] * length_of_degree
satellite_tile_width = patch_width * scale_x_meters
return satellite_tile_width
def calculate_cropped_uav_image_width(
self,
fov_vertical,
orig_width,
orig_height,
crop_width,
crop_height,
altitude=150.0,
):
"""
Computes the width of the UAV image in world coordinate units
"""
# Convert fov from degrees to radians
fov_rad = np.radians(fov_vertical)
# Calculate the full width of the UAV image
full_width = 2 * (altitude * np.tan(fov_rad / 2))
# Determine the cropping ratio
crop_ratio_width = crop_width / orig_width
crop_ratio_height = crop_height / orig_height
# Calculate the adjusted horizontal fov
fov_horizontal = 2 * np.arctan(np.tan(fov_rad / 2) * (orig_width / orig_height))
adjusted_fov_horizontal = 2 * np.arctan(
np.tan(fov_horizontal / 2) * crop_ratio_width
)
# Calculate the new full width using the adjusted horizontal fov
full_width = 2 * (altitude * np.tan(adjusted_fov_horizontal / 2))
# Adjust the width according to the crop ratio
cropped_width = full_width * crop_ratio_width
return cropped_width
def compute_homography_points(self, pts1_, pts2_, shape):
"""Compute the homography matrix from 4 point correspondences"""
# Rescale to actual size
shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x]
pts1 = pts1_ * np.expand_dims(shape, axis=0)
pts2 = pts2_ * np.expand_dims(shape, axis=0)
def ax(p, q):
return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
def ay(p, q):
return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
def flat2mat(H):
return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])
a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
p_mat = np.transpose(
np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
)
try:
homography = np.transpose(np.linalg.solve(a_mat, p_mat))
except np.linalg.LinAlgError:
print("Singular matrix")
return np.eye(3)
return flat2mat(homography)
def compute_homography(
self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor
):
# Adjust rot_angle if it's greater than 180 degrees
if rot_angle > 180:
rot_angle -= 360
# Convert rotation angle to radians
theta = np.radians(rot_angle)
# Rotation matrix
R = np.array(
[
[np.cos(theta), -np.sin(theta), 0],
[np.sin(theta), np.cos(theta), 0],
[0, 0, 1],
]
)
# Scale matrix
S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
# Translation matrix to center the UAV image
T_uav = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]])
# Translation matrix to move to the satellite image position
T_sat = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]])
# Compute the combined homography matrix
H = np.dot(T_sat, np.dot(R, np.dot(S, T_uav)))
return H
def sample_four_points(self, width: int, height: int) -> np.ndarray:
"""
Samples four points from the UAV image.
"""
PADDING = 50
CENTER_PADDING = 10
points = np.array(
[
[random.randint(CENTER_PADDING, width - PADDING), random.randint(CENTER_PADDING, height - PADDING)]
for _ in range(4)
]
)
return points
def warp_points(
self, points: np.ndarray, homography_matrix: np.ndarray
) -> np.ndarray:
"""
Warps the given points using the given homography matrix.
"""
points = np.array(points)
points = np.concatenate([points, np.ones((4, 1))], axis=1)
points = np.dot(homography_matrix, points.T).T
points = points[:, :2] / points[:, 2:]
return points
def count_total_uav_samples(self) -> int:
"""
Count the total number of uav image samples in the dataset
(train + test)
"""
total_samples = 0
for dirpath, dirnames, filenames in os.walk(self.uav_dataset_dir):
for filename in filenames:
if filename.endswith(".jpeg"):
total_samples += 1
return total_samples
def get_number_of_city_samples(self) -> int:
"""
TODO: Count the total number of city samples in the dataset
"""
return 11
def get_entry_paths(self, directory: str) -> List[str]:
"""
Recursively retrieves paths to image and metadata files in the given directory.
"""
entry_paths = []
entries = os.listdir(directory)
images_to_take_per_folder = int(
self.total_uav_samples
* self.test_from_train_ratio
/ self.get_number_of_city_samples()
)
for entry in entries:
entry_path = os.path.join(directory, entry)
# If it's a directory, recurse into it
if os.path.isdir(entry_path):
entry_paths += self.get_entry_paths(entry_path)
# Handle train dataset
elif (self.dataset == "train" and "Train" in entry_path) or (
self.dataset == "train"
and self.test_from_train_ratio > 0
and "Test" in entry_path
):
if entry_path.endswith(".jpeg"):
_, number = self.extract_info_from_filename(entry_path)
else:
number = None
if entry_path.endswith(".json"):
self.get_metadata(entry_path)
if number is None:
continue
if (
number >= images_to_take_per_folder
): # Only include images beyond the ones taken for test
if entry_path.endswith(".jpeg"):
entry_paths.append(entry_path)
# Handle test dataset
elif self.dataset == "test":
if entry_path.endswith(".jpeg"):
_, number = self.extract_info_from_filename(entry_path)
else:
number = None
if entry_path.endswith(".json"):
self.get_metadata(entry_path)
if number is None:
continue
if (
("Test" in entry_path and number < images_to_take_per_folder)
or (number < images_to_take_per_folder and "Train" in entry_path)
or (self.test_from_train_ratio == 0.0 and "Test" in entry_path)
):
if entry_path.endswith(".jpeg"):
entry_paths.append(entry_path)
return sorted(entry_paths, key=self.extract_info_from_filename)
def get_metadata(self, path: str) -> None:
"""
Extracts metadata from a JSON file and stores it in the metadata dictionary.
"""
with open(path, newline="") as jsonfile:
json_dict = json.load(jsonfile)
path = path.split("/")[-1]
path = path.replace(".json", "")
self.metadata_dict[path] = json_dict["cameraFrames"]
def extract_info_from_filename(self, filename: str) -> (str, int):
"""
Extracts information from the filename.
"""
filename_without_ext = filename.replace(".jpeg", "")
segments = filename_without_ext.split("/")
info = segments[-1]
try:
number = int(info.split("_")[-1])
except ValueError:
print("Could not extract number from filename: ", filename)
return None, None
info = "_".join(info.split("_")[:-1])
return info, number
def get_heatmap_gt(
self, x: int, y: int, height: int, width: int, square_size: int = 33
) -> torch.Tensor:
"""
Returns 2D heatmap ground truth for the given x and y coordinates,
with the given square size.
"""
x_map, y_map = x, y
heatmap = torch.zeros((height, width))
half_size = square_size // 2
# Calculate the valid range for the square
start_x = max(0, x_map - half_size)
end_x = min(
width, x_map + half_size + 1
) # +1 to include the end_x in the square
start_y = max(0, y_map - half_size)
end_y = min(
height, y_map + half_size + 1
) # +1 to include the end_y in the square
heatmap[start_y:end_y, start_x:end_x] = 1
return heatmap
if __name__ == "__main__":
from .. import logger # overwrite the logger
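
To make the geometry in compute_homography above concrete, here is a standalone sketch of the same T_sat @ R @ S @ T_uav chain with a numeric check: the center of a 400x400 UAV patch should land exactly on the matched satellite pixel (x_sat, y_sat).

import numpy as np

def uav_to_sat_H(rot_angle_deg, x_sat, y_sat, uav_w, uav_h, scale):
    # Same construction as compute_homography: recenter the UAV patch,
    # scale and rotate it, then translate onto the satellite pixel.
    th = np.radians(rot_angle_deg)
    R = np.array([[np.cos(th), -np.sin(th), 0.0],
                  [np.sin(th),  np.cos(th), 0.0],
                  [0.0, 0.0, 1.0]])
    S = np.diag([scale, scale, 1.0])
    T_uav = np.array([[1.0, 0.0, -uav_w / 2], [0.0, 1.0, -uav_h / 2], [0.0, 0.0, 1.0]])
    T_sat = np.array([[1.0, 0.0, x_sat], [0.0, 1.0, y_sat], [0.0, 0.0, 1.0]])
    return T_sat @ R @ S @ T_uav

H = uav_to_sat_H(0, 200, 150, 400, 400, 0.5)
center = H @ np.array([200.0, 200.0, 1.0])
print(center[:2] / center[2])  # [200. 150.] -- the UAV center hits (x_sat, y_sat)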

View File

@ -0,0 +1,163 @@
#! /usr/bin/env python3
import os
import mercantile
import rasterio
import numpy as np
import random
import warnings
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile
from rasterio.transform import from_bounds
from rasterio.merge import merge
from PIL import Image
import gc
import random
from osgeo import gdal, osr
from affine import Affine
def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
neighbors = []
for main_neighbour in mercantile.neighbors(tile):
for sub_neighbour in mercantile.neighbors(main_neighbour):
if sub_neighbour not in neighbors:
neighbors.append(sub_neighbour)
return neighbors
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str) -> (np.ndarray, dict):
"""
Returns a TIFF map of the given tile using GDAL.
"""
tile_data = []
neighbors = get_5x5_neighbors(tile)
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
for neighbor in neighbors:
west, south, east, north = mercantile.bounds(neighbor)
tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
if not os.path.exists(tile_path):
raise FileNotFoundError(f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found.")
img = Image.open(tile_path)
img_array = np.array(img)
# Create an in-memory GDAL dataset
mem_driver = gdal.GetDriverByName('MEM')
dataset = mem_driver.Create('', img_array.shape[1], img_array.shape[0], 3, gdal.GDT_Byte)
for i in range(3):
dataset.GetRasterBand(i + 1).WriteArray(img_array[:, :, i])
# Set GeoTransform and Projection
geotransform = (west, (east - west) / img_array.shape[1], 0, north, 0, -(north - south) / img_array.shape[0])
dataset.SetGeoTransform(geotransform)
srs = osr.SpatialReference()
srs.ImportFromEPSG(3857)
dataset.SetProjection(srs.ExportToWkt())
tile_data.append(dataset)
# Merge tiles using GDAL
vrt_options = gdal.BuildVRTOptions()
vrt = gdal.BuildVRT('', [td for td in tile_data], options=vrt_options)
mosaic = vrt.ReadAsArray()
# Get metadata
out_trans = vrt.GetGeoTransform()
out_crs = vrt.GetProjection()
out_trans = Affine.from_gdal(*out_trans)
out_meta = {
"driver": "GTiff",
"height": mosaic.shape[1],
"width": mosaic.shape[2],
"transform": out_trans,
"crs": out_crs,
}
# Clean up
for td in tile_data:
td.FlushCache()
vrt = None
gc.collect()
return mosaic, out_meta
def get_random_tiff_patch(
lat: float,
lon: float,
patch_width: int,
patch_height: int,
sat_year: str,
satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
) -> (np.ndarray, int, int, int, int, rasterio.transform.Affine):
"""
Returns a random patch from the satellite image.
"""
tile = get_tile_from_coord(lat, lon, 17)
mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)
transform = out_meta["transform"]
del out_meta
x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)
# TODO: temporary constant, replace with a better solution
KS = 120
x_offset_range = [
x_pixel - patch_width + KS + 1,
x_pixel - KS - 1,
]
y_offset_range = [
y_pixel - patch_height + KS + 1,
y_pixel - KS - 1,
]
# Randomly select an offset within the valid range
x_offset = random.randint(*x_offset_range)
y_offset = random.randint(*y_offset_range)
x_offset = np.clip(x_offset, 0, mosaic.shape[-1] - patch_width)
y_offset = np.clip(y_offset, 0, mosaic.shape[-2] - patch_height)
# Update x, y to reflect the clamping of x_offset and y_offset
x, y = x_pixel - x_offset, y_pixel - y_offset
patch = mosaic[
:, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
]
patch_transform = rasterio.transform.Affine(
transform.a,
transform.b,
transform.c + x_offset * transform.a + y_offset * transform.b,
transform.d,
transform.e,
transform.f + x_offset * transform.d + y_offset * transform.e,
)
gc.collect()
return patch, x, y, x_offset, y_offset, patch_transform
def get_tile_from_coord(
lat: float, lng: float, zoom_level: int
) -> mercantile.Tile:
"""
Returns the tile containing the given coordinates.
"""
tile = mercantile.tile(lng, lat, zoom_level)
return tile
def geo_to_pixel_coordinates(
lat: float, lon: float, transform: rasterio.transform.Affine
) -> (int, int):
"""
Converts a pair of (lat, lon) coordinates to pixel coordinates.
"""
x_pixel, y_pixel = ~transform * (lon, lat)
return round(x_pixel), round(y_pixel)
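
geo_to_pixel_coordinates relies on the affine package's inverse operator: ~transform * (x, y) maps world coordinates to fractional (col, row) pixel indices. A small round-trip sketch; the transform values here are made up for illustration only.

from affine import Affine

# Made-up transform: 1 unit/pixel, origin at (500000, 4649776), north-up (e < 0)
transform = Affine(1.0, 0.0, 500000.0, 0.0, -1.0, 4649776.0)
col, row = ~transform * (500123.4, 4649650.0)
print(round(col), round(row))  # -> 123 126
print(transform * (123, 126))  # -> (500123.0, 4649650.0), back to world coords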

View File

@ -0,0 +1,109 @@
#! /usr/bin/env python3
import os
import mercantile
import rasterio
import numpy as np
import random
import warnings
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile
from rasterio.transform import from_bounds
from rasterio.merge import merge
from PIL import Image
import gc
def get_3x3_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
neighbors = []
for neighbour in mercantile.neighbors(tile):
if neighbour not in neighbors:
neighbors.append(neighbour)
neighbors.append(tile)
return neighbors
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir:str) -> (np.ndarray, dict):
"""
Returns a TIFF map of the given tile.
"""
tile_data = []
neighbors = get_3x3_neighbors(tile)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
for neighbor in neighbors:
west, south, east, north = mercantile.bounds(neighbor)
tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
if not os.path.exists(tile_path):
raise FileNotFoundError(
f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
)
with Image.open(tile_path) as img:
width, height = img.size
memfile = MemoryFile()
with memfile.open(
driver="GTiff",
height=height,
width=width,
count=3,
dtype="uint8",
crs="EPSG:3857",
transform=from_bounds(west, south, east, north, width, height),
) as dataset:
data = rasterio.open(tile_path).read()
dataset.write(data)
tile_data.append(memfile.open())
memfile.close()
mosaic, out_trans = merge(tile_data)
out_meta = tile_data[0].meta.copy()
out_meta.update(
{
"driver": "GTiff",
"height": mosaic.shape[1],
"width": mosaic.shape[2],
"transform": out_trans,
"crs": "EPSG:3857",
}
)
# Clean up MemoryFile instances to free up memory
for td in tile_data:
td.close()
del neighbors
del tile_data
gc.collect()
return mosaic, out_meta
def get_random_tiff_patch(
lat: float,
lon: float,
satellite_dataset_dir: str,
) -> (np.ndarray):
"""
Returns a random patch from the satellite image.
"""
tile = get_tile_from_coord(lat, lon, 17)
sat_years = ["2023", "2021", "2019", "2016"]
# Randomly select a satellite year
sat_year = random.choice(sat_years)
mosaic, _ = get_tiff_map(tile, sat_year, satellite_dataset_dir)
return mosaic
def get_tile_from_coord(
lat: float, lng: float, zoom_level: int
) -> mercantile.Tile:
"""
Returns the tile containing the given coordinates.
"""
tile = mercantile.tile(lng, lat, zoom_level)
return tile
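
For reference, a short sketch of how mercantile tiles map onto the on-disk layout these utilities read ({year}/{z}_{x}_{y}.jpg under the tiles directory); the coordinate below is an arbitrary example.

import mercantile

lat, lon = 48.2082, 16.3738
tile = mercantile.tile(lon, lat, 17)  # note the (lng, lat) argument order
print(f"/mnt/drive/tiles/2023/{tile.z}_{tile.x}_{tile.y}.jpg")  # expected path
print(len(mercantile.neighbors(tile)))  # 8 surrounding tiles for an interior tile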

View File

@ -0,0 +1,297 @@
"""
Simply load images from a folder or nested folders (does not have any split),
and apply homographic adaptations to it. Yields an image pair without border
artifacts.
"""
import argparse
import logging
import shutil
import tarfile
from pathlib import Path
import mercantile
import rasterio
import cv2
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
from PIL import Image
from ..geometry.homography import (
compute_homography,
sample_homography_corners,
warp_points,
)
from ..models.cache_loader import CacheLoader, pad_local_features
from ..settings import DATA_PATH
from ..utils.image import read_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
from .augmentations import IdentityAugmentation, augmentations
from .base_dataset import BaseDataset
from .sat_utils import get_random_tiff_patch
logger = logging.getLogger(__name__)
def sample_homography(img, conf: dict, size: list):
data = {}
H, _, coords, _ = sample_homography_corners(img.shape[:2][::-1], **conf)
data["image"] = cv2.warpPerspective(img, H, tuple(size))
data["H_"] = H.astype(np.float32)
data["coords"] = coords.astype(np.float32)
data["image_size"] = np.array(size, dtype=np.float32)
return data
class SatelliteDataset(BaseDataset):
default_conf = {
# image search
"data_dir": "revisitop1m", # the top-level directory
"image_dir": "jpg/", # the subdirectory with the images
"image_list": "revisitop1m.txt", # optional: list or filename of list
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
"metadata_dir": None,
# splits
"train_size": 100,
"val_size": 10,
"shuffle_seed": 0, # or None to skip
# image loading
"grayscale": False,
"triplet": False,
"right_only": False, # image0 is orig (rescaled), image1 is right
"reseed": False,
"homography": {
"difficulty": 0.8,
"translation": 1.0,
"max_angle": 60,
"n_angles": 10,
"patch_shape": [640, 480],
"min_convexity": 0.05,
},
"photometric": {
"name": "dark",
"p": 0.75,
# 'difficulty': 1.0, # currently unused
},
# feature loading
"load_features": {
"do": False,
**CacheLoader.default_conf,
"collate": False,
"thresh": 0.0,
"max_num_keypoints": -1,
"force_num_keypoints": False,
},
}
def _init(self, conf):
data_dir = conf.data_dir
coordinates_file = conf.metadata_dir
with open(coordinates_file, "r") as cf:
coordinates = cf.readlines()
parsed_coordinates = []
for coordinate in coordinates:
lat_part, lon_part = coordinate.split(',')
lat = float(lat_part.split(':')[-1].strip())
lon = float(lon_part.split(':')[-1].strip())
parsed_coordinates.append((lat, lon))
# Split into train (80%) and val (20%)
train_images = parsed_coordinates[:int(len(parsed_coordinates) * 0.8)]
val_images = parsed_coordinates[int(len(parsed_coordinates) * 0.8):]
self.images = {"train": train_images, "val": val_images}
def get_dataset(self, split):
return _Dataset(self.conf, self.images[split], split)
class _Dataset(torch.utils.data.Dataset):
def __init__(self, conf, image_names, split):
self.conf = conf
self.split = split
self.image_names = np.array(image_names)
self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir
aug_conf = conf.photometric
aug_name = aug_conf.name
assert (
aug_name in augmentations.keys()
), f'{aug_name} not in {" ".join(augmentations.keys())}'
#self.left_augment = (
# IdentityAugmentation() if conf.right_only else self.photo_augment
#)
self.photo_augment = augmentations[aug_name](aug_conf)
self.left_augment = augmentations[aug_name](aug_conf)
self.img_to_tensor = IdentityAugmentation()
if conf.load_features.do:
self.feature_loader = CacheLoader(conf.load_features)
def _transform_keypoints(self, features, data):
"""Transform keypoints by a homography, threshold them,
and potentially keep only the best ones."""
# Warp points
features["keypoints"] = warp_points(
features["keypoints"], data["H_"], inverse=False
)
h, w = data["image"].shape[1:3]
valid = (
(features["keypoints"][:, 0] >= 0)
& (features["keypoints"][:, 0] <= w - 1)
& (features["keypoints"][:, 1] >= 0)
& (features["keypoints"][:, 1] <= h - 1)
)
features["keypoints"] = features["keypoints"][valid]
# Threshold
if self.conf.load_features.thresh > 0:
valid = features["keypoint_scores"] >= self.conf.load_features.thresh
features = {k: v[valid] for k, v in features.items()}
# Get the top keypoints and pad
n = self.conf.load_features.max_num_keypoints
if n > -1:
inds = np.argsort(-features["keypoint_scores"])
features = {k: v[inds[:n]] for k, v in features.items()}
if self.conf.load_features.force_num_keypoints:
features = pad_local_features(
features, self.conf.load_features.max_num_keypoints
)
return features
def __getitem__(self, idx):
if self.conf.reseed:
with fork_rng(self.conf.seed + idx, False):
return self.getitem(idx)
else:
return self.getitem(idx)
def _read_view(self, img, H_conf, ps, left=False):
data = sample_homography(img, H_conf, ps)
if left:
data["image"] = self.left_augment(data["image"], return_tensor=True)
else:
data["image"] = self.photo_augment(data["image"], return_tensor=True)
gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
if self.conf.grayscale:
data["image"] = (data["image"] * gs).sum(0, keepdim=True)
if self.conf.load_features.do:
features = self.feature_loader({k: [v] for k, v in data.items()})
features = self._transform_keypoints(features, data)
data["cache"] = features
return data
def getitem(self, idx):
# Look up the coordinates for this index and fetch a satellite patch there
lat, lon = self.image_names[idx]
img = get_random_tiff_patch(lat, lon, self.conf.data_dir)
if img is None:
logging.warning("Image %f %f could not be read.", lat, lon)
# Fallback patch is channel-first so the transpose below still applies
img = np.zeros((1 if self.conf.grayscale else 3, 1024, 1024))
img = img.transpose(1, 2, 0)
img = img.astype(np.float32) / 255.0
#write_status = cv2.imwrite(f"/mnt/drive/{str(idx)}.jpg", img)
#if write_status == True:
# print("Writing success")
#else:
# print("fuck")
size = img.shape[:2][::-1]
ps = self.conf.homography.patch_shape
left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
if self.conf.right_only:
left_conf["difficulty"] = 0.0
data0 = self._read_view(img, left_conf, ps, left=True)
data1 = self._read_view(img, self.conf.homography, ps, left=False)
print("Data0 homography_matrix", data0["H_"])
print("Data1 homography matrix", data1["H_"])
# image_1 = data0["image"]
# image_1 = image_1.numpy()
# image_1 = image_1.transpose(1,2,0)
# image_1 = np.uint8(image_1 * 255)
#
# image_2 = data1["image"]
# image_2 = image_2.numpy()
# image_2 = image_2.transpose(1,2,0)
# image_2 = np.uint8(image_2 * 255)
#
# PIL_IMAGE = Image.fromarray(image_1)
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}.jpg")
# PIL_IMAGE = Image.fromarray(image_2)
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}-s.jpg")
#
# exit()
#
H = compute_homography(data0["coords"], data1["coords"], [1, 1])
print("COmputed homography", H)
name = f"lat{lat}_lon{lon}"
data = {
"name": name,
"original_image_size": np.array(size),
"H_0to1": H.astype(np.float32),
"idx": idx,
"view0": data0,
"view1": data1,
}
return data
def __len__(self):
return len(self.image_names)
def visualize(args):
conf = {
"batch_size": 1,
"num_workers": 1,
"prefetch_factor": 1,
}
conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
dataset = SatelliteDataset(conf)
loader = dataset.get_data_loader("train")
logger.info("The dataset has %d elements.", len(loader))
with fork_rng(seed=dataset.conf.seed):
images = []
for _, data in zip(range(args.num_items), loader):
images.append(
[data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
)
plot_image_grid(images, dpi=args.dpi)
plt.tight_layout()
plt.savefig("implot.png")
#plt.show()
if __name__ == "__main__":
from .. import logger # overwrite the logger
parser = argparse.ArgumentParser()
parser.add_argument("--num_items", type=int, default=8)
parser.add_argument("--dpi", type=int, default=100)
parser.add_argument("dotlist", nargs="*")
args = parser.parse_intermixed_args()
visualize(args)
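
The H_0to1 computed in getitem relates two homographically warped views of the same source patch. The underlying identity, sketched with plain numpy: if view0 = warp(img, H0) and view1 = warp(img, H1), points move between the views via H1 @ inv(H0).

import numpy as np

H0 = np.array([[1.0, 0.0, 10.0], [0.0, 1.0, 5.0], [0.0, 0.0, 1.0]])  # translation
H1 = np.array([[0.9, 0.0, 0.0], [0.0, 0.9, 0.0], [0.0, 0.0, 1.0]])   # mild scaling
H_0to1 = H1 @ np.linalg.inv(H0)   # view0 -> source image -> view1
p0 = np.array([30.0, 20.0, 1.0])  # a pixel in view0, homogeneous coordinates
p1 = H_0to1 @ p0
print(p1[:2] / p1[2])             # [18.  13.5]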

View File

@ -53,6 +53,7 @@ def sample_homography_corners(
min_pts1 = create_center_patch(shape, (pwidth, pheight))
full = create_center_patch(shape)
pts2 = create_center_patch(patch_shape)
scale = min_pts1 - full
found_valid = False
cnt = -1
@ -68,7 +69,9 @@ def sample_homography_corners(
# Rotation
if n_angles > 0 and difficulty > 0:
angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
#angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
# In our case the difficulty parameter should not affect the rotation
angles = np.linspace(-max_angle, max_angle, n_angles)
rng.shuffle(angles)
angles = np.concatenate([[0.0], angles], axis=0)

View File

@ -12,6 +12,7 @@ import signal
from collections import defaultdict
from pathlib import Path
from pydoc import locate
import gc
import numpy as np
import torch
@ -89,6 +90,7 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
for i, data in enumerate(
tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
):
data = batch_to_device(data, device, non_blocking=True)
with torch.no_grad():
pred = model(data)
@ -102,7 +104,6 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
pred[v["predictions"]],
mask=pred[v["mask"]] if "mask" in v.keys() else None,
)
del pred, data
numbers = {**metrics, **{"loss/" + k: v for k, v in losses.items()}}
for k, v in numbers.items():
if k not in results:
@ -118,7 +119,9 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
if k in conf.recall_metrics.keys():
q = conf.recall_metrics[k]
results[k + f"_recall{int(q)}"].update(v)
del numbers
del pred, data, losses, metrics
gc.collect()
results = {k: results[k].compute() for k in results}
return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
@ -257,6 +260,7 @@ def training(rank, conf, output_dir, args):
dataset = get_dataset(data_conf.name)(data_conf)
# Optionally load a different validation dataset than the training one
val_data_conf = conf.get("data_val", None)
if val_data_conf is None:
@ -408,7 +412,9 @@ def training(rank, conf, output_dir, args):
getattr(loader.dataset, conf.train.dataset_callback_fn)(
conf.train.seed + epoch
)
for it, data in enumerate(train_loader):
tot_it = (len(train_loader) * epoch + it) * (
args.n_gpus if args.distributed else 1
)
@ -427,10 +433,17 @@ def training(rank, conf, output_dir, args):
loss = torch.mean(losses["total"])
if torch.isnan(loss).any():
print(f"Detected NAN, skipping iteration {it}")
print("name", data["name"])
print("loss", loss)
print("losses", losses)
print("data", data)
print("pred", pred)
raise RuntimeError("Detected NAN in training.")
del pred, data, loss, losses
continue
do_backward = loss.requires_grad
if args.distributed:
do_backward = torch.tensor(do_backward).float().to(device)
torch.distributed.all_reduce(
@ -469,7 +482,6 @@ def training(rank, conf, output_dir, args):
else:
if rank == 0:
logger.warning(f"Skip iteration {it} due to detach.")
if args.profile:
prof.step()
@ -508,8 +520,11 @@ def training(rank, conf, output_dir, args):
norm = torch.norm(param.grad.detach(), 2)
grad_txt += f"{name} {norm.item():.3f} \n"
writer.add_text("grad/summary", grad_txt, tot_n_samples)
del pred, data, loss, losses
pred.clear()
data.clear()
del pred, data, loss, losses
gc.collect()
# Run validation
if (
(
@ -529,6 +544,7 @@ def training(rank, conf, output_dir, args):
pbar=(rank == -1),
)
if rank == 0:
str_results = [
f"{k} {v:.3E}"
@ -569,6 +585,10 @@ def training(rank, conf, output_dir, args):
f"figures/{i}_{name}", fig, tot_n_samples
)
torch.cuda.empty_cache() # should be cleared at the first iter
str_results.clear()
pr_metrics.clear()
del str_results, figures, pr_metrics
if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
if results is None:
@ -622,7 +642,7 @@ def training(rank, conf, output_dir, args):
writer.close()
def main_worker(rank, conf, output_dir, args):
if rank == 0:
with capture_outputs(output_dir / "log.txt"):
training(rank, conf, output_dir, args)
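
The training loop above steps a profiler whenever args.profile is set, matching the profile: true flag added to the config earlier. A minimal sketch of the torch.profiler wiring such a flag typically drives; the schedule values and log directory here are illustrative, not taken from glue-factory.

import torch

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./prof_logs"),
) as prof:
    for it in range(8):
        # ... one training iteration would run here ...
        prof.step()  # advance the profiling schedule, as in the loop above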

View File

@ -33,6 +33,15 @@ def batch_to_device(batch, device, non_blocking=True):
return map_tensor(batch, _func)
def detach_tensors(batch):
"""
Detach all tensors in a batch recursively.
This is useful for detaching tensors from the computational graph to free up memory.
"""
def _detach(tensor):
return tensor.detach()
return map_tensor(batch, _detach)
def rbd(data: dict) -> dict:
"""Remove batch dimension from elements in data"""

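detach_tensors leans on map_tensor, glue-factory's recursive walker over nested containers. A minimal re-implementation sketch of that pattern, for readers without the source at hand (the real helper lives in gluefactory.utils.tensor):

import torch

def map_tensor(input_, func):
    # Apply func to every tensor found in nested dicts/lists/tuples
    if isinstance(input_, torch.Tensor):
        return func(input_)
    if isinstance(input_, (list, tuple)):
        return type(input_)(map_tensor(x, func) for x in input_)
    if isinstance(input_, dict):
        return {k: map_tensor(v, func) for k, v in input_.items()}
    return input_

batch = {"pred": {"scores": torch.ones(3, requires_grad=True)}}
detached = map_tensor(batch, lambda t: t.detach())
print(detached["pred"]["scores"].requires_grad)  # False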
View File

@ -67,8 +67,12 @@ class MedianMetric:
else:
return np.nanmedian(self._elements)
def __len__(self):
return len(self._elements)
class PRMetric:
def __init__(self):
self.labels = []
self.predictions = []

View File

@ -338,6 +338,14 @@ class SuperPoint(BaseModel):
for k, d in zip(keypoints, dense_desc)
]
if isinstance(desc, list):
if all(isinstance(d, torch.Tensor) for d in desc):
# If desc is a list of tensors
desc = torch.stack(desc)
else:
# If desc is a list of non-tensor elements
desc = torch.stack([torch.tensor(d) for d in desc])
pred = {
"keypoints": keypoints + 0.5,
"keypoint_scores": scores,

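The guard added above handles both a homogeneous list of descriptor tensors and a list of plain arrays. A quick illustration of the two branches (shapes are arbitrary):

import torch

desc_tensors = [torch.randn(5, 256) for _ in range(2)]
print(torch.stack(desc_tensors).shape)  # torch.Size([2, 5, 256])

desc_arrays = [d.numpy() for d in desc_tensors]
print(torch.stack([torch.tensor(a) for a in desc_arrays]).shape)  # same shape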
run.sh Normal file
View File

@ -0,0 +1,12 @@
#python -m gluefactory.train sp+lg_homography \
# --conf gluefactory/configs/superpoint+lightglue_homography.yaml \
# data.batch_size=32 # for 1x 1080 GPU
#python -m gluefactory.train satellites_aug_both_40 \
# --conf gluefactory/configs/satellites3.yaml \
# data.batch_size=4 # for 1x 1080 GPU
python -m gluefactory.train drone2sat_v1 \
--conf gluefactory/configs/drone2sat_v1.yaml \
data.batch_size=6 # for 1x 1080 GPU

stats.sh Normal file
View File

@ -0,0 +1,22 @@
#!/bin/bash
LOGFILE="/var/log/h/system_stats.log"
echo "Logging CPU, Memory, Swap, Disk, and Network usage to $LOGFILE"
while true; do
echo "----- $(date) -----" >> $LOGFILE
echo "CPU Usage:" >> $LOGFILE
mpstat -P ALL 1 1 >> $LOGFILE
echo "Disk Usage:" >> $LOGFILE
iostat >> $LOGFILE
echo "Memory and Swap Usage:" >> $LOGFILE
free -m >> $LOGFILE
echo "Network Usage:" >> $LOGFILE
ifstat 1 1 >> $LOGFILE
sleep 5
done