glue-factory-custom/gluefactory/models/extractors/disk_kornia.py

import kornia
import torch

from ..base_model import BaseModel
from ..utils.misc import pad_and_stack


class DISK(BaseModel):
    """Wrapper around the kornia implementation of the DISK extractor."""

    default_conf = {
        "weights": "depth",  # kornia checkpoint: "depth" or "epipolar"
        "dense_outputs": False,
        "max_num_keypoints": None,
        "desc_dim": 128,
        "nms_window_size": 5,
        "detection_threshold": 0.0,
        "force_num_keypoints": False,
        "pad_if_not_divisible": True,
        "chunk": 4,  # process the batch in chunks of this size to reduce VRAM in training
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        # kornia downloads the pretrained checkpoint on first use.
        self.model = kornia.feature.DISK.from_pretrained(conf.weights)
        self.set_initialized()
    def _get_dense_outputs(self, images):
        B = images.shape[0]
        if self.conf.pad_if_not_divisible:
            h, w = images.shape[2:]
            # Zero-pad the bottom/right so that H and W are multiples of 16.
            pd_h = 16 - h % 16 if h % 16 > 0 else 0
            pd_w = 16 - w % 16 if w % 16 > 0 else 0
            images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)

        heatmaps, descriptors = self.model.heatmap_and_dense_descriptors(images)
        if self.conf.pad_if_not_divisible:
            # Crop the outputs back to the original resolution.
            heatmaps = heatmaps[..., :h, :w]
            descriptors = descriptors[..., :h, :w]

        keypoints = kornia.feature.disk.detector.heatmap_to_keypoints(
            heatmaps,
            n=self.conf.max_num_keypoints,
            window_size=self.conf.nms_window_size,
            score_threshold=self.conf.detection_threshold,
        )

        features = []
        for i in range(B):
            features.append(keypoints[i].merge_with_descriptors(descriptors[i]))

        return features, descriptors
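
    # The padding in `_get_dense_outputs` rounds each side up to the next
    # multiple of 16, the downsampling factor of DISK's U-Net backbone.
    # A worked example, for illustration only: a 500x640 image gets
    # pd_h = 16 - 500 % 16 = 12 and pd_w = 0, is processed at 512x640, and the
    # heatmaps and dense descriptors are cropped back to 500x640 afterwards.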
    def _forward(self, data):
        image = data["image"]

        keypoints, scores, descriptors = [], [], []
        if self.conf.dense_outputs:
            dense_descriptors = []
        # Process the batch in chunks to bound peak memory usage.
        chunk = self.conf.chunk
        for i in range(0, image.shape[0], chunk):
            if self.conf.dense_outputs:
                features, d_descriptors = self._get_dense_outputs(
                    image[i : min(image.shape[0], i + chunk)]
                )
                dense_descriptors.append(d_descriptors)
            else:
                features = self.model(
                    image[i : min(image.shape[0], i + chunk)],
                    n=self.conf.max_num_keypoints,
                    window_size=self.conf.nms_window_size,
                    score_threshold=self.conf.detection_threshold,
                    pad_if_not_divisible=self.conf.pad_if_not_divisible,
                )
            keypoints += [f.keypoints for f in features]
            scores += [f.detection_scores for f in features]
            descriptors += [f.descriptors for f in features]
            del features

        if self.conf.force_num_keypoints:
            # Pad each image's detections to exactly max_num_keypoints so that
            # the per-image tensors can be stacked into a batch.
            target_length = self.conf.max_num_keypoints
            keypoints = pad_and_stack(
                keypoints,
                target_length,
                -2,
                mode="random_c",
                bounds=(
                    0,
                    data.get("image_size", torch.tensor(image.shape[-2:])).min().item(),
                ),
            )
            scores = pad_and_stack(scores, target_length, -1, mode="zeros")
            descriptors = pad_and_stack(descriptors, target_length, -2, mode="zeros")
        else:
            # Without padding, stacking assumes every image yielded the same
            # number of keypoints.
            keypoints = torch.stack(keypoints, 0)
            scores = torch.stack(scores, 0)
            descriptors = torch.stack(descriptors, 0)

        pred = {
            # Shift by 0.5 to move from kornia's integer pixel indices to
            # coordinates at pixel centers, the convention used downstream.
            "keypoints": keypoints.to(image) + 0.5,
            "keypoint_scores": scores.to(image),
            "descriptors": descriptors.to(image),
        }
        if self.conf.dense_outputs:
            pred["dense_descriptors"] = torch.cat(dense_descriptors, 0)
        return pred

    def loss(self, pred, data):
        raise NotImplementedError
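

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # that gluefactory's BaseModel accepts a partial conf dict, merges it with
    # `default_conf`, and dispatches `model(data)` to `_forward`. Running it
    # requires downloading the pretrained DISK weights via kornia.
    model = DISK({"max_num_keypoints": 256}).eval()
    image = torch.rand(1, 3, 480, 640)  # dummy RGB batch in [0, 1]
    with torch.no_grad():
        pred = model({"image": image})
    # With force_num_keypoints=False (the default), each image may yield
    # N <= 256 keypoints:
    print(pred["keypoints"].shape)  # (1, N, 2)
    print(pred["keypoint_scores"].shape)  # (1, N)
    print(pred["descriptors"].shape)  # (1, N, 128)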