import inspect
import sys
import warnings

import matplotlib.pyplot as plt
import torch
from matplotlib.backend_tools import ToolToggleBase
from matplotlib.widgets import RadioButtons, Slider

from ..geometry.epipolar import T_to_F, generalized_epi_dist
from ..geometry.homography import sym_homography_error
from ..visualization.viz2d import (
    cm_ranking,
    cm_RdGn,
    draw_epipolar_line,
    get_line,
    plot_color_line_matches,
    plot_heatmaps,
    plot_keypoints,
    plot_lines,
    plot_matches,
)

# Switch matplotlib to the "toolmanager" toolbar; any warning emitted while
# setting this rcParam is silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    plt.rcParams["toolbar"] = "toolmanager"


def _subplot_params(fig):
    """Return the left/right/top/bottom entries of ``fig.subplotpars`` as a dict."""
    return {
        k: v
        for k, v in vars(fig.subplotpars).items()
        if k in ["left", "right", "top", "bottom"]
    }


def _recolor_matches(fig, threshold):
    """Recolor every labelled figure-level artist by comparing its label to a threshold.

    The label of each artist in ``fig.artists`` is parsed as a float; the artist
    is colored with ``cm_RdGn`` applied to ``label < threshold``. Artists whose
    label is ``None`` are skipped.
    """
    for art in fig.artists:
        label = art.get_label()
        if label is not None:
            art.set_color(cm_RdGn([float(label) < threshold])[0])


