487 lines
14 KiB
Python
487 lines
14 KiB
Python
"""
|
|
2D visualization primitives based on Matplotlib.
|
|
1) Plot images with `plot_images`.
|
|
2) Call `plot_keypoints` or `plot_matches` any number of times.
|
|
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
|
|
"""
|
|
|
|
import matplotlib
|
|
import matplotlib.patheffects as path_effects
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import seaborn as sns
|
|
|
|
|
|
def cm_ranking(sc, ths=(512, 1024, 2048, 4096)):
    """Color each element according to the rank of its score.

    Elements are ranked by descending score (rank 0 = highest score). An
    element of rank ``r`` gets the color paired with the first threshold in
    ``ths`` (followed by ``len(sc)``) for which ``r < threshold``.

    The colors are red, yellow, lime, cyan and blue, in that order. With the
    defaults this gives:
      - red for the top 512,
      - yellow for ranks 512-1023,
      - lime for ranks 1024-2047,
      - cyan for ranks 2048-4095,
      - blue for the rest.

    Only five colors exist. With more than four thresholds, ranks past the
    fifth threshold stay "gray".

    Args:
        sc: array-like of scores (NumPy array or CPU tensor); assumed 1-D.
        ths: rank thresholds, presumably increasing (the first satisfied
            threshold wins). Any sequence is accepted.

    Returns:
        NumPy array of color names aligned with ``sc``: ``out[j]`` is the
        color of element ``j``.
    """
    scores = np.asarray(sc)
    n = scores.shape[0]
    palette = ["red", "yellow", "lime", "cyan", "blue"]
    bounds = list(ths) + [n]

    # Color assigned to each rank position.
    by_rank = ["gray"] * n
    for rank in range(n):
        for color, bound in zip(palette[: len(bounds)], bounds):
            if rank < bound:
                by_rank[rank] = color
                break
    by_rank = np.array(by_rank)

    # order[r] is the index of the element with rank r (descending scores).
    order = np.argsort(scores, axis=0)[::-1]
    # Scatter the rank colors back to element positions (element order[r]
    # has rank r). Indexing by_rank[order] would apply the inverse permutation.
    out = np.empty_like(by_rank)
    out[order] = by_rank
    return out
|
|
|
|
|
|
def cm_RdBl(x):
    """Custom colormap: red (0) -> magenta (0.5) -> blue (1).

    Inputs are clipped to [0, 1] and the channels change piecewise-linearly:
      - Over [0, 0.5] blue ramps from 0 to 1 while red stays at 1.
      - Over [0.5, 1] red ramps from 1 to 0 while blue stays at 1.

    Args:
        x: scalar or array of values.

    Returns:
        Array with a trailing axis of size 3 holding RGB values in [0, 1].
    """
    x = np.clip(x, 0, 1)[..., None] * 2
    # Blend between blue and red with weights in [0, 2]; the final clip
    # saturates the channels, producing the piecewise ramps described above.
    c = x * np.array([[0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0]])
    return np.clip(c, 0, 1)
|
|
|
|
|
|
def cm_RdGn(x):
    """Map values in [0, 1] to RGB: red (0) -> yellow (0.5) -> green (1).

    Values outside [0, 1] are clipped first. The result has a trailing axis
    of size 3 holding RGB values in [0, 1].
    """
    weight = 2 * np.clip(x, 0, 1)[..., None]
    green = np.array([[0, 1.0, 0]])
    red = np.array([[1.0, 0, 0]])
    # Weights lie in [0, 2]; clipping saturates channels, so yellow sits at 0.5.
    blended = weight * green + (2 - weight) * red
    return np.clip(blended, 0, 1)
