# glue-factory-custom/gluefactory/datasets/hpatches.py

"""
HPatches sequences as image pairs: image 1 of every sequence is paired with
images 2-6 and the homography stored in the matching H_1_<k> file (no split).
"""
import argparse
import logging
import tarfile
import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from ..settings import DATA_PATH
from ..utils.image import ImagePreprocessor, load_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
from .base_dataset import BaseDataset
# Module-level logger; rebound to the package logger when this file is run as
# a script (see the __main__ guard at the bottom of the file).
logger = logging.getLogger(__name__)
def read_homography(path):
    """Read a whitespace-separated numeric matrix from a text file.

    Each non-blank line becomes one row. Values may be separated by any run
    of whitespace (repeated spaces, tabs, trailing blanks, CRLF endings).

    Args:
        path: location of the text file (anything accepted by ``open``).

    Returns:
        A float ``numpy.ndarray`` with one row per non-blank line.
    """
    with open(path) as f:
        # str.split() with no argument collapses whitespace runs and never
        # yields empty strings, so no manual clean-up loop is required.
        rows = [line.split() for line in f]
    # Blank lines split into an empty list and are dropped here.
    return np.array([row for row in rows if row]).astype(float)
class HPatches(BaseDataset, torch.utils.data.Dataset):
    """HPatches sequences exposed as image pairs related by a homography.

    Every retained sequence folder contributes five items: image ``1.ppm``
    paired with each of ``2.ppm`` ... ``6.ppm``, together with the matrix read
    from the matching ``H_1_<k>`` file. The dataset has no train split; the
    same object is returned for ``"val"`` and ``"test"``.
    """

    default_conf = {
        "preprocessing": ImagePreprocessor.default_conf,
        "data_dir": "hpatches-sequences-release",
        # None keeps every sequence; otherwise only sequences whose name
        # starts with this character are kept.
        "subset": None,
        # Skip the sequences listed in ``ignored_scenes``.
        "ignore_large_images": True,
        # Forwarded to ``load_image``.
        "grayscale": False,
    }

    # Large images that were ignored in previous papers
    # (names are matched against folder names, so spelling is kept verbatim).
    ignored_scenes = (
        "i_contruction",
        "i_crownnight",
        "i_dc",
        "i_pencils",
        "i_whitebuilding",
        "v_artisans",
        "v_astronautis",
        "v_talent",
    )
    url = "http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz"

    def _init(self, conf):
        """Locate (downloading if absent) the data and enumerate the pairs.

        Raises:
            ValueError: if the data directory contains no sequence folder.
        """
        assert conf.batch_size == 1
        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        self.root = DATA_PATH / conf.data_dir
        if not self.root.exists():
            logger.info("Downloading the HPatches dataset.")
            self.download()

        # Only directories are sequences: a stray file in the root would
        # otherwise be turned into items whose images cannot be loaded.
        self.sequences = sorted(
            entry.name for entry in self.root.iterdir() if entry.is_dir()
        )
        if not self.sequences:
            raise ValueError("No image found!")

        self.items = []  # (seq, q_idx, is_illu)
        for seq in self.sequences:
            if conf.ignore_large_images and seq in self.ignored_scenes:
                continue
            if conf.subset is not None and conf.subset != seq[0]:
                continue
            # Image 1 is the reference; images 2..6 are the queries.
            # is_illu is True for sequences whose name starts with "i".
            for i in range(2, 7):
                self.items.append((seq, i, seq[0] == "i"))

    def download(self):
        """Fetch the archive next to ``self.root``, unpack it, and delete it."""
        data_dir = self.root.parent
        data_dir.mkdir(exist_ok=True, parents=True)
        tar_path = data_dir / self.url.rsplit("/", 1)[-1]
        torch.hub.download_url_to_file(self.url, tar_path)
        with tarfile.open(tar_path) as tar:
            # The archive comes over plain HTTP: where the interpreter
            # supports extraction filters, refuse members that would escape
            # data_dir or create special files.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(data_dir, filter="data")
            else:
                tar.extractall(data_dir)
        tar_path.unlink()

    def get_dataset(self, split):
        """Return this object itself; only "val" and "test" are accepted."""
        assert split in ["val", "test"]
        return self

    def _read_image(self, seq: str, idx: int) -> dict:
        """Load ``<root>/<seq>/<idx>.ppm`` and run it through the preprocessor."""
        img = load_image(self.root / seq / f"{idx}.ppm", self.conf.grayscale)
        return self.preprocessor(img)

    def __getitem__(self, idx):
        """Build the pair dictionary for item ``idx``."""
        seq, q_idx, is_illu = self.items[idx]
        data0 = self._read_image(seq, 1)
        data1 = self._read_image(seq, q_idx)
        H = read_homography(self.root / seq / f"H_1_{q_idx}")
        # Conjugate the file homography with the per-image "transform"
        # matrices returned by the preprocessor.
        # NOTE(review): presumably "transform" maps original to preprocessed
        # pixel coordinates, making H valid between the two preprocessed
        # images - confirm against ImagePreprocessor.
        H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])
        return {
            "H_0to1": H.astype(np.float32),
            "scene": seq,
            "idx": idx,
            "is_illu": is_illu,
            # NOTE(review): built from the dataset index, not q_idx, so this
            # is not the query image path - confirm this is intended.
            "name": f"{seq}/{idx}.ppm",
            "view0": data0,
            "view1": data1,
        }

    def __len__(self):
        """Number of (reference, query) pairs."""
        return len(self.items)
def visualize(args):
    """Plot the first ``args.num_items`` image pairs of the dataset.

    Args:
        args: namespace providing ``dotlist`` (OmegaConf CLI overrides),
            ``num_items`` (number of pairs to show) and ``dpi``.
    """
    conf = {
        "batch_size": 1,
        "num_workers": 8,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = HPatches(conf)
    loader = dataset.get_data_loader("test")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            # Build the row eagerly. A generator expression would read `data`
            # only when consumed after the loop, so every row would resolve to
            # the last batch, and the row would have no len().
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    # Command line: optional display settings, then free-form OmegaConf
    # overrides collected into `dotlist`.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--num_items", type=int, default=8)
    cli_parser.add_argument("--dpi", type=int, default=100)
    cli_parser.add_argument("dotlist", nargs="*")
    # Intermixed parsing lets options and dotlist entries appear in any order.
    args = cli_parser.parse_intermixed_args()
    visualize(args)