import torch
import torch.nn as nn
from omegaconf import OmegaConf


def weight_loss(log_assignment, weights, gamma=0.0):
    """Reduce a weighted log-assignment tensor to per-batch NLL terms.

    ``log_assignment`` is unpacked as a 3-D tensor ``(b, m + 1, n + 1)``.
    The inner ``[:m, :n]`` region is treated as the "positive" part, while
    the last column (rows ``:m``) and the last row (columns ``:n``) are
    treated as the two "negative" parts. ``weights`` is multiplied
    element-wise with ``log_assignment`` and is indexed the same way.

    Args:
        log_assignment: 3-D tensor of log scores.
        weights: tensor multiplied element-wise with ``log_assignment``.
        gamma: accepted for interface compatibility; not used by the
            current implementation.

    Returns:
        A 4-tuple ``(nll_pos, nll_neg, num_pos, num_neg)`` where

        * ``nll_pos`` is the negated weighted sum over ``[:m, :n]`` divided
          by ``num_pos``;
        * ``nll_neg`` is the negated weighted sum over the last column and
          last row divided by the sum of their (clamped) weight totals;
        * ``num_pos`` is the weight total over ``[:m, :n]``, clamped to at
          least 1;
        * ``num_neg`` is the mean of the two clamped negative weight totals.
    """
    _, m, n = log_assignment.shape
    # Exclude the extra last row / column from the "positive" region.
    m -= 1
    n -= 1

    loss_sc = log_assignment * weights

    # Clamp the normalisers to >= 1 so that empty groups do not divide by zero.
    num_neg0 = weights[:, :m, -1].sum(-1).clamp(min=1.0)
    num_neg1 = weights[:, -1, :n].sum(-1).clamp(min=1.0)
    num_pos = weights[:, :m, :n].sum((-1, -2)).clamp(min=1.0)

    nll_pos = -loss_sc[:, :m, :n].sum((-1, -2))
    nll_pos /= num_pos  # already clamped above

    nll_neg0 = -loss_sc[:, :m, -1].sum(-1)
    nll_neg1 = -loss_sc[:, -1, :n].sum(-1)
    nll_neg = (nll_neg0 + nll_neg1) / (num_neg0 + num_neg1)

    return nll_pos, nll_neg, num_pos, (num_neg0 + num_neg1) / 2.0


class NLLLoss(nn.Module):
    """Negative log-likelihood loss over a log-assignment tensor.

    The loss is a convex combination of a "positive" and a "negative" term
    (see :func:`weight_loss`), balanced by ``conf.nll_balancing``.
    """

    default_conf = {
        "nll_balancing": 0.5,
        "gamma_f": 0.0,  # focal loss
    }

    def __init__(self, conf):
        """Merge ``conf`` over ``default_conf`` and select the weight builder."""
        super().__init__()
        self.conf = OmegaConf.merge(self.default_conf, conf)
        self.loss_fn = self.nll_loss

    def forward(self, pred, data, weights=None):
        """Compute the balanced NLL.

        Args:
            pred: mapping providing ``"log_assignment"``.
            data: mapping passed to ``self.loss_fn`` when ``weights`` is
                not supplied (see :meth:`nll_loss` for the keys it reads).
            weights: optional precomputed weight tensor; built from ``data``
                when ``None``.

        Returns:
            ``(nll, weights, metrics)`` where ``metrics`` is a dict with the
            keys ``assignment_nll``, ``nll_pos``, ``nll_neg``,
            ``num_matchable`` and ``num_unmatchable``.
        """
        log_assignment = pred["log_assignment"]
        if weights is None:
            weights = self.loss_fn(log_assignment, data)

        nll_pos, nll_neg, num_pos, num_neg = weight_loss(
            log_assignment, weights, gamma=self.conf.gamma_f
        )

        balance = self.conf.nll_balancing
        nll = balance * nll_pos + (1 - balance) * nll_neg

        metrics = {
            "assignment_nll": nll,
            "nll_pos": nll_pos,
            "nll_neg": nll_neg,
            "num_matchable": num_pos,
            "num_unmatchable": num_neg,
        }
        return nll, weights, metrics

    def nll_loss(self, log_assignment, data):
        """Build the weight tensor used by :func:`weight_loss`.

        Reads ``data["gt_assignment"]``, ``data["gt_matches0"]`` and
        ``data["gt_matches1"]``. Entries of ``gt_matches0`` / ``gt_matches1``
        equal to ``-1`` are given weight 1 in the last column / last row
        respectively; ``gt_assignment`` (cast to float) fills the inner
        ``[:m, :n]`` region. Everything else stays zero.

        Returns:
            A tensor shaped and typed like ``log_assignment``.
        """
        m = data["gt_matches0"].size(-1)
        n = data["gt_matches1"].size(-1)

        positive = data["gt_assignment"].float()
        neg0 = (data["gt_matches0"] == -1).float()
        neg1 = (data["gt_matches1"] == -1).float()

        weights = torch.zeros_like(log_assignment)
        weights[:, :m, :n] = positive
        weights[:, :m, -1] = neg0
        # The last row is indexed by the second set (length n), matching how
        # weight_loss reads it back with ``weights[:, -1, :n]``.
        weights[:, -1, :n] = neg1
        return weights