glue-factory-custom/gluefactory/visualization/viz2d.py

487 lines
14 KiB
Python

"""
2D visualization primitives based on Matplotlib.
1) Plot images with `plot_images`.
2) Call `plot_keypoints` or `plot_matches` any number of times.
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
"""
import matplotlib
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def cm_ranking(sc, ths=(512, 1024, 2048, 4096)):
    """Color each score by the bucket of its rank (rank 0 = highest score).

    Ranks ``[0, ths[0])`` are red, ``[ths[0], ths[1])`` yellow, then lime,
    cyan and blue. Ranks not covered by any of the five color buckets (only
    possible when more than four thresholds are given) are gray.

    Args:
        sc: 1D array-like of scores, shape (N,) (NumPy array or CPU tensor).
        ths: rank thresholds separating the color buckets, expected ascending.
    Returns:
        NumPy array of N color names, aligned element-wise with ``sc``.
    """
    scores = np.asarray(sc)
    ls = scores.shape[0]
    colors = ["red", "yellow", "lime", "cyan", "blue"]
    bounds = list(ths) + [ls]
    # by_rank[r] is the color of the element with rank r.
    by_rank = ["gray"] * ls
    for i in range(ls):
        for c, th in zip(colors[: len(bounds)], bounds):
            if i < th:
                by_rank[i] = c
                break
    # order[r] is the index of the element with rank r (descending scores).
    order = np.argsort(scores, axis=0)[::-1]
    # Scatter rank colors back to element positions (inverse permutation).
    # Gathering with by_rank[order] would permute the colors a second time.
    by_rank = np.array(by_rank)
    out = np.empty_like(by_rank)
    out[order] = by_rank
    return out
def cm_RdBl(x):
    """Custom colormap: red (0) -> magenta (0.5) -> blue (1).

    Values are clipped to [0, 1]. For an array input of shape S (with at least
    one dimension), returns RGB values of shape S + (3,).
    """
    # Rescale to [0, 2] so the two endpoint colors blend linearly; at the
    # midpoint both contribute fully (red + blue = magenta after clipping).
    x = np.clip(x, 0, 1)[..., None] * 2
    c = x * np.array([[0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0]])
    return np.clip(c, 0, 1)
def cm_RdGn(x):
    """Map values in [0, 1] to RGB: red at 0, yellow at 0.5, green at 1.

    Inputs are clipped to [0, 1]; the colors are stacked along a new
    trailing axis of size 3.
    """
    green = np.array([[0, 1.0, 0]])
    red = np.array([[1.0, 0, 0]])
    # t runs over [0, 2]: red weight falls as green weight rises.
    t = np.clip(x, 0, 1)[..., None] * 2
    blended = t * green + (2 - t) * red
    return np.clip(blended, 0, 1)
