51 lines
1.8 KiB
Python
51 lines
1.8 KiB
Python
|
import torch
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def matcher_metrics(pred, data, prefix="", prefix_gt=None):
    """Compute per-row recall, precision, accuracy and AP of predicted matches.

    Reads ``pred[f"{prefix}matches0"]`` and ``pred[f"{prefix}matching_scores0"]``
    and compares them against ``data[f"gt_{prefix_gt}matches0"]``.

    The tensors are reduced along dim 1, so they are assumed to be 2-D
    (presumably ``(batch, num_keypoints)`` -- confirm against the caller).
    Ground-truth entries ``> -1`` count as positives, entries ``>= -1`` are
    the ones that are evaluated at all, and anything below ``-1`` is excluded
    from every metric. A prediction ``> -1`` counts as a proposed match.

    Args:
        pred: mapping holding the predicted matches and their scores.
        data: mapping holding the ground-truth matches.
        prefix: prefix of the prediction keys and of the returned metric names.
        prefix_gt: prefix used inside the ground-truth key; defaults to
            ``prefix`` when ``None``.

    Returns:
        Dict with the keys ``{prefix}match_recall``, ``{prefix}match_precision``,
        ``{prefix}accuracy`` and ``{prefix}average_precision``; each value has
        one entry per row of the inputs.
    """
    eps = 1e-8  # keeps the ratios finite when a mask selects nothing

    def masked_ratio(hits, mask):
        # Fraction of mask-selected entries that are hits, per row.
        return (hits * mask).sum(1) / (eps + mask.sum(1))

    def recall(m, gt_m):
        # Among ground-truth positives, how many were predicted exactly.
        mask = (gt_m > -1).float()
        return masked_ratio(m == gt_m, mask)

    def accuracy(m, gt_m):
        # Among all evaluated entries (positives and -1), how many agree.
        mask = (gt_m >= -1).float()
        return masked_ratio(m == gt_m, mask)

    def precision(m, gt_m):
        # Among proposed matches on evaluated entries, how many are right.
        mask = ((m > -1) & (gt_m >= -1)).float()
        return masked_ratio(m == gt_m, mask)

    def ranking_ap(m, gt_m, scores):
        # Average precision of the matches ranked by decreasing score:
        # AP = sum_k (R_k - R_{k-1}) * P_k with R_0 = 0.
        p_mask = ((m > -1) & (gt_m >= -1)).float()
        r_mask = (gt_m > -1).float()
        order = torch.argsort(-scores)
        sorted_p_mask = torch.gather(p_mask, -1, order)
        sorted_r_mask = torch.gather(r_mask, -1, order)
        sorted_tp = torch.gather((m == gt_m).float(), -1, order)
        # Precision after each rank.
        p_pts = torch.cumsum(sorted_tp * sorted_p_mask, -1) / (
            eps + torch.cumsum(sorted_p_mask, -1)
        )
        # Recall increment contributed by each rank (including the first one).
        r_steps = (sorted_tp * sorted_r_mask) / (
            eps + sorted_r_mask.sum(-1, keepdim=True)
        )
        # Each recall step is weighted by the precision at that same rank;
        # weighting by the final precision only would make a perfect ranking
        # score (N - 1) / N instead of 1.
        return torch.sum(r_steps * p_pts, dim=-1)

    if prefix_gt is None:
        prefix_gt = prefix

    matches = pred[f"{prefix}matches0"]
    gt_matches = data[f"gt_{prefix_gt}matches0"]
    scores = pred[f"{prefix}matching_scores0"]

    metrics = {
        f"{prefix}match_recall": recall(matches, gt_matches),
        f"{prefix}match_precision": precision(matches, gt_matches),
        f"{prefix}accuracy": accuracy(matches, gt_matches),
        f"{prefix}average_precision": ranking_ap(matches, gt_matches, scores),
    }
    return metrics
|