glue-factory-custom/gluefactory/models/two_view_pipeline.py

115 lines
3.9 KiB
Python
Raw Normal View History

"""
A two-view sparse feature matching pipeline.
This model contains sub-models for each step:
feature extraction, feature matching, outlier filtering, pose estimation.
Each step is optional, and the features or matches can be provided as input.
Default: no sub-model is selected (every `name` in `default_conf` is None);
a typical setup is SuperPoint with nearest neighbor matching.
Convention for the matches: m0[i] is the index of the keypoint in image 1
that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
"""
from omegaconf import OmegaConf
from .base_model import BaseModel
from . import get_model
# Shorthand for OmegaConf.to_container (convert DictConfig to dict); applied to
# each sub-model's config section before it is passed to that sub-model's
# constructor.
to_ctr = OmegaConf.to_container
class TwoViewPipeline(BaseModel):
    """Chain optional sub-models over a pair of views.

    The pipeline is made of up to five components (see ``components``):
    extractor, matcher, filter, solver and ground_truth. A component is only
    instantiated and run when its ``name`` entry in the configuration is set;
    with the default configuration every name is None and nothing is run.

    Configuration keys (see ``default_conf``):
        <component>.name: identifier passed to ``get_model`` to look up the
            sub-model class; the whole component section is forwarded to the
            sub-model constructor as a plain container.
        <component>.apply_loss: optional; when present and falsy, the
            component is skipped in ``loss``.
        allow_no_extract: when truthy, a view that comes with a non-empty
            ``"cache"`` entry is not passed through the extractor.
        run_gt_in_forward: when truthy, the ground-truth sub-model runs in
            ``_forward``; otherwise it runs in ``loss``.
    """

    default_conf = {
        "extractor": {
            "name": None,
            "trainable": False,
        },
        "matcher": {"name": None},
        "filter": {"name": None},
        "solver": {"name": None},
        "ground_truth": {"name": None},
        "allow_no_extract": False,
        "run_gt_in_forward": False,
    }
    required_data_keys = ["view0", "view1"]
    strict_conf = False  # need to pass new confs to children models
    # Order in which the sub-model losses are accumulated in ``loss``.
    components = [
        "extractor",
        "matcher",
        "filter",
        "solver",
        "ground_truth",
    ]

    def _init(self, conf):
        """Instantiate every sub-model whose ``name`` is set in ``conf``.

        Components with a falsy name get no attribute at all, so the rest of
        the class always checks the configured name before touching them.
        """
        if conf.extractor.name:
            self.extractor = get_model(conf.extractor.name)(to_ctr(conf.extractor))

        if conf.matcher.name:
            self.matcher = get_model(conf.matcher.name)(to_ctr(conf.matcher))

        if conf.filter.name:
            self.filter = get_model(conf.filter.name)(to_ctr(conf.filter))

        if conf.solver.name:
            self.solver = get_model(conf.solver.name)(to_ctr(conf.solver))

        if conf.ground_truth.name:
            self.ground_truth = get_model(conf.ground_truth.name)(
                to_ctr(conf.ground_truth)
            )

    def extract_view(self, data, i):
        """Return the per-view predictions for ``data[f"view{i}"]``.

        The result starts from the view's optional ``"cache"`` entry (an empty
        dict when absent). If an extractor is configured it is run on the view
        and its outputs are merged over the cached entries, unless the cache
        is non-empty and ``allow_no_extract`` is set, in which case the cached
        entries are returned as they are.

        Args:
            data: mapping holding the ``view{i}`` entry.
            i: view suffix used to build the key (``_forward`` passes "0"
                and "1").

        Returns:
            A dict of predictions for that view.
        """
        data_i = data[f"view{i}"]
        pred_i = data_i.get("cache", {})
        skip_extract = len(pred_i) > 0 and self.conf.allow_no_extract
        # NOTE: an earlier `elif extractor.name and not allow_no_extract`
        # branch was dropped here. It could only be evaluated when
        # skip_extract was truthy, which itself requires allow_no_extract to
        # be truthy, so its condition was never satisfied.
        if self.conf.extractor.name and not skip_extract:
            pred_i = {**pred_i, **self.extractor(data_i)}
        return pred_i

    def _forward(self, data):
        """Run the configured components in sequence and merge their outputs.

        Per-view predictions are stored under their original key suffixed
        with "0" or "1". The matcher, filter and solver (when configured)
        each receive ``{**data, **pred}`` and their outputs are merged into
        ``pred``, later components overriding earlier keys. When
        ``run_gt_in_forward`` is set, the ground-truth outputs are added
        under keys prefixed with ``gt_``.

        Args:
            data: mapping with at least the ``view0`` and ``view1`` entries.

        Returns:
            The merged prediction dict.
        """
        pred0 = self.extract_view(data, "0")
        pred1 = self.extract_view(data, "1")
        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
        }

        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}
        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}
        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}

        if self.conf.ground_truth.name and self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
        return pred

    def loss(self, pred, data):
        """Accumulate the losses and metrics of all configured components.

        If a ground-truth sub-model is configured and was not already run in
        ``_forward``, it is run first and its outputs are written into
        ``pred`` in place under ``gt_``-prefixed keys. The in-place update is
        kept on purpose: the caller's ``pred`` gains these entries.

        Each component with a set name (and without a falsy ``apply_loss``)
        contributes its ``loss(pred, {**pred, **data})`` result. Components
        whose ``loss`` raises NotImplementedError are skipped.

        Args:
            pred: prediction dict, typically the output of ``_forward``.
            data: input data dict.

        Returns:
            A ``(losses, metrics)`` tuple. ``losses`` merges the per-component
            loss dicts and its ``"total"`` entry is the sum of every
            component's ``"total"`` (0 when no component contributed).
        """
        losses = {}
        metrics = {}
        total = 0

        # get labels
        if self.conf.ground_truth.name and not self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})

        for k in self.components:
            apply = True
            if "apply_loss" in self.conf[k].keys():
                apply = self.conf[k].apply_loss
            if self.conf[k].name and apply:
                try:
                    losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
                except NotImplementedError:
                    # This component defines no loss; nothing to accumulate.
                    continue
                losses = {**losses, **losses_}
                metrics = {**metrics, **metrics_}
                total = losses_["total"] + total
        # The summed total replaces any per-component "total" merged above.
        return {**losses, "total": total}, metrics