glue-factory-custom/gluefactory/visualization/tools.py

466 lines
14 KiB
Python
Raw Permalink Normal View History

import inspect
import sys
import warnings
import matplotlib.pyplot as plt
import torch
from matplotlib.backend_tools import ToolToggleBase
from matplotlib.widgets import RadioButtons, Slider
from ..geometry.epipolar import T_to_F, generalized_epi_dist
from ..geometry.homography import sym_homography_error
from ..visualization.viz2d import (
cm_ranking,
cm_RdGn,
draw_epipolar_line,
get_line,
plot_color_line_matches,
plot_heatmaps,
plot_keypoints,
plot_lines,
plot_matches,
)
# Select matplotlib's "toolmanager" toolbar at import time. Every warning
# raised while the rcParam is assigned is suppressed; the previous warning
# filters are restored as soon as the context manager exits.
# NOTE(review): presumably needed so the ToolToggleBase subclasses in this
# module can be registered on a figure - confirm against the callers.
with warnings.catch_warnings():
    warnings.simplefilter(action="ignore")
    plt.rcParams["toolbar"] = "toolmanager"
class RadioHideTool(ToolToggleBase):
    """Toggle tool that shows a radio-button selector on the figure.

    While the tool is enabled, a ``RadioButtons`` widget listing ``options``
    is drawn in a dedicated axes near the right edge of the figure. Clicking
    an entry records its index in ``active`` and passes the clicked label to
    ``callback_fn``.

    Args:
        *args: forwarded to ``ToolToggleBase``.
        options: labels offered by the radio buttons. ``None`` (the default)
            gives every instance its own empty list.
        active: label selected initially; a falsy value selects index 0.
        callback_fn: optional callable invoked with the clicked label.
        keymap: stored as ``default_keymap`` on the instance.
        **kwargs: forwarded to ``ToolToggleBase``.

    Raises:
        ValueError: if ``active`` is truthy but not contained in ``options``.
    """

    default_keymap = "R"
    description = "Show by gid"
    default_toggled = False
    radio_group = "default"

    def __init__(
        self, *args, options=None, active=None, callback_fn=None, keymap="R", **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Retained for backward compatibility; not used by this class.
        self.f = 1.0
        # A fresh list per instance avoids sharing one default between tools.
        self.options = [] if options is None else options
        self.callback_fn = callback_fn
        self.active = self.options.index(active) if active else 0
        self.default_keymap = keymap
        self.enabled = self.default_toggled
        # Created lazily by build_radios(); None while the panel is hidden.
        self.radios_ax = None
        self.radios = None

    def build_radios(self):
        """Create the panel axes and the ``RadioButtons`` widget inside it."""
        w = 0.2
        # Panel of relative width ``w`` flush with the right figure edge.
        self.radios_ax = self.figure.add_axes([1.0 - w, 0.7, w, 0.2], zorder=1)
        self.radios = RadioButtons(self.radios_ax, self.options, active=self.active)
        self.radios.on_clicked(self.on_radio_clicked)

    def enable(self, *args):
        """Show the radio panel and request a redraw."""
        self.build_radios()
        self.figure.canvas.draw_idle()
        self.enabled = True

    def disable(self, *args):
        """Remove the radio panel, if it exists, and request a redraw."""
        # Guard against disable() being triggered before the panel was built.
        if self.radios_ax is not None:
            self.radios_ax.remove()
            self.radios_ax = None
        self.radios = None
        self.figure.canvas.draw_idle()
        self.enabled = False

    def on_radio_clicked(self, value):
        """Record the clicked option and forward it to ``callback_fn``.

        If the panel is currently shown it is removed before the callback
        runs and rebuilt afterwards.
        """
        self.active = self.options.index(value)
        was_enabled = self.enabled
        if was_enabled:
            self.disable()
        if self.callback_fn is not None:
            self.callback_fn(value)
        if was_enabled:
            self.enable()
class ToggleTool(ToolToggleBase):
    """Toggle tool that reports its on/off state to a callback.

    Args:
        *args: forwarded to ``ToolToggleBase``.
        callback_fn: optional callable; receives ``True`` when the tool is
            enabled and ``False`` when it is disabled. ``None`` makes both
            transitions a no-op.
        keymap: stored as ``default_keymap`` on the instance.
        **kwargs: forwarded to ``ToolToggleBase``.
    """

    default_keymap = "t"
    description = "Show by gid"

    def __init__(self, *args, callback_fn=None, keymap="t", **kwargs):
        super().__init__(*args, **kwargs)
        self.f = 1.0
        self.callback_fn = callback_fn
        self.default_keymap = keymap
        # Mirrors ``default_toggled`` at construction; enable()/disable()
        # below do not update it.
        self.enabled = self.default_toggled

    def enable(self, *args):
        """Notify the callback that the tool was switched on."""
        # callback_fn defaults to None, so it must be checked before calling.
        if self.callback_fn is not None:
            self.callback_fn(True)

    def disable(self, *args):
        """Notify the callback that the tool was switched off."""
        if self.callback_fn is not None:
            self.callback_fn(False)
def add_whitespace_left(fig, factor):
    """Widen ``fig`` and enlarge its left subplot margin.

    The figure width is multiplied by ``1 + factor`` (height unchanged) and
    the ``left`` subplot parameter becomes ``(factor + left) / (1 + factor)``,
    so the added width ends up in the left margin.

    Args:
        fig: figure-like object providing ``get_size_inches``,
            ``set_size_inches``, ``subplots_adjust`` and ``subplotpars``.
        factor: extra width, as a fraction of the current figure width.
    """
    width, height = fig.get_size_inches()
    old_left = fig.subplotpars.left
    scale = 1 + factor
    fig.set_size_inches([width * scale, height])
    fig.subplots_adjust(left=(factor + old_left) / scale)
def add_whitespace_bottom(fig, factor):
    """Make ``fig`` taller and enlarge its bottom subplot margin.

    The figure height is multiplied by ``1 + factor`` (width unchanged), the
    ``bottom`` subplot parameter becomes ``(factor + bottom) / (1 + factor)``
    and a redraw of the canvas is requested.

    Args:
        fig: figure-like object providing ``get_size_inches``,
            ``set_size_inches``, ``subplots_adjust``, ``subplotpars`` and
            ``canvas``.
        factor: extra height, as a fraction of the current figure height.
    """
    width, height = fig.get_size_inches()
    old_bottom = fig.subplotpars.bottom
    scale = 1 + factor
    fig.set_size_inches([width, height * scale])
    fig.subplots_adjust(bottom=(factor + old_bottom) / scale)
    fig.canvas.draw_idle()
class KeypointPlot:
    """Draw the keypoints of both views for every prediction.

    The i-th prediction (in iteration order of ``preds``) is drawn on
    ``axes[i]``. ``fig`` and ``data`` are accepted but not used.
    """

    plot_name = "keypoints"
    required_keys = ["keypoints0", "keypoints1"]

    def __init__(self, fig, axes, data, preds):
        for row, pred in enumerate(preds.values()):
            # Index 0 takes the first entry along the leading dimension
            # (presumably the batch dimension - TODO confirm).
            keypoints = [pred[key][0] for key in ("keypoints0", "keypoints1")]
            plot_keypoints(keypoints, axes=axes[row])
class LinePlot:
    """Draw the line segments of both views for every prediction.

    ``fig``, ``axes`` and ``data`` are accepted but not used.
    """

    plot_name = "lines"
    required_keys = ["lines0", "lines1"]

    def __init__(self, fig, axes, data, preds):
        for pred in preds.values():
            # NOTE(review): neither ``axes`` nor the prediction index is
            # forwarded, so every prediction goes to whatever target
            # plot_lines picks by default - confirm this is intended.
            line_sets = [pred[key][0] for key in ("lines0", "lines1")]
            plot_lines(line_sets)
class KeypointRankingPlot:
    """Draw keypoints of both views, colored by ``cm_ranking`` of their scores.

    The i-th prediction (in iteration order of ``preds``) is drawn on
    ``axes[i]``. ``fig`` and ``data`` are accepted but not used.
    """

    plot_name = "keypoint_ranking"
    required_keys = ["keypoints0", "keypoints1", "keypoint_scores0", "keypoint_scores1"]

    def __init__(self, fig, axes, data, preds):
        for row, pred in enumerate(preds.values()):
            keypoints = [pred["keypoints0"][0], pred["keypoints1"][0]]
            scores = [pred["keypoint_scores0"][0], pred["keypoint_scores1"][0]]
            # One color array per view, derived from that view's scores.
            colors = [cm_ranking(view_scores) for view_scores in scores]
            plot_keypoints(keypoints, axes=axes[row], colors=colors)
class KeypointScoresPlot:
    """Draw keypoints of both views, colored by ``cm_RdGn`` of their scores.

    The i-th prediction (in iteration order of ``preds``) is drawn on
    ``axes[i]``. ``fig`` and ``data`` are accepted but not used.
    """

    plot_name = "keypoint_scores"
    required_keys = ["keypoints0", "keypoints1", "keypoint_scores0", "keypoint_scores1"]

    def __init__(self, fig, axes, data, preds):
        for row, pred in enumerate(preds.values()):
            keypoints = [pred["keypoints0"][0], pred["keypoints1"][0]]
            scores = [pred["keypoint_scores0"][0], pred["keypoint_scores1"][0]]
            # One color array per view, derived from that view's scores.
            colors = [cm_RdGn(view_scores) for view_scores in scores]
            plot_keypoints(keypoints, axes=axes[row], colors=colors)
class HeatmapPlot:
    """Overlay the heatmaps of both views for every prediction.

    The artists returned by ``plot_heatmaps`` are collected in ``artists`` so
    that ``clear`` can remove them again. ``fig`` and ``data`` are accepted
    but not used.
    """

    plot_name = "heatmaps"
    required_keys = ["heatmap0", "heatmap1"]

    def __init__(self, fig, axes, data, preds):
        self.artists = []
        for row, pred in enumerate(preds.values()):
            heatmaps = []
            for key in ("heatmap0", "heatmap1"):
                heatmap = pred[key][0, 0]
                # A map containing negative values is passed through a
                # sigmoid first (presumably it holds raw logits - confirm).
                if heatmap.min() < 0.0:
                    heatmap = torch.sigmoid(heatmap)
                heatmaps.append(heatmap)
            self.artists.extend(plot_heatmaps(heatmaps, axes=axes[row], cmap="rainbow"))

    def clear(self):
        """Remove every heatmap artist created by this plot."""
        for artist in self.artists:
            artist.remove()
class ImagePlot:
    """Plot entry that adds no overlay: its constructor does nothing."""

    plot_name = "images"
    required_keys = ["view0", "view1"]

    def __init__(self, fig, axes, data, preds):
        """Accept the common plot arguments without drawing anything."""
class MatchesPlot:
    """Draw keypoints and predicted matches, colored by matching score.

    For every prediction, all keypoints are drawn in blue on ``axes[i]``.
    Each keypoint of view 0 whose ``matches0`` entry is greater than -1 is
    then connected to the view-1 keypoint at that index. Match colors come
    from ``cm_RdGn`` applied to the matching scores, which are also passed
    to ``plot_matches`` as labels. ``data`` is accepted but not used.
    """

    plot_name = "matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "matching_scores0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        # Snapshot of the four subplot margins at construction time.
        margin_keys = ("left", "right", "top", "bottom")
        self.sbpars = {
            key: value
            for key, value in vars(fig.subplotpars).items()
            if key in margin_keys
        }

        for row, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            plot_keypoints([kp0, kp1], axes=axes[row], colors="blue")

            matches = pred["matches0"][0]
            matched = matches > -1
            kpm0 = kp0[matched]
            kpm1 = kp1[matches[matched]]
            scores = pred["matching_scores0"][0][matched]
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(scores).tolist(),
                axes=axes[row],
                labels=scores,
                lw=0.5,
            )
