glue-factory-custom/tests/test_integration.py

133 lines
4.3 KiB
Python
Raw Permalink Normal View History

import unittest
from collections import namedtuple
from os.path import splitext
import cv2
import matplotlib.pyplot as plt
import torch.cuda
from kornia import image_to_tensor
from omegaconf import OmegaConf
from parameterized import parameterized
from torch import Tensor
from gluefactory import logger
from gluefactory.eval.utils import (
eval_homography_dlt,
eval_homography_robust,
eval_matches_homography,
)
from gluefactory.models.two_view_pipeline import TwoViewPipeline
from gluefactory.settings import root
from gluefactory.utils.image import ImagePreprocessor
from gluefactory.utils.tensor import map_tensor
from gluefactory.utils.tools import set_seed
from gluefactory.visualization.viz2d import (
plot_color_line_matches,
plot_images,
plot_matches,
)
def create_input_data(cv_img0, cv_img1, device):
    """Build the two-view input dictionary for a pair of images.

    Each image is converted with ``image_to_tensor``, cast to float and
    divided by 255, then run through an ``ImagePreprocessor`` built from an
    empty config. Finally ``map_tensor`` is applied so that each value it
    visits gets a new leading dimension and is moved to ``device``.

    Args:
        cv_img0: First image, in whatever form ``image_to_tensor`` accepts
            (the test in this file passes ``cv2.imread`` output).
        cv_img1: Second image, same convention as ``cv_img0``.
        device: Target device forwarded to ``.to(...)``.

    Returns:
        The mapped structure with the preprocessed views stored under the
        keys ``"view0"`` and ``"view1"``.
    """

    def _batch_on_device(value):
        # Anything that is not already a Tensor is wrapped with
        # torch.from_numpy first (presumably numpy arrays emitted by the
        # preprocessor -- confirm against ImagePreprocessor).
        tensor = value if isinstance(value, Tensor) else torch.from_numpy(value)
        return tensor[None].to(device)

    # Convert both images before instantiating the preprocessor.
    scaled = [image_to_tensor(img).float() / 255 for img in (cv_img0, cv_img1)]
    preprocessor = ImagePreprocessor({})
    views = {
        name: preprocessor(image)
        for name, image in zip(("view0", "view1"), scaled)
    }
    return map_tensor(views, _batch_on_device)
# Per-configuration thresholds checked by TestIntegration: ``num_matches`` and
# ``prec3px`` are compared with assertGreater (lower bounds), ``h_error`` with
# assertLess (upper bound).
ExpectedResults = namedtuple(
    "ExpectedResults",
    ["num_matches", "prec3px", "h_error"],
)
class TestIntegration(unittest.TestCase):
    """End-to-end homography test of several matching configs on one image pair."""

    # Each entry: (config file name under gluefactory/configs, name of the
    # robust homography estimator, thresholds the results must satisfy).
    methods_to_test = [
        ("superpoint+NN.yaml", "poselib", ExpectedResults(1300, 0.8, 1.0)),
        ("superpoint-open+NN.yaml", "poselib", ExpectedResults(1300, 0.8, 1.0)),
        (
            "superpoint+lsd+gluestick.yaml",
            "homography_est",
            ExpectedResults(1300, 0.8, 1.0),
        ),
        (
            "superpoint+lightglue-official.yaml",
            "poselib",
            ExpectedResults(1300, 0.8, 1.0),
        ),
    ]
    # When True, match plots are saved as SVG files and shown once the
    # assertions have passed.
    visualize = False

    @parameterized.expand(methods_to_test)
    @torch.no_grad()
    def test_real_homography(self, conf_file, estimator, exp_results):
        """Run one pipeline config on the boat image pair and check its metrics.

        The pipeline described by ``conf_file`` is run on ``boat1.png`` /
        ``boat2.png`` from the assets directory; the predictions are scored
        against a hard-coded ground-truth homography, and the resulting
        metrics are compared with ``exp_results``.

        Args:
            conf_file: File name of the YAML config under
                ``gluefactory/configs``.
            estimator: Value passed as ``"estimator"`` to
                ``eval_homography_robust``.
            exp_results: ``ExpectedResults`` thresholds; ``num_matches`` and
                ``prec3px`` must be exceeded, ``h_error`` must not be reached.
        """
        set_seed(0)
        model_path = root / "gluefactory" / "configs" / conf_file
        img_path0 = root / "assets" / "boat1.png"
        img_path1 = root / "assets" / "boat2.png"

        # Load the images before building the model so a missing asset fails
        # fast. cv2.imread returns None instead of raising when a file cannot
        # be read, so check explicitly to get a meaningful failure message.
        cv_img0, cv_img1 = cv2.imread(str(img_path0)), cv2.imread(str(img_path1))
        self.assertIsNotNone(cv_img0, f"Could not read image: {img_path0}")
        self.assertIsNotNone(cv_img1, f"Could not read image: {img_path1}")

        # Ground-truth homography used as H_0to1 for the image pair.
        h_gt = torch.tensor(
            [
                [0.85799, 0.21669, 9.4839],
                [-0.21177, 0.85855, 130.48],
                [1.5015e-06, 9.2033e-07, 1],
            ]
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        gs = TwoViewPipeline(OmegaConf.load(model_path).model).to(device).eval()

        data = create_input_data(cv_img0, cv_img1, device)
        pred = gs(data)
        # Drop the leading dimension that create_input_data added.
        pred = map_tensor(
            pred, lambda t: torch.squeeze(t, dim=0) if isinstance(t, Tensor) else t
        )

        data["H_0to1"] = h_gt.to(device)
        data["H_1to0"] = torch.linalg.inv(h_gt).to(device)

        # Merge the metrics of the three evaluators into one dictionary.
        results = eval_matches_homography(data, pred)
        results = {**results, **eval_homography_dlt(data, pred)}
        results = {
            **results,
            **eval_homography_robust(
                data,
                pred,
                {"estimator": estimator},
            ),
        }
        logger.info(results)

        self.assertGreater(results["num_matches"], exp_results.num_matches)
        self.assertGreater(results["prec@3px"], exp_results.prec3px)
        self.assertLess(results["H_error_ransac"], exp_results.h_error)

        if self.visualize:
            pred = map_tensor(
                pred, lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t
            )
            kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
            m0 = pred["matches0"]
            # Entries equal to -1 in matches0 are skipped.
            valid0 = m0 != -1
            kpm0, kpm1 = kp0[valid0], kp1[m0[valid0]]

            plot_images([cv_img0, cv_img1])
            plot_matches(kpm0, kpm1, a=0.0)
            plt.savefig(f"{splitext(conf_file)[0]}_point_matches.svg")

            # Line matches are only plotted when the prediction contains lines.
            if "lines0" in pred and "lines1" in pred:
                lines0, lines1 = pred["lines0"], pred["lines1"]
                lm0 = pred["line_matches0"]
                lvalid0 = lm0 != -1
                linem0, linem1 = lines0[lvalid0], lines1[lm0[lvalid0]]

                plot_images([cv_img0, cv_img1])
                plot_color_line_matches([linem0, linem1])
                plt.savefig(f"{splitext(conf_file)[0]}_line_matches.svg")
            plt.show()