203 lines
6.3 KiB
Python
203 lines
6.3 KiB
Python
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from omegaconf import OmegaConf
|
|
from tqdm import tqdm
|
|
|
|
from ..datasets import get_dataset
|
|
from ..models.cache_loader import CacheLoader
|
|
from ..settings import EVAL_PATH
|
|
from ..utils.export_predictions import export_predictions
|
|
from .eval_pipeline import EvalPipeline, load_eval
|
|
from .io import get_eval_parser, load_model, parse_eval_args
|
|
from .utils import aggregate_pr_results, get_tp_fp_pts
|
|
|
|
|
|
def eval_dataset(loader, pred_file, suffix=""):
    """Accumulate precision/recall statistics over cached matching predictions.

    For every item yielded by ``loader``, the cached prediction stored in
    ``pred_file`` is fetched. Its matches are ordered by decreasing matching
    score and passed to ``get_tp_fp_pts``. The per-item true positives, false
    positives and scores are collected, and the positive counts are summed,
    before being handed to ``aggregate_pr_results``.

    Args:
        loader: Iterable of data items that ``CacheLoader`` can look up.
        pred_file: Path to the exported predictions file.
        suffix: ``""`` evaluates point matches (``matches0``,
            ``gt_matches0``, ``matching_scores0``). Any other value evaluates
            line matches (``line_matches0``, ``gt_line_matches0``,
            ``line_matching_scores0``). The suffix is also appended to every
            key of the accumulated results.

    Returns:
        The value returned by ``aggregate_pr_results`` for the accumulated
        results.
    """
    # Pick the cached prediction keys for points (no suffix) or lines.
    prefix = "" if suffix == "" else "line_"
    score_key = prefix + "matching_scores0"
    gt_key = "gt_" + prefix + "matches0"
    match_key = prefix + "matches0"

    # Keys of the accumulated results, all tagged with the suffix.
    tp_key = "tp" + suffix
    fp_key = "fp" + suffix
    scores_key = "scores" + suffix
    pos_key = "num_pos" + suffix

    results = defaultdict(list)
    results[pos_key] = 0
    cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()

    for data in tqdm(loader):
        pred = cache_loader(data)

        raw_scores = pred[score_key].numpy()
        # Indices that order the matches by descending score.
        order = np.argsort(raw_scores)[::-1]
        gt_matches = pred[gt_key].numpy()[order]
        pred_matches = pred[match_key].numpy()[order]

        tp, fp, sorted_scores, num_pos = get_tp_fp_pts(
            pred_matches, gt_matches, raw_scores[order]
        )
        results[tp_key].append(tp)
        results[fp_key].append(fp)
        results[scores_key].append(sorted_scores)
        results[pos_key] += num_pos

    # Aggregate the per-item statistics into the final PR results.
    return aggregate_pr_results(results, suffix=suffix)
