glue-factory-custom/gluefactory/models/triplet_pipeline.py

100 lines
3.2 KiB
Python
Raw Normal View History

"""
A two-view sparse feature matching pipeline that also handles image triplets.
If the data contains a third view (a "view2" entry), runs the extractor on
all three images and then runs matcher/filter/solver for all three pairs
(0to1, 0to2, 1to2). Losses and metrics get accumulated accordingly.
If no triplet is found, this falls back to two_view_pipeline.py
"""
import torch
from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews
from .two_view_pipeline import TwoViewPipeline
def has_triplet(data):
    """Return whether ``data`` carries a third view.

    Args:
        data: mapping of input entries for one batch.

    Returns:
        True if ``data`` has a ``"view2"`` key, False otherwise.
    """
    # we already check for image0 and image1 in required_keys
    return "view2" in data
class TripletPipeline(TwoViewPipeline):
    """Two-view pipeline extended to image triplets.

    When the input has a ``"view2"`` entry (see ``has_triplet``), features
    are extracted for views 0, 1 and 2 and the matcher/filter/solver are run
    on the pairs ``0to1``, ``0to2`` and ``1to2``. Otherwise both ``_forward``
    and ``loss`` delegate unchanged to ``TwoViewPipeline``.

    Config (in addition to ``TwoViewPipeline.default_conf``):
        batch_triplets: if true, the pairs are combined with
            ``stack_twoviews`` and processed in a single pass; otherwise
            each pair is processed separately via ``get_twoview``.
    """

    default_conf = {"batch_triplets": True, **TwoViewPipeline.default_conf}

    def _forward(self, data):
        """Run extraction on three views and matching on the three pairs.

        Args:
            data: input mapping; must contain ``"image1"`` when
                ``batch_triplets`` is enabled (its first dimension is used
                as the batch size for unstacking).

        Returns:
            A dict with the per-view extractor outputs (each key suffixed
            with ``"0"``, ``"1"`` or ``"2"``) plus the pairwise outputs.
            In the non-batched mode the pairwise outputs are stored under
            ``pred["0to1"]``, ``pred["0to2"]`` and ``pred["1to2"]``; in the
            batched mode they are whatever ``unstack_twoviews`` returns
            (NOTE(review): presumably the same layout - confirm).
        """
        if not has_triplet(data):
            return super()._forward(data)

        # Ground truth in the forward pass is not handled on the triplet
        # path, so refuse to run with that option enabled.
        assert not self.conf.run_gt_in_forward

        # Per-view extraction; keys are suffixed with the view index so the
        # three result dicts can be merged without collisions.
        pred = {}
        for view in ("0", "1", "2"):
            view_pred = self.extract_view(data, view)
            pred.update({k + view: v for k, v in view_pred.items()})

        def predict_twoview(pair_pred, pair_data):
            """Run the configured matcher, filter and solver on one pair."""
            # Every stage reads the data passed in explicitly rather than a
            # variable captured from the enclosing scope.
            if self.conf.matcher.name:
                pair_pred = {
                    **pair_pred,
                    **self.matcher({**pair_data, **pair_pred}),
                }
            if self.conf.filter.name:
                pair_pred = {
                    **pair_pred,
                    **self.filter({**pair_data, **pair_pred}),
                }
            if self.conf.solver.name:
                pair_pred = {
                    **pair_pred,
                    **self.solver({**pair_data, **pair_pred}),
                }
            return pair_pred

        if self.conf.batch_triplets:
            batch_size = data["image1"].shape[0]
            # stack on batch dimension
            m_data = stack_twoviews(data)
            m_pred = stack_twoviews(pred)
            # forward pass
            m_pred = predict_twoview(m_pred, m_data)
            # unstack
            pred = {**pred, **unstack_twoviews(m_pred, batch_size)}
        else:
            for idx in ["0to1", "0to2", "1to2"]:
                m_data = get_twoview(data, idx)
                m_pred = get_twoview(pred, idx)
                pred[idx] = predict_twoview(m_pred, m_data)
        return pred

    def loss(self, pred, data):
        """Compute losses and metrics over the three pairs of a triplet.

        Args:
            pred: output of ``_forward`` for the same ``data``.
            data: input mapping.

        Returns:
            A ``(losses, metrics)`` tuple of dicts. In the batched mode this
            is the parent's loss on the stacked pairs. In the non-batched
            mode, losses are summed per key across the pairs and metrics
            are concatenated per key along dimension 0.
        """
        if not has_triplet(data):
            return super().loss(pred, data)

        if self.conf.batch_triplets:
            m_data = stack_twoviews(data)
            m_pred = stack_twoviews(pred)
            losses, metrics = super().loss(m_pred, m_data)
        else:
            losses = {}
            metrics = {}
            for idx in ["0to1", "0to2", "1to2"]:
                data_i = get_twoview(data, idx)
                pred_i = pred[idx]
                losses_i, metrics_i = super().loss(pred_i, data_i)
                for k, v in losses_i.items():
                    # Sum each loss term over the pairs.
                    losses[k] = losses[k] + v if k in losses else v
                for k, v in metrics_i.items():
                    # Keep every pair's metric values side by side.
                    if k in metrics:
                        metrics[k] = torch.cat([metrics[k], v], 0)
                    else:
                        metrics[k] = v
        return losses, metrics