288 lines
9.4 KiB
Python
288 lines
9.4 KiB
Python
|
import traceback
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
from omegaconf import OmegaConf
|
||
|
from matplotlib.widgets import Button
|
||
|
from copy import deepcopy
|
||
|
import functools
|
||
|
|
||
|
# from ..eval.export_predictions import load_predictions
|
||
|
from ..models.cache_loader import CacheLoader
|
||
|
from ..datasets.base_dataset import collate
|
||
|
from .tools import RadioHideTool
|
||
|
|
||
|
|
||
|
class GlobalFrame:
    """Interactive scatter plot comparing per-sample metrics across models.

    Every model in ``results`` is drawn as a scatter of metric ``conf.x``
    against metric ``conf.y``. Hovering a point connects it to the same sample
    of every other model; clicking a point (optionally) opens a
    ``child_frame`` window for that sample.
    """

    default_conf = {
        "x": "???",
        "y": "???",
        "diff": False,
        "child": {},
        "remove_outliers": False,
    }

    child_frame = None  # MatchFrame

    # Class-level placeholders kept only for backward compatibility: each
    # instance gets its own containers in __init__ so that separate frames do
    # not share (and leak into each other's) children, lines or scatters.
    childs = []

    lines = []

    scatters = {}

    def __init__(
        self, conf, results, loader, predictions, title=None, child_frame=None
    ):
        """Merge the configuration, build the figure and register x/y tools.

        Args:
            conf: user configuration merged over ``default_conf``. Empty or
                missing ``x``/``y`` fall back to the first/second metric.
            results: mapping model name -> mapping metric name -> per-sample
                values, indexed by sample index.
            loader: object whose ``dataset`` is indexed by sample index when a
                child frame is spawned.
            predictions: mapping model name -> prediction cache path, passed to
                ``CacheLoader``.
            title: optional window title.
            child_frame: optional frame class instantiated when a point is
                clicked.
        """
        self.child_frame = child_frame
        # Per-instance containers (the class-level ones would be shared).
        self.childs = []
        self.lines = []
        self.scatters = {}

        # Work on a copy so the class-level defaults are never mutated.
        default_conf = dict(self.default_conf)
        if self.child_frame is not None:
            # We do NOT merge inside the child frame to keep settings across
            # figs: the merged child config is handed to every spawned child.
            default_conf["child"] = self.child_frame.default_conf

        self.conf = OmegaConf.merge(default_conf, conf)
        self.results = results
        self.loader = loader
        self.predictions = predictions

        metrics = set()
        for per_model in results.values():
            metrics.update(per_model.keys())
        self.metrics = sorted(metrics)

        # Default to the first two metrics (or the only one, if just one).
        default_y = self.metrics[1] if len(self.metrics) > 1 else self.metrics[0]
        self.conf.x = conf.get("x") or self.metrics[0]
        self.conf.y = conf.get("y") or default_y

        assert self.conf.x in self.metrics, f"unknown x metric {self.conf.x}"
        assert self.conf.y in self.metrics, f"unknown y metric {self.conf.y}"

        self.names = list(results)
        self.fig, self.axes = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        self.xradios = self.fig.canvas.manager.toolmanager.add_tool(
            "x",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_x,
            active=self.conf.x,
            keymap="x",
        )

        self.yradios = self.fig.canvas.manager.toolmanager.add_tool(
            "y",
            RadioHideTool,
            options=self.metrics,
            callback_fn=self.update_y,
            active=self.conf.y,
            keymap="y",
        )
        if self.fig.canvas.manager.toolbar is not None:
            self.fig.canvas.manager.toolbar.add_tool("x", "navigation")
            self.fig.canvas.manager.toolbar.add_tool("y", "navigation")

    def init_frame(self):
        """Create the figure, the diff toggle button and the event hooks.

        Returns:
            The ``(figure, axes)`` pair from ``plt.subplots()``.
        """
        fig, ax = plt.subplots()
        ax.set_title("click on points")
        diffb_ax = fig.add_axes([0.01, 0.02, 0.12, 0.06])
        self.diffb = Button(diffb_ax, label="diff_only")
        self.diffb.on_clicked(self.diff_clicked)
        fig.canvas.mpl_connect("pick_event", self.on_scatter_pick)
        fig.canvas.mpl_connect("motion_notify_event", self.hover)
        return fig, ax

    def draw(self):
        """Redraw every model's scatter plus its mean/median guide lines."""
        self.scatters = {}
        self.axes.clear()
        self.axes.set_xlabel(self.conf.x)
        self.axes.set_ylabel(self.conf.y)

        first = self.results[self.names[0]]
        x_cat = isinstance(first[self.conf.x][0], (bytes, str))
        y_cat = isinstance(first[self.conf.y][0], (bytes, str))

        # In diff mode, numeric metrics are shown relative to the first model.
        refx = 0.0
        refy = 0.0
        if self.conf.diff:
            if not x_cat:
                refx = np.array(first[self.conf.x])
            if not y_cat:
                refy = np.array(first[self.conf.y])

        # Only set by the categorical-x / numeric-y branch below.
        xunique = None
        sort_ax = None
        for name in list(self.results.keys()):
            x = np.array(self.results[name][self.conf.x])
            y = np.array(self.results[name][self.conf.y])

            # Categorical values that are all digit strings are plotted as ints.
            if x_cat and np.char.isdigit(x.astype(str)).all():
                x = x.astype(int)
            if y_cat and np.char.isdigit(y.astype(str)).all():
                y = y.astype(int)

            x = x if x_cat else x - refx
            y = y if y_cat else y - refy

            (s,) = self.axes.plot(
                x, y, "o", markersize=3, label=name, picker=True, pickradius=5
            )
            self.scatters[name] = s

            if x_cat and not y_cat:
                # Step curve of the mean y per x category, in first-seen order.
                xunique, ind, xinv, xbin = np.unique(
                    x, return_inverse=True, return_counts=True, return_index=True
                )
                ybin = np.bincount(xinv, weights=y)
                sort_ax = np.argsort(ind)
                self.axes.step(
                    xunique[sort_ax],
                    (ybin / xbin)[sort_ax],
                    where="mid",
                    color=s.get_color(),
                )

            # x / y are already relative to refx / refy here, so the medians
            # must not subtract the reference a second time.
            if not x_cat:
                xavg = np.nan_to_num(x).mean()
                self.axes.axvline(xavg, c=s.get_color(), zorder=1, alpha=1.0)
                xmed = np.median(x)
                self.axes.axvline(
                    xmed,
                    c=s.get_color(),
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )

            if not y_cat:
                yavg = np.nan_to_num(y).mean()
                self.axes.axhline(yavg, c=s.get_color(), zorder=1, alpha=0.5)
                ymed = np.median(y)
                self.axes.axhline(
                    ymed,
                    c=s.get_color(),
                    zorder=0,
                    alpha=0.5,
                    linestyle="dashed",
                    visible=False,
                )
        # xunique only exists when the step curve was computed; guarding on it
        # avoids a NameError when both axes are categorical.
        if xunique is not None and x.dtype == object and xunique.shape[0] > 5:
            self.axes.set_xticklabels(xunique[sort_ax], rotation=90)
        self.axes.legend()

    def on_scatter_pick(self, handle):
        """``pick_event`` callback: spawn a child frame for the picked point.

        Any error while spawning is printed and then aborts the program with
        a non-zero exit status.
        """
        try:
            art = handle.artist
            try:
                event = handle.mouseevent.button.value
            except AttributeError:
                # The mouse event has no button exposing ``.value``.
                return
            name = art.get_label()
            ind = handle.ind[0]
            # draw lines
            self.spawn_child(name, ind, event=event)
        except Exception:
            traceback.print_exc()
            # ``exit`` is a site-provided helper meant for the interactive
            # shell, and status 0 would report the failure as success.
            raise SystemExit(1)

    def _clear_lines(self):
        """Remove the connector lines currently drawn on the axes."""
        for line in self.lines:
            line.remove()
        self.lines = []

    def spawn_child(self, model_name, ind, event=None):
        """Highlight sample ``ind`` and, if configured, open a child frame.

        Red lines connect ``model_name``'s point to the same sample of every
        model. With a ``child_frame`` class, the sample is loaded from
        ``self.loader.dataset``, each model's cached predictions are read,
        and a new child figure is opened and tracked in ``self.childs``.
        """
        self._clear_lines()

        x_source = self.scatters[model_name].get_xdata()[ind]
        y_source = self.scatters[model_name].get_ydata()[ind]
        for oname in self.names:
            xn = self.scatters[oname].get_xdata()[ind]
            yn = self.scatters[oname].get_ydata()[ind]
            (ln,) = self.axes.plot([x_source, xn], [y_source, yn], "r")
            self.lines.append(ln)

        self.fig.canvas.draw_idle()

        if self.child_frame is None:
            return

        data = collate([self.loader.dataset[ind]])

        preds = {
            name: CacheLoader({"path": str(pfile), "add_data_path": False})(data)
            for name, pfile in self.predictions.items()
        }
        summaries_i = {
            name: {k: v[ind] for k, v in res.items() if k != "names"}
            for name, res in self.results.items()
        }
        frame = self.child_frame(
            self.conf.child,
            deepcopy(data),
            preds,
            title=str(data["name"][0]),
            event=event,
            summaries=summaries_i,
        )

        frame.fig.canvas.mpl_connect(
            "key_press_event",
            functools.partial(
                self.on_childframe_key_event, frame=frame, ind=ind, event=event
            ),
        )
        self.childs.append(frame)
        frame.fig.show()

    def hover(self, event):
        """``motion_notify_event`` callback: connect the hovered point to the
        same sample of every model."""
        if event.inaxes != self.axes:
            return
        for s in self.scatters.values():
            cont, info = s.contains(event)
            if not cont:
                continue
            ind = info["ind"][0]
            xdata, ydata = s.get_data()
            self._clear_lines()

            for oname in self.names:
                xn = self.scatters[oname].get_xdata()[ind]
                yn = self.scatters[oname].get_ydata()[ind]

                (ln,) = self.axes.plot(
                    [xdata[ind], xn],
                    [ydata[ind], yn],
                    "black",
                    zorder=0,
                    alpha=0.5,
                )
                self.lines.append(ln)
            self.fig.canvas.draw_idle()
            break

    def diff_clicked(self, args):
        """Button callback: toggle diff mode and redraw."""
        self.conf.diff = not self.conf.diff
        self.draw()
        self.fig.canvas.draw_idle()

    def update_x(self, x):
        """Callback of the ``x`` radio tool: switch the x metric and redraw."""
        self.conf.x = x
        self.draw()
        self.fig.canvas.draw_idle()

    def update_y(self, y):
        """Callback of the ``y`` radio tool: switch the y metric and redraw."""
        self.conf.y = y
        self.draw()
        self.fig.canvas.draw_idle()

    @staticmethod
    def _navigation_target(key, ind, num_items):
        """Return ``(new_ind, close_current)`` for an arrow key, else ``None``.

        ``right`` steps forward and ``left`` backward, wrapping modulo
        ``num_items``. A ``shift+`` prefix keeps the current child open.
        """
        if key not in ("left", "right", "shift+left", "shift+right"):
            return None
        keep_open = key.startswith("shift+")
        direction = key.replace("shift+", "")
        step = 1 if direction == "right" else -1
        return (ind + step) % num_items, not keep_open

    def on_childframe_key_event(self, key_event, frame, ind, event):
        """Key handler installed on each child frame.

        ``delete`` closes the child. Arrow keys open the previous/next sample:
        plain arrows replace the current child, ``shift+`` arrows keep it.
        """
        if key_event.key == "delete":
            plt.close(frame.fig)
            self.childs.remove(frame)
            return
        # ``ind`` indexes the dataset (see spawn_child), so wrap on its length.
        target = self._navigation_target(
            key_event.key, ind, len(self.loader.dataset)
        )
        if target is None:
            return
        new_ind, close_current = target
        if close_current:
            plt.close(frame.fig)
            self.childs.remove(frame)
        self.spawn_child(
            self.names[0],
            new_ind,
            event=event,
        )