694 lines
25 KiB
Python
694 lines
25 KiB
Python
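"""
Drone-to-satellite geolocalization dataset.

Loads UAV (drone) images together with satellite patches sampled around the
UAV's recorded GPS position, and yields image pairs plus the homography that
relates them. (Adapted from a loader that read images from a folder or nested
folders and applied homographic adaptations to them.)
"""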
|
|
|
|
import argparse
|
|
import logging
|
|
import shutil
|
|
import tarfile
|
|
from pathlib import Path
|
|
import os
|
|
import json
|
|
import gc
|
|
|
|
import cv2
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import omegaconf
|
|
import torch
|
|
from omegaconf import OmegaConf
|
|
from tqdm import tqdm
|
|
from PIL import Image
|
|
from typing import List, Literal
|
|
from torchvision import transforms
|
|
import random
|
|
from torchvision.transforms import functional as F
|
|
|
|
|
|
|
|
|
|
from ..geometry.homography import (
|
|
compute_homography,
|
|
sample_homography_corners,
|
|
warp_points,
|
|
)
|
|
from ..models.cache_loader import CacheLoader, pad_local_features
|
|
from ..settings import DATA_PATH
|
|
from ..utils.image import read_image
|
|
from ..utils.tools import fork_rng
|
|
from ..visualization.viz2d import plot_image_grid
|
|
from .augmentations import IdentityAugmentation, augmentations
|
|
from .base_dataset import BaseDataset
|
|
from .s_utils import get_random_tiff_patch
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class Drone2SatDataset(BaseDataset):
    """Drone (UAV) to satellite image-pair dataset.

    Thin wrapper exposing :class:`GeoLocalizationDataset` through the
    ``BaseDataset`` interface. :meth:`get_dataset` reads only the
    ``geo_dataset`` section and the ``train_size`` / ``val_size`` keys; the
    other keys of ``default_conf`` are not read by this class's own methods.
    """

    default_conf = {
        # image search
        "data_dir": "revisitop1m",  # the top-level directory
        "image_dir": "jpg/",  # the subdirectory with the images
        "image_list": "revisitop1m.txt",  # optional: list or filename of list
        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
        "metadata_dir": None,
        "uav_dataset_dir": None,
        "satellite_dataset_dir": None,
        # splits
        "train_size": 100,
        "val_size": 10,
        "shuffle_seed": 0,  # or None to skip
        # image loading
        "grayscale": False,
        "triplet": False,
        "right_only": False,  # image0 is orig (rescaled), image1 is right
        "reseed": False,
        "homography": {
            "difficulty": 0.8,
            "translation": 1.0,
            "max_angle": 60,
            "n_angles": 10,
            "patch_shape": [640, 480],
            "min_convexity": 0.05,
        },
        "photometric": {
            "name": "dark",
            "p": 0.75,
            # 'difficulty': 1.0,  # currently unused
        },
        # feature loading
        "load_features": {
            "do": False,
            **CacheLoader.default_conf,
            "collate": False,
            "thresh": 0.0,
            "max_num_keypoints": -1,
            "force_num_keypoints": False,
        },
        # Geolocalization parameters forwarded to GeoLocalizationDataset
        "geo_dataset": {
            "uav_dataset_dir": None,
            "satellite_dataset_dir": None,
            "misslabeled_images_path": None,
            "sat_zoom_level": 17,
            "uav_patch_width": 400,
            "uav_patch_height": 400,
            "sat_patch_width": 400,
            "sat_patch_height": 400,
            "heatmap_kernel_size": 33,
            "test_from_train_ratio": 0.0,
            "transform_mean": [0.485, 0.456, 0.406],
            "transform_std": [0.229, 0.224, 0.225],
            # (sic) Spelling kept: get_dataset and user configs refer to this key.
            "sat_availaible_years": ["2023", "2021", "2019", "2016"],
            "max_rotation_angle": 10,
            "uav_image_scale": 1,
            "use_heatmap": False,
        },
    }

    # split name -> (GeoLocalizationDataset ``dataset`` argument,
    #                conf key holding that split's subset size)
    _SPLITS = {
        "train": ("train", "train_size"),
        "val": ("test", "val_size"),
    }

    def _init(self, conf):
        self.images = {"train": "train", "val": "val"}

    def get_dataset(self, split):
        """Build the :class:`GeoLocalizationDataset` for ``split``.

        Args:
            split: ``"train"`` or ``"val"``. ``"val"`` is served from the UAV
                dataset's ``"test"`` partition, sized by ``conf.val_size``;
                ``"train"`` is sized by ``conf.train_size``.

        Returns:
            A configured GeoLocalizationDataset.

        Raises:
            ValueError: if ``split`` is not one of the supported split names.
        """
        if split not in self._SPLITS:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {sorted(self._SPLITS)}."
            )
        dataset_name, size_key = self._SPLITS[split]
        geo = self.conf.geo_dataset
        return GeoLocalizationDataset(
            uav_dataset_dir=geo.uav_dataset_dir,
            satellite_dataset_dir=geo.satellite_dataset_dir,
            misslabeled_images_path=geo.misslabeled_images_path,
            dataset=dataset_name,
            sat_zoom_level=geo.sat_zoom_level,
            uav_patch_width=geo.uav_patch_width,
            uav_patch_height=geo.uav_patch_height,
            sat_patch_width=geo.sat_patch_width,
            sat_patch_height=geo.sat_patch_height,
            heatmap_kernel_size=geo.heatmap_kernel_size,
            test_from_train_ratio=geo.test_from_train_ratio,
            transform_mean=geo.transform_mean,
            transform_std=geo.transform_std,
            sat_available_years=geo.sat_availaible_years,
            max_rotation_angle=geo.max_rotation_angle,
            uav_image_scale=geo.uav_image_scale,
            use_heatmap=geo.use_heatmap,
            subset_size=self.conf[size_key],
        )
|
|
|
|
|
|
class GeoLocalizationDataset(torch.utils.data.Dataset):
    """UAV-to-satellite image-pair dataset.

    Every UAV image kept in ``entry_paths`` is paired with one satellite patch
    per year in ``sat_available_years``, so the dataset has
    ``len(entry_paths) * len(sat_available_years)`` samples. A sample holds the
    randomly rotated, resized and center-cropped UAV image, the satellite
    patch, four sampled UAV points with their positions in the satellite
    patch, and the homography ``H_0to1`` estimated from those points.
    """

    # Upper bound on random re-draws in __getitem__ when the sampled point
    # correspondences give a singular homography system.
    max_resample_attempts = 100

    def __init__(
        self,
        uav_dataset_dir: str,
        satellite_dataset_dir: str,
        misslabeled_images_path: str,
        dataset: Literal["train", "test"],
        sat_zoom_level: int = 16,
        uav_patch_width: int = 128,
        uav_patch_height: int = 128,
        sat_patch_width: int = 400,
        sat_patch_height: int = 400,
        heatmap_kernel_size: int = 33,
        test_from_train_ratio: float = 0.0,
        transform_mean: List[float] = [0.485, 0.456, 0.406],
        transform_std: List[float] = [0.229, 0.224, 0.225],
        sat_available_years: List[str] = ["2023", "2021", "2019", "2016"],
        max_rotation_angle: int = 10,
        uav_image_scale: float = 1,
        use_heatmap: bool = True,
        subset_size: int = None,
    ):
        """Index the UAV images and load their JSON metadata.

        Args:
            uav_dataset_dir: Root directory walked recursively for ``.jpeg``
                images and ``.json`` metadata files.
            satellite_dataset_dir: Passed to ``get_random_tiff_patch``.
            misslabeled_images_path: Text file with one image-name substring
                per line to exclude, or ``None`` to exclude nothing.
            dataset: ``"train"`` or ``"test"`` partition to build.
            test_from_train_ratio: Fraction used to compute how many
                low-numbered images per folder are reserved for ``"test"``.
            sat_available_years: Satellite years; each image yields one
                sample per year.
            max_rotation_angle: UAV images are rotated by a random integer
                angle in ``[-max_rotation_angle, max_rotation_angle]``.
            uav_image_scale: UAV image size is divided by this value before
                the center crop.
            use_heatmap: Also return a square ground-truth heatmap.
            subset_size: Approximate number of samples to keep (split evenly
                across years); ``None`` or ``-1`` keeps everything.
        """
        self.uav_dataset_dir = uav_dataset_dir
        self.satellite_dataset_dir = satellite_dataset_dir
        self.dataset = dataset
        self.sat_zoom_level = sat_zoom_level
        self.uav_patch_width = uav_patch_width
        self.uav_patch_height = uav_patch_height
        self.sat_patch_width = sat_patch_width
        self.sat_patch_height = sat_patch_height
        self.heatmap_kernel_size = heatmap_kernel_size
        self.test_from_train_ratio = test_from_train_ratio
        # Copy list arguments so the shared default lists are never aliased.
        self.transform_mean = list(transform_mean)
        self.transform_std = list(transform_std)
        self.sat_available_years = list(sat_available_years)
        self.max_rotation_angle = max_rotation_angle
        self.uav_image_scale = uav_image_scale
        self.use_heatmap = use_heatmap
        self.subset_size = subset_size
        self.misslabeled_images_path = misslabeled_images_path

        # Filled by get_metadata() while get_entry_paths() walks the tree:
        # JSON file stem -> that file's "cameraFrames" list.
        self.metadata_dict = {}
        self.total_uav_samples = self.count_total_uav_samples()
        self.misslabelled_images = self.read_misslabelled_images(
            self.misslabeled_images_path
        )
        self.entry_paths = self.get_entry_paths(self.uav_dataset_dir)
        self.cleanup_misslabelled_images()

        # Normalization (transforms.Normalize with transform_mean/std) is
        # currently disabled; only conversion to a float tensor is applied.
        self.transforms = transforms.Compose([transforms.ToTensor()])
        self.inverse_transforms = transforms.Compose(
            [
                transforms.Normalize(
                    mean=[
                        -m / s for m, s in zip(self.transform_mean, self.transform_std)
                    ],
                    std=[1 / s for s in self.transform_std],
                ),
                transforms.ToPILImage(),
            ]
        )

        if self.subset_size is not None and self.subset_size != -1:
            per_year = int(self.subset_size / len(self.sat_available_years))
            self.entry_paths = self.entry_paths[: min(per_year, len(self.entry_paths))]

    def __len__(self) -> int:
        """One sample per (UAV image, satellite year) pair."""
        return len(self.entry_paths) * len(self.sat_available_years)

    def read_misslabelled_images(
        self, path: str = "misslabels/misslabeled.txt"
    ) -> List[str]:
        """Return the stripped, non-blank lines of ``path`` (``[]`` if None).

        Blank lines are dropped: an empty string is a substring of every
        path and would otherwise exclude an arbitrary image.
        """
        if path is None:
            return []
        with open(path, "r") as f:
            return [line.strip() for line in f if line.strip()]

    def cleanup_misslabelled_images(self) -> None:
        """Drop from ``entry_paths`` the first path containing each listed name.

        Indices are collected in a set so that duplicate names do not make
        the removal pop unrelated entries.
        """
        to_remove = set()
        for name in self.misslabelled_images:
            if not name:
                continue
            for index, image_path in enumerate(self.entry_paths):
                if name in image_path:
                    to_remove.add(index)
                    break
        self.entry_paths = [
            p for i, p in enumerate(self.entry_paths) if i not in to_remove
        ]

    def __getitem__(self, idx) -> dict:
        """Return the sample at ``idx`` (see ``_load_sample`` for its keys).

        If the sampled point correspondences give a singular homography
        system, a random index is tried instead, up to
        ``max_resample_attempts`` times.

        Raises:
            RuntimeError: if no valid sample is found within the retry budget.
        """
        for _ in range(self.max_resample_attempts + 1):
            sample = self._load_sample(idx)
            if sample is not None:
                return sample
            logger.warning("Singular matrix, resampling a random index")
            idx = random.randint(0, len(self) - 1)
        raise RuntimeError(
            f"No non-singular sample found after {self.max_resample_attempts} retries."
        )

    @staticmethod
    def _parse_agl_altitude(image_path: str, default: float = 150.0) -> float:
        """Parse the AGL altitude from the third ``_``-separated filename field.

        For example ``city_x_120m_3.jpeg`` gives 120.0. Falls back to
        ``default`` (with a warning) when the field is missing or not numeric.
        """
        try:
            return float(os.path.basename(image_path).split("_")[2].split("m")[0])
        except (IndexError, ValueError):
            # warnings (unlike logging) de-duplicates this per-sample message.
            import warnings

            warnings.warn(
                "Could not extract AGL altitude from filename, "
                f"using default value of {default:g}m."
            )
            return default

    def _load_sample(self, idx: int):
        """Build one sample, or return ``None`` if its homography is singular.

        Returned dict keys: ``name``, ``original_image_size``, ``H_0to1``
        (float32, UAV points -> satellite points), ``view0`` (UAV) and
        ``view1`` (satellite), each with ``image``, ``H_``, ``coords`` and
        ``image_size``, plus ``idx``. ``heatmap`` is added when
        ``use_heatmap`` is set.
        """
        n_years = len(self.sat_available_years)
        image_path = self.entry_paths[idx // n_years]
        sat_year = self.sat_available_years[idx % n_years]
        rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle)

        with Image.open(image_path) as raw_image:
            uav_image = raw_image.convert("RGB")  # ensure a 3-channel image
        original_uav_image_width = uav_image.width
        original_uav_image_height = uav_image.height

        lookup_str, file_number = self.extract_info_from_filename(image_path)
        img_info = self.metadata_dict[lookup_str][file_number]
        lat = img_info["coordinate"]["latitude"]
        lon = img_info["coordinate"]["longitude"]
        fov_vertical = img_info["fovVertical"]
        agl_altitude = self._parse_agl_altitude(image_path)

        (
            satellite_patch,
            x_sat,
            y_sat,
            _x_offset,
            _y_offset,
            patch_transform,
        ) = get_random_tiff_patch(
            lat,
            lon,
            self.sat_patch_width,
            self.sat_patch_height,
            sat_year,
            self.satellite_dataset_dir,
        )

        # Rotate, rescale by uav_image_scale (rounding up), then center-crop.
        h = int(np.ceil(uav_image.height / self.uav_image_scale))
        w = int(np.ceil(uav_image.width / self.uav_image_scale))
        uav_image = F.rotate(uav_image, rot_angle)
        uav_image = F.resize(uav_image, [h, w])
        uav_image = F.center_crop(uav_image, (self.uav_patch_height, self.uav_patch_width))
        uav_image = self.transforms(uav_image)

        # ToTensor expects HWC, so the patch is presumably channel-first here
        # (as returned by get_random_tiff_patch) — transpose before converting.
        satellite_patch_pytorch = self.transforms(satellite_patch.transpose(1, 2, 0))
        del satellite_patch

        heatmap = None
        if self.use_heatmap:
            # ToTensor output is (C, H, W).
            heatmap = self.get_heatmap_gt(
                x_sat,
                y_sat,
                satellite_patch_pytorch.shape[1],
                satellite_patch_pytorch.shape[2],
                self.heatmap_kernel_size,
            )

        cropped_uav_image_width = self.calculate_cropped_uav_image_width(
            fov_vertical,
            original_uav_image_width,
            original_uav_image_height,
            self.uav_patch_width,
            self.uav_patch_height,
            agl_altitude,
        )
        satellite_tile_width = self.calculate_cropped_sat_image_width(
            lat, self.sat_patch_width, patch_transform
        )
        # NOTE(review): the factor 10 is empirical; its origin is not documented.
        scale_factor = cropped_uav_image_width / satellite_tile_width * 10

        homography_matrix = self.compute_homography(
            rot_angle,
            x_sat,
            y_sat,
            self.uav_patch_width,
            self.uav_patch_height,
            scale_factor,
        )

        # Four random UAV points and their satellite positions give H_0to1.
        points = self.sample_four_points(self.uav_patch_width, self.uav_patch_height)
        warped_points = self.warp_points(points, homography_matrix)
        hm = self.compute_homography_points(points, warped_points, [1.0, 1.0])
        if np.array_equal(hm, np.eye(3)):
            # compute_homography_points signals a singular system with I.
            return None

        data = {
            "name": f"{image_path}_{rot_angle}_{sat_year}",
            "original_image_size": np.array(
                [original_uav_image_width, original_uav_image_height]
            ),
            "H_0to1": hm.astype(np.float32),
            "view0": {
                "image": uav_image,
                "H_": homography_matrix,
                "coords": points,
                "image_size": np.array([self.uav_patch_width, self.uav_patch_height]),
            },
            "view1": {
                "image": satellite_patch_pytorch,
                "H_": np.linalg.inv(homography_matrix),
                "coords": warped_points,
                "image_size": np.array([self.sat_patch_width, self.sat_patch_height]),
            },
            "idx": idx,
        }
        if heatmap is not None:
            data["heatmap"] = heatmap

        # Kept from the original: a full collection per sample, presumably to
        # bound memory held by raster buffers. Costly; revisit after profiling.
        gc.collect()
        return data

    def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform):
        """Return the satellite patch width in meters.

        ``patch_transform[0]`` is assumed to be the pixel width in degrees of
        longitude (TODO confirm against get_random_tiff_patch);
        ``111320 * cos(lat)`` approximates meters per degree of longitude.
        """
        length_of_degree = 111320 * np.cos(np.radians(latitude))
        scale_x_meters = patch_transform[0] * length_of_degree
        return patch_width * scale_x_meters

    def calculate_cropped_uav_image_width(
        self,
        fov_vertical,
        orig_width,
        orig_height,
        crop_width,
        crop_height,
        altitude=150.0,
    ):
        """Return the ground width covered by the UAV crop, in altitude units.

        The horizontal FOV is derived from the vertical FOV (degrees) and the
        original aspect ratio, then narrowed to the crop. ``crop_height`` is
        accepted for interface compatibility but not used.
        """
        fov_rad = np.radians(fov_vertical)
        crop_ratio_width = crop_width / orig_width

        fov_horizontal = 2 * np.arctan(np.tan(fov_rad / 2) * (orig_width / orig_height))
        adjusted_fov_horizontal = 2 * np.arctan(
            np.tan(fov_horizontal / 2) * crop_ratio_width
        )
        full_width = 2 * (altitude * np.tan(adjusted_fov_horizontal / 2))

        # NOTE(review): adjusted_fov_horizontal already includes the crop, so
        # this multiplies by crop_ratio_width a second time. Left unchanged
        # because the caller's empirical "* 10" may be tuned to it — confirm.
        return full_width * crop_ratio_width

    def compute_homography_points(self, pts1_, pts2_, shape):
        """Compute the homography mapping 4 points ``pts1_`` onto ``pts2_``.

        Points are scaled by ``shape`` (given as [x, y] scale, applied
        reversed) and the 8x8 DLT system with h33 = 1 is solved.
        Returns the 3x3 identity if the system is singular.
        """
        shape = np.array(shape[::-1], dtype=np.float32)  # different convention [y, x]
        pts1 = pts1_ * np.expand_dims(shape, axis=0)
        pts2 = pts2_ * np.expand_dims(shape, axis=0)

        def ax(p, q):
            return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]

        def ay(p, q):
            return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]

        def flat2mat(H):
            return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])

        a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
        p_mat = np.transpose(
            np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
        )

        try:
            homography = np.transpose(np.linalg.solve(a_mat, p_mat))
        except np.linalg.LinAlgError:
            return np.eye(3)
        return flat2mat(homography)

    def compute_homography(self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor):
        """Return ``T_sat @ R @ S @ T_uav`` mapping UAV patch pixels to satellite pixels.

        The UAV patch is centered on its origin, scaled by ``scale_factor``,
        rotated by ``rot_angle`` degrees and translated to ``(x_sat, y_sat)``.
        """
        # Map angles above 180 degrees into the (-180, 180] range.
        if rot_angle > 180:
            rot_angle -= 360
        theta = np.radians(rot_angle)

        R = np.array(
            [
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1],
            ]
        )
        S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
        T_uav = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]])
        T_sat = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]])
        return np.dot(T_sat, np.dot(R, np.dot(S, T_uav)))

    def sample_four_points(self, width: int, height: int) -> np.ndarray:
        """Draw four random integer (x, y) points inside the UAV patch.

        x is drawn from [10, width - 50] and y from [10, height - 50].
        ``random.randint`` raises ValueError if the patch is too small for
        these margins.
        """
        PADDING = 50
        CENTER_PADDING = 10
        return np.array(
            [
                [
                    random.randint(CENTER_PADDING, width - PADDING),
                    random.randint(CENTER_PADDING, height - PADDING),
                ]
                for _ in range(4)
            ]
        )

    def warp_points(self, points: np.ndarray, homography_matrix: np.ndarray) -> np.ndarray:
        """Apply a 3x3 homography to an (N, 2) array of points; returns (N, 2)."""
        points = np.asarray(points)
        homogeneous = np.concatenate([points, np.ones((len(points), 1))], axis=1)
        warped = np.dot(homography_matrix, homogeneous.T).T
        return warped[:, :2] / warped[:, 2:]

    def count_total_uav_samples(self) -> int:
        """Count every ``.jpeg`` under ``uav_dataset_dir`` (train and test)."""
        return sum(
            1
            for _dirpath, _dirnames, filenames in os.walk(self.uav_dataset_dir)
            for filename in filenames
            if filename.endswith(".jpeg")
        )

    def get_number_of_city_samples(self) -> int:
        """
        TODO: Count the total number of city samples in the dataset
        (currently hard-coded).
        """
        return 11

    def get_entry_paths(self, directory: str) -> List[str]:
        """Recursively collect image paths for this split and load metadata.

        The first ``images_to_take_per_folder`` images (by filename number)
        are reserved for ``"test"``; ``"train"`` takes the rest of the
        ``Train`` images, plus ``Test`` images when
        ``test_from_train_ratio > 0``. With a ratio of 0, ``"test"`` takes
        every ``Test`` image. JSON files encountered are loaded via
        ``get_metadata``. Returns paths sorted by (name prefix, number).
        """
        entry_paths = []
        images_to_take_per_folder = int(
            self.total_uav_samples
            * self.test_from_train_ratio
            / self.get_number_of_city_samples()
        )

        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)

            if os.path.isdir(entry_path):
                entry_paths += self.get_entry_paths(entry_path)
                continue

            is_train_file = self.dataset == "train" and (
                "Train" in entry_path
                or (self.test_from_train_ratio > 0 and "Test" in entry_path)
            )

            if is_train_file:
                if entry_path.endswith(".json"):
                    self.get_metadata(entry_path)
                elif entry_path.endswith(".jpeg"):
                    _, number = self.extract_info_from_filename(entry_path)
                    # Only include images beyond the ones taken for test.
                    if number is not None and number >= images_to_take_per_folder:
                        entry_paths.append(entry_path)

            elif self.dataset == "test":
                if entry_path.endswith(".json"):
                    self.get_metadata(entry_path)
                elif entry_path.endswith(".jpeg"):
                    _, number = self.extract_info_from_filename(entry_path)
                    if number is None:
                        continue
                    held_out = number < images_to_take_per_folder and (
                        "Test" in entry_path or "Train" in entry_path
                    )
                    all_test = self.test_from_train_ratio == 0.0 and "Test" in entry_path
                    if held_out or all_test:
                        entry_paths.append(entry_path)

        return sorted(entry_paths, key=self.extract_info_from_filename)

    def get_metadata(self, path: str) -> None:
        """Store the ``cameraFrames`` list of a JSON file under its file stem."""
        with open(path, newline="", encoding="utf-8") as jsonfile:
            json_dict = json.load(jsonfile)
        key = os.path.basename(path).replace(".json", "")
        self.metadata_dict[key] = json_dict["cameraFrames"]

    def extract_info_from_filename(self, filename: str) -> (str, int):
        """Split ``.../<prefix>_<number>.jpeg`` into ``(prefix, number)``.

        Returns ``(None, None)`` if the last ``_`` field is not an integer.
        """
        info = filename.replace(".jpeg", "").split("/")[-1]
        *prefix_parts, number_str = info.split("_")
        try:
            number = int(number_str)
        except ValueError:
            logger.warning("Could not extract number from filename: %s", filename)
            return None, None
        return "_".join(prefix_parts), number

    def get_heatmap_gt(
        self, x: int, y: int, height: int, width: int, square_size: int = 33
    ) -> torch.Tensor:
        """Return a (height, width) map of zeros with a square of ones at (x, y).

        The square has side ``square_size`` (odd sizes are centered) and is
        clipped to the map bounds. Coordinates are truncated to int.
        """
        x_map, y_map = int(x), int(y)
        heatmap = torch.zeros((height, width))
        half_size = square_size // 2

        start_x = max(0, x_map - half_size)
        end_x = min(width, x_map + half_size + 1)  # +1: end is exclusive
        start_y = max(0, y_map - half_size)
        end_y = min(height, y_map + half_size + 1)

        heatmap[start_y:end_y, start_x:end_x] = 1
        return heatmap
|
|
|
|
if __name__ == "__main__":
|
|
from .. import logger # overwrite the logger |