class RadioHideTool(ToolToggleBase):
    """Toggle tool that shows a radio-button panel to pick one of ``options``.

    While the tool is enabled, an axes holding a ``RadioButtons`` widget is
    added to the figure. Clicking a button stores the selected index in
    ``self.active`` and forwards the selected value to ``callback_fn``.

    Args:
        options: labels offered by the radio buttons (defaults to an empty list).
        active: label selected initially; the first option is used if falsy.
        callback_fn: optional callable invoked with the clicked label.
        keymap: key bound to the tool.
    """

    default_keymap = "R"
    description = "Show by gid"
    default_toggled = False
    radio_group = "default"

    def __init__(
        self, *args, options=None, active=None, callback_fn=None, keymap="R", **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.f = 1.0
        # A fresh list per instance avoids sharing one mutable default object.
        self.options = [] if options is None else options
        self.callback_fn = callback_fn
        self.active = self.options.index(active) if active else 0
        self.default_keymap = keymap

        self.enabled = self.default_toggled

    def build_radios(self):
        """Create the radio-button axes on the right side of the figure."""
        w = 0.2
        self.radios_ax = self.figure.add_axes([1.0 - w, 0.7, w, 0.2], zorder=1)
        self.radios = RadioButtons(self.radios_ax, self.options, active=self.active)
        self.radios.on_clicked(self.on_radio_clicked)

    def enable(self, *args):
        """Build the radio panel and request a redraw."""
        # NOTE(review): the scaled size is never written back to the figure
        # here, and self.f is 1.0 by default -- looks like a no-op, confirm.
        size = self.figure.get_size_inches()
        size[0] *= self.f
        self.build_radios()
        self.figure.canvas.draw_idle()
        self.enabled = True

    def disable(self, *args):
        """Remove the radio panel and request a redraw."""
        # NOTE(review): same as in enable(), the size is not written back.
        size = self.figure.get_size_inches()
        size[0] /= self.f
        self.radios_ax.remove()
        self.radios = None
        self.figure.canvas.draw_idle()
        self.enabled = False

    def on_radio_clicked(self, value):
        """Record the clicked option and notify ``callback_fn``.

        If the panel is currently shown, it is removed before the callback
        runs and rebuilt afterwards.
        """
        self.active = self.options.index(value)
        enabled = self.enabled
        if enabled:
            self.disable()
        if self.callback_fn is not None:
            self.callback_fn(value)
        if enabled:
            self.enable()


class ToggleTool(ToolToggleBase):
    """Toggle tool that forwards its on/off state to ``callback_fn``.

    Args:
        callback_fn: optional callable invoked with ``True`` when the tool is
            enabled and ``False`` when it is disabled.
        keymap: key bound to the tool.
    """

    default_keymap = "t"
    description = "Show by gid"

    def __init__(self, *args, callback_fn=None, keymap="t", **kwargs):
        super().__init__(*args, **kwargs)
        self.f = 1.0
        self.callback_fn = callback_fn
        self.default_keymap = keymap
        self.enabled = self.default_toggled

    def enable(self, *args):
        """Notify the callback that the tool was switched on."""
        # callback_fn defaults to None, so guard the call.
        if self.callback_fn is not None:
            self.callback_fn(True)

    def disable(self, *args):
        """Notify the callback that the tool was switched off."""
        if self.callback_fn is not None:
            self.callback_fn(False)


def add_whitespace_left(fig, factor):
    """Widen ``fig`` by ``factor`` and shift the left subplot margin accordingly.

    The width is multiplied by ``1 + factor`` and the left margin becomes
    ``(factor + left) / (1 + factor)``.
    """
    w, h = fig.get_size_inches()
    left = fig.subplotpars.left
    fig.set_size_inches([w * (1 + factor), h])
    fig.subplots_adjust(left=(factor + left) / (1 + factor))


def add_whitespace_bottom(fig, factor):
    """Grow ``fig`` in height by ``factor`` and raise the bottom subplot margin.

    The height is multiplied by ``1 + factor`` and the bottom margin becomes
    ``(factor + bottom) / (1 + factor)``; a redraw is then requested.
    """
    w, h = fig.get_size_inches()
    b = fig.subplotpars.bottom
    fig.set_size_inches([w, h * (1 + factor)])
    fig.subplots_adjust(bottom=(factor + b) / (1 + factor))
    fig.canvas.draw_idle()


class KeypointPlot:
    """Plot the keypoints of both views, one ``axes[i]`` entry per prediction."""

    plot_name = "keypoints"
    required_keys = ["keypoints0", "keypoints1"]

    def __init__(self, fig, axes, data, preds):
        for i, pred in enumerate(preds.values()):
            plot_keypoints([pred["keypoints0"][0], pred["keypoints1"][0]], axes=axes[i])


class LinePlot:
    """Plot the lines of both views for every prediction."""

    plot_name = "lines"
    required_keys = ["lines0", "lines1"]

    def __init__(self, fig, axes, data, preds):
        # NOTE(review): `axes` is not forwarded to plot_lines, so where the
        # lines are drawn is decided by plot_lines itself -- confirm intended.
        for pred in preds.values():
            plot_lines([pred["lines0"][0], pred["lines1"][0]])


class KeypointRankingPlot:
    """Plot keypoints colored by ``cm_ranking`` of their keypoint scores."""

    plot_name = "keypoint_ranking"
    required_keys = ["keypoints0", "keypoints1", "keypoint_scores0", "keypoint_scores1"]

    def __init__(self, fig, axes, data, preds):
        for i, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            sc0, sc1 = pred["keypoint_scores0"][0], pred["keypoint_scores1"][0]
            plot_keypoints(
                [kp0, kp1], axes=axes[i], colors=[cm_ranking(sc0), cm_ranking(sc1)]
            )


class KeypointScoresPlot:
    """Plot keypoints colored by ``cm_RdGn`` of their keypoint scores."""

    plot_name = "keypoint_scores"
    required_keys = ["keypoints0", "keypoints1", "keypoint_scores0", "keypoint_scores1"]

    def __init__(self, fig, axes, data, preds):
        for i, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            sc0, sc1 = pred["keypoint_scores0"][0], pred["keypoint_scores1"][0]
            plot_keypoints(
                [kp0, kp1], axes=axes[i], colors=[cm_RdGn(sc0), cm_RdGn(sc1)]
            )


class HeatmapPlot:
    """Overlay the ``heatmap0``/``heatmap1`` predictions and keep the artists.

    A heatmap containing negative values is passed through a sigmoid before
    plotting; other heatmaps are plotted as they are.
    """

    plot_name = "heatmaps"
    required_keys = ["heatmap0", "heatmap1"]

    def __init__(self, fig, axes, data, preds):
        self.artists = []
        for i, pred in enumerate(preds.values()):
            heatmaps = [pred["heatmap0"][0, 0], pred["heatmap1"][0, 0]]
            heatmaps = [torch.sigmoid(h) if h.min() < 0.0 else h for h in heatmaps]
            self.artists += plot_heatmaps(heatmaps, axes=axes[i], cmap="rainbow")

    def clear(self):
        """Remove every heatmap artist created by this plot."""
        for x in self.artists:
            x.remove()
        # Forget the removed artists so that a second clear() is harmless.
        self.artists = []


class ImagePlot:
    """Placeholder plot that draws nothing on top of the images."""

    plot_name = "images"
    required_keys = ["view0", "view1"]

    def __init__(self, fig, axes, data, preds):
        pass


class MatchesPlot:
    """Plot keypoints in blue and predicted matches colored by matching score."""

    plot_name = "matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "matching_scores0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        self.sbpars = _subplot_params(fig)

        for i, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            plot_keypoints([kp0, kp1], axes=axes[i], colors="blue")
            m0 = pred["matches0"][0]
            # Entries of matches0 that are > -1 index the matched keypoint in view 1.
            valid = m0 > -1
            kpm0 = kp0[valid]
            kpm1 = kp1[m0[valid]]
            mscores = pred["matching_scores0"][0][valid]
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(mscores).tolist(),
                axes=axes[i],
                labels=mscores,
                lw=0.5,
            )


