186 lines
5.8 KiB
Python
186 lines
5.8 KiB
Python
import torch
|
|
from pathlib import Path
|
|
from omegaconf import OmegaConf
|
|
from pprint import pprint
|
|
import matplotlib.pyplot as plt
|
|
from collections import defaultdict
|
|
from collections.abc import Iterable
|
|
from tqdm import tqdm
|
|
import zipfile
|
|
import numpy as np
|
|
from ..visualization.viz2d import plot_cumulative
|
|
from .io import (
|
|
parse_eval_args,
|
|
load_model,
|
|
get_eval_parser,
|
|
)
|
|
from ..utils.export_predictions import export_predictions
|
|
from ..settings import EVAL_PATH, DATA_PATH
|
|
from ..models.cache_loader import CacheLoader
|
|
from ..datasets import get_dataset
|
|
from .eval_pipeline import EvalPipeline
|
|
|
|
from .utils import eval_relative_pose_robust, eval_poses, eval_matches_epipolar
|
|
|
|
|
|
class MegaDepth1500Pipeline(EvalPipeline):
    """Evaluation pipeline for the MegaDepth-1500 relative-pose benchmark.

    Predictions (keypoints, matches and their scores) are exported once per
    experiment to ``predictions.h5`` and then evaluated with epipolar match
    metrics plus robust relative-pose estimation over one or more RANSAC
    thresholds.
    """

    default_conf = {
        "data": {
            "name": "image_pairs",
            "pairs": "megadepth1500/pairs_calibrated.txt",
            "root": "megadepth1500/images/",
            "extra_data": "relative_pose",
            "preprocessing": {
                "side": "long",
            },
        },
        "model": {
            "ground_truth": {
                "name": None,  # remove gt matches
            }
        },
        "eval": {
            "estimator": "poselib",
            # A positive value is used as-is; any value <= 0 sweeps a default
            # set of thresholds and keeps the best; an iterable of thresholds
            # is swept as given.
            "ransac_th": 1.0,
        },
    }

    # Prediction entries written to the cache file by ``get_predictions``.
    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]

    optional_export_keys = []

    def _init(self, conf):
        """Download and unpack the MegaDepth-1500 data on first use.

        Nothing is done when ``DATA_PATH / "megadepth1500"`` already exists.
        Otherwise the archive is downloaded into ``DATA_PATH``, extracted
        there, and then deleted.
        """
        if (DATA_PATH / "megadepth1500").exists():
            return
        url = "https://cvg-data.inf.ethz.ch/megadepth/megadepth1500.zip"
        zip_path = DATA_PATH / url.rsplit("/", 1)[-1]
        # The downloader writes a temp file next to the destination, so the
        # target directory must exist beforehand.
        DATA_PATH.mkdir(parents=True, exist_ok=True)
        torch.hub.download_url_to_file(url, zip_path)
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(DATA_PATH)
        zip_path.unlink()

    @classmethod
    def get_dataloader(cls, data_conf=None):
        """Return the test-split data loader with one sample per eval datapoint.

        Args:
            data_conf: dataset configuration. Falls back to
                ``default_conf["data"]`` when ``None`` or empty.
        """
        data_conf = data_conf if data_conf else cls.default_conf["data"]
        dataset = get_dataset(data_conf["name"])(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint.

        Predictions are written to ``experiment_dir / "predictions.h5"``
        unless that file already exists and ``overwrite`` is False.

        Args:
            experiment_dir: directory of the current experiment.
            model: model used for inference. When ``None`` it is loaded from
                ``self.conf.model`` and ``self.conf.checkpoint``.
            overwrite: re-export even if the prediction file already exists.

        Returns:
            Path of the prediction file.
        """
        pred_file = experiment_dir / "predictions.h5"
        if pred_file.exists() and not overwrite:
            return pred_file
        if model is None:
            model = load_model(self.conf.model, self.conf.checkpoint)
        export_predictions(
            self.get_dataloader(self.conf.data),
            model,
            pred_file,
            keys=self.export_keys,
            optional_keys=self.optional_export_keys,
        )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions.

        Args:
            loader: data loader yielding the eval datapoints.
            pred_file: prediction cache produced by ``get_predictions``.

        Returns:
            A ``(summaries, figures, results)`` tuple:

            - ``summaries``: the mean (prefixed with ``m``) of every numeric
              per-datapoint metric, merged with the pose results at the
              threshold selected by ``eval_poses``.
            - ``figures``: a ``"pose_recall"`` cumulative plot of
              ``rel_pose_error``.
            - ``results``: per-datapoint lists, merged with the pose lists of
              the selected threshold.
        """
        conf = self.conf.eval
        results = defaultdict(list)

        # Thresholds to sweep: an explicit iterable, a single positive value,
        # or a default sweep when ransac_th <= 0.
        if isinstance(conf.ransac_th, Iterable):
            test_thresholds = conf.ransac_th
        elif conf.ransac_th > 0:
            test_thresholds = [conf.ransac_th]
        else:
            test_thresholds = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]

        # threshold -> metric name -> per-datapoint values
        pose_results = defaultdict(lambda: defaultdict(list))
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
        for data in tqdm(loader):
            pred = cache_loader(data)
            # add custom evaluations here
            results_i = eval_matches_epipolar(data, pred)
            for th in test_thresholds:
                pose_results_i = eval_relative_pose_robust(
                    data,
                    pred,
                    {"estimator": conf.estimator, "ransac_th": th},
                )
                for k, v in pose_results_i.items():
                    pose_results[th][k].append(v)

            # we also store the names for later reference
            results_i["names"] = data["name"][0]
            if "scene" in data:
                results_i["scenes"] = data["scene"][0]

            for k, v in results_i.items():
                results[k].append(v)

        # summarize results as a dict[str, float]
        # you can also add your custom evaluations here
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            # Non-numeric entries (e.g. the stored names) are not averaged.
            if not np.issubdtype(arr.dtype, np.number):
                continue
            summaries[f"m{k}"] = round(np.mean(arr), 3)

        best_pose_results, best_th = eval_poses(
            pose_results, auc_ths=[5, 10, 20], key="rel_pose_error"
        )
        results = {**results, **pose_results[best_th]}
        summaries = {**summaries, **best_pose_results}

        figures = {
            "pose_recall": plot_cumulative(
                {conf.estimator: results["rel_pose_error"]},
                [0, 30],
                unit="°",
                title="Pose ",
            )
        }

        return summaries, figures, results
|
|
|
|
|
|
if __name__ == "__main__":
    # The benchmark (and its output folder) is named after this script file.
    dataset_name = Path(__file__).stem
    args = get_eval_parser().parse_intermixed_args()
    default_conf = OmegaConf.create(MegaDepth1500Pipeline.default_conf)

    # Resolve and create the output directories.
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = MegaDepth1500Pipeline(conf)
    summaries, figures, _ = pipeline.run(
        experiment_dir,
        overwrite=args.overwrite,
        overwrite_eval=args.overwrite_eval,
    )

    pprint(summaries)

    # Optionally display every figure, titled by its key.
    if args.plot:
        for fig_name, fig in figures.items():
            fig.canvas.manager.set_window_title(fig_name)
        plt.show()
|