|
|
|
|
|
|
def cm_BlRdGn(x_):
    """Custom colormap: blue (-1) -> red (0.0) -> green (1).

    Values are clipped to [-1, 1]:
      - Non-negative values blend red -> yellow -> green.
      - Negative values blend red -> magenta -> blue.

    Args:
        x_: scalar or array of values (converted with ``np.asarray``, so plain
            Python numbers are accepted).

    Returns:
        Array with a trailing axis of size 4 holding RGBA values in [0, 1];
        alpha is always 1.
    """
    x_ = np.asarray(x_)

    # Positive side: red (0) -> green (1).
    x = np.clip(x_, 0, 1)[..., None] * 2
    c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])

    # Negative side: red (0) -> blue (-1). This must blend towards blue;
    # blending towards green would make -1 and +1 indistinguishable.
    xn = -np.clip(x_, -1, 0)[..., None] * 2
    cn = xn * np.array([[0, 0.1, 1.0, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])

    return np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
|
|
|
|
|
|
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
    """Plot a set of images side by side in a single row of a new figure.

    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: optional list of strings, one title per image.
        cmaps: a single colormap name shared by all images, or a list/tuple
            with one colormap per image (only affects monochrome images).
        dpi: resolution of the created figure.
        pad: padding passed to ``Figure.tight_layout``.
        adaptive: if True, each subplot's width follows its image's W / H
            ratio; otherwise every subplot uses a 4:3 ratio.
    """
    count = len(imgs)
    cmap_list = cmaps if isinstance(cmaps, (list, tuple)) else [cmaps] * count

    if adaptive:
        widths = [img.shape[1] / img.shape[0] for img in imgs]  # W / H
    else:
        widths = [4 / 3] * count

    fig, axs = plt.subplots(
        1,
        count,
        figsize=[sum(widths) * 4.5, 4.5],
        dpi=dpi,
        gridspec_kw={"width_ratios": widths},
    )
    # A single subplot comes back as a bare Axes rather than an array.
    axs = [axs] if count == 1 else axs

    for idx, img in enumerate(imgs):
        ax = axs[idx]
        ax.imshow(img, cmap=plt.get_cmap(cmap_list[idx]))
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[idx])
    fig.tight_layout(pad=pad)
|
|
|
|
|
|
def plot_image_grid(
    imgs,
    titles=None,
    cmaps="gray",
    dpi=100,
    pad=0.5,
    fig=None,
    adaptive=True,
    figs=2.0,
    return_fig=False,
    set_lim=False,
):
    """Plot a grid of images, one list of images per row.

    Args:
        imgs: a list of rows, each a list of NumPy or PyTorch images,
            RGB (H, W, 3) or mono (H, W). All rows are assumed to have the
            same length as the first one.
        titles: optional list of lists of strings, one title per image.
        cmaps: a colormap name, or a list/tuple with one colormap per column
            (only affects monochrome images).
        dpi: resolution of a newly created figure (unused when ``fig`` is given).
        pad: padding for ``tight_layout``; applied only when ``fig`` is a
            ``plt.Figure``.
        fig: optional existing figure (or sub-figure) to create the Axes in.
        adaptive: if True, column widths follow the W / H ratios of the
            images in the first row; otherwise every column uses 4:3.
        figs: size scale in inches: the figure is
            ``[sum(width ratios) * figs, n_rows * figs]``.
        return_fig: if True, return ``(fig, axs)`` instead of ``axs``.
        set_lim: if True, set each Axes' limits to the exact image extent.

    Returns:
        ``axs``, a 2-D array of Axes indexed ``axs[row][col]``, or
        ``(fig, axs)`` when ``return_fig`` is True.
    """
    nr, n = len(imgs), len(imgs[0])
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs[0]]  # W / H
    else:
        ratios = [4 / 3] * n

    figsize = [sum(ratios) * figs, nr * figs]
    gridspec_kw = {"width_ratios": ratios}
    # squeeze=False always yields a 2-D (nr, n) array of Axes. With the
    # default squeezing, 1x1 and single-column grids return a bare Axes or a
    # 1-D array, and the axs[j][i] indexing below fails.
    if fig is None:
        fig, axs = plt.subplots(
            nr, n, figsize=figsize, dpi=dpi, gridspec_kw=gridspec_kw, squeeze=False
        )
    else:
        axs = fig.subplots(nr, n, gridspec_kw=gridspec_kw, squeeze=False)
        fig.figure.set_size_inches(figsize)

    for j in range(nr):
        for i in range(n):
            ax = axs[j][i]
            ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
            ax.set_axis_off()
            if set_lim:
                ax.set_xlim([0, imgs[j][i].shape[1]])
                ax.set_ylim([imgs[j][i].shape[0], 0])
            if titles:
                ax.set_title(titles[j][i])
    # tight_layout only for a top-level Figure (not e.g. a SubFigure).
    if isinstance(fig, plt.Figure):
        fig.tight_layout(pad=pad)
    if return_fig:
        return fig, axs
    else:
        return axs
