glue-factory-custom/gluefactory/models/extractors/grid_extractor.py

61 lines
1.4 KiB
Python

import math
import torch
from ..base_model import BaseModel
def to_sequence(map):
return map.flatten(-2).transpose(-1, -2)
def to_map(sequence):
n = sequence.shape[-2]
e = math.isqrt(n)
assert e * e == n
assert e * e == n
sequence.transpose(-1, -2).unflatten(-1, [e, e])
class GridExtractor(BaseModel):
default_conf = {"cell_size": 14}
required_data_keys = ["image"]
def _init(self, conf):
pass
def _forward(self, data):
b, c, h, w = data["image"].shape
cgrid = (
torch.stack(
torch.meshgrid(
torch.arange(
h // self.conf.cell_size,
dtype=torch.float32,
device=data["image"].device,
),
torch.arange(
w // self.conf.cell_size,
dtype=torch.float32,
device=data["image"].device,
),
indexing="ij",
)[::-1],
dim=0,
)
.unsqueeze(0)
.repeat([b, 1, 1, 1])
* self.conf.cell_size
+ self.conf.cell_size / 2
)
pred = {
"grid": cgrid + 0.5,
"keypoints": to_sequence(cgrid) + 0.5,
}
return pred
def loss(self, pred, data):
raise NotImplementedError