def cm_BlRdGn(x_):
    """Custom colormap: blue (-1) -> red (0.0) -> green (1).

    Values are clipped to [-1, 1]. Returns RGBA colors (alpha always 1) in a
    new trailing axis of size 4. Scalars and array-likes are accepted.
    """
    x_ = np.asarray(x_)
    # Positive half: red (0) -> green (1).
    x = np.clip(x_, 0, 1)[..., None] * 2
    c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
    # Negative half: red (0) -> blue (-1); xn grows from 0 to 2 as x_ -> -1.
    xn = -np.clip(x_, -1, 0)[..., None] * 2
    cn = xn * np.array([[0, 0, 1.0, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
    out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
    return out
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
    """Show images side by side in a single row of a new figure.

    Args:
        imgs: list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: optional list of titles, one per image.
        cmaps: colormap for monochrome images, or a list with one per image.
        dpi: figure resolution.
        pad: padding passed to ``tight_layout``.
        adaptive: if True, each column's width follows its image's W / H
            ratio; otherwise every column uses a 4:3 ratio.
    """
    count = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * count
    widths = [im.shape[1] / im.shape[0] for im in imgs] if adaptive else [4 / 3] * count
    height = 4.5
    # squeeze=False gives a (1, count) grid even for a single image.
    fig, grid = plt.subplots(
        1,
        count,
        figsize=[sum(widths) * height, height],
        dpi=dpi,
        squeeze=False,
        gridspec_kw={"width_ratios": widths},
    )
    for idx, (img, ax) in enumerate(zip(imgs, grid[0])):
        ax.imshow(img, cmap=plt.get_cmap(cmaps[idx]))
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[idx])
    fig.tight_layout(pad=pad)
def plot_image_grid(
    imgs,
    titles=None,
    cmaps="gray",
    dpi=100,
    pad=0.5,
    fig=None,
    adaptive=True,
    figs=2.0,
    return_fig=False,
    set_lim=False,
):
    """Plot a grid of images, one grid row per inner list.

    Args:
        imgs: a list of rows, each a list of NumPy or PyTorch images,
            RGB (H, W, 3) or mono (H, W). All rows should have equal length.
        titles: optional list of rows of strings, indexed ``titles[row][col]``.
        cmaps: colormap for monochrome images, or a list with one per column.
        dpi: resolution of a newly created figure (unused when ``fig`` is given).
        pad: padding for ``tight_layout`` (only applied to a top-level Figure,
            e.g. not a SubFigure).
        fig: optional existing Figure/SubFigure to create the axes in.
        adaptive: whether column widths follow the aspect ratios (W / H) of the
            first row's images; otherwise every column uses 4:3.
        figs: size in inches of one unit of width ratio and of one row height.
        return_fig: if True, return ``(fig, axs)`` instead of ``axs``.
        set_lim: whether to set each axis's limits to its image extent.
    Returns:
        ``axs`` indexable as ``axs[row][col]`` (a one-element list holding the
        row when there is a single row), or ``(fig, axs)`` if ``return_fig``.
    """
    nr, n = len(imgs), len(imgs[0])
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n
    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs[0]]  # W / H
    else:
        ratios = [4 / 3] * n
    figsize = [sum(ratios) * figs, nr * figs]
    gridspec_kw = {"width_ratios": ratios}
    # squeeze=False always yields a 2D (nr, n) array of Axes so axs[j][i]
    # also works for a single column or a single image; the squeezed result
    # was a 1D array or a bare Axes there and could not be indexed twice.
    if fig is None:
        fig, axs = plt.subplots(
            nr, n, figsize=figsize, dpi=dpi, squeeze=False, gridspec_kw=gridspec_kw
        )
    else:
        axs = fig.subplots(nr, n, squeeze=False, gridspec_kw=gridspec_kw)
        fig.figure.set_size_inches(figsize)
    if nr == 1:
        # Keep the historical single-row return type: a list holding the row.
        axs = list(axs)
    for j in range(nr):
        for i in range(n):
            ax = axs[j][i]
            ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
            ax.set_axis_off()
            if set_lim:
                ax.set_xlim([0, imgs[j][i].shape[1]])
                ax.set_ylim([imgs[j][i].shape[0], 0])
            if titles:
                ax.set_title(titles[j][i])
    if isinstance(fig, plt.Figure):
        fig.tight_layout(pad=pad)
    if return_fig:
        return fig, axs
    else:
        return axs
def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
    """Scatter keypoints on top of already-plotted images.

    Args:
        kpts: list of ndarrays of size (N, 2), one per axis, as (x, y).
        colors: a single color used for every image, or a list with one color
            spec per image (each may itself be a per-keypoint list of tuples).
        ps: marker size of the keypoints as float.
        axes: axes to draw on; defaults to all axes of the current figure.
        a: opacity, a single value or a list with one value per image.
    """
    count = len(kpts)
    colors = colors if isinstance(colors, list) else [colors] * count
    alphas = a if isinstance(a, list) else [a] * count
    targets = plt.gcf().axes if axes is None else axes
    for ax, pts, col, alpha in zip(targets, kpts, colors, alphas):
        ax.scatter(pts[:, 0], pts[:, 1], c=col, s=ps, linewidths=0, alpha=alpha)
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
    """Plot matches for a pair of existing images.

    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2).
        color: either a single color (string or RGB(A) tuple) applied to all
            matches, or one color per match (list of tuples/lists or an
            (N, 3|4) array). Random if not given.
        lw: width of the lines (no lines if lw=0).
        ps: size of the end points (no endpoint if ps=0).
        a: alpha opacity of the match lines.
        labels: optional per-match labels attached to the lines; ``labels[0]``
            and ``labels[1]`` are also used to label the two endpoint scatters.
        axes: the pair (ax0, ax1) to draw on; defaults to the first two axes
            of the current figure.
    """
    fig = plt.gcf()
    if axes is None:
        ax = fig.axes
        ax0, ax1 = ax[0], ax[1]
    else:
        ax0, ax1 = axes
    assert len(kpts0) == len(kpts1)
    if color is None:
        color = sns.color_palette("husl", n_colors=len(kpts0))
    elif len(color) > 0 and not isinstance(color[0], (tuple, list, np.ndarray)):
        # A single color such as "lime" or (1, 0, 0): repeat it per match.
        # Rows of an (N, 3|4) array are ndarrays, so such arrays stay per-match.
        color = [color] * len(kpts0)
    if lw > 0:
        for i in range(len(kpts0)):
            line = matplotlib.patches.ConnectionPatch(
                xyA=(kpts0[i, 0], kpts0[i, 1]),
                xyB=(kpts1[i, 0], kpts1[i, 1]),
                coordsA=ax0.transData,
                coordsB=ax1.transData,
                axesA=ax0,
                axesB=ax1,
                zorder=1,
                color=color[i],
                linewidth=lw,
                clip_on=True,
                alpha=a,
                label=None if labels is None else labels[i],
                picker=5.0,
            )
            line.set_annotation_clip(True)
            fig.add_artist(line)
    # freeze the axes to prevent the transform to change
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)
    if ps > 0:
        ax0.scatter(
            kpts0[:, 0],
            kpts0[:, 1],
            c=color,
            s=ps,
            label=None if labels is None else labels[0],
        )
        ax1.scatter(
            kpts1[:, 0],
            kpts1[:, 1],
            c=color,
            s=ps,
            label=None if labels is None else labels[1],
        )
