# glue-factory-custom/gluefactory/models/extractors/sift_kornia.py
# (file-viewer metadata from the original paste: 46 lines, 1.3 KiB, Python)

import kornia
import torch
from ..base_model import BaseModel
class KorniaSIFT(BaseModel):
    """SIFT detector/descriptor backed by ``kornia.feature.SIFTFeature``.

    Converts kornia's local affine frames (LAFs) into a prediction dict with
    keypoint centers, scales, orientations (in radians), detection scores
    and, optionally, descriptors.
    """

    default_conf = {
        "has_detector": True,
        "has_descriptor": True,
        # A non-positive value (or None) means "no explicit cap". In that case
        # kornia's own default ``num_features`` is used (see ``_init``).
        "max_num_keypoints": -1,
        # NOTE(review): this key is not read by this class.
        "detection_threshold": None,
        "rootsift": True,
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        """Build the kornia SIFT pipeline from ``self.conf``."""
        sift_kwargs = {"rootsift": self.conf.rootsift}
        max_kpts = self.conf.max_num_keypoints
        # NOTE(review): kornia appears to use ``num_features`` as a top-k
        # count, so the -1 "unlimited" sentinel is not a valid value for it.
        # Forward only a positive cap; otherwise keep kornia's default.
        # Verify against the installed kornia version.
        if max_kpts is not None and max_kpts > 0:
            sift_kwargs["num_features"] = max_kpts
        self.sift = kornia.feature.SIFTFeature(**sift_kwargs)

    def _forward(self, data):
        """Detect (and optionally describe) SIFT features on ``data["image"]``.

        Returns:
            dict with ``keypoints`` (LAF centers), ``scales`` (LAF scales),
            ``oris`` (LAF orientations converted from degrees to radians),
            ``keypoint_scores`` and, if ``has_descriptor`` is set,
            ``descriptors``. Every tensor is moved to the input image's device.
        """
        image = data["image"]
        # SIFT works on intensity images, so collapse RGB batches to one
        # channel. This assumes a (B, C, H, W) layout -- TODO confirm.
        if image.dim() == 4 and image.shape[1] == 3:
            image = kornia.color.rgb_to_grayscale(image)

        lafs, scores, descriptors = self.sift(image)
        pred = {
            # @TODO: confirm keypoints are in corner convention
            "keypoints": kornia.feature.get_laf_center(lafs),
            "scales": kornia.feature.get_laf_scale(lafs),
            "oris": kornia.feature.get_laf_orientation(lafs),
            "keypoint_scores": scores,
        }
        if self.conf.has_descriptor:
            pred["descriptors"] = descriptors

        device = image.device
        pred = {key: value.to(device=device) for key, value in pred.items()}
        # kornia's LAF orientation is in degrees; expose radians instead.
        pred["oris"] = torch.deg2rad(pred["oris"])
        return pred

    def loss(self, pred, data):
        """Training losses are not implemented for this extractor."""
        raise NotImplementedError