# glue-factory-custom/tests/test_eval_utils.py

import unittest
import torch
from gluefactory.eval.utils import eval_matches_homography
from gluefactory.geometry.homography import warp_points_torch
class TestEvalUtils(unittest.TestCase):
    """Unit tests for ``eval_matches_homography``.

    Every test pairs a ground-truth homography (``data["H_0to1"]``) with a
    prediction dict holding ``keypoints0``, ``keypoints1``, ``matches0`` and
    ``matching_scores0``, then checks the metrics returned for that input.
    """

    @staticmethod
    def default_pts():
        """Return the four corners of a 10x10 square as a ``(4, 2)`` tensor."""
        return torch.tensor(
            [
                [10.0, 10.0],
                [10.0, 20.0],
                [20.0, 20.0],
                [20.0, 10.0],
            ]
        )

    @staticmethod
    def default_pred(kps0, kps1):
        """Build a prediction dict in which entry ``i`` of ``matches0`` is ``i``.

        Args:
            kps0: Keypoints of the first image, indexable along dim 0.
            kps1: Keypoints of the second image.

        Returns:
            A dict with the keys ``keypoints0``, ``keypoints1``, ``matches0``
            and ``matching_scores0``; every score is 1.
        """
        num_kps0 = len(kps0)
        return {
            "keypoints0": kps0,
            "keypoints1": kps1,
            "matches0": torch.arange(num_kps0),
            # Sized from kps0 (not kps1) so that the scores stay aligned
            # one-to-one with matches0, mirroring the batched test which
            # builds them with ones_like(matches0).
            "matching_scores0": torch.ones(num_kps0),
        }

    def test_eval_matches_homography_trivial(self):
        """Identical keypoints under the identity homography are all correct."""
        data = {"H_0to1": torch.eye(3)}
        kps = self.default_pts()
        pred = self.default_pred(kps, kps)

        results = eval_matches_homography(data, pred)

        self.assertEqual(results["prec@1px"], 1)
        self.assertEqual(results["prec@3px"], 1)
        self.assertEqual(results["num_matches"], 4)
        self.assertEqual(results["num_keypoints"], 4)

    def test_eval_matches_homography_real(self):
        """Keypoints warped by the ground-truth homography are all correct."""
        data = {"H_0to1": torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])}
        kps0 = self.default_pts()
        kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
        pred = self.default_pred(kps0, kps1)

        results = eval_matches_homography(data, pred)

        self.assertEqual(results["prec@1px"], 1)
        self.assertEqual(results["prec@3px"], 1)

    def test_eval_matches_homography_real_outliers(self):
        """One displaced keypoint out of five lowers precision at 1px only."""
        data = {"H_0to1": torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])}
        kps0 = self.default_pts()
        kps0 = torch.cat([kps0, torch.tensor([[5.0, 5.0]])])
        kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
        # Move one keypoint 1.5 pixels away in x and y. The assertions below
        # expect the resulting error to fall between the 1px and 3px
        # thresholds: 4 of 5 matches pass at 1px, all 5 pass at 3px.
        kps1[-1] += 1.5
        pred = self.default_pred(kps0, kps1)

        results = eval_matches_homography(data, pred)

        self.assertAlmostEqual(results["prec@1px"], 0.8)
        self.assertAlmostEqual(results["prec@3px"], 1.0)

    def test_eval_matches_homography_batched(self):
        """Batched inputs yield one metric value per batch element."""
        H0 = torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])
        H1 = torch.tensor([[0.7, 0.1, -5], [-0.1, 0.65, 13], [0, 0, 1.0]])
        data = {"H_0to1": torch.stack([H0, H1])}
        # The second batch element uses the same points in reversed order.
        kps0 = torch.stack([self.default_pts(), self.default_pts().flip(0)])
        kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
        # In the first element of the batch there is one outlier
        kps1[0, -1] += 5
        matches0 = torch.stack([torch.arange(4), torch.arange(4)])
        # In the second element of the batch there are only 2 matches: the
        # first two entries are set to -1, which the num_matches assertion
        # below expects to be excluded from the count.
        matches0[1, :2] = -1
        pred = {
            "keypoints0": kps0,
            "keypoints1": kps1,
            "matches0": matches0,
            "matching_scores0": torch.ones_like(matches0),
        }

        results = eval_matches_homography(data, pred)

        self.assertAlmostEqual(results["prec@1px"][0], 0.75)
        self.assertAlmostEqual(results["prec@1px"][1], 1.0)
        self.assertAlmostEqual(results["num_matches"][0], 4)
        self.assertAlmostEqual(results["num_matches"][1], 2)