class LineMatchesPlot:
    """Draw the predicted line matches of every prediction.

    A line of view 0 is kept when its ``line_matches0`` entry is greater
    than -1; it is paired with the view-1 line at that index. ``axes`` and
    ``data`` are accepted but not used.
    """

    plot_name = "line_matches"
    required_keys = ["lines0", "lines1", "line_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        # Snapshot of the four subplot margins at construction time.
        margin_keys = ("left", "right", "top", "bottom")
        self.sbpars = {
            key: value
            for key, value in vars(fig.subplotpars).items()
            if key in margin_keys
        }

        for pred in preds.values():
            lines0, lines1 = pred["lines0"][0], pred["lines1"][0]
            matches = pred["line_matches0"][0]
            matched = matches > -1
            # NOTE(review): ``axes`` is not forwarded, so every prediction is
            # drawn on plot_color_line_matches' default target - confirm.
            plot_color_line_matches([lines0[matched], lines1[matches[matched]]])
class GtMatchesPlot:
    """Draw predicted matches, colored by agreement with the ground truth.

    For every prediction, all keypoints are drawn in blue on ``axes[i]``.
    A predicted match is shown when ``matches0 > -1`` and the corresponding
    ``gt_matches0`` entry is at least -1. It counts as correct when both
    entries are equal; that boolean drives the ``cm_RdGn`` color and is also
    passed to ``plot_matches`` as the label. ``data`` is accepted but not
    used.
    """

    plot_name = "gt_matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "gt_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        # Snapshot of the four subplot margins at construction time.
        margin_keys = ("left", "right", "top", "bottom")
        self.sbpars = {
            key: value
            for key, value in vars(fig.subplotpars).items()
            if key in margin_keys
        }

        for row, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            plot_keypoints([kp0, kp1], axes=axes[row], colors="blue")

            matches = pred["matches0"][0]
            gt_matches = pred["gt_matches0"][0]
            # Ground-truth entries below -1 are excluded (presumably they
            # mark keypoints without a usable label - confirm).
            shown = (matches > -1) & (gt_matches >= -1)
            kpm0 = kp0[shown]
            kpm1 = kp1[matches[shown]]
            correct = gt_matches[shown] == matches[shown]
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(correct).tolist(),
                axes=axes[row],
                labels=correct,
                lw=0.5,
            )
