""" A two-view sparse feature matching pipeline on triplets. If a triplet is found, runs the extractor on three images and then runs matcher/filter/solver for all three pairs. Losses and metrics get accumulated accordingly. If no triplet is found, this falls back to two_view_pipeline.py """ import torch from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews from .two_view_pipeline import TwoViewPipeline def has_triplet(data): # we already check for image0 and image1 in required_keys return "view2" in data.keys() class TripletPipeline(TwoViewPipeline): default_conf = {"batch_triplets": True, **TwoViewPipeline.default_conf} def _forward(self, data): if not has_triplet(data): return super()._forward(data) # the two-view outputs are stored in # pred['0to1'],pred['0to2'], pred['1to2'] assert not self.conf.run_gt_in_forward pred0 = self.extract_view(data, "0") pred1 = self.extract_view(data, "1") pred2 = self.extract_view(data, "2") pred = {} pred = { **{k + "0": v for k, v in pred0.items()}, **{k + "1": v for k, v in pred1.items()}, **{k + "2": v for k, v in pred2.items()}, } def predict_twoview(pred, data): # forward pass if self.conf.matcher.name: pred = {**pred, **self.matcher({**data, **pred})} if self.conf.filter.name: pred = {**pred, **self.filter({**m_data, **pred})} if self.conf.solver.name: pred = {**pred, **self.solver({**m_data, **pred})} return pred if self.conf.batch_triplets: B = data["image1"].shape[0] # stack on batch dimension m_data = stack_twoviews(data) m_pred = stack_twoviews(pred) # forward pass m_pred = predict_twoview(m_pred, m_data) # unstack pred = {**pred, **unstack_twoviews(m_pred, B)} else: for idx in ["0to1", "0to2", "1to2"]: m_data = get_twoview(data, idx) m_pred = get_twoview(pred, idx) pred[idx] = predict_twoview(m_pred, m_data) return pred def loss(self, pred, data): if not has_triplet(data): return super().loss(pred, data) if self.conf.batch_triplets: m_data = stack_twoviews(data) m_pred = stack_twoviews(pred) losses, metrics = super().loss(m_pred, m_data) else: losses = {} metrics = {} for idx in ["0to1", "0to2", "1to2"]: data_i = get_twoview(data, idx) pred_i = pred[idx] losses_i, metrics_i = super().loss(pred_i, data_i) for k, v in losses_i.items(): if k in losses.keys(): losses[k] = losses[k] + v else: losses[k] = v for k, v in metrics_i.items(): if k in metrics.keys(): metrics[k] = torch.cat([metrics[k], v], 0) else: metrics[k] = v return losses, metrics