import torch
from pathlib import Path
from omegaconf import OmegaConf
from pprint import pprint
import matplotlib.pyplot as plt
from collections import defaultdict
from collections.abc import Iterable
from tqdm import tqdm
import zipfile
import numpy as np
from ..visualization.viz2d import plot_cumulative
from .io import (
    parse_eval_args,
    load_model,
    get_eval_parser,
)
from ..utils.export_predictions import export_predictions
from ..settings import EVAL_PATH, DATA_PATH
from ..models.cache_loader import CacheLoader
from ..datasets import get_dataset
from .eval_pipeline import EvalPipeline
from .utils import eval_relative_pose_robust, eval_poses, eval_matches_epipolar


class MegaDepth1500Pipeline(EvalPipeline):
    """Evaluation pipeline for relative pose estimation on MegaDepth-1500.

    Image pairs are listed in ``megadepth1500/pairs_calibrated.txt``. The
    model's keypoints and matches are exported to ``predictions.h5``. They are
    then scored with epipolar match metrics and with robust relative-pose
    estimation over one or more RANSAC thresholds.
    """

    default_conf = {
        "data": {
            "name": "image_pairs",
            "pairs": "megadepth1500/pairs_calibrated.txt",
            "root": "megadepth1500/images/",
            "extra_data": "relative_pose",
            "preprocessing": {
                "side": "long",
            },
        },
        "model": {
            "ground_truth": {
                "name": None,  # remove gt matches
            }
        },
        "eval": {
            "estimator": "poselib",
            # A value <= 0 sweeps [0.5, 1.0, ..., 3.0] and keeps the best
            # threshold; an iterable of thresholds is also accepted.
            "ransac_th": 1.0,
        },
    }

    # Prediction entries written to (and later read back from) the cache file.
    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]
    optional_export_keys = []

    def _init(self, conf):
        """Download and unpack the MegaDepth-1500 data if it is not present.

        Args:
            conf: pipeline configuration (not used by this method).
        """
        if not (DATA_PATH / "megadepth1500").exists():
            url = "https://cvg-data.inf.ethz.ch/megadepth/megadepth1500.zip"
            # torch.hub.download_url_to_file creates its temporary file in the
            # destination directory, so that directory must already exist.
            DATA_PATH.mkdir(parents=True, exist_ok=True)
            zip_path = DATA_PATH / url.rsplit("/", 1)[-1]
            torch.hub.download_url_to_file(url, zip_path)
            with zipfile.ZipFile(zip_path) as archive:
                archive.extractall(DATA_PATH)
            zip_path.unlink()

    @classmethod
    def get_dataloader(cls, data_conf=None):
        """Return a data loader with samples for each eval datapoint.

        Args:
            data_conf: dataset configuration; falls back to
                ``default_conf["data"]`` when falsy.

        Returns:
            The ``"test"`` loader of the configured dataset.
        """
        data_conf = data_conf if data_conf else cls.default_conf["data"]
        dataset = get_dataset(data_conf["name"])(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint.

        Args:
            experiment_dir: directory in which ``predictions.h5`` is stored.
            model: model to run; loaded from ``self.conf`` when None.
            overwrite: re-export even if the prediction file already exists.

        Returns:
            Path to the prediction file.
        """
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions.

        Args:
            loader: iterable of eval datapoints (e.g. from ``get_dataloader``).
            pred_file: path to the cached predictions.

        Returns:
            A ``(summaries, figures, results)`` tuple:
            - ``summaries`` maps metric names to scalar values.
            - ``figures`` maps figure names to matplotlib figures.
            - ``results`` holds the per-pair values, including the pose
              results at the best RANSAC threshold.
        """
        conf = self.conf.eval
        results = defaultdict(list)
        if isinstance(conf.ransac_th, Iterable):
            test_thresholds = conf.ransac_th
        elif conf.ransac_th > 0:
            test_thresholds = [conf.ransac_th]
        else:
            # Non-positive threshold: sweep a fixed range and pick the best.
            test_thresholds = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
        # pose_results[threshold][metric] -> list of per-pair values
        pose_results = defaultdict(lambda: defaultdict(list))
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
        for data in tqdm(loader):
            pred = cache_loader(data)
            # add custom evaluations here
            results_i = eval_matches_epipolar(data, pred)
            for th in test_thresholds:
                pose_results_i = eval_relative_pose_robust(
                    data,
                    pred,
                    {"estimator": conf.estimator, "ransac_th": th},
                )
                for k, v in pose_results_i.items():
                    pose_results[th][k].append(v)

            # we also store the names for later reference
            results_i["names"] = data["name"][0]
            if "scene" in data.keys():
                results_i["scenes"] = data["scene"][0]

            for k, v in results_i.items():
                results[k].append(v)

        # summarize results as a dict[str, float]
        # you can also add your custom evaluations here
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            # Skip non-numeric columns such as names / scenes.
            if not np.issubdtype(arr.dtype, np.number):
                continue
            summaries[f"m{k}"] = round(np.mean(arr), 3)

        best_pose_results, best_th = eval_poses(
            pose_results, auc_ths=[5, 10, 20], key="rel_pose_error"
        )
        results = {**results, **pose_results[best_th]}
        summaries = {
            **summaries,
            **best_pose_results,
        }

        figures = {
            "pose_recall": plot_cumulative(
                {conf.estimator: results["rel_pose_error"]},
                [0, 30],
                unit="°",
                title="Pose ",
            )
        }

        return summaries, figures, results


if __name__ == "__main__":
    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(MegaDepth1500Pipeline.default_conf)

    # mingle paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(
        dataset_name,
        args,
        "configs/",
        default_conf,
    )

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = MegaDepth1500Pipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir,
        overwrite=args.overwrite,
        overwrite_eval=args.overwrite_eval,
    )

    pprint(s)

    if args.plot:
        for fig_name, fig in f.items():
            fig.canvas.manager.set_window_title(fig_name)
        plt.show()