glue-factory-custom/gluefactory/models/matchers/kornia_loftr.py

67 lines
2.1 KiB
Python

import kornia
import torch
from ...models import BaseModel
class LoFTRModule(BaseModel):
    """Detector-free matcher wrapping kornia's pretrained LoFTR network.

    LoFTR directly outputs matched point pairs, so the matched points are
    reported as the keypoints of both views and the matches are the identity
    (keypoint ``i`` of view0 matches keypoint ``i`` of view1).

    Configuration:
        topk: if not None, keep only the ``topk`` most confident matches.
        zero_pad: if True, zero-pad each image on the bottom/right to a square
            and pass the matching validity masks ("mask0"/"mask1") to LoFTR.
        pretrained: name of the kornia LoFTR weights to load ("outdoor" keeps
            the previous behavior).
    """

    default_conf = {
        "topk": None,
        "zero_pad": False,
        "pretrained": "outdoor",
    }
    required_data_keys = ["view0", "view1"]

    def _init(self, conf):
        self.net = kornia.feature.LoFTR(pretrained=conf.pretrained)
        self.set_initialized()

    def _forward(self, data):
        """Match ``data["view0"]["image"]`` against ``data["view1"]["image"]``.

        Returns:
            dict with "keypoints0"/"keypoints1" (matched points),
            "matches0"/"matches1" (identity indices) and the per-match
            confidence under all score keys, each with a leading singleton
            batch dimension.
        """
        image0 = data["view0"]["image"]
        image1 = data["view1"]["image"]
        inputs = {"image0": image0, "image1": image1}
        if self.conf.zero_pad:
            image0, mask0 = self.zero_pad(image0)
            image1, mask1 = self.zero_pad(image1)
            # Padding is bottom/right only, so keypoint coordinates in the
            # padded images coincide with those in the original images.
            inputs = {
                "image0": image0,
                "image1": image1,
                "mask0": mask0,
                "mask1": mask1,
            }
        # Single forward pass (previously the masked result was discarded by
        # an immediate second, unmasked call).
        res = self.net(inputs)

        kpts0 = res["keypoints0"]
        kpts1 = res["keypoints1"]
        scores = res["confidence"]
        topk = self.conf.topk
        if topk is not None and scores.shape[-1] > topk:
            _, top = torch.topk(scores, topk, -1)
            kpts0, kpts1, scores = kpts0[top], kpts1[top], scores[top]

        # NOTE(review): kornia's outputs are flat per-match tensors; adding a
        # singleton batch dim here assumes a batch of one image pair — confirm
        # before using with batch size > 1.
        m_kpts0 = kpts0[None]
        m_kpts1 = kpts1[None]
        scores = scores[None]
        num_matches = scores.shape[-1]
        m0 = torch.arange(num_matches, device=scores.device)[None]
        m1 = torch.arange(num_matches, device=scores.device)[None]
        return {
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": scores,
            "keypoints0": m_kpts0,
            "keypoints1": m_kpts1,
            "keypoint_scores0": scores,
            "keypoint_scores1": scores,
            "matching_scores1": scores,
        }

    def zero_pad(self, img):
        """Zero-pad an image batch on the bottom/right to a square canvas.

        Args:
            img: image batch of shape (B, C, H, W).

        Returns:
            (image, mask): ``image`` has shape (B, C, S, S) with
            ``S = max(H, W)`` (the input itself when already square);
            ``mask`` is a float32 tensor of shape (B, S, S) that is 1 on
            original pixels and 0 on padding (all ones for square input, so
            both views can always supply a mask together).
        """
        b, c, h, w = img.shape
        s = max(h, w)
        mask = img.new_zeros((b, s, s), dtype=torch.float32)
        mask[:, :h, :w] = 1.0
        if h == w:
            return img, mask
        image = img.new_zeros((b, c, s, s))
        image[:, :, :h, :w] = img
        return image, mask

    def loss(self, pred, data):
        """No training loss is defined for this inference-only wrapper."""
        raise NotImplementedError