|
|
|
|
|
|
def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
    """Scatter keypoints on top of already-plotted images.

    Args:
        kpts: list of ndarrays of size (N, 2), one per Axes.
        colors: a single color used for every set, or a list with one entry
            per set (an entry may itself be a per-keypoint color sequence).
        ps: marker size of the keypoints as float.
        axes: Axes to draw on; defaults to all Axes of the current figure.
        a: alpha, either one value for every set or a list with one per set.
    """
    colors = colors if isinstance(colors, list) else [colors] * len(kpts)
    alphas = a if isinstance(a, list) else [a] * len(kpts)
    targets = plt.gcf().axes if axes is None else axes
    for ax, pts, color, alpha in zip(targets, kpts, colors, alphas):
        xs, ys = pts[:, 0], pts[:, 1]
        ax.scatter(xs, ys, c=color, s=ps, linewidths=0, alpha=alpha)
|
|
|
|
|
|
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
    """Plot matches for a pair of existing images.

    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2).
        color: color of the matches. Either a single color (string or RGB(A)
            tuple) for all matches, or a per-match sequence. A per-match
            sequence may be a list of tuples/lists or an (N, 3/4) array such
            as the output of ``cm_RdGn``. If not given, a seaborn "husl"
            palette with one color per match is used.
        lw: width of the lines (no lines if lw=0).
        ps: size of the end points (no endpoint if ps=0).
        a: alpha opacity of the match lines.
        labels: optional per-match labels attached to the lines. ``labels[0]``
            and ``labels[1]`` (when present) also label the two endpoint
            scatters.
        axes: pair (ax0, ax1) to draw on; defaults to the first two Axes of
            the current figure.
    """
    fig = plt.gcf()
    if axes is None:
        ax = fig.axes
        ax0, ax1 = ax[0], ax[1]
    else:
        ax0, ax1 = axes

    assert len(kpts0) == len(kpts1)
    if color is None:
        color = sns.color_palette("husl", n_colors=len(kpts0))
    elif len(color) > 0 and not isinstance(color[0], (tuple, list, np.ndarray)):
        # A single color (string, RGB tuple, or 1-D RGB array): repeat it per
        # match. A 2-D array's rows are ndarrays, so those stay per-match.
        color = [color] * len(kpts0)

    if lw > 0:
        for i in range(len(kpts0)):
            line = matplotlib.patches.ConnectionPatch(
                xyA=(kpts0[i, 0], kpts0[i, 1]),
                xyB=(kpts1[i, 0], kpts1[i, 1]),
                coordsA=ax0.transData,
                coordsB=ax1.transData,
                axesA=ax0,
                axesB=ax1,
                zorder=1,
                color=color[i],
                linewidth=lw,
                clip_on=True,
                alpha=a,
                label=None if labels is None else labels[i],
                picker=5.0,
            )
            line.set_annotation_clip(True)
            fig.add_artist(line)

    # freeze the axes to prevent the transform to change
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)

    if ps > 0:
        # Guard against short label lists (e.g. a single match with one label).
        label0 = labels[0] if labels is not None and len(labels) > 0 else None
        label1 = labels[1] if labels is not None and len(labels) > 1 else None
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, label=label0)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, label=label1)
|
|
|
|
|
|
def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=2,
    ha="left",
    va="top",
    axes=None,
    **kwargs,
):
    """Write ``text`` on one Axes, positioned in axes-relative coordinates.

    Args:
        idx: index into ``axes`` of the Axes to annotate.
        text: the string to draw.
        pos: position in axes coordinates; the default is near the top-left.
        fs: font size.
        color: text color.
        lcolor: color of an outline stroke drawn behind the glyphs; ``None``
            disables the outline.
        lwidth: line width of that outline stroke.
        ha, va: horizontal / vertical alignment relative to ``pos``.
        axes: sequence of Axes; defaults to all Axes of the current figure.
        **kwargs: forwarded to ``Axes.text``.

    Returns:
        The artist returned by ``Axes.text``.
    """
    candidates = plt.gcf().axes if axes is None else axes
    target = candidates[idx]

    label = target.text(
        *pos,
        text,
        fontsize=fs,
        ha=ha,
        va=va,
        color=color,
        transform=target.transAxes,
        **kwargs,
    )
    if lcolor is None:
        return label

    # Draw a stroke in lcolor first, then the normal text on top of it.
    outline = path_effects.Stroke(linewidth=lwidth, foreground=lcolor)
    label.set_path_effects([outline, path_effects.Normal()])
    return label
