2023-10-05 16:53:51 +02:00
|
|
|
import pprint
|
|
|
|
|
2023-10-09 08:32:43 +02:00
|
|
|
import numpy as np
|
2023-10-05 16:53:51 +02:00
|
|
|
|
2023-10-09 08:32:43 +02:00
|
|
|
from . import viz2d
|
|
|
|
from .tools import RadioHideTool, ToggleTool, __plot_dict__
|
2023-10-05 16:53:51 +02:00
|
|
|
|
|
|
|
|
|
|
|
class FormatPrinter(pprint.PrettyPrinter):
|
|
|
|
def __init__(self, formats):
|
|
|
|
super(FormatPrinter, self).__init__()
|
|
|
|
self.formats = formats
|
|
|
|
|
|
|
|
def format(self, obj, ctx, maxlvl, lvl):
|
|
|
|
if type(obj) in self.formats:
|
|
|
|
return self.formats[type(obj)] % obj, 1, 0
|
|
|
|
return pprint.PrettyPrinter.format(self, obj, ctx, maxlvl, lvl)
|
|
|
|
|
|
|
|
|
|
|
|
class TwoViewFrame:
|
|
|
|
default_conf = {
|
|
|
|
"default": "matches",
|
|
|
|
"summary_visible": False,
|
|
|
|
}
|
|
|
|
|
|
|
|
plot_dict = __plot_dict__
|
|
|
|
|
|
|
|
childs = []
|
|
|
|
|
|
|
|
event_to_image = [None, "color", "depth", "color+depth"]
|
|
|
|
|
|
|
|
def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
|
|
|
|
self.conf = conf
|
|
|
|
self.data = data
|
|
|
|
self.preds = preds
|
|
|
|
self.names = list(preds.keys())
|
|
|
|
self.plot = self.event_to_image[event]
|
|
|
|
self.summaries = summaries
|
|
|
|
self.fig, self.axes, self.summary_arts = self.init_frame()
|
|
|
|
if title is not None:
|
|
|
|
self.fig.canvas.manager.set_window_title(title)
|
|
|
|
|
|
|
|
keys = None
|
|
|
|
for _, pred in preds.items():
|
|
|
|
if keys is None:
|
|
|
|
keys = set(pred.keys())
|
|
|
|
else:
|
|
|
|
keys = keys.intersection(pred.keys())
|
|
|
|
keys = keys.union(data.keys())
|
|
|
|
|
|
|
|
self.options = [
|
|
|
|
k for k, v in self.plot_dict.items() if set(v.required_keys).issubset(keys)
|
|
|
|
]
|
|
|
|
self.handle = None
|
|
|
|
self.radios = self.fig.canvas.manager.toolmanager.add_tool(
|
|
|
|
"switch plot",
|
|
|
|
RadioHideTool,
|
|
|
|
options=self.options,
|
|
|
|
callback_fn=self.draw,
|
|
|
|
active=conf.default,
|
|
|
|
keymap="R",
|
|
|
|
)
|
|
|
|
|
|
|
|
self.toggle_summary = self.fig.canvas.manager.toolmanager.add_tool(
|
|
|
|
"toggle summary",
|
|
|
|
ToggleTool,
|
|
|
|
toggled=self.conf.summary_visible,
|
|
|
|
callback_fn=self.set_summary_visible,
|
|
|
|
keymap="t",
|
|
|
|
)
|
|
|
|
|
|
|
|
if self.fig.canvas.manager.toolbar is not None:
|
|
|
|
self.fig.canvas.manager.toolbar.add_tool("switch plot", "navigation")
|
|
|
|
self.draw(conf.default)
|
|
|
|
|
|
|
|
def init_frame(self):
|
|
|
|
"""initialize frame"""
|
|
|
|
view0, view1 = self.data["view0"], self.data["view1"]
|
|
|
|
if self.plot == "color" or self.plot == "color+depth":
|
|
|
|
imgs = [
|
|
|
|
view0["image"][0].permute(1, 2, 0),
|
|
|
|
view1["image"][0].permute(1, 2, 0),
|
|
|
|
]
|
|
|
|
elif self.plot == "depth":
|
|
|
|
imgs = [view0["depth"][0], view1["depth"][0]]
|
|
|
|
else:
|
|
|
|
raise ValueError(self.plot)
|
|
|
|
imgs = [imgs for _ in self.names] # repeat for each model
|
|
|
|
|
|
|
|
fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
|
|
|
|
[viz2d.add_text(0, n, axes=axes[i]) for i, n in enumerate(self.names)]
|
|
|
|
|
|
|
|
if (
|
|
|
|
self.plot == "color+depth"
|
|
|
|
and "depth" in view0.keys()
|
|
|
|
and view0["depth"] is not None
|
|
|
|
):
|
|
|
|
hmaps = [[view0["depth"][0], view1["depth"][0]] for _ in self.names]
|
|
|
|
[
|
|
|
|
viz2d.plot_heatmaps(hmaps[i], axes=axes[i], cmap="Spectral")
|
|
|
|
for i, _ in enumerate(hmaps)
|
|
|
|
]
|
|
|
|
|
|
|
|
fig.canvas.mpl_connect("pick_event", self.click_artist)
|
|
|
|
if self.summaries is not None:
|
|
|
|
formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
|
|
|
|
toggle_artists = [
|
|
|
|
viz2d.add_text(
|
|
|
|
0,
|
|
|
|
formatter.pformat(self.summaries[n]),
|
|
|
|
axes=axes[i],
|
|
|
|
pos=(0.01, 0.01),
|
|
|
|
va="bottom",
|
|
|
|
backgroundcolor=(0, 0, 0, 0.5),
|
|
|
|
visible=self.conf.summary_visible,
|
|
|
|
)
|
|
|
|
for i, n in enumerate(self.names)
|
|
|
|
]
|
|
|
|
else:
|
|
|
|
toggle_artists = []
|
|
|
|
return fig, axes, toggle_artists
|
|
|
|
|
|
|
|
def draw(self, value):
|
|
|
|
"""redraw content in frame"""
|
|
|
|
self.clear()
|
|
|
|
self.conf.default = value
|
|
|
|
self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
|
|
|
|
return self.handle
|
|
|
|
|
|
|
|
def clear(self):
|
|
|
|
if self.handle is not None:
|
|
|
|
try:
|
|
|
|
self.handle.clear()
|
|
|
|
except AttributeError:
|
|
|
|
pass
|
|
|
|
self.handle = None
|
|
|
|
for row in self.axes:
|
|
|
|
for ax in row:
|
|
|
|
[li.remove() for li in ax.lines]
|
|
|
|
[c.remove() for c in ax.collections]
|
|
|
|
self.fig.artists.clear()
|
|
|
|
self.fig.canvas.draw_idle()
|
|
|
|
self.handle = None
|
|
|
|
|
|
|
|
def click_artist(self, event):
|
|
|
|
art = event.artist
|
|
|
|
select = art.get_arrowstyle().arrow == "-"
|
|
|
|
art.set_arrowstyle("<|-|>" if select else "-")
|
|
|
|
if select:
|
|
|
|
art.set_zorder(1)
|
|
|
|
if hasattr(self.handle, "click_artist"):
|
|
|
|
self.handle.click_artist(event)
|
|
|
|
self.fig.canvas.draw_idle()
|
|
|
|
|
|
|
|
def set_summary_visible(self, visible):
|
|
|
|
self.conf.summary_visible = visible
|
|
|
|
[s.set_visible(visible) for s in self.summary_arts]
|
|
|
|
self.fig.canvas.draw_idle()
|