def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=2,
    ha="left",
    va="top",
    axes=None,
    **kwargs,
):
    """Write ``text`` on the ``idx``-th axes, positioned in axes coordinates.

    Args:
        idx: index into ``axes``.
        text: string to draw.
        pos: (x, y) position interpreted with the axes' ``transAxes``.
        fs: font size.
        color: text color.
        lcolor: outline color drawn behind the text; None disables the outline.
        lwidth: outline width.
        ha, va: horizontal and vertical alignment.
        axes: list of axes; defaults to the current figure's axes.
        **kwargs: forwarded to ``ax.text``.
    Returns:
        The text artist returned by ``ax.text``.
    """
    target = (plt.gcf().axes if axes is None else axes)[idx]
    label = target.text(
        *pos,
        text,
        fontsize=fs,
        ha=ha,
        va=va,
        color=color,
        transform=target.transAxes,
        **kwargs,
    )
    if lcolor is None:
        return label
    outline = path_effects.Stroke(linewidth=lwidth, foreground=lcolor)
    label.set_path_effects([outline, path_effects.Normal()])
    return label
def draw_epipolar_line(
    line, axis, imshape=None, color="b", label=None, alpha=1.0, visible=True
):
    """Draw the part of a 2D line that falls inside an image as a dashed segment.

    Args:
        line: homogeneous line coefficients (a, b, c) of a*x + b*y + c = 0.
        axis: axes to draw on.
        imshape: (h, w, ...) extent of the image; if None it is derived from
            the axis limits as (ylim[0] + 0.5, xlim[1] + 0.5).
        color: color/format string passed positionally to ``axis.plot``.
        label, alpha, visible: forwarded to ``axis.plot``.
    Returns:
        The plotted line artist, or None if the line does not cross the image.
    """
    if imshape is not None:
        h, w = imshape[:2]
    else:
        # Same (h + 0.5, w + 0.5) extent that plot_epipolar_lines derives
        # from the axis limits; previously computed here but never used.
        _, w = axis.get_xlim()
        h, _ = axis.get_ylim()
        h, w = h + 0.5, w + 0.5
    # Image borders as homogeneous lines: x = 1, x = w, y = 1, y = h.
    borders = ([1, 0, -1], [1, 0, -w], [0, 1, -1], [0, 1, -h])
    points = []
    # A line parallel to a border meets it at infinity (third coordinate 0);
    # the resulting inf/nan fails the bounds test, so silence the warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        for border in borders:
            X = np.cross(line, border)
            X = X[:2] / X[2]
            inside = (0 <= X[0] <= (w + 1e-6)) and (0 <= X[1] <= (h + 1e-6))
            # A line through an image corner hits two borders at the same
            # point; skip the duplicate so the real second endpoint is found.
            if inside and not any(np.allclose(X, p) for p in points):
                points.append(X)
                if len(points) == 2:
                    break
    # Plot line, if it's visible in the image.
    if len(points) < 2:
        return None
    (x0, y0), (x1, y1) = points
    return axis.plot(
        [x0, x1],
        [y0, y1],
        color,
        linestyle="dashed",
        label=label,
        alpha=alpha,
        visible=visible,
    )[0]