class LineMatchesPlot:
    """Plot the matched lines of both views with ``plot_color_line_matches``."""

    plot_name = "line_matches"
    required_keys = ["lines0", "lines1", "line_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        self.sbpars = _subplot_params(fig)

        # NOTE(review): `axes` is not forwarded to plot_color_line_matches.
        for pred in preds.values():
            lines0, lines1 = pred["lines0"][0], pred["lines1"][0]
            m0 = pred["line_matches0"][0]
            valid = m0 > -1
            m_lines0 = lines0[valid]
            m_lines1 = lines1[m0[valid]]
            plot_color_line_matches([m_lines0, m_lines1])


class GtMatchesPlot:
    """Plot predicted matches colored by agreement with ``gt_matches0``.

    Only matches with ``matches0 > -1`` and ``gt_matches0 >= -1`` are drawn;
    each is colored by whether the prediction equals the ground truth.
    """

    plot_name = "gt_matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "gt_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        self.sbpars = _subplot_params(fig)

        for i, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            plot_keypoints([kp0, kp1], axes=axes[i], colors="blue")
            m0 = pred["matches0"][0]
            gtm0 = pred["gt_matches0"][0]
            valid = (m0 > -1) & (gtm0 >= -1)
            kpm0 = kp0[valid]
            kpm1 = kp1[m0[valid]]
            correct = gtm0[valid] == m0[valid]
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(correct).tolist(),
                axes=axes[i],
                labels=correct,
                lw=0.5,
            )