|
|
|
|
|
|
def draw_epipolar_line(
    line, axis, imshape=None, color="b", label=None, alpha=1.0, visible=True
):
    """Draw the part of a 2D line lying inside an image as a dashed segment.

    Args:
        line: homogeneous line coefficients (a, b, c) of ``a*x + b*y + c = 0``.
        axis: the Axes to draw on.
        imshape: (h, w, ...) giving the image extent. If None, w is taken
            from the upper x-limit and h from the first y-limit of ``axis``.
        color: format/color string passed positionally to ``axis.plot``.
        label, alpha, visible: forwarded to ``axis.plot``.

    Returns:
        The first artist returned by ``axis.plot``, or None if fewer than two
        distinct border intersections lie inside the image.
    """
    if imshape is not None:
        h, w = imshape[:2]
    else:
        _, w = axis.get_xlim()
        h, _ = axis.get_ylim()

    # Image borders as homogeneous lines: x = 1, x = w, y = 1, y = h.
    # NOTE(review): borders are at 1 rather than 0; presumably intentional.
    borders = ([1, 0, -1], [1, 0, -w], [0, 1, -1], [0, 1, -h])

    # Keep intersections that lie inside the image; they are on its border.
    points = []
    for border in borders:
        X = np.cross(line, border)
        if X[2] == 0:
            # Parallel to this border: no finite intersection. Skipping here
            # avoids divide-by-zero warnings (e.g. horizontal lines).
            continue
        X = X[:2] / X[2]
        if not ((0 <= X[0] <= (w + 1e-6)) and (0 <= X[1] <= (h + 1e-6))):
            continue
        # A line through a corner meets two borders at the same point; keep
        # searching for the distinct exit point instead of a zero-length segment.
        if points and np.allclose(X, points[0]):
            continue
        points.append(X)
        if len(points) == 2:
            break

    # Plot line, if it's visible in the image.
    if len(points) < 2:
        return None
    return axis.plot(
        [points[0][0], points[1][0]],
        [points[0][1], points[1][1]],
        color,
        linestyle="dashed",
        label=label,
        alpha=alpha,
        visible=visible,
    )[0]
|
|
|
|
|
|
def get_line(F, kp):
    """Return ``F`` applied to the homogeneous version of point ``kp``.

    Args:
        F: a 3x3 matrix.
        kp: a 2D point (any iterable of two coordinates).

    Returns:
        A (3, 1) column holding ``F @ [x, y, 1]^T`` (line coefficients).
    """
    homogeneous = np.array([[*kp, 1.0]]).T
    return np.dot(F, homogeneous)