|
|
|
|
|
|
class ETH3DPipeline(EvalPipeline):
    """Evaluation pipeline for point (and optionally line) matching on ETH3D.

    Predictions are exported once to ``<experiment_dir>/predictions.h5``.
    They are then read back and scored with ``eval_dataset``.
    """

    default_conf = {
        "data": {
            "name": "eth3d",
            "batch_size": 1,
            "train_batch_size": 1,
            "val_batch_size": 1,
            "test_batch_size": 1,
            "num_workers": 16,
        },
        "model": {
            "name": "gluefactory.models.two_view_pipeline",
            "ground_truth": {
                "name": "gluefactory.models.matchers.depth_matcher",
                "use_lines": False,
            },
            "run_gt_in_forward": True,
        },
        # plot_methods / plot_line_methods: experiment names whose saved
        # results are plotted by the script entry point. eval_lines toggles
        # the additional line-matching evaluation in run_eval.
        "eval": {"plot_methods": [], "plot_line_methods": [], "eval_lines": False},
    }

    # Prediction keys passed to export_predictions as required keys.
    export_keys = [
        "gt_matches0",
        "matches0",
        "matching_scores0",
    ]

    # Line-matching keys passed as ``optional_keys``. Presumably they are
    # exported only when the model produces them; confirm in export_predictions.
    optional_export_keys = [
        "gt_line_matches0",
        "line_matches0",
        "line_matching_scores0",
    ]

    def get_dataloader(self, data_conf=None):
        """Return the ``test`` split loader of the ETH3D dataset.

        Args:
            data_conf: Dataset configuration. When None, falls back to
                ``default_conf["data"]``.
        """
        data_conf = data_conf if data_conf is not None else self.default_conf["data"]
        dataset = get_dataset("eth3d")(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export predictions to ``experiment_dir / "predictions.h5"`` if needed.

        The export runs only when the file is missing or ``overwrite`` is set.
        If no ``model`` is given, one is loaded from ``self.conf.model`` and
        ``self.conf.checkpoint``.

        Returns:
            Path to the predictions file.
        """
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Score cached predictions for points, plus lines if enabled.

        Line results (keys suffixed with ``_lines``) are merged into the point
        results when ``self.conf.eval.eval_lines`` is true.

        Returns:
            A 3-tuple: two empty dicts followed by the merged results.
        """
        results = eval_dataset(loader, pred_file)
        if self.conf.eval.eval_lines:
            # eval_dataset accepts no ``conf`` argument; passing one used to
            # raise TypeError whenever line evaluation was enabled.
            results.update(eval_dataset(loader, pred_file, suffix="_lines"))
        return {}, {}, results
|
|
|
|
|
|
def plot_pr_curve(
    models_name, results, dst_file="eth3d_pr_curve.pdf", title=None, suffix=""
):
    """Plot precision/recall curves for several methods and save the figure.

    Iso-F-score contours for F = 0.2, 0.3, ..., 0.9 are drawn behind the curves.

    Args:
        models_name: Names of the methods to plot, in legend order.
        results: Mapping from method name to its results dict. Each entry must
            contain ``"AP" + suffix``, ``"curve_recall" + suffix`` and
            ``"curve_precision" + suffix``.
        dst_file: Output path passed to ``plt.savefig``.
        title: Optional figure title. It is skipped when falsy.
        suffix: Key suffix that selects which results to plot (e.g. ``"_lines"``).
    """
    plt.figure()

    # Iso-F contours: solving F = 2PR / (P + R) for P gives P = F*R / (2R - F).
    recall_grid = np.linspace(0.01, 1)
    for f_score in np.linspace(0.2, 0.9, num=8):
        precision = f_score * recall_grid / (2 * recall_grid - f_score)
        valid = precision >= 0
        plt.plot(recall_grid[valid], precision[valid], color=[0, 0.5, 0], alpha=0.3)
        plt.annotate(
            "f={0:0.1}".format(f_score),
            xy=(0.9, precision[45] + 0.02),
            alpha=0.4,
            fontsize=14,
        )

    plt.rcParams.update({"font.size": 12})
    plt.grid(True)
    plt.axis([0.0, 1.0, 0.0, 1.0])
    plt.xticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.xlabel("Recall", fontsize=18)
    plt.ylabel("Precision", fontsize=18)
    plt.yticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.ylim([0.3, 1.0])

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    for i, m in enumerate(models_name):
        # Wrap around the color cycle. zip(models_name, colors) used to drop
        # every method past the number of available colors without warning.
        c = colors[i % len(colors)]
        sAP_string = f'{m}: {results[m]["AP" + suffix]:.1f}'
        plt.plot(
            results[m]["curve_recall" + suffix],
            results[m]["curve_precision" + suffix],
            label=sAP_string,
            color=c,
        )

    plt.legend(fontsize=16, loc="lower right")
    if title:
        plt.title(title)

    plt.tight_layout(pad=0.5)
    print(f"Saving plot to: {dst_file}")
    plt.savefig(dst_file)
    plt.show()
|
|
|
|
|
|
if __name__ == "__main__":
    # The module file name (e.g. "eth3d") names both the config set and the
    # output directory.
    dataset_name = Path(__file__).stem
    args = get_eval_parser().parse_intermixed_args()
    default_conf = OmegaConf.create(ETH3DPipeline.default_conf)

    # Outputs are written to EVAL_PATH/<dataset_name>/<experiment name>.
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = ETH3DPipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
    )

    # Report only the average-precision entries of the results.
    ap_entries = ((key, value) for key, value in r.items() if key.startswith("AP"))
    for key, value in ap_entries:
        print(f"{key}: {value:.2f}")

    if args.plot:
        # Load the saved results of every method to compare against.
        results = {
            method: load_eval(output_dir / method)[1]
            for method in conf.eval.plot_methods
        }
        plot_pr_curve(conf.eval.plot_methods, results, dst_file="eth3d_pr_curve.pdf")

        if conf.eval.eval_lines:
            results.update(
                {
                    method: load_eval(output_dir / method)[1]
                    for method in conf.eval.plot_line_methods
                }
            )
            plot_pr_curve(
                conf.eval.plot_line_methods,
                results,
                dst_file="eth3d_pr_curve_lines.pdf",
                suffix="_lines",
            )
|