2023-10-05 16:53:51 +02:00
|
|
|
import kornia
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from ..base_model import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
class KorniaSIFT(BaseModel):
    """SIFT feature extractor backed by ``kornia.feature.SIFTFeature``.

    Outputs keypoints, per-keypoint scales, orientations (in radians),
    detection scores and, when ``has_descriptor`` is set, descriptors.
    """

    default_conf = {
        "has_detector": True,
        "has_descriptor": True,
        # Forwarded to kornia as ``num_features``.
        # NOTE(review): confirm kornia accepts -1 as "no limit"; it may
        # require a positive keypoint budget.
        "max_num_keypoints": -1,
        # NOTE(review): not read anywhere in this class.
        "detection_threshold": None,
        "rootsift": True,
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        """Build the kornia SIFT pipeline from the resolved config."""
        self.sift = kornia.feature.SIFTFeature(
            num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
        )
        self.set_initialized()

    def _forward(self, data):
        """Detect and describe SIFT features for ``data["image"]``.

        Args:
            data: dict whose ``"image"`` entry is passed straight to kornia
                (presumably a (B, 1, H, W) grayscale tensor — TODO confirm).

        Returns:
            dict with ``keypoints``, ``scales``, ``oris`` (radians) and
            ``keypoint_scores``, plus ``descriptors`` when
            ``conf.has_descriptor`` is set. Every tensor is moved to the
            device of ``data["image"]``.
        """
        image = data["image"]
        lafs, scores, descriptors = self.sift(image)

        keypoints = kornia.feature.get_laf_center(lafs)
        # Drop the trailing singleton dims kornia keeps on per-LAF scalars.
        scales = kornia.feature.get_laf_scale(lafs).squeeze(-1).squeeze(-1)
        # kornia reports orientation in degrees; downstream expects radians.
        oris = torch.deg2rad(
            kornia.feature.get_laf_orientation(lafs).squeeze(-1)
        )

        pred = {
            "keypoints": keypoints,  # @TODO: confirm keypoints are in corner convention
            "scales": scales,
            "oris": oris,
            "keypoint_scores": scores,
        }
        if self.conf.has_descriptor:
            pred["descriptors"] = descriptors

        # Guarantee all outputs share the input image's device.
        return {key: value.to(device=image.device) for key, value in pred.items()}

    def loss(self, pred, data):
        """Not implemented for this extractor."""
        raise NotImplementedError
|