298 lines
9.6 KiB
Python
298 lines
9.6 KiB
Python
|
"""
|
||
|
Simply load images from a folder or nested folders (does not have any split),
|
||
|
and apply homographic adaptations to it. Yields an image pair without border
|
||
|
artifacts.
|
||
|
"""
|
||
|
|
||
|
import argparse
|
||
|
import logging
|
||
|
import shutil
|
||
|
import tarfile
|
||
|
from pathlib import Path
|
||
|
import mercantile
|
||
|
import rasterio
|
||
|
|
||
|
import cv2
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
import omegaconf
|
||
|
import torch
|
||
|
from omegaconf import OmegaConf
|
||
|
from tqdm import tqdm
|
||
|
from PIL import Image
|
||
|
|
||
|
from ..geometry.homography import (
|
||
|
compute_homography,
|
||
|
sample_homography_corners,
|
||
|
warp_points,
|
||
|
)
|
||
|
from ..models.cache_loader import CacheLoader, pad_local_features
|
||
|
from ..settings import DATA_PATH
|
||
|
from ..utils.image import read_image
|
||
|
from ..utils.tools import fork_rng
|
||
|
from ..visualization.viz2d import plot_image_grid
|
||
|
from .augmentations import IdentityAugmentation, augmentations
|
||
|
from .base_dataset import BaseDataset
|
||
|
from .sat_utils import get_random_tiff_patch
|
||
|
|
||
|
# Module-level logger; the `__main__` block at the bottom of this file rebinds
# it to the package logger when the module is run as a script.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def sample_homography(img, conf: dict, size: list):
    """Warp ``img`` with a randomly sampled homography.

    Args:
        img: image array; its first two dimensions are reversed before being
            handed to ``sample_homography_corners`` (i.e. (width, height) for
            an array laid out as (height, width, ...)).
        conf: keyword arguments forwarded verbatim to
            ``sample_homography_corners``.
        size: output size, passed to ``cv2.warpPerspective`` as ``dsize``.

    Returns:
        dict with ``image`` (warped image), ``H_`` (sampled homography as
        float32), ``coords`` (corner coordinates from the sampler as float32)
        and ``image_size`` (``size`` as a float32 array).
    """
    width_height = img.shape[:2][::-1]
    H, _, corners, _ = sample_homography_corners(width_height, **conf)
    warped = cv2.warpPerspective(img, H, tuple(size))
    return {
        "image": warped,
        "H_": H.astype(np.float32),
        "coords": corners.astype(np.float32),
        "image_size": np.array(size, dtype=np.float32),
    }
|
||
|
|
||
|
|
||
|
class SatelliteDataset(BaseDataset):
    """Dataset of image patches sampled around a list of coordinates.

    The coordinates are read from the text file at ``conf.metadata_dir``. Each
    non-blank line holds ``lat, lon``; each value may carry a ``label:``
    prefix (e.g. ``lat: 48.1, lon: 11.5``). The first 80% of entries (file
    order, no shuffling) form the ``train`` split, the rest the ``val`` split.
    """

    default_conf = {
        # image search
        "data_dir": "revisitop1m",  # the top-level directory
        "image_dir": "jpg/",  # the subdirectory with the images
        "image_list": "revisitop1m.txt",  # optional: list or filename of list
        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
        # Path to the coordinates text file (a file, despite the name).
        # Required: _init raises ValueError when it is left as None.
        "metadata_dir": None,
        # splits
        # NOTE(review): train_size / val_size / shuffle_seed are not read by
        # the code in this file; _init uses a fixed, unshuffled 80/20 split.
        "train_size": 100,
        "val_size": 10,
        "shuffle_seed": 0,  # or None to skip
        # image loading
        "grayscale": False,
        "triplet": False,
        "right_only": False,  # image0 is orig (rescaled), image1 is right
        "reseed": False,
        "homography": {
            "difficulty": 0.8,
            "translation": 1.0,
            "max_angle": 60,
            "n_angles": 10,
            "patch_shape": [640, 480],
            "min_convexity": 0.05,
        },
        "photometric": {
            "name": "dark",
            "p": 0.75,
            # 'difficulty': 1.0,  # currently unused
        },
        # feature loading
        "load_features": {
            "do": False,
            **CacheLoader.default_conf,
            "collate": False,
            "thresh": 0.0,
            "max_num_keypoints": -1,
            "force_num_keypoints": False,
        },
    }

    def _init(self, conf):
        """Read the coordinates file and build the train/val splits.

        Raises:
            ValueError: if ``conf.metadata_dir`` is not set.
        """
        coordinates_file = conf.metadata_dir
        if coordinates_file is None:
            raise ValueError(
                "SatelliteDataset requires `metadata_dir` to point to a "
                "coordinates file."
            )

        with open(coordinates_file, "r", encoding="utf-8") as cf:
            # Blank lines (e.g. a trailing empty line) are skipped instead of
            # failing the "lat, lon" unpacking.
            parsed_coordinates = [
                self._parse_coordinate(line) for line in cf if line.strip()
            ]

        # Split into train and val: first 80% train, remaining 20% val.
        n_train = int(len(parsed_coordinates) * 0.8)
        self.images = {
            "train": parsed_coordinates[:n_train],
            "val": parsed_coordinates[n_train:],
        }

    @staticmethod
    def _parse_coordinate(line):
        """Parse one ``lat, lon`` line (values may have a ``label:`` prefix)."""
        lat_part, lon_part = line.split(",")
        lat = float(lat_part.split(":")[-1].strip())
        lon = float(lon_part.split(":")[-1].strip())
        return lat, lon

    def get_dataset(self, split):
        """Return the torch dataset for ``split`` ("train" or "val")."""
        return _Dataset(self.conf, self.images[split], split)
