2023-10-05 16:53:51 +02:00
|
|
|
import kornia
|
2023-10-09 08:32:43 +02:00
|
|
|
import torch
|
2023-10-05 16:53:51 +02:00
|
|
|
|
|
|
|
from ..base_model import BaseModel
|
|
|
|
from ..utils.misc import pad_and_stack
|
|
|
|
|
|
|
|
|
|
|
|
class DISK(BaseModel):
    """Glue-factory wrapper around kornia's pretrained DISK feature extractor.

    For a batch of images it produces keypoints, keypoint scores and
    descriptors, and optionally the dense descriptor map. Images are
    processed in chunks of ``conf.chunk`` to bound peak memory.
    """

    default_conf = {
        # Checkpoint name forwarded to ``kornia.feature.DISK.from_pretrained``.
        "weights": "depth",
        # If True, features are sampled from the dense heatmap/descriptor maps
        # and the dense descriptor map is returned as "dense_descriptors".
        "dense_outputs": False,
        # Maximum number of keypoints per image (forwarded to kornia as ``n``);
        # also the padding target when ``force_num_keypoints`` is set.
        "max_num_keypoints": None,
        # NOTE(review): not read anywhere in this class.
        "desc_dim": 128,
        # NMS window size forwarded to kornia as ``window_size``.
        "nms_window_size": 5,
        # Forwarded to kornia as ``score_threshold``.
        "detection_threshold": 0.0,
        # Pad each image's keypoints/scores/descriptors to
        # ``max_num_keypoints`` (via ``pad_and_stack``) so they can be stacked.
        "force_num_keypoints": False,
        # Dense path: zero-pad images to a multiple of 16 and crop the outputs
        # back; sparse path: forwarded to kornia's DISK call.
        "pad_if_not_divisible": True,
        "chunk": 4,  # for reduced VRAM in training
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        """Load the pretrained kornia DISK network selected by ``conf.weights``."""
        self.model = kornia.feature.DISK.from_pretrained(conf.weights)
        self.set_initialized()

    def _get_dense_outputs(self, images):
        """Run DISK densely, then extract sparse keypoints from the heatmaps.

        Args:
            images: batched image tensor; dims 2 and 3 are read as (H, W).

        Returns:
            A tuple ``(features, descriptors)``:
            - ``features``: a list with one entry per image, built with
              kornia's ``Keypoints.merge_with_descriptors``;
            - ``descriptors``: the dense descriptor map, cropped back to the
              original (H, W) when padding was applied.
        """
        pad = self.conf.pad_if_not_divisible
        if pad:
            h, w = images.shape[2:]
            # Distance to the next multiple of 16 (0 when already divisible).
            pd_h = (16 - h % 16) % 16
            pd_w = (16 - w % 16) % 16
            # Pad only on the bottom/right so that coordinates are unchanged.
            images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)

        heatmaps, descriptors = self.model.heatmap_and_dense_descriptors(images)
        if pad:
            # Drop the padded border again.
            heatmaps = heatmaps[..., :h, :w]
            descriptors = descriptors[..., :h, :w]

        keypoints = kornia.feature.disk.detector.heatmap_to_keypoints(
            heatmaps,
            n=self.conf.max_num_keypoints,
            window_size=self.conf.nms_window_size,
            score_threshold=self.conf.detection_threshold,
        )

        # heatmap_to_keypoints yields one Keypoints object per image; pair each
        # with that image's dense descriptor map.
        features = [
            kps.merge_with_descriptors(desc)
            for kps, desc in zip(keypoints, descriptors)
        ]
        return features, descriptors

    def _forward(self, data):
        """Extract DISK features for a batch of images.

        Args:
            data: dict with key "image" (batched image tensor; presumably
                (B, C, H, W) — TODO confirm). The optional "image_size" is
                used only to bound the random keypoints added when
                ``force_num_keypoints`` pads.

        Returns:
            dict with "keypoints" (shifted by +0.5), "keypoint_scores" and
            "descriptors", each cast to the dtype/device of the image. When
            ``conf.dense_outputs`` is set it also contains
            "dense_descriptors".
        """
        image = data["image"]
        num_images = image.shape[0]
        # A missing or zero chunk size would make range() fail; process the
        # whole batch at once instead.
        chunk = self.conf.chunk or max(num_images, 1)

        keypoints, scores, descriptors = [], [], []
        dense_descriptors = []
        for start in range(0, num_images, chunk):
            # Only images [start, start + chunk). Slicing from 0 would
            # re-process earlier images and duplicate their features.
            batch = image[start : start + chunk]
            if self.conf.dense_outputs:
                features, d_descriptors = self._get_dense_outputs(batch)
                dense_descriptors.append(d_descriptors)
            else:
                features = self.model(
                    batch,
                    n=self.conf.max_num_keypoints,
                    window_size=self.conf.nms_window_size,
                    score_threshold=self.conf.detection_threshold,
                    pad_if_not_divisible=self.conf.pad_if_not_divisible,
                )
            keypoints += [f.keypoints for f in features]
            scores += [f.detection_scores for f in features]
            descriptors += [f.descriptors for f in features]
            del features  # release the chunk's outputs before the next chunk

        if self.conf.force_num_keypoints:
            # pad to target_length
            target_length = self.conf.max_num_keypoints
            image_size = data.get("image_size")
            if image_size is None:
                image_size = torch.tensor(image.shape[-2:])
            # NOTE(review): "random_c" presumably fills missing keypoints with
            # random coordinates in [0, smallest image side] — confirm against
            # pad_and_stack.
            keypoints = pad_and_stack(
                keypoints,
                target_length,
                -2,
                mode="random_c",
                bounds=(0, image_size.min().item()),
            )
            scores = pad_and_stack(scores, target_length, -1, mode="zeros")
            descriptors = pad_and_stack(descriptors, target_length, -2, mode="zeros")
        else:
            # Requires every image to have the same number of keypoints.
            keypoints = torch.stack(keypoints, 0)
            scores = torch.stack(scores, 0)
            descriptors = torch.stack(descriptors, 0)

        pred = {
            # NOTE(review): +0.5 presumably maps kornia's pixel indices to the
            # pixel-center coordinate convention used downstream — confirm.
            "keypoints": keypoints.to(image) + 0.5,
            "keypoint_scores": scores.to(image),
            "descriptors": descriptors.to(image),
        }
        if self.conf.dense_outputs:
            pred["dense_descriptors"] = torch.cat(dense_descriptors, 0)
        return pred

    def loss(self, pred, data):
        """Training loss is not implemented for this extractor."""
        raise NotImplementedError
|