class GtLineMatchesPlot:
    """Draw predicted line matches that have a ground-truth entry.

    A line of view 0 is kept when its ``line_matches0`` entry is greater
    than -1 and its ``gt_line_matches0`` entry is at least -1; it is paired
    with the view-1 line at the predicted index. The ground truth is only
    used for this filter. ``axes`` and ``data`` are accepted but not used.
    """

    plot_name = "gt_line_matches"
    # Must name exactly the keys read in __init__; the ground-truth key is
    # "gt_line_matches0".
    required_keys = ["lines0", "lines1", "line_matches0", "gt_line_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        # Snapshot of the four subplot margins at construction time.
        self.sbpars = {
            k: v
            for k, v in vars(fig.subplotpars).items()
            if k in ["left", "right", "top", "bottom"]
        }
        for pred in preds.values():
            lines0, lines1 = pred["lines0"][0], pred["lines1"][0]
            m0 = pred["line_matches0"][0]
            gtm0 = pred["gt_line_matches0"][0]
            valid = (m0 > -1) & (gtm0 >= -1)
            m_lines0 = lines0[valid]
            m_lines1 = lines1[m0[valid]]
            # NOTE(review): ``axes`` is not forwarded, so every prediction is
            # drawn on plot_color_line_matches' default target - confirm.
            plot_color_line_matches([m_lines0, m_lines1])
class HomographyMatchesPlot:
    """Draw predicted matches, colored by symmetric homography error.

    A slider labelled "Homography Error" (0 to 5 in steps of 1, initially 3)
    is placed in extra space added at the bottom of the figure. A match is
    colored by ``cm_RdGn`` according to whether its error under
    ``data["H_0to1"][0]`` lies below the slider value. Moving the slider
    recolors the figure artists through ``color_matches``.
    """

    plot_name = "homography"
    required_keys = ["keypoints0", "keypoints1", "matches0", "H_0to1"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        # Remember the margins *before* the figure is enlarged so that
        # clear() can restore them.
        margin_keys = ("left", "right", "top", "bottom")
        self.sbpars = {
            key: value
            for key, value in vars(fig.subplotpars).items()
            if key in margin_keys
        }

        add_whitespace_bottom(fig, 0.1)
        self.range_ax = fig.add_axes([0.3, 0.02, 0.4, 0.06])
        self.range = Slider(
            self.range_ax,
            label="Homography Error",
            valmin=0,
            valmax=5,
            valinit=3.0,
            valstep=1.0,
        )
        self.range.on_changed(self.color_matches)

        for row, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            plot_keypoints([kp0, kp1], axes=axes[row], colors="blue")

            matches = pred["matches0"][0]
            matched = matches > -1
            kpm0 = kp0[matched]
            kpm1 = kp1[matches[matched]]
            errors = sym_homography_error(kpm0, kpm1, data["H_0to1"][0])
            below_threshold = errors < self.range.val
            # The errors are handed over as labels; color_matches() later
            # parses artist labels as floats (presumably these - confirm).
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(below_threshold).tolist(),
                axes=axes[row],
                labels=errors.numpy(),
                lw=0.5,
            )

    def clear(self):
        """Shrink the figure back, restore the margins and drop the slider."""
        width, height = self.fig.get_size_inches()
        # Inverse of the add_whitespace_bottom(fig, 0.1) call in __init__.
        self.fig.set_size_inches(width, height / 1.1)
        self.fig.subplots_adjust(**self.sbpars)
        self.range_ax.remove()

    def color_matches(self, args):
        """Recolor every figure artist against the threshold ``args``.

        Each artist's label is parsed as a float and compared with ``args``.
        """
        for artist in self.fig.artists:
            error = float(artist.get_label())
            artist.set_color(cm_RdGn([error < args])[0])
class EpipolarMatchesPlot:
    """Draw predicted matches, colored by epipolar error.

    A slider labelled "Epipolar Error [px]" (0 to 5 in steps of 1,
    initially 3) is placed in extra space added at the bottom of the figure.
    A match is colored by ``cm_RdGn`` according to whether its error from
    ``generalized_epi_dist`` lies below the slider value. ``click_artist``
    draws or toggles the epipolar lines of a clicked artist, using the
    matrix stored in ``F``.
    """

    plot_name = "epipolar_matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "T_0to1", "view0", "view1"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        self.axes = axes
        # Remember the margins *before* the figure is enlarged so that
        # clear() can restore them.
        margin_keys = ("left", "right", "top", "bottom")
        self.sbpars = {
            key: value
            for key, value in vars(fig.subplotpars).items()
            if key in margin_keys
        }

        add_whitespace_bottom(fig, 0.1)
        self.range_ax = fig.add_axes([0.3, 0.02, 0.4, 0.06])
        self.range = Slider(
            self.range_ax,
            label="Epipolar Error [px]",
            valmin=0,
            valmax=5,
            valinit=3.0,
            valstep=1.0,
        )
        self.range.on_changed(self.color_matches)

        # Cameras and relative pose come from ``data`` and are shared by all
        # predictions.
        camera0 = data["view0"]["camera"][0]
        camera1 = data["view1"]["camera"][0]
        T_0to1 = data["T_0to1"][0]

        for row, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            plot_keypoints([kp0, kp1], axes=axes[row], colors="blue")

            matches = pred["matches0"][0]
            matched = matches > -1
            kpm0 = kp0[matched]
            kpm1 = kp1[matches[matched]]
            errors = generalized_epi_dist(
                kpm0,
                kpm1,
                camera0,
                camera1,
                T_0to1,
                all=False,
                essential=False,
            )
            below_threshold = errors < self.range.val
            # The errors are handed over as labels; color_matches() later
            # parses artist labels as floats (presumably these - confirm).
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(below_threshold).tolist(),
                axes=axes[row],
                labels=errors.numpy(),
                lw=0.5,
            )

        # Kept for click_artist(), which derives epipolar lines from it.
        self.F = T_to_F(camera0, camera1, T_0to1)

    def clear(self):
        """Shrink the figure back, restore the margins and drop the slider."""
        width, height = self.fig.get_size_inches()
        # Inverse of the add_whitespace_bottom(fig, 0.1) call in __init__.
        self.fig.set_size_inches(width, height / 1.1)
        self.fig.subplots_adjust(**self.sbpars)
        self.range_ax.remove()

    def color_matches(self, args):
        """Recolor every labelled figure artist against threshold ``args``.

        Artists whose label is ``None`` are skipped; any other label is
        parsed as a float and compared with ``args``.
        """
        for artist in self.fig.artists:
            label = artist.get_label()
            if label is None:
                continue
            artist.set_color(cm_RdGn([float(label) < args])[0])

    def click_artist(self, event):
        """Show or toggle the epipolar lines belonging to the clicked artist.

        On the first click the two lines are drawn and cached on the artist
        as ``epilines``; later clicks flip the visibility of the cached lines.
        Artists whose label is ``None`` are ignored.
        """
        art = event.artist
        if art.get_label() is None:
            return

        if hasattr(art, "epilines"):
            for epiline in art.epilines:
                if epiline is not None:
                    epiline.set_visible(not epiline.get_visible())
            return

        # NOTE(review): xy1/xy2/axesA/axesB look like the attributes of a
        # matplotlib ConnectionPatch created by plot_matches - confirm.
        xy1 = art.xy1
        xy2 = art.xy2
        # Line for axesA comes from xy2 through the transposed F, the line
        # for axesB from xy1 through F itself.
        line0 = get_line(self.F.transpose(0, 1), xy2)[:, 0]
        line1 = get_line(self.F, xy1)[:, 0]
        art.epilines = [
            draw_epipolar_line(line0, art.axesA),
            draw_epipolar_line(line1, art.axesB),
        ]
# Registry mapping ``plot_name`` to the class that declares it. It is built
# once at import time from every class found in this module's namespace
# that carries a ``plot_name`` attribute.
__plot_dict__ = {
    plot_cls.plot_name: plot_cls
    for _member_name, plot_cls in inspect.getmembers(
        sys.modules[__name__], inspect.isclass
    )
    if hasattr(plot_cls, "plot_name")
}