glue-factory-custom/gluefactory/visualization/visualize_batch.py

58 lines
1.9 KiB
Python
Raw Normal View History

import torch
from ..utils.tensor import batch_to_device
from .viz2d import cm_RdGn, plot_heatmaps, plot_image_grid, plot_keypoints, plot_matches
def make_match_figures(pred_, data_, n_pairs=2):
    """Draw the predicted matches for the first pairs of a batch.

    Args:
        pred_: Dict of predictions. If it contains the key ``"0to1"``, the
            nested dict under that key is used instead. The keys read are
            ``"keypoints0"``, ``"keypoints1"``, ``"matches0"`` and
            ``"gt_matches0"``, plus ``"heatmap0"`` / ``"heatmap1"`` when
            ``"heatmap0"`` is present.
        data_: Dict with ``"view0"`` and ``"view1"`` entries, each holding an
            ``"image"`` batch and optionally a ``"depth"`` batch.
            (Images are presumably channel-first, since each one is permuted
            with ``(1, 2, 0)`` before plotting -- confirm with the data loader.)
        n_pairs: Maximum number of pairs to draw; clamped to the batch size.

    Returns:
        A dict ``{"matching": fig}`` where ``fig`` is the figure returned by
        ``plot_image_grid``.
    """
    # Use the nested "0to1" predictions when they are provided.
    if "0to1" in pred_:
        pred_ = pred_["0to1"]

    # Everything is moved to the CPU because ``.numpy()`` is called below.
    pred = batch_to_device(pred_, "cpu", non_blocking=False)
    data = batch_to_device(data_, "cpu", non_blocking=False)
    view0, view1 = data["view0"], data["view1"]

    # Never draw more pairs than the batch contains.
    n_pairs = min(n_pairs, view0["image"].shape[0])

    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0 = pred["matches0"]
    gtm0 = pred["gt_matches0"]

    # The background source does not depend on the pair index: predicted
    # heatmaps take precedence, depth maps are the fallback.
    use_pred_heatmaps = "heatmap0" in pred
    use_depth = (
        not use_pred_heatmaps
        and "depth" in view0
        and view0["depth"] is not None
    )

    images, kpts, matches, mcolors = [], [], [], []
    heatmaps = []
    for i in range(n_pairs):
        # Keep keypoints that have a predicted match (index > -1) and whose
        # ground-truth entry is >= -1; lower ground-truth values are skipped
        # (presumably "ignore" labels -- TODO confirm).
        valid = (m0[i] > -1) & (gtm0[i] >= -1)
        matched_idx = m0[i][valid]
        kpm0 = kp0[i][valid].numpy()
        kpm1 = kp1[i][matched_idx].numpy()

        images.append(
            [view0["image"][i].permute(1, 2, 0), view1["image"][i].permute(1, 2, 0)]
        )
        kpts.append([kp0[i], kp1[i]])
        matches.append((kpm0, kpm1))

        # A match counts as correct when it equals the ground-truth match.
        correct = gtm0[i][valid] == matched_idx

        if use_pred_heatmaps:
            heatmaps.append(
                [
                    torch.sigmoid(pred["heatmap0"][i, 0]),
                    torch.sigmoid(pred["heatmap1"][i, 0]),
                ]
            )
        elif use_depth:
            heatmaps.append([view0["depth"][i], view1["depth"][i]])

        mcolors.append(cm_RdGn(correct).tolist())

    fig, axes = plot_image_grid(images, return_fig=True, set_lim=True)

    # Layers are drawn in the same order as before: heatmaps, then keypoints,
    # then matches, each pass covering every pair.
    if heatmaps:
        for pair_heatmaps, pair_axes in zip(heatmaps, axes):
            plot_heatmaps(pair_heatmaps, axes=pair_axes, a=1.0)
    for pair_kpts, pair_axes in zip(kpts, axes):
        plot_keypoints(pair_kpts, axes=pair_axes, colors="royalblue")
    for pair_matches, pair_colors, pair_axes in zip(matches, mcolors, axes):
        plot_matches(
            *pair_matches, color=pair_colors, axes=pair_axes, a=0.5, lw=1.0, ps=0.0
        )

    return {"matching": fig}