class GtLineMatchesPlot:
    """Plot matched lines restricted to entries with ``gt_line_matches0 >= -1``."""

    plot_name = "gt_line_matches"
    # The ground-truth key listed here must be the one read in __init__.
    required_keys = ["lines0", "lines1", "line_matches0", "gt_line_matches0"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        self.sbpars = _subplot_params(fig)

        # NOTE(review): `axes` is not forwarded to plot_color_line_matches.
        for pred in preds.values():
            lines0, lines1 = pred["lines0"][0], pred["lines1"][0]
            m0 = pred["line_matches0"][0]
            gtm0 = pred["gt_line_matches0"][0]
            valid = (m0 > -1) & (gtm0 >= -1)
            m_lines0 = lines0[valid]
            m_lines1 = lines1[m0[valid]]
            plot_color_line_matches([m_lines0, m_lines1])


class HomographyMatchesPlot:
    """Plot matches colored by symmetric homography error against a slider threshold.

    A slider (0 to 5, step 1, initial value 3) is added below the subplots.
    Each match is labelled with its error and colored by whether that error is
    below the slider value; moving the slider recolors the matches.
    """

    plot_name = "homography"
    required_keys = ["keypoints0", "keypoints1", "matches0", "H_0to1"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        # Margins are captured before extra space is added, so clear() can restore them.
        self.sbpars = _subplot_params(fig)

        add_whitespace_bottom(fig, 0.1)

        self.range_ax = fig.add_axes([0.3, 0.02, 0.4, 0.06])
        self.range = Slider(
            self.range_ax,
            label="Homography Error",
            valmin=0,
            valmax=5,
            valinit=3.0,
            valstep=1.0,
        )
        self.range.on_changed(self.color_matches)

        for i, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            plot_keypoints([kp0, kp1], axes=axes[i], colors="blue")
            m0 = pred["matches0"][0]
            valid = m0 > -1
            kpm0 = kp0[valid]
            kpm1 = kp1[m0[valid]]
            errors = sym_homography_error(kpm0, kpm1, data["H_0to1"][0])
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(errors < self.range.val).tolist(),
                axes=axes[i],
                labels=errors.numpy(),
                lw=0.5,
            )

    def clear(self):
        """Undo the layout changes: shrink the figure, restore margins, drop the slider."""
        w, h = self.fig.get_size_inches()
        # Dividing by 1.1 reverses add_whitespace_bottom(fig, 0.1) from __init__.
        self.fig.set_size_inches(w, h / 1.1)
        self.fig.subplots_adjust(**self.sbpars)
        self.range_ax.remove()

    def color_matches(self, args):
        """Slider callback: recolor matches given the new threshold ``args``."""
        _recolor_matches(self.fig, args)


class EpipolarMatchesPlot:
    """Plot matches colored by epipolar error against a slider threshold.

    A slider (0 to 5, step 1, initial value 3) is added below the subplots.
    Each match is labelled with its error from ``generalized_epi_dist`` and
    colored by whether that error is below the slider value. ``click_artist``
    handles an event carrying a match artist by creating or toggling that
    match's epipolar lines.
    """

    plot_name = "epipolar_matches"
    required_keys = ["keypoints0", "keypoints1", "matches0", "T_0to1", "view0", "view1"]

    def __init__(self, fig, axes, data, preds):
        self.fig = fig
        self.axes = axes
        # Margins are captured before extra space is added, so clear() can restore them.
        self.sbpars = _subplot_params(fig)

        add_whitespace_bottom(fig, 0.1)

        self.range_ax = fig.add_axes([0.3, 0.02, 0.4, 0.06])
        self.range = Slider(
            self.range_ax,
            label="Epipolar Error [px]",
            valmin=0,
            valmax=5,
            valinit=3.0,
            valstep=1.0,
        )
        self.range.on_changed(self.color_matches)

        camera0 = data["view0"]["camera"][0]
        camera1 = data["view1"]["camera"][0]
        T_0to1 = data["T_0to1"][0]

        for i, pred in enumerate(preds.values()):
            kp0, kp1 = pred["keypoints0"][0], pred["keypoints1"][0]
            plot_keypoints([kp0, kp1], axes=axes[i], colors="blue")
            m0 = pred["matches0"][0]
            valid = m0 > -1
            kpm0 = kp0[valid]
            kpm1 = kp1[m0[valid]]

            errors = generalized_epi_dist(
                kpm0,
                kpm1,
                camera0,
                camera1,
                T_0to1,
                all=False,
                essential=False,
            )
            plot_matches(
                kpm0,
                kpm1,
                color=cm_RdGn(errors < self.range.val).tolist(),
                axes=axes[i],
                labels=errors.numpy(),
                lw=0.5,
            )

        # Kept for click_artist, which derives epipolar lines from it.
        self.F = T_to_F(camera0, camera1, T_0to1)

    def clear(self):
        """Undo the layout changes: shrink the figure, restore margins, drop the slider."""
        w, h = self.fig.get_size_inches()
        # Dividing by 1.1 reverses add_whitespace_bottom(fig, 0.1) from __init__.
        self.fig.set_size_inches(w, h / 1.1)
        self.fig.subplots_adjust(**self.sbpars)
        self.range_ax.remove()

    def color_matches(self, args):
        """Slider callback: recolor matches given the new threshold ``args``."""
        _recolor_matches(self.fig, args)

    def click_artist(self, event):
        """Create or toggle the epipolar lines of the artist carried by ``event``.

        Artists whose label is ``None`` are ignored. On the first event for an
        artist, its two epipolar lines are drawn and stored on the artist as
        ``epilines``; later events flip the visibility of those lines.
        """
        art = event.artist
        if art.get_label() is None:
            return

        if hasattr(art, "epilines"):
            for epiline in art.epilines:
                if epiline is not None:
                    epiline.set_visible(not epiline.get_visible())
            return

        xy1 = art.xy1
        xy2 = art.xy2
        line0 = get_line(self.F.transpose(0, 1), xy2)[:, 0]
        line1 = get_line(self.F, xy1)[:, 0]
        art.epilines = [
            draw_epipolar_line(line0, art.axesA),
            draw_epipolar_line(line1, art.axesB),
        ]


# Registry mapping each plot_name to the class in this module that declares it.
__plot_dict__ = {
    obj.plot_name: obj
    for _, obj in inspect.getmembers(sys.modules[__name__], predicate=inspect.isclass)
    if hasattr(obj, "plot_name")
}