"""Evaluation pipeline for the HPatches dataset (homography estimation)."""

from pathlib import Path
from omegaconf import OmegaConf
from pprint import pprint
import matplotlib.pyplot as plt
from collections import defaultdict
from collections.abc import Iterable
from tqdm import tqdm
import numpy as np

from ..visualization.viz2d import plot_cumulative
from .io import (
    parse_eval_args,
    load_model,
    get_eval_parser,
)
from ..utils.export_predictions import export_predictions
from ..settings import EVAL_PATH
from ..models.cache_loader import CacheLoader
from ..datasets import get_dataset
from .utils import (
    eval_homography_robust,
    eval_poses,
    eval_matches_homography,
    eval_homography_dlt,
)
from ..utils.tools import AUCMetric
from .eval_pipeline import EvalPipeline


class HPatchesPipeline(EvalPipeline):
    """Evaluate matching predictions on HPatches via homography estimation.

    Predictions (keypoints, matches and optional line outputs) are exported
    once to an HDF5 file and then scored with match-level metrics, a DLT
    homography estimate, and a robust estimator swept over one or more
    RANSAC thresholds.
    """

    default_conf = {
        "data": {
            # run_eval reads only element [0] of per-batch fields such as
            # data["name"], so it relies on batch_size == 1.
            "batch_size": 1,
            "name": "hpatches",
            "num_workers": 16,
            "preprocessing": {
                "resize": 480,  # we also resize during eval to have comparable metrics
                "side": "short",
            },
        },
        "model": {
            "ground_truth": {
                "name": None,  # remove gt matches
            }
        },
        "eval": {
            "estimator": "poselib",
            # A positive value uses that single threshold; a non-positive
            # value (e.g. -1) sweeps a fixed set of thresholds and keeps the
            # best one; an iterable of thresholds sweeps exactly those.
            "ransac_th": 1.0,
        },
    }
    # Prediction fields always written to the predictions file.
    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]
    # Prediction fields passed to export_predictions as optional_keys.
    optional_export_keys = [
        "lines0",
        "lines1",
        "orig_lines0",
        "orig_lines1",
        "line_matches0",
        "line_matches1",
        "line_matching_scores0",
        "line_matching_scores1",
    ]

    def _init(self, conf):
        """No extra initialization is needed for this pipeline."""
        pass

    @classmethod
    def get_dataloader(cls, data_conf=None):
        """Return the HPatches "test" data loader.

        Args:
            data_conf: dataset configuration; when falsy, the class default
                ``default_conf["data"]`` is used.
        """
        data_conf = data_conf if data_conf else cls.default_conf["data"]
        dataset = get_dataset("hpatches")(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export model predictions to ``experiment_dir / "predictions.h5"``.

        The export runs only if the file does not exist yet or ``overwrite``
        is set. If no ``model`` is given, one is loaded from
        ``self.conf.model`` / ``self.conf.checkpoint``.

        Returns:
            Path to the predictions file.
        """
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Score cached predictions against the dataset ground truth.

        Args:
            loader: iterable over evaluation batches (batch size 1 expected).
            pred_file: predictions file produced by ``get_predictions``.

        Returns:
            ``(summaries, figures, results)``: scalar summary metrics, a dict
            holding the homography-recall figure, and the per-pair results
            (merged with the robust-estimation results of the best threshold).
        """
        assert pred_file.exists()
        results = defaultdict(list)

        conf = self.conf.eval

        if isinstance(conf.ransac_th, Iterable):
            test_thresholds = conf.ransac_th
        elif conf.ransac_th > 0:
            test_thresholds = [conf.ransac_th]
        else:
            test_thresholds = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
        # threshold -> metric name -> list of per-pair values
        pose_results = defaultdict(lambda: defaultdict(list))
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
        for data in tqdm(loader):
            pred = cache_loader(data)
            # add custom evaluations here
            if "keypoints0" in pred:
                results_i = eval_matches_homography(data, pred, {})
                results_i = {**results_i, **eval_homography_dlt(data, pred)}
            else:
                results_i = {}
            for th in test_thresholds:
                pose_results_i = eval_homography_robust(
                    data,
                    pred,
                    {"estimator": conf.estimator, "ransac_th": th},
                )
                for k, v in pose_results_i.items():
                    pose_results[th][k].append(v)

            # we also store the names for later reference
            results_i["names"] = data["name"][0]
            results_i["scenes"] = data["scene"][0]

            for k, v in results_i.items():
                results[k].append(v)

        # summarize results as a dict[str, float]
        # you can also add your custom evaluations here
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            # skip non-numeric entries such as names / scenes
            if not np.issubdtype(arr.dtype, np.number):
                continue
            summaries[f"m{k}"] = round(np.median(arr), 3)

        auc_ths = [1, 3, 5]
        best_pose_results, best_th = eval_poses(
            pose_results, auc_ths=auc_ths, key="H_error_ransac", unit="px"
        )
        # DLT errors exist only when predictions contained keypoints.
        has_dlt = "H_error_dlt" in results
        if has_dlt:
            dlt_aucs = AUCMetric(auc_ths, results["H_error_dlt"]).compute()
            for ath, auc in zip(auc_ths, dlt_aucs):
                summaries[f"H_error_dlt@{ath}px"] = auc

        # Note: this converts the defaultdict into a plain dict, so missing
        # keys raise KeyError from here on.
        results = {**results, **pose_results[best_th]}
        summaries = {
            **summaries,
            **best_pose_results,
        }

        curves = {}
        if has_dlt:
            curves["DLT"] = results["H_error_dlt"]
        curves[self.conf.eval.estimator] = results["H_error_ransac"]
        figures = {
            "homography_recall": plot_cumulative(
                curves,
                [0, 10],
                unit="px",
                title="Homography ",
            )
        }

        return summaries, figures, results


if __name__ == "__main__":
    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(HPatchesPipeline.default_conf)

    # mingle paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(
        dataset_name,
        args,
        "configs/",
        default_conf,
    )

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = HPatchesPipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
    )

    # print results
    pprint(s)
    if args.plot:
        for fig_name, fig in f.items():
            fig.canvas.manager.set_window_title(fig_name)
        plt.show()