|
|
|
|
|
|
def plot_epipolar_lines(
    pts0, pts1, F, color="b", axes=None, labels=None, a=1.0, visible=True
):
    """Draw, in each image, the epipolar lines of the other image's points.

    Lines in image 0 are computed as ``F^T x1`` from ``pts1``. Lines in
    image 1 are computed as ``F x0`` from ``pts0``. This presumably assumes
    ``x1^T F x0 = 0``; confirm against callers.

    Args:
        pts0, pts1: keypoint arrays of size (N, 2) for image 0 and image 1.
        F: 3x3 fundamental matrix (NumPy array or 2-D torch tensor).
        color: color passed to ``draw_epipolar_line``.
        axes: the two Axes (image 0, image 1); defaults to the Axes of the
            current figure, which must contain exactly two.
        labels: optional per-point labels for the drawn lines.
        a: alpha of the lines.
        visible: initial visibility of the drawn lines.
    """
    if axes is None:
        axes = plt.gcf().axes
    assert len(axes) == 2

    # ``.T`` transposes both NumPy arrays and 2-D torch tensors. NumPy's
    # ``transpose(0, 1)`` is the identity permutation, not a transpose.
    F_T = F.T

    for k, (ax, kps) in enumerate(zip(axes, [pts1, pts0])):
        _, w = ax.get_xlim()
        h, _ = ax.get_ylim()
        imshape = (h + 0.5, w + 0.5)

        # The first Axes (image 0) receives lines from image-1 points.
        F_k = F_T if k == 0 else F
        for i in range(kps.shape[0]):
            line = get_line(F_k, kps[i])[:, 0]
            draw_epipolar_line(
                line,
                ax,
                imshape,
                color=color,
                label=None if labels is None else labels[i],
                alpha=a,
                visible=visible,
            )
|
|
|
|
|
|
def plot_heatmaps(heatmaps, vmin=0.0, vmax=None, cmap="Spectral", a=0.5, axes=None):
    """Overlay heatmaps on existing Axes, hiding pixels at or below ``vmin``.

    Args:
        heatmaps: sequence of heatmaps (NumPy arrays or CPU torch tensors),
            one per Axes.
        vmin, vmax: color limits forwarded to ``imshow``.
        cmap: colormap name forwarded to ``imshow``.
        a: overlay opacity. Either a single number or a sequence with one
            value per Axes.
        axes: Axes to draw on; defaults to all Axes of the current figure.
            Drawing stops at the shorter of ``axes`` and ``heatmaps``.

    Returns:
        List of the artists returned by ``imshow``.
    """
    if axes is None:
        axes = plt.gcf().axes
    artists = []
    for i, (ax, heatmap) in enumerate(zip(axes, heatmaps)):
        # np.ndim == 0 accepts ints and NumPy scalars as well as floats.
        a_ = a if np.ndim(a) == 0 else a[i]
        # Per-pixel alpha: fully transparent where heatmap <= vmin.
        # np.asarray handles NumPy arrays and CPU tensors alike
        # (``.float()`` only exists on tensors).
        mask = np.asarray(heatmap > vmin, dtype=float)
        art = ax.imshow(
            heatmap,
            alpha=mask * a_,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
        )
        artists.append(art)
    return artists
|
|
|
|
|
|
def plot_lines(
    lines,
    line_colors="orange",
    point_colors="cyan",
    ps=4,
    lw=2,
    alpha=1.0,
    indices=(0, 1),
    axes=None,
):
    """Plot line segments and their endpoints on existing images.

    Args:
        lines: list of ndarrays of size (N, 2, 2), one per image.
            ``lines[k][i]`` holds the two (x, y) endpoints of segment ``i``.
        line_colors: a single color for all segments, or a list with one
            color per image.
        point_colors: a single color for all endpoints, or a list with one
            color per image.
        ps: endpoint marker size, passed to ``scatter`` as ``s``.
        lw: line width, passed to ``Line2D`` as ``linewidth``.
        alpha: transparency of the points and lines.
        indices: indices (into the current figure's Axes) of the images to
            draw on; ignored when ``axes`` is given.
        axes: optional explicit sequence of Axes to draw on.
    """
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * len(lines)
    if not isinstance(point_colors, list):
        point_colors = [point_colors] * len(lines)

    if axes is None:
        fig_axes = plt.gcf().axes
        assert len(fig_axes) > max(indices)
        axes = [fig_axes[i] for i in indices]

    # Plot the segments (zorder 1) underneath their endpoints (zorder 2).
    for ax, segments, lc, pc in zip(axes, lines, line_colors, point_colors):
        for seg in segments:
            ax.add_line(
                matplotlib.lines.Line2D(
                    (seg[0, 0], seg[1, 0]),
                    (seg[0, 1], seg[1, 1]),
                    zorder=1,
                    c=lc,
                    linewidth=lw,
                    alpha=alpha,
                )
            )
        pts = segments.reshape(-1, 2)
        ax.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2, alpha=alpha)
