79 lines
2.5 KiB
Python
79 lines
2.5 KiB
Python
|
from omegaconf import OmegaConf
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
from ..base_model import BaseModel
|
||
|
from .. import get_model
|
||
|
|
||
|
# from ...geometry.depth import sample_fmap
|
||
|
|
||
|
# Shorthand used when handing a sub-config to a sub-model constructor:
# converts an OmegaConf DictConfig into a plain dict.
to_ctr = OmegaConf.to_container
|
||
|
|
||
|
|
||
|
class MixedExtractor(BaseModel):
    """Extractor that pairs an independently configured detector and descriptor.

    Each of ``conf.detector`` and ``conf.descriptor`` names a model that is
    resolved through ``get_model``. When a name is left as ``None`` the
    corresponding outputs are expected to be supplied by the caller in
    ``data["cache"]`` instead (``"keypoints"`` for a missing detector,
    ``"descriptors"`` for a missing descriptor).

    If ``conf.interpolate_descriptors_from`` is set, it names a field of the
    prediction dict that is bilinearly sampled at the keypoint locations to
    produce ``pred["descriptors"]``.
    """

    default_conf = {
        "detector": {"name": None},
        "descriptor": {"name": None},
        "interpolate_descriptors_from": None,  # field name
    }

    required_data_keys = ["image"]
    required_cache_keys = []

    def _require_cached(self, key):
        """Record that ``key`` must be provided through ``data["cache"]``."""
        if "cache" not in self.required_data_keys:
            self.required_data_keys.append("cache")
        if key not in self.required_cache_keys:
            self.required_cache_keys.append(key)

    def _init(self, conf):
        """Instantiate the configured sub-models and register cache requirements.

        Args:
            conf: merged configuration; ``conf.detector`` and
                ``conf.descriptor`` are each converted to a plain dict and
                passed to the constructor returned by ``get_model``.
        """
        # Rebind per-instance copies first: ``self.x += [...]`` on a list that
        # lives on the class would extend the shared class-level list in place
        # and leak requirements into every other instance.
        self.required_data_keys = list(self.required_data_keys)
        self.required_cache_keys = list(self.required_cache_keys)

        if conf.detector.name:
            self.detector = get_model(conf.detector.name)(to_ctr(conf.detector))
        else:
            self._require_cached("keypoints")

        if conf.descriptor.name:
            self.descriptor = get_model(conf.descriptor.name)(to_ctr(conf.descriptor))
        else:
            self._require_cached("descriptors")

    def _forward(self, data):
        """Run detection and description and return the merged predictions.

        Args:
            data: input dict. Must contain ``"image"`` when descriptors are
                interpolated, and ``"cache"`` when the detector is not
                configured.

        Returns:
            A dict with the detector output (or a copy of the cached
            predictions), updated with the descriptor output and, if
            requested, the interpolated ``"descriptors"``.
        """
        if self.conf.detector.name:
            pred = self.detector(data)
        else:
            # Shallow copy so that writing "descriptors" below never mutates
            # the caller's cache dict.
            pred = {**data["cache"]}

        # The descriptor must be gated on its own config entry: it only exists
        # as an attribute when ``conf.descriptor.name`` is set.
        if self.conf.descriptor.name:
            pred = {**pred, **self.descriptor({**pred, **data})}

        if self.conf.interpolate_descriptors_from:
            h, w = data["image"].shape[-2:]
            kpts = pred["keypoints"]
            # Map pixel coordinates to the [-1, 1] range expected by
            # grid_sample and add a singleton grid dimension.
            # NOTE(review): assumes keypoints are (B, N, 2) in (x, y) order
            # and the sampled field is (B, C, H', W') -- confirm with callers.
            pts = (kpts / kpts.new_tensor([[w, h]]) * 2 - 1)[:, None]
            pred["descriptors"] = (
                F.grid_sample(
                    pred[self.conf.interpolate_descriptors_from],
                    pts,
                    align_corners=False,
                    mode="bilinear",
                )
                .squeeze(-2)
                .transpose(-2, -1)
                .contiguous()
            )

        return pred

    def loss(self, pred, data):
        """Aggregate the losses and metrics of the configured sub-models.

        A sub-model contributes only if it is configured (has a name) and its
        optional ``apply_loss`` config entry is not falsy. Sub-models whose
        ``loss`` raises ``NotImplementedError`` are skipped.

        Args:
            pred: predictions returned by the forward pass.
            data: input dict; sub-models receive it merged with ``pred``.

        Returns:
            A ``(losses, metrics)`` tuple. ``losses["total"]`` is the sum of
            the ``"total"`` entries of every contributing sub-model (``0`` if
            none contributed).
        """
        losses = {}
        metrics = {}
        total = 0

        for k in ["detector", "descriptor"]:
            sub_conf = self.conf[k]
            apply_loss = True
            if "apply_loss" in sub_conf.keys():
                apply_loss = sub_conf.apply_loss
            if not (sub_conf.name and apply_loss):
                continue

            try:
                losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
            except NotImplementedError:
                # This sub-model does not define a training loss.
                continue

            losses = {**losses, **losses_}
            metrics = {**metrics, **metrics_}
            total = losses_["total"] + total

        # The per-model "total" entries overwrite each other during merging,
        # so the accumulated sum is written last.
        return {**losses, "total": total}, metrics
|