2023-10-05 16:53:51 +02:00
|
|
|
"""PyTorch implementation of the SuperPoint model,
derived from the TensorFlow re-implementation (2018).
Authors: Rémi Pautrat, Paul-Edouard Sarlin
https://github.com/rpautrat/SuperPoint
The implementation of this model and its trained weights are made
available under the MIT license.
"""
|
|
|
|
from collections import OrderedDict
|
|
|
|
from types import SimpleNamespace
|
|
|
|
|
2023-10-09 08:32:43 +02:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-10-05 16:53:51 +02:00
|
|
|
from ..base_model import BaseModel
|
|
|
|
from ..utils.misc import pad_and_stack
|
|
|
|
|
|
|
|
|
|
|
|
def sample_descriptors(keypoints, descriptors, s: int = 8):
    """Bilinearly sample a dense descriptor map at keypoint locations.

    Args:
        keypoints: keypoint coordinates whose last dimension is ordered
            like ``[w, h]`` (x first); reshaped to ``(b, 1, -1, 2)``.
        descriptors: dense descriptor map of shape ``(b, c, h, w)``.
        s: ratio between keypoint coordinates and descriptor cells.

    Returns:
        Tensor of shape ``(b, c, n)``, L2-normalized along the channel dim.
    """
    batch, dim, height, width = descriptors.shape
    # Extent of the descriptor map expressed in keypoint units, as (x, y).
    extent = keypoints.new_tensor([width, height]) * s
    # Shift by half a unit, scale to [0, 1], then map onto the [-1, 1]
    # coordinate range expected by grid_sample.
    grid = (keypoints + 0.5) / extent
    grid = grid * 2 - 1
    sampled = torch.nn.functional.grid_sample(
        descriptors,
        grid.view(batch, 1, -1, 2),
        mode="bilinear",
        align_corners=False,
    )
    return torch.nn.functional.normalize(
        sampled.reshape(batch, dim, -1), p=2, dim=1
    )
|
|
|
|
|
|
|
|
|
|
|
|
def batched_nms(scores, nms_radius: int):
    """Keep only scores that are local maxima within a square window.

    A score survives if it equals the maximum of the
    ``(2 * nms_radius + 1)``-wide window centered on it. Two extra passes
    then recover maxima among the scores lying outside the windows of
    already accepted maxima.

    Args:
        scores: score tensor accepted by ``max_pool2d``.
        nms_radius: non-negative half-width of the suppression window.

    Returns:
        Tensor shaped like ``scores`` with non-maximal entries set to zero.
    """
    assert nms_radius >= 0

    window = nms_radius * 2 + 1

    def local_max(t):
        # Stride 1 with padding equal to the radius keeps the spatial size.
        return torch.nn.functional.max_pool2d(
            t, kernel_size=window, stride=1, padding=nms_radius
        )

    blank = torch.zeros_like(scores)
    keep = scores == local_max(scores)
    for _ in range(2):
        # Everything within reach of an accepted maximum is suppressed.
        suppressed = local_max(keep.float()) > 0
        remaining = torch.where(suppressed, blank, scores)
        # Maxima among what remains, restricted to unsuppressed positions.
        fresh = (remaining == local_max(remaining)) & (~suppressed)
        keep = keep | fresh
    return torch.where(keep, scores, blank)
|
|
|
|
|
|
|
|
|
|
|
|
def select_top_k_keypoints(keypoints, scores, k):
    """Return the ``k`` highest-scoring keypoints with their scores.

    When there are at most ``k`` keypoints, both inputs are returned
    untouched (and therefore unsorted). Otherwise the result is ordered by
    decreasing score.
    """
    if len(keypoints) <= k:
        return keypoints, scores
    best = torch.topk(scores, k, dim=0, sorted=True)
    return keypoints[best.indices], best.values
