2023-10-05 16:53:51 +02:00
|
|
|
"""
|
|
|
|
# %BANNER_BEGIN%
|
|
|
|
# ---------------------------------------------------------------------
|
|
|
|
# %COPYRIGHT_BEGIN%
|
|
|
|
#
|
|
|
|
# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
|
|
|
|
#
|
|
|
|
# Unpublished Copyright (c) 2020
|
|
|
|
# Magic Leap, Inc., All Rights Reserved.
|
|
|
|
#
|
|
|
|
# NOTICE: All information contained herein is, and remains the property
|
|
|
|
# of COMPANY. The intellectual and technical concepts contained herein
|
|
|
|
# are proprietary to COMPANY and may be covered by U.S. and Foreign
|
|
|
|
# Patents, patents in process, and are protected by trade secret or
|
|
|
|
# copyright law. Dissemination of this information or reproduction of
|
|
|
|
# this material is strictly forbidden unless prior written permission is
|
|
|
|
# obtained from COMPANY. Access to the source code contained herein is
|
|
|
|
# hereby forbidden to anyone except current COMPANY employees, managers
|
|
|
|
# or contractors who have executed Confidentiality and Non-disclosure
|
|
|
|
# agreements explicitly covering such access.
|
|
|
|
#
|
|
|
|
# The copyright notice above does not evidence any actual or intended
|
|
|
|
# publication or disclosure of this source code, which includes
|
|
|
|
# information that is confidential and/or proprietary, and is a trade
|
|
|
|
# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
|
|
|
|
# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
|
|
|
|
# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
|
|
|
|
# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
|
|
|
|
# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
|
|
|
|
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
|
|
|
|
# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
|
|
|
|
# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
|
|
|
|
#
|
|
|
|
# %COPYRIGHT_END%
|
|
|
|
# ----------------------------------------------------------------------
|
|
|
|
# %AUTHORS_BEGIN%
|
|
|
|
#
|
|
|
|
# Originating Authors: Paul-Edouard Sarlin
|
|
|
|
#
|
|
|
|
# %AUTHORS_END%
|
|
|
|
# --------------------------------------------------------------------*/
|
|
|
|
# %BANNER_END%
|
|
|
|
|
|
|
|
Described in:
|
|
|
|
SuperPoint: Self-Supervised Interest Point Detection and Description,
|
|
|
|
Daniel DeTone, Tomasz Malisiewicz, Andrew Rabinovich, CVPRW 2018.
|
|
|
|
|
|
|
|
Original code: github.com/MagicLeapResearch/SuperPointPretrainedNetwork
|
|
|
|
|
|
|
|
Adapted by Philipp Lindenberger (Phil26AT)
|
|
|
|
"""
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
from gluefactory.models.base_model import BaseModel
|
|
|
|
from gluefactory.models.utils.misc import pad_and_stack
|
|
|
|
|
|
|
|
|
|
|
|
def simple_nms(scores, radius):
|
|
|
|
"""Perform non maximum suppression on the heatmap using max-pooling.
|
|
|
|
This method does not suppress contiguous points that have the same score.
|
|
|
|
Args:
|
|
|
|
scores: the score heatmap of size `(B, H, W)`.
|
|
|
|
radius: an integer scalar, the radius of the NMS window.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def max_pool(x):
|
|
|
|
return torch.nn.functional.max_pool2d(
|
|
|
|
x, kernel_size=radius * 2 + 1, stride=1, padding=radius
|
|
|
|
)
|
|
|
|
|
|
|
|
zeros = torch.zeros_like(scores)
|
|
|
|
max_mask = scores == max_pool(scores)
|
|
|
|
for _ in range(2):
|
|
|
|
supp_mask = max_pool(max_mask.float()) > 0
|
|
|
|
supp_scores = torch.where(supp_mask, zeros, scores)
|
|
|
|
new_max_mask = supp_scores == max_pool(supp_scores)
|
|
|
|
max_mask = max_mask | (new_max_mask & (~supp_mask))
|
|
|
|
return torch.where(max_mask, scores, zeros)
|
|
|
|
|
|
|
|
|
|
|
|
def top_k_keypoints(keypoints, scores, k):
|
|
|
|
if k >= len(keypoints):
|
|
|
|
return keypoints, scores
|
|
|
|
scores, indices = torch.topk(scores, k, dim=0, sorted=True)
|
|
|
|
return keypoints[indices], scores
|
|
|
|
|
|
|
|
|
|
|
|
def sample_k_keypoints(keypoints, scores, k):
|
|
|
|
if k >= len(keypoints):
|
|
|
|
return keypoints, scores
|
|
|
|
indices = torch.multinomial(scores, k, replacement=False)
|
|
|
|
return keypoints[indices], scores[indices]
|
|
|
|
|
|
|
|
|
|
|
|
def soft_argmax_refinement(keypoints, scores, radius: int):
|
|
|
|
width = 2 * radius + 1
|
|
|
|
sum_ = torch.nn.functional.avg_pool2d(
|
|
|
|
scores[:, None], width, 1, radius, divisor_override=1
|
|
|
|
)
|
|
|
|
ar = torch.arange(-radius, radius + 1).to(scores)
|
|
|
|
kernel_x = ar[None].expand(width, -1)[None, None]
|
|
|
|
dx = torch.nn.functional.conv2d(scores[:, None], kernel_x, padding=radius)
|
|
|
|
dy = torch.nn.functional.conv2d(
|
|
|
|
scores[:, None], kernel_x.transpose(2, 3), padding=radius
|
|
|
|
)
|
|
|
|
dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None]
|
|
|
|
refined_keypoints = []
|
|
|
|
for i, kpts in enumerate(keypoints):
|
|
|
|
delta = dydx[i][tuple(kpts.t())]
|
|
|
|
refined_keypoints.append(kpts.float() + delta)
|
|
|
|
return refined_keypoints
|
|
|
|
|
|
|
|
|
|
|
|
# Legacy (broken) sampling of the descriptors
|
|
|
|
def sample_descriptors(keypoints, descriptors, s):
|
|
|
|
b, c, h, w = descriptors.shape
|
|
|
|
keypoints = keypoints - s / 2 + 0.5
|
|
|
|
keypoints /= torch.tensor(
|
|
|
|
[(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
|
|
|
|
).to(
|
|
|
|
keypoints
|
|
|
|
)[None]
|
|
|
|
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
|
|
|
|
args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
|
|
|
|
descriptors = torch.nn.functional.grid_sample(
|
|
|
|
descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
|
|
|
|
)
|
|
|
|
descriptors = torch.nn.functional.normalize(
|
|
|
|
descriptors.reshape(b, c, -1), p=2, dim=1
|
|
|
|
)
|
|
|
|
return descriptors
|
|
|
|
|
|
|
|
|
|
|
|
# The original keypoint sampling is incorrect. We patch it here but
|
|
|
|
# keep the original one above for legacy.
|
|
|
|
def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
|
|
|
|
"""Interpolate descriptors at keypoint locations"""
|
|
|
|
b, c, h, w = descriptors.shape
|
|
|
|
keypoints = keypoints / (keypoints.new_tensor([w, h]) * s)
|
|
|
|
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
|
|
|
|
descriptors = torch.nn.functional.grid_sample(
|
|
|
|
descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
|
|
|
|
)
|
|
|
|
descriptors = torch.nn.functional.normalize(
|
|
|
|
descriptors.reshape(b, c, -1), p=2, dim=1
|
|
|
|
)
|
|
|
|
return descriptors
|
|
|
|
|
|
|
|
|
|
|
|
class SuperPoint(BaseModel):
|
|
|
|
default_conf = {
|
|
|
|
"has_detector": True,
|
|
|
|
"has_descriptor": True,
|
|
|
|
"descriptor_dim": 256,
|
|
|
|
# Inference
|
|
|
|
"sparse_outputs": True,
|
|
|
|
"dense_outputs": False,
|
|
|
|
"nms_radius": 4,
|
|
|
|
"refinement_radius": 0,
|
|
|
|
"detection_threshold": 0.005,
|
|
|
|
"max_num_keypoints": -1,
|
|
|
|
"max_num_keypoints_val": None,
|
|
|
|
"force_num_keypoints": False,
|
|
|
|
"randomize_keypoints_training": False,
|
|
|
|
"remove_borders": 4,
|
|
|
|
"legacy_sampling": True, # True to use the old broken sampling
|
|
|
|
}
|
|
|
|
required_data_keys = ["image"]
|
|
|
|
|
|
|
|
checkpoint_url = "https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superpoint_v1.pth" # noqa: E501
|
|
|
|
|
|
|
|
def _init(self, conf):
|
|
|
|
self.relu = nn.ReLU(inplace=True)
|
|
|
|
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
|
|
c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
|
|
|
|
|
|
|
|
self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
|
|
|
|
|
|
|
|
if conf.has_detector:
|
|
|
|
self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
|
|
|
|
|
|
|
|
if conf.has_descriptor:
|
|
|
|
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
|
|
|
self.convDb = nn.Conv2d(
|
|
|
|
c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0
|
|
|
|
)
|
|
|
|
|
|
|
|
self.load_state_dict(
|
|
|
|
torch.hub.load_state_dict_from_url(str(self.checkpoint_url)), strict=False
|
|
|
|
)
|
|
|
|
|
|
|
|
def _forward(self, data):
|
|
|
|
image = data["image"]
|
|
|
|
if image.shape[1] == 3: # RGB
|
|
|
|
scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
|
|
|
|
image = (image * scale).sum(1, keepdim=True)
|
|
|
|
|
|
|
|
# Shared Encoder
|
|
|
|
x = self.relu(self.conv1a(image))
|
|
|
|
x = self.relu(self.conv1b(x))
|
|
|
|
x = self.pool(x)
|
|
|
|
x = self.relu(self.conv2a(x))
|
|
|
|
x = self.relu(self.conv2b(x))
|
|
|
|
x = self.pool(x)
|
|
|
|
x = self.relu(self.conv3a(x))
|
|
|
|
x = self.relu(self.conv3b(x))
|
|
|
|
x = self.pool(x)
|
|
|
|
x = self.relu(self.conv4a(x))
|
|
|
|
x = self.relu(self.conv4b(x))
|
|
|
|
|
|
|
|
pred = {}
|
|
|
|
if self.conf.has_detector:
|
|
|
|
# Compute the dense keypoint scores
|
|
|
|
cPa = self.relu(self.convPa(x))
|
|
|
|
scores = self.convPb(cPa)
|
|
|
|
scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
|
|
|
|
b, c, h, w = scores.shape
|
|
|
|
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
|
|
|
|
scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
|
|
|
|
pred["keypoint_scores"] = dense_scores = scores
|
|
|
|
if self.conf.has_descriptor:
|
|
|
|
# Compute the dense descriptors
|
|
|
|
cDa = self.relu(self.convDa(x))
|
|
|
|
dense_desc = self.convDb(cDa)
|
|
|
|
dense_desc = torch.nn.functional.normalize(dense_desc, p=2, dim=1)
|
|
|
|
pred["descriptors"] = dense_desc
|
|
|
|
|
|
|
|
if self.conf.sparse_outputs:
|
|
|
|
assert self.conf.has_detector and self.conf.has_descriptor
|
|
|
|
|
|
|
|
scores = simple_nms(scores, self.conf.nms_radius)
|
|
|
|
|
|
|
|
# Discard keypoints near the image borders
|
|
|
|
if self.conf.remove_borders:
|
|
|
|
scores[:, : self.conf.remove_borders] = -1
|
|
|
|
scores[:, :, : self.conf.remove_borders] = -1
|
|
|
|
if "image_size" in data:
|
|
|
|
for i in range(scores.shape[0]):
|
|
|
|
w, h = data["image_size"][i]
|
|
|
|
scores[i, int(h.item()) - self.conf.remove_borders :] = -1
|
|
|
|
scores[i, :, int(w.item()) - self.conf.remove_borders :] = -1
|
|
|
|
else:
|
|
|
|
scores[:, -self.conf.remove_borders :] = -1
|
|
|
|
scores[:, :, -self.conf.remove_borders :] = -1
|
|
|
|
|
|
|
|
# Extract keypoints
|
|
|
|
best_kp = torch.where(scores > self.conf.detection_threshold)
|
|
|
|
scores = scores[best_kp]
|
|
|
|
|
|
|
|
# Separate into batches
|
|
|
|
keypoints = [
|
|
|
|
torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
|
|
|
|
]
|
|
|
|
scores = [scores[best_kp[0] == i] for i in range(b)]
|
|
|
|
|
|
|
|
# Keep the k keypoints with highest score
|
|
|
|
max_kps = self.conf.max_num_keypoints
|
|
|
|
|
|
|
|
# for val we allow different
|
|
|
|
if not self.training and self.conf.max_num_keypoints_val is not None:
|
|
|
|
max_kps = self.conf.max_num_keypoints_val
|
|
|
|
|
|
|
|
# Keep the k keypoints with highest score
|
|
|
|
if max_kps > 0:
|
|
|
|
if self.conf.randomize_keypoints_training and self.training:
|
|
|
|
# instead of selecting top-k, sample k by score weights
|
|
|
|
keypoints, scores = list(
|
|
|
|
zip(
|
|
|
|
*[
|
|
|
|
sample_k_keypoints(k, s, max_kps)
|
|
|
|
for k, s in zip(keypoints, scores)
|
|
|
|
]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
keypoints, scores = list(
|
|
|
|
zip(
|
|
|
|
*[
|
|
|
|
top_k_keypoints(k, s, max_kps)
|
|
|
|
for k, s in zip(keypoints, scores)
|
|
|
|
]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
keypoints, scores = list(keypoints), list(scores)
|
|
|
|
|
|
|
|
if self.conf["refinement_radius"] > 0:
|
|
|
|
keypoints = soft_argmax_refinement(
|
|
|
|
keypoints, dense_scores, self.conf["refinement_radius"]
|
|
|
|
)
|
|
|
|
|
|
|
|
# Convert (h, w) to (x, y)
|
|
|
|
keypoints = [torch.flip(k, [1]).float() for k in keypoints]
|
|
|
|
|
|
|
|
if self.conf.force_num_keypoints:
|
|
|
|
keypoints = pad_and_stack(
|
|
|
|
keypoints,
|
|
|
|
max_kps,
|
|
|
|
-2,
|
|
|
|
mode="random_c",
|
|
|
|
bounds=(
|
|
|
|
0,
|
|
|
|
data.get("image_size", torch.tensor(image.shape[-2:]))
|
|
|
|
.min()
|
|
|
|
.item(),
|
|
|
|
),
|
|
|
|
)
|
|
|
|
scores = pad_and_stack(scores, max_kps, -1, mode="zeros")
|
|
|
|
else:
|
|
|
|
keypoints = torch.stack(keypoints, 0)
|
|
|
|
scores = torch.stack(scores, 0)
|
|
|
|
|
|
|
|
# Extract descriptors
|
|
|
|
if (len(keypoints) == 1) or self.conf.force_num_keypoints:
|
|
|
|
# Batch sampling of the descriptors
|
|
|
|
if self.conf.legacy_sampling:
|
|
|
|
desc = sample_descriptors(keypoints, dense_desc, 8)
|
|
|
|
else:
|
|
|
|
desc = sample_descriptors_fix_sampling(keypoints, dense_desc, 8)
|
|
|
|
else:
|
|
|
|
if self.conf.legacy_sampling:
|
|
|
|
desc = [
|
|
|
|
sample_descriptors(k[None], d[None], 8)[0]
|
|
|
|
for k, d in zip(keypoints, dense_desc)
|
|
|
|
]
|
|
|
|
else:
|
|
|
|
desc = [
|
|
|
|
sample_descriptors_fix_sampling(k[None], d[None], 8)[0]
|
|
|
|
for k, d in zip(keypoints, dense_desc)
|
|
|
|
]
|
|
|
|
|
2024-02-14 00:02:09 +01:00
|
|
|
if isinstance(desc, list):
|
|
|
|
if all(isinstance(d, torch.Tensor) for d in desc):
|
|
|
|
# If desc is a list of tensors
|
|
|
|
desc = torch.stack(desc)
|
|
|
|
else:
|
|
|
|
# If desc is a list of non-tensor elements
|
|
|
|
desc = torch.stack([torch.tensor(d) for d in desc])
|
|
|
|
|
2023-10-05 16:53:51 +02:00
|
|
|
pred = {
|
|
|
|
"keypoints": keypoints + 0.5,
|
|
|
|
"keypoint_scores": scores,
|
|
|
|
"descriptors": desc.transpose(-1, -2),
|
|
|
|
}
|
|
|
|
|
|
|
|
if self.conf.dense_outputs:
|
|
|
|
pred["dense_descriptors"] = dense_desc
|
|
|
|
|
|
|
|
return pred
|
|
|
|
|
|
|
|
def loss(self, pred, data):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def metrics(self, pred, data):
|
|
|
|
raise NotImplementedError
|