glue-factory-custom/gluefactory/visualization/two_view_frame.py

159 lines
5.1 KiB
Python
Raw Permalink Normal View History

import pprint
import numpy as np
from . import viz2d
from .tools import RadioHideTool, ToggleTool, __plot_dict__
class FormatPrinter(pprint.PrettyPrinter):
    """``PrettyPrinter`` that renders chosen types with printf-style formats.

    Args:
        formats: dict mapping a type to a ``%``-format string, e.g.
            ``{np.float32: "%.4f"}``. Matching is on the exact type of the
            object (``type(obj)``), so subclasses are not matched.
    """

    def __init__(self, formats):
        super().__init__()
        self.formats = formats

    def format(self, obj, ctx, maxlvl, lvl):
        """Return the ``(repr, readable, recursive)`` triple for ``obj``.

        If ``type(obj)`` has an entry in ``self.formats``, that format string is
        applied with ``%`` and the result is reported as readable (1) and
        non-recursive (0). Any other object uses the default pretty-printer.
        """
        obj_type = type(obj)
        if obj_type not in self.formats:
            return super().format(obj, ctx, maxlvl, lvl)
        return self.formats[obj_type] % obj, 1, 0
class TwoViewFrame:
    """Interactive figure comparing several models' predictions on one image pair.

    ``init_frame`` builds a grid in which ``axes[i]`` shows the two views for
    model ``self.names[i]``. A radio tool (key "R") switches between the plots
    of ``plot_dict`` whose required keys are available. A toggle tool (key "t")
    shows or hides the per-model summary text boxes.
    """

    # Values for the ``conf`` fields this class reads (``conf.default`` and
    # ``conf.summary_visible``). NOTE(review): presumably merged into ``conf``
    # by the caller; this class never reads ``default_conf`` itself.
    default_conf = {
        "default": "matches",
        "summary_visible": False,
    }

    # Plot name -> plot factory from ``.tools``. Each value exposes
    # ``required_keys`` and is called as ``fn(fig, axes, data, preds)``.
    plot_dict = __plot_dict__

    # NOTE(review): mutable class attribute shared by all instances. It is not
    # referenced anywhere in this file.
    childs = []

    # Maps the ``event`` constructor argument to the background drawn by
    # ``init_frame``. Index 0 (None) makes ``init_frame`` raise ValueError.
    event_to_image = [None, "color", "depth", "color+depth"]

    def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
        """Build the figure, register the interactive tools and draw ``conf.default``.

        Args:
            conf: config object with attributes ``default`` (name of the plot to
                draw first) and ``summary_visible``. Both are written back as
                the user interacts with the frame.
            data: dict holding ``"view0"`` and ``"view1"`` (see ``init_frame``).
            preds: dict mapping model name -> dict of predictions. Its key
                order defines the row order of the grid.
            title: optional window title.
            event: index into ``event_to_image`` selecting the background.
            summaries: optional dict mapping model name -> object rendered in
                that model's summary text box.
        """
        self.conf = conf
        self.data = data
        self.preds = preds
        self.names = list(preds.keys())
        self.plot = self.event_to_image[event]
        self.summaries = summaries
        # Initialise before anything can read it, e.g. the pick callback
        # connected in ``init_frame`` or the tools registered below.
        self.handle = None
        self.fig, self.axes, self.summary_arts = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        # Offer only plots whose required keys are available for every model.
        keys = self._available_keys(data, preds)
        self.options = [
            k for k, v in self.plot_dict.items() if set(v.required_keys).issubset(keys)
        ]

        toolmanager = self.fig.canvas.manager.toolmanager
        self.radios = toolmanager.add_tool(
            "switch plot",
            RadioHideTool,
            options=self.options,
            callback_fn=self.draw,
            active=conf.default,
            keymap="R",
        )
        self.toggle_summary = toolmanager.add_tool(
            "toggle summary",
            ToggleTool,
            toggled=self.conf.summary_visible,
            callback_fn=self.set_summary_visible,
            keymap="t",
        )

        if self.fig.canvas.manager.toolbar is not None:
            self.fig.canvas.manager.toolbar.add_tool("switch plot", "navigation")
        self.draw(conf.default)

    @staticmethod
    def _available_keys(data, preds):
        """Return the keys shared by every prediction dict, plus all keys of ``data``."""
        pred_keys = [set(pred.keys()) for pred in preds.values()]
        # ``set.intersection`` needs at least one argument, so handle an empty
        # ``preds`` explicitly.
        common = set.intersection(*pred_keys) if pred_keys else set()
        return common | set(data.keys())

    def init_frame(self):
        """Create the image grid, row labels, optional depth overlay and summaries.

        Returns:
            ``(fig, axes, toggle_artists)``. ``toggle_artists`` holds one summary
            text artist per model, or is empty when ``self.summaries`` is None.

        Raises:
            ValueError: if ``self.plot`` is not "color", "depth" or "color+depth".
        """
        view0, view1 = self.data["view0"], self.data["view1"]
        if self.plot in ("color", "color+depth"):
            # Presumably (B, C, H, W) tensors: ``[0]`` takes the first item and
            # ``permute(1, 2, 0)`` moves channels last for display. TODO confirm.
            imgs = [
                view0["image"][0].permute(1, 2, 0),
                view1["image"][0].permute(1, 2, 0),
            ]
        elif self.plot == "depth":
            imgs = [view0["depth"][0], view1["depth"][0]]
        else:
            raise ValueError(self.plot)
        imgs = [imgs for _ in self.names]  # repeat the image pair for each model

        fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
        # Label each row with its model name.
        for i, name in enumerate(self.names):
            viz2d.add_text(0, name, axes=axes[i])

        if (
            self.plot == "color+depth"
            and "depth" in view0
            and view0["depth"] is not None
        ):
            # Overlay the depth maps of both views on every model's row.
            for i in range(len(self.names)):
                viz2d.plot_heatmaps(
                    [view0["depth"][0], view1["depth"][0]],
                    axes=axes[i],
                    cmap="Spectral",
                )

        fig.canvas.mpl_connect("pick_event", self.click_artist)
        if self.summaries is not None:
            # Print numpy floats with 4 decimals to keep the summary boxes compact.
            formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
            toggle_artists = [
                viz2d.add_text(
                    0,
                    formatter.pformat(self.summaries[n]),
                    axes=axes[i],
                    pos=(0.01, 0.01),
                    va="bottom",
                    backgroundcolor=(0, 0, 0, 0.5),
                    visible=self.conf.summary_visible,
                )
                for i, n in enumerate(self.names)
            ]
        else:
            toggle_artists = []
        return fig, axes, toggle_artists

    def draw(self, value):
        """Clear the current plot and draw the one registered under ``value``.

        Stores ``value`` in ``conf.default`` and returns the new plot handle.
        """
        self.clear()
        self.conf.default = value
        self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
        return self.handle

    def clear(self):
        """Remove the current plot's lines, collections and figure artists.

        Images and axis texts (model labels, summaries) are left in place.
        """
        if self.handle is not None:
            try:
                self.handle.clear()
            except AttributeError:
                # Best effort, presumably for handles without a ``clear`` method.
                pass
        self.handle = None
        for row in self.axes:
            for ax in row:
                # Iterate over snapshots. ``Artist.remove`` mutates the
                # container, which skips elements when ``ax.lines`` /
                # ``ax.collections`` are live lists (older matplotlib).
                for line in list(ax.lines):
                    line.remove()
                for collection in list(ax.collections):
                    collection.remove()
        self.fig.artists.clear()
        self.fig.canvas.draw_idle()

    def click_artist(self, event):
        """Handle a pick event: toggle a picked arrow and forward to the plot handle."""
        art = event.artist
        get_arrowstyle = getattr(art, "get_arrowstyle", None)
        # Only arrow-like artists can be toggled. Other pickable artists are
        # still forwarded to the handle instead of raising here.
        if get_arrowstyle is not None:
            # A plain "-" style means "not selected": switch to a double-headed
            # arrow. Otherwise revert to the plain style.
            select = get_arrowstyle().arrow == "-"
            art.set_arrowstyle("<|-|>" if select else "-")
            if select:
                art.set_zorder(1)
        if hasattr(self.handle, "click_artist"):
            self.handle.click_artist(event)
        self.fig.canvas.draw_idle()

    def set_summary_visible(self, visible):
        """Show or hide all summary text boxes and record the state in ``conf``."""
        self.conf.summary_visible = visible
        for art in self.summary_arts:
            art.set_visible(visible)
        self.fig.canvas.draw_idle()