|
|
|
|
|
|
|
|
|
|
|
|
class VGGBlock(nn.Sequential):
    """Convolution, optional ReLU, then batch normalization.

    The submodules are registered as ``conv``, ``activation`` and ``bn``,
    in that order.
    """

    def __init__(self, c_in, c_out, kernel_size, relu=True):
        """Build the block.

        Args:
            c_in: number of input channels.
            c_out: number of output channels.
            kernel_size: side of the square convolution kernel.
            relu: use a ReLU activation if true, an identity otherwise.
        """
        stages = OrderedDict()
        # With stride 1, this padding keeps the spatial size for odd kernels.
        stages["conv"] = nn.Conv2d(
            c_in,
            c_out,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        if relu:
            stages["activation"] = nn.ReLU(inplace=True)
        else:
            stages["activation"] = nn.Identity()
        stages["bn"] = nn.BatchNorm2d(c_out, eps=0.001)
        super().__init__(stages)
|
|
|
|
|
|
|
|
|
|
|
|
class SuperPoint(BaseModel):
    """SuperPoint keypoint detector and descriptor.

    A stack of ``VGGBlock`` stages feeds two heads: a detector that yields a
    per-pixel keypoint score map and a descriptor head that yields a dense,
    L2-normalized descriptor map, which is then sampled at the keypoints.
    """

    default_conf = {
        "descriptor_dim": 256,
        "nms_radius": 4,
        "max_num_keypoints": None,
        "force_num_keypoints": False,
        "detection_threshold": 0.005,
        "remove_borders": 4,
        "channels": [64, 64, 128, 128, 256],
        "dense_outputs": None,
    }

    checkpoint_url = "https://github.com/rpautrat/SuperPoint/raw/master/weights/superpoint_v6_from_tf.pth"  # noqa: E501

    def _init(self, conf):
        """Build the backbone and both heads, then load pretrained weights.

        Args:
            conf: mapping providing the keys listed in ``default_conf``.
        """
        self.conf = SimpleNamespace(**conf)
        # All backbone stages except the last end with a 2x max-pooling,
        # hence the total downsampling factor.
        self.stride = 2 ** (len(self.conf.channels) - 2)
        # The backbone consumes a single-channel image; the last entry of
        # conf.channels is the hidden width of the two heads.
        channels = [1, *self.conf.channels[:-1]]

        backbone = []
        for i, c in enumerate(channels[1:], 1):
            layers = [VGGBlock(channels[i - 1], c, 3), VGGBlock(c, c, 3)]
            if i < len(channels) - 1:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            backbone.append(nn.Sequential(*layers))
        self.backbone = nn.Sequential(*backbone)

        c = self.conf.channels[-1]
        # One output channel per pixel of a stride x stride cell, plus one
        # extra channel that _forward drops after the softmax.
        self.detector = nn.Sequential(
            VGGBlock(channels[-1], c, 3),
            VGGBlock(c, self.stride**2 + 1, 1, relu=False),
        )
        self.descriptor = nn.Sequential(
            VGGBlock(channels[-1], c, 3),
            VGGBlock(c, self.conf.descriptor_dim, 1, relu=False),
        )

        state_dict = torch.hub.load_state_dict_from_url(self.checkpoint_url)
        self.load_state_dict(state_dict)

    def _forward(self, data):
        """Detect keypoints and compute their descriptors.

        Args:
            data: dict with an ``"image"`` tensor indexed as
                ``(b, channels, h, w)``; 3-channel inputs are converted to
                grayscale. ``"image_size"`` is read, if present, only when
                ``force_num_keypoints`` is set.

        Returns:
            dict with ``"keypoints"`` (x, y coordinates offset by 0.5),
            ``"keypoint_scores"``, ``"descriptors"`` (descriptor dimension
            last) and, if ``dense_outputs`` is set, ``"dense_descriptors"``.

        Raises:
            RuntimeError: from ``torch.stack`` when ``force_num_keypoints``
                is false and the images of a batch yield different numbers
                of keypoints.
        """
        image = data["image"]
        if image.shape[1] == 3:  # RGB
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        features = self.backbone(image)
        descriptors_dense = torch.nn.functional.normalize(
            self.descriptor(features), p=2, dim=1
        )

        # Decode the detection scores
        scores = self.detector(features)
        # Softmax over the cell channels, then drop the extra last channel.
        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
        b, _, h, w = scores.shape
        # Unfold each stride**2 channel vector into a stride x stride patch
        # to obtain a score map of size (h * stride, w * stride).
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, self.stride, self.stride)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(
            b, h * self.stride, w * self.stride
        )
        scores = batched_nms(scores, self.conf.nms_radius)

        # Discard keypoints near the image borders
        if self.conf.remove_borders:
            pad = self.conf.remove_borders
            # -1 is below any threshold that a softmax score can pass.
            scores[:, :pad] = -1
            scores[:, :, :pad] = -1
            scores[:, -pad:] = -1
            scores[:, :, -pad:] = -1

        # Extract keypoints
        if b > 1:
            idxs = torch.where(scores > self.conf.detection_threshold)
            # mask[i] selects the detections that belong to image i.
            mask = idxs[0] == torch.arange(b, device=scores.device)[:, None]
        else:  # Faster shortcut
            scores = scores.squeeze(0)
            idxs = torch.where(scores > self.conf.detection_threshold)

        # Convert (i, j) to (x, y)
        keypoints_all = torch.stack(idxs[-2:], dim=-1).flip(1).float()
        scores_all = scores[idxs]

        keypoints = []
        scores = []
        for i in range(b):
            if b > 1:
                k = keypoints_all[mask[i]]
                s = scores_all[mask[i]]
            else:
                k = keypoints_all
                s = scores_all
            if self.conf.max_num_keypoints is not None:
                k, s = select_top_k_keypoints(k, s, self.conf.max_num_keypoints)

            keypoints.append(k)
            scores.append(s)

        if self.conf.force_num_keypoints:
            keypoints = pad_and_stack(
                keypoints,
                self.conf.max_num_keypoints,
                -2,
                mode="random_c",
                bounds=(
                    0,
                    data.get("image_size", torch.tensor(image.shape[-2:])).min().item(),
                ),
            )
            scores = pad_and_stack(
                scores, self.conf.max_num_keypoints, -1, mode="zeros"
            )
        else:
            # Requires every image of the batch to have the same number of
            # keypoints; torch.stack raises otherwise.
            keypoints = torch.stack(keypoints, 0)
            scores = torch.stack(scores, 0)

        # The keypoints are a single batched tensor in both branches above,
        # so the descriptors can always be sampled in one batched call. The
        # former per-image fallback built a Python list, on which the
        # transpose below raised AttributeError for batches larger than one.
        desc = sample_descriptors(keypoints, descriptors_dense, self.stride)

        pred = {
            "keypoints": keypoints + 0.5,
            "keypoint_scores": scores,
            "descriptors": desc.transpose(-1, -2),
        }
        if self.conf.dense_outputs:
            pred["dense_descriptors"] = descriptors_dense

        return pred

    def loss(self, pred, data):
        """Training loss; not implemented for this model."""
        raise NotImplementedError
|