def get_line(F, kp):
    """Multiply F by the homogeneous 2D point ``kp`` = (x, y).

    Returns ``F @ [x, y, 1]^T`` as a (3, 1) column; with a fundamental
    matrix this is the epipolar line of ``kp`` in the other image.
    """
    column = np.array([[*kp, 1.0]]).T
    return np.dot(F, column)
def plot_epipolar_lines(
    pts0, pts1, F, color="b", axes=None, labels=None, a=1.0, visible=True
):
    """Draw, on each of two images, the epipolar lines of the other's points.

    Lines in image 0 are ``F.T @ x1`` for the points ``pts1``; lines in
    image 1 are ``F @ x0`` for the points ``pts0``.

    Args:
        pts0, pts1: keypoints of shape (N, 2) in image 0 and image 1.
        F: 3x3 matrix (NumPy array or CPU tensor) relating the two images.
        color: color/format string for the lines.
        axes: the two axes of images 0 and 1; defaults to the current figure's.
        labels: optional per-point labels for the lines.
        a: line opacity.
        visible: initial visibility of the line artists.
    """
    if axes is None:
        axes = plt.gcf().axes
    assert len(axes) == 2
    # ndarray.transpose(0, 1) keeps a 2D array unchanged (only tensors swap
    # dims), so convert once and use .T to get the actual transpose.
    F = np.asarray(F)
    for ax, kps, F_ax in zip(axes, [pts1, pts0], [F.T, F]):
        _, w = ax.get_xlim()
        h, _ = ax.get_ylim()
        imshape = (h + 0.5, w + 0.5)
        for i in range(kps.shape[0]):
            line = get_line(F_ax, kps[i])[:, 0]
            draw_epipolar_line(
                line,
                ax,
                imshape,
                color=color,
                label=None if labels is None else labels[i],
                alpha=a,
                visible=visible,
            )
def plot_heatmaps(heatmaps, vmin=0.0, vmax=None, cmap="Spectral", a=0.5, axes=None):
    """Overlay one heatmap per axis, fully transparent where value <= vmin.

    Args:
        heatmaps: sequence of image-shaped maps (NumPy arrays or CPU tensors),
            indexed by axis position.
        vmin, vmax: color limits for ``imshow``; vmin is also the cutoff below
            which (inclusive) pixels are transparent.
        cmap: colormap name.
        a: overlay opacity, a scalar or one value per axis.
        axes: axes to draw on; defaults to all axes of the current figure.
    Returns:
        list of the artists returned by ``imshow``, one per axis.
    """
    if axes is None:
        axes = plt.gcf().axes
    artists = []
    for i in range(len(axes)):
        # np.isscalar also accepts ints and NumPy scalars, not just floats.
        a_ = a if np.isscalar(a) else a[i]
        # np.asarray works for NumPy arrays and CPU tensors alike, unlike the
        # tensor-only ``.float()``.
        mask = np.asarray(heatmaps[i] > vmin, dtype=float)
        art = axes[i].imshow(
            heatmaps[i],
            alpha=mask * a_,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
        )
        artists.append(art)
    return artists
