145 lines
4.5 KiB
Python
145 lines
4.5 KiB
Python
|
"""
|
||
|
Simply load images from a folder or nested folders (does not have any split).
|
||
|
"""
|
||
|
import argparse
|
||
|
import logging
|
||
|
import tarfile
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from omegaconf import OmegaConf
|
||
|
|
||
|
from .base_dataset import BaseDataset
|
||
|
from ..settings import DATA_PATH
|
||
|
from ..utils.image import load_image, ImagePreprocessor
|
||
|
from ..utils.tools import fork_rng
|
||
|
from ..visualization.viz2d import plot_image_grid
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def read_homography(path):
    """Read a homography matrix from a whitespace-separated text file.

    Each non-empty line of the file is one matrix row; fields may be
    separated by any run of spaces/tabs (HPatches files sometimes contain
    double spaces and trailing blanks).

    Args:
        path: path to the ``H_1_*`` homography file.

    Returns:
        np.ndarray of dtype float with one row per non-empty line
        (3x3 for a well-formed homography file).
    """
    with open(path) as f:
        # str.split() with no argument splits on any whitespace run and
        # drops empty fields, so no manual space-collapsing is needed.
        rows = [line.split() for line in f]
    # Discard blank lines (e.g. a trailing newline at end of file).
    return np.array([row for row in rows if row]).astype(float)
|
||
|
|
||
|
|
||
|
class HPatches(BaseDataset, torch.utils.data.Dataset):
    """HPatches homography-evaluation dataset (no train/val/test split).

    Each item pairs the reference image (index 1) of a sequence with one of
    its five query images (indices 2..6), plus the ground-truth homography
    mapping view0 to view1, adjusted for the preprocessing transforms.
    """

    default_conf = {
        "preprocessing": ImagePreprocessor.default_conf,
        "data_dir": "hpatches-sequences-release",  # relative to DATA_PATH
        "subset": None,  # None for all, or "i"/"v" to keep one subset only
        "ignore_large_images": True,  # skip scenes listed in ignored_scenes
        "grayscale": False,  # load single-channel images if True
    }

    # Large images that were ignored in previous papers
    ignored_scenes = (
        "i_contruction",
        "i_crownnight",
        "i_dc",
        "i_pencils",
        "i_whitebuilding",
        "v_artisans",
        "v_astronautis",
        "v_talent",
    )
    url = "http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz"

    def _init(self, conf):
        # Each item is already a full image pair; batching is unsupported.
        # NOTE(review): assert is stripped under -O; presumably acceptable
        # here since conf comes from trusted project configs.
        assert conf.batch_size == 1
        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        self.root = DATA_PATH / conf.data_dir
        if not self.root.exists():
            logger.info("Downloading the HPatches dataset.")
            self.download()
        # One subdirectory per sequence, sorted for a deterministic order.
        self.sequences = sorted([x.name for x in self.root.iterdir()])
        if not self.sequences:
            raise ValueError("No image found!")
        self.items = []  # (seq, q_idx, is_illu)
        for seq in self.sequences:
            if conf.ignore_large_images and seq in self.ignored_scenes:
                continue
            # Sequence names start with "i" (illumination change) or
            # "v" (viewpoint change); conf.subset selects one family.
            if conf.subset is not None and conf.subset != seq[0]:
                continue
            # Query images are numbered 2..6; image 1 is the reference.
            for i in range(2, 7):
                self.items.append((seq, i, seq[0] == "i"))

    def download(self):
        """Download the archive next to root and extract it in place."""
        data_dir = self.root.parent
        data_dir.mkdir(exist_ok=True, parents=True)
        # Archive file name is the last component of the URL.
        tar_path = data_dir / self.url.rsplit("/", 1)[-1]
        torch.hub.download_url_to_file(self.url, tar_path)
        # NOTE(review): extractall on a downloaded archive is exposed to
        # tar path traversal; consider filter="data" (Python 3.12+) if the
        # source is not fully trusted.
        with tarfile.open(tar_path) as tar:
            tar.extractall(data_dir)
        # Remove the archive once extracted.
        tar_path.unlink()

    def get_dataset(self, split):
        # HPatches has no real split: the same data serves val and test.
        assert split in ["val", "test"]
        return self

    def _read_image(self, seq: str, idx: int) -> dict:
        """Load and preprocess image `idx` (1-based .ppm) of sequence `seq`."""
        img = load_image(self.root / seq / f"{idx}.ppm", self.conf.grayscale)
        # Returns the preprocessor's dict (image tensor plus a "transform"
        # entry used below to rectify the homography).
        return self.preprocessor(img)

    def __getitem__(self, idx):
        seq, q_idx, is_illu = self.items[idx]
        data0 = self._read_image(seq, 1)  # reference view
        data1 = self._read_image(seq, q_idx)  # query view
        H = read_homography(self.root / seq / f"H_1_{q_idx}")
        # Compose the raw homography with the per-view preprocessing
        # transforms so it maps preprocessed view0 to preprocessed view1.
        H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])
        return {
            "H_0to1": H.astype(np.float32),
            "scene": seq,
            "idx": idx,
            "is_illu": is_illu,
            "name": f"{seq}/{idx}.ppm",
            "view0": data0,
            "view1": data1,
        }

    def __len__(self):
        return len(self.items)
|
||
|
|
||
|
|
||
|
def visualize(args):
    """Plot a grid of image pairs from the dataset.

    Args:
        args: parsed CLI namespace with ``num_items`` (pairs to show),
            ``dpi`` (figure resolution) and ``dotlist`` (OmegaConf
            dotlist overrides for the dataset config).
    """
    conf = {
        "batch_size": 1,
        "num_workers": 8,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = HPatches(conf)
    loader = dataset.get_data_loader("test")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            # Must be a list comprehension, not a generator expression:
            # a generator would be evaluated lazily inside plot_image_grid
            # after the loop ends, when `data` is bound to the LAST batch,
            # so every row would show the same pair.
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    plt.show()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    # CLI: number of pairs to display, figure DPI, plus free-form
    # OmegaConf dotlist overrides for the dataset configuration.
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_items", type=int, default=8)
    parser.add_argument("--dpi", type=int, default=100)
    parser.add_argument("dotlist", nargs="*")
    visualize(parser.parse_intermixed_args())
|