|
||
|
|
||
|
|
||
|
class _Dataset(torch.utils.data.Dataset):
|
||
|
def __init__(self, conf, image_names, split):
|
||
|
self.conf = conf
|
||
|
self.split = split
|
||
|
self.image_names = np.array(image_names)
|
||
|
self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir
|
||
|
|
||
|
aug_conf = conf.photometric
|
||
|
aug_name = aug_conf.name
|
||
|
assert (
|
||
|
aug_name in augmentations.keys()
|
||
|
), f'{aug_name} not in {" ".join(augmentations.keys())}'
|
||
|
#self.left_augment = (
|
||
|
# IdentityAugmentation() if conf.right_only else self.photo_augment
|
||
|
#)
|
||
|
self.photo_augment = augmentations[aug_name](aug_conf)
|
||
|
self.left_augment = augmentations[aug_name](aug_conf)
|
||
|
self.img_to_tensor = IdentityAugmentation()
|
||
|
|
||
|
if conf.load_features.do:
|
||
|
self.feature_loader = CacheLoader(conf.load_features)
|
||
|
|
||
|
def _transform_keypoints(self, features, data):
|
||
|
"""Transform keypoints by a homography, threshold them,
|
||
|
and potentially keep only the best ones."""
|
||
|
# Warp points
|
||
|
features["keypoints"] = warp_points(
|
||
|
features["keypoints"], data["H_"], inverse=False
|
||
|
)
|
||
|
h, w = data["image"].shape[1:3]
|
||
|
valid = (
|
||
|
(features["keypoints"][:, 0] >= 0)
|
||
|
& (features["keypoints"][:, 0] <= w - 1)
|
||
|
& (features["keypoints"][:, 1] >= 0)
|
||
|
& (features["keypoints"][:, 1] <= h - 1)
|
||
|
)
|
||
|
features["keypoints"] = features["keypoints"][valid]
|
||
|
|
||
|
# Threshold
|
||
|
if self.conf.load_features.thresh > 0:
|
||
|
valid = features["keypoint_scores"] >= self.conf.load_features.thresh
|
||
|
features = {k: v[valid] for k, v in features.items()}
|
||
|
|
||
|
# Get the top keypoints and pad
|
||
|
n = self.conf.load_features.max_num_keypoints
|
||
|
if n > -1:
|
||
|
inds = np.argsort(-features["keypoint_scores"])
|
||
|
features = {k: v[inds[:n]] for k, v in features.items()}
|
||
|
|
||
|
if self.conf.load_features.force_num_keypoints:
|
||
|
features = pad_local_features(
|
||
|
features, self.conf.load_features.max_num_keypoints
|
||
|
)
|
||
|
|
||
|
return features
|
||
|
|
||
|
def __getitem__(self, idx):
|
||
|
if self.conf.reseed:
|
||
|
with fork_rng(self.conf.seed + idx, False):
|
||
|
return self.getitem(idx)
|
||
|
else:
|
||
|
return self.getitem(idx)
|
||
|
|
||
|
def _read_view(self, img, H_conf, ps, left=False):
|
||
|
data = sample_homography(img, H_conf, ps)
|
||
|
if left:
|
||
|
data["image"] = self.left_augment(data["image"], return_tensor=True)
|
||
|
else:
|
||
|
data["image"] = self.photo_augment(data["image"], return_tensor=True)
|
||
|
|
||
|
gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
|
||
|
if self.conf.grayscale:
|
||
|
data["image"] = (data["image"] * gs).sum(0, keepdim=True)
|
||
|
|
||
|
if self.conf.load_features.do:
|
||
|
features = self.feature_loader({k: [v] for k, v in data.items()})
|
||
|
features = self._transform_keypoints(features, data)
|
||
|
data["cache"] = features
|
||
|
|
||
|
return data
|
||
|
|
||
|
def getitem(self, idx):
|
||
|
# Generate a list of coordinates, based on the coordinate do the split
|
||
|
lat, lon = self.image_names[idx]
|
||
|
img = get_random_tiff_patch(lat, lon, self.conf.data_dir)
|
||
|
|
||
|
if img is None:
|
||
|
logging.warning("Image %f %f could not be read.", lat, lon)
|
||
|
img = np.zeros((1024, 1024) + (() if self.conf.grayscale else (3,)))
|
||
|
|
||
|
img = img.transpose(1,2,0)
|
||
|
img = img.astype(np.float32) / 255.0
|
||
|
#write_status = cv2.imwrite(f"/mnt/drive/{str(idx)}.jpg", img)
|
||
|
#if write_status == True:
|
||
|
# print("Writing success")
|
||
|
#else:
|
||
|
# print("fuck")
|
||
|
size = img.shape[:2][::-1]
|
||
|
ps = self.conf.homography.patch_shape
|
||
|
|
||
|
left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
|
||
|
if self.conf.right_only:
|
||
|
left_conf["difficulty"] = 0.0
|
||
|
|
||
|
data0 = self._read_view(img, left_conf, ps, left=True)
|
||
|
data1 = self._read_view(img, self.conf.homography, ps, left=False)
|
||
|
|
||
|
print("Data0 homography_matrix", data0["H_"])
|
||
|
print("Data1 homography matrix", data1["H_"])
|
||
|
|
||
|
# image_1 = data0["image"]
|
||
|
# image_1 = image_1.numpy()
|
||
|
# image_1 = image_1.transpose(1,2,0)
|
||
|
# image_1 = np.uint8(image_1 * 255)
|
||
|
#
|
||
|
# image_2 = data1["image"]
|
||
|
# image_2 = image_2.numpy()
|
||
|
# image_2 = image_2.transpose(1,2,0)
|
||
|
# image_2 = np.uint8(image_2 * 255)
|
||
|
#
|
||
|
# PIL_IMAGE = Image.fromarray(image_1)
|
||
|
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}.jpg")
|
||
|
# PIL_IMAGE = Image.fromarray(image_2)
|
||
|
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}-s.jpg")
|
||
|
#
|
||
|
# exit()
|
||
|
#
|
||
|
H = compute_homography(data0["coords"], data1["coords"], [1, 1])
|
||
|
|
||
|
print("COmputed homography", H)
|
||
|
|
||
|
name = f"lat{lat}_lon{lon}"
|
||
|
|
||
|
data = {
|
||
|
"name": name,
|
||
|
"original_image_size": np.array(size),
|
||
|
"H_0to1": H.astype(np.float32),
|
||
|
"idx": idx,
|
||
|
"view0": data0,
|
||
|
"view1": data1,
|
||
|
}
|
||
|
|
||
|
return data
|
||
|
|
||
|
def __len__(self):
|
||
|
return len(self.image_names)
|
||
|
|
||
|
|
||
|
def visualize(args):
    """Plot a few training pairs and save the grid to ``implot.png``.

    Args:
        args: namespace with ``dotlist`` (OmegaConf CLI overrides),
            ``num_items`` (number of pairs to plot) and ``dpi``.
    """
    conf = {
        "batch_size": 1,
        "num_workers": 1,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = SatelliteDataset(conf)
    loader = dataset.get_data_loader("train")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            # One grid row per item: view0 and view1 of the first batch
            # element, permuted from (C, H, W) to (H, W, C). A list (not a
            # generator) so the row can be measured and re-iterated.
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    # plt.imsave needs an image array; savefig writes the current figure.
    plt.savefig("implot.png")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
from .. import logger # overwrite the logger
|
||
|
|
||
|
parser = argparse.ArgumentParser()
|
||
|
parser.add_argument("--num_items", type=int, default=8)
|
||
|
parser.add_argument("--dpi", type=int, default=100)
|
||
|
parser.add_argument("dotlist", nargs="*")
|
||
|
args = parser.parse_intermixed_args()
|
||
|
visualize(args)
|