def plot_lines(
    lines,
    line_colors="orange",
    point_colors="cyan",
    ps=4,
    lw=2,
    alpha=1.0,
    indices=(0, 1),
):
    """Draw line segments and their endpoints on existing images.

    Args:
        lines: list (one entry per target image) of arrays of shape (N, 2, 2),
            each segment given as two (x, y) endpoints.
        line_colors: one color for all segments, or a list with one per image.
        point_colors: one color for all endpoints, or a list with one per image.
        ps: size of the endpoints as float pixels.
        lw: segment line width as float pixels.
        alpha: transparency of the endpoints and segments.
        indices: indices into the current figure's axes selecting where to draw.
    """
    n_images = len(lines)
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * n_images
    if not isinstance(point_colors, list):
        point_colors = [point_colors] * n_images
    all_axes = plt.gcf().axes
    assert len(all_axes) > max(indices)
    targets = [all_axes[idx] for idx in indices]
    for ax, segs, seg_color, pt_color in zip(targets, lines, line_colors, point_colors):
        # One Line2D per segment: xdata holds both x's, ydata both y's.
        for seg in segs:
            ax.add_line(
                matplotlib.lines.Line2D(
                    (seg[0, 0], seg[1, 0]),
                    (seg[0, 1], seg[1, 1]),
                    zorder=1,
                    c=seg_color,
                    linewidth=lw,
                    alpha=alpha,
                )
            )
        endpoints = segs.reshape(-1, 2)
        ax.scatter(
            endpoints[:, 0],
            endpoints[:, 1],
            c=pt_color,
            s=ps,
            linewidths=0,
            zorder=2,
            alpha=alpha,
        )
def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors.

    Segment i gets the same (randomly shuffled) color in every image, so
    matched segments can be identified by color.

    Args:
        lines: list (one per image) of ndarrays of size (N, 2, 2).
        correct_matches: bool array of size (N,) indicating correct matches;
            incorrect ones are drawn with alpha 0.2. Integer 0/1 masks are
            accepted too.
        lw: line width as float pixels.
        indices: indices of the current figure's axes to draw the matches on.
    """
    n_lines = len(lines[0])
    colors = sns.color_palette("husl", n_colors=n_lines)
    np.random.shuffle(colors)
    alphas = np.ones(n_lines)
    # If correct_matches is not None, display wrong matches with a low alpha
    if correct_matches is not None:
        # Cast to bool: on an int 0/1 mask ``~`` is bitwise NOT (-1 / -2),
        # which would index the last elements instead of the wrong matches.
        alphas[~np.asarray(correct_matches, dtype=bool)] = 0.2
    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    # Plot the lines
    for a, img_lines in zip(axes, lines):
        for i, line in enumerate(img_lines):
            fig.add_artist(
                matplotlib.patches.ConnectionPatch(
                    xyA=tuple(line[0]),
                    coordsA=a.transData,
                    xyB=tuple(line[1]),
                    coordsB=a.transData,
                    zorder=1,
                    color=colors[i],
                    linewidth=lw,
                    alpha=alphas[i],
                )
            )
def save_plot(path, **kw):
    """Write the current figure to ``path`` cropped tight with zero padding.

    Any extra keyword arguments are forwarded to ``plt.savefig``.
    """
    options = dict(bbox_inches="tight", pad_inches=0)
    plt.savefig(path, **options, **kw)
def plot_cumulative(
    errors: dict,
    thresholds: list,
    colors=None,
    title="",
    unit="-",
    logx=False,
):
    """Plot, per method, the percentage of errors at or below each threshold.

    Args:
        errors: mapping method name -> sequence of per-sample errors.
        thresholds: only its min and max are used; each curve is sampled at
            100 evenly spaced thresholds between them.
        colors: optional mapping method name -> line color.
        title: prefix prepended to the y-axis label "Recall [%]".
        unit: x-axis label.
        logx: whether to use a logarithmic x axis.
    Returns:
        The current figure after plotting.
    """
    grid = np.linspace(min(thresholds), max(thresholds), 100)
    plt.figure(figsize=[5, 8])
    for method, method_errors in errors.items():
        errs = np.array(method_errors)
        recall = np.array([np.mean(errs <= th) for th in grid])
        plt.plot(
            grid,
            recall * 100,
            label=method,
            c=colors[method] if colors else None,
            linewidth=3,
        )
    plt.grid()
    plt.xlabel(unit, fontsize=25)
    if logx:
        plt.semilogx()
    plt.ylim([0, 100])
    plt.yticks(ticks=[0, 20, 40, 60, 80, 100])
    plt.ylabel(title + "Recall [%]", rotation=0, fontsize=25)
    plt.gca().yaxis.set_label_coords(x=0.45, y=1.02)
    plt.tick_params(axis="both", which="major", labelsize=20)
    plt.yticks(rotation=0)
    plt.legend(
        bbox_to_anchor=(0.45, -0.12),
        ncol=2,
        loc="upper center",
        fontsize=20,
        handlelength=3,
    )
    plt.tight_layout()
    return plt.gcf()