|
|
|
|
|
|
def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors.

    Segment ``i`` of every image is drawn with color ``i``, so matched
    segments share a color.

    Args:
        lines: list of ndarrays of size (N, 2, 2), one per image.
        correct_matches: bool array of size (N,) indicating correct matches;
            wrong matches are drawn with alpha 0.2. 0/1 integer masks are
            also accepted.
        lw: line width as float pixels.
        indices: indices (into the current figure's Axes) of the images to
            draw the matches on.
    """
    n_lines = len(lines[0])
    colors = sns.color_palette("husl", n_colors=n_lines)
    # Randomize hue order; presumably so that consecutive segments do not get
    # near-identical colors. Uses NumPy's global RNG.
    np.random.shuffle(colors)
    alphas = np.ones(n_lines)
    # If correct_matches is not None, display wrong matches with a low alpha.
    # Cast to bool first: on integer arrays ``~`` is bitwise NOT (~1 == -2,
    # ~0 == -1), which would index the wrong elements.
    if correct_matches is not None:
        alphas[~np.asarray(correct_matches, dtype=bool)] = 0.2

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]

    # Plot the lines; both endpoints are in the data coordinates of one Axes.
    for target, img_lines in zip(axes, lines):
        for i, line in enumerate(img_lines):
            fig.add_artist(
                matplotlib.patches.ConnectionPatch(
                    xyA=tuple(line[0]),
                    coordsA=target.transData,
                    xyB=tuple(line[1]),
                    coordsB=target.transData,
                    zorder=1,
                    color=colors[i],
                    linewidth=lw,
                    alpha=alphas[i],
                )
            )
|
|
|
|
|
|
def save_plot(path, **kw):
    """Write the current Matplotlib figure to ``path``, cropped tightly.

    The bounding box is shrunk to the drawn content with zero padding. Extra
    keyword arguments (e.g. ``dpi``) are forwarded to ``plt.savefig``.
    """
    plt.savefig(path, **kw, bbox_inches="tight", pad_inches=0)
|
|
|
|
|
|
def plot_cumulative(
    errors: dict,
    thresholds: list,
    colors=None,
    title="",
    unit="-",
    logx=False,
):
    """Plot recall-vs-threshold (cumulative error) curves, one per method.

    Args:
        errors: mapping from method name to a sequence of errors.
        thresholds: only its min and max are used; recall is evaluated at 100
            evenly spaced thresholds between them.
        colors: optional mapping from method name to line color.
        title: prefix of the y-axis label (``title + "Recall [%]"``).
        unit: x-axis label.
        logx: if True, switch the x-axis to a logarithmic scale.

    Returns:
        The current figure (``plt.gcf()``) after plotting.
    """
    ths = np.linspace(min(thresholds), max(thresholds), 100)

    plt.figure(figsize=[5, 8])
    for method in errors:
        errs = np.array(errors[method])
        # Fraction of errors at or below each threshold, as a percentage.
        recall_pct = np.array([np.mean(errs <= th) for th in ths]) * 100
        line_color = colors[method] if colors else None
        plt.plot(ths, recall_pct, label=method, c=line_color, linewidth=3)

    plt.grid()
    plt.xlabel(unit, fontsize=25)
    if logx:
        plt.semilogx()
    plt.ylim([0, 100])
    plt.yticks(ticks=[0, 20, 40, 60, 80, 100])
    plt.ylabel(title + "Recall [%]", rotation=0, fontsize=25)
    plt.gca().yaxis.set_label_coords(x=0.45, y=1.02)
    plt.tick_params(axis="both", which="major", labelsize=20)
    plt.yticks(rotation=0)

    plt.legend(
        bbox_to_anchor=(0.45, -0.12),
        ncol=2,
        loc="upper center",
        fontsize=20,
        handlelength=3,
    )
    plt.tight_layout()

    return plt.gcf()
|