74 lines
2.1 KiB
Python
74 lines
2.1 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from omegaconf import OmegaConf
|
||
|
|
||
|
|
||
|
def weight_loss(log_assignment, weights, gamma=0.0):
    """Reduce a weighted log-assignment tensor to per-batch NLL terms.

    The last row and the last column of the two trailing dimensions are
    handled separately from the inner block. The inner block produces the
    "positive" term; the last column (every row but the last) together with
    the last row (every column but the last) produce the "negative" term.

    Args:
        log_assignment: 3-D tensor; its shape is unpacked into three sizes.
        weights: tensor multiplied element-wise with ``log_assignment`` and
            summed to obtain the normalisers.
        gamma: accepted for interface compatibility; the computation below
            does not use it.

    Returns:
        A tuple ``(nll_pos, nll_neg, num_pos, num_neg)`` with one entry per
        batch element. ``num_pos`` is the inner-block weight sum floored at
        1, and ``num_neg`` is half the sum of the two (individually floored)
        border weight sums.
    """
    _, total_rows, total_cols = log_assignment.shape
    inner_rows = total_rows - 1
    inner_cols = total_cols - 1

    weighted = log_assignment * weights

    # Normalisers are floored at 1 so an empty group never divides by zero.
    pos_count = weights[:, :inner_rows, :inner_cols].sum((-1, -2)).clamp(min=1.0)
    last_col_count = weights[:, :inner_rows, -1].sum(-1).clamp(min=1.0)
    last_row_count = weights[:, -1, :inner_cols].sum(-1).clamp(min=1.0)
    neg_count = last_col_count + last_row_count

    # Positive term: negated weighted sum over the inner block, normalised.
    nll_pos = torch.neg(weighted[:, :inner_rows, :inner_cols].sum((-1, -2)))
    nll_pos /= pos_count

    # Negative term: the two border strips share a single normaliser.
    nll_last_col = torch.neg(weighted[:, :inner_rows, -1].sum(-1))
    nll_last_row = torch.neg(weighted[:, -1, :inner_cols].sum(-1))
    nll_neg = (nll_last_col + nll_last_row) / neg_count

    return nll_pos, nll_neg, pos_count, neg_count / 2.0
|
||
|
|
||
|
|
||
|
class NLLLoss(nn.Module):
    """Balanced negative log-likelihood loss over a log-assignment tensor.

    The loss combines a "positive" term and a "negative" term, both obtained
    from ``weight_loss``, using the ``nll_balancing`` configuration value.
    """

    default_conf = {
        "nll_balancing": 0.5,
        "gamma_f": 0.0,  # focal loss
    }

    def __init__(self, conf):
        """Merge ``conf`` over ``default_conf`` and select the weight builder."""
        super().__init__()
        self.conf = OmegaConf.merge(self.default_conf, conf)
        self.loss_fn = self.nll_loss

    def forward(self, pred, data, weights=None):
        """Compute the balanced NLL.

        Args:
            pred: mapping that provides ``"log_assignment"``.
            data: mapping with the ground-truth entries read by
                ``nll_loss``; only used when ``weights`` is None.
            weights: optional precomputed weight tensor. When None, the
                weights are built from ``data`` via ``self.loss_fn``.

        Returns:
            A tuple ``(nll, weights, metrics)`` where ``metrics`` is a dict
            with the keys ``assignment_nll``, ``nll_pos``, ``nll_neg``,
            ``num_matchable`` and ``num_unmatchable``.
        """
        log_assignment = pred["log_assignment"]
        if weights is None:
            weights = self.loss_fn(log_assignment, data)
        nll_pos, nll_neg, num_pos, num_neg = weight_loss(
            log_assignment, weights, gamma=self.conf.gamma_f
        )
        nll = (
            self.conf.nll_balancing * nll_pos + (1 - self.conf.nll_balancing) * nll_neg
        )

        return (
            nll,
            weights,
            {
                "assignment_nll": nll,
                "nll_pos": nll_pos,
                "nll_neg": nll_neg,
                "num_matchable": num_pos,
                "num_unmatchable": num_neg,
            },
        )

    def nll_loss(self, log_assignment, data):
        """Build the weight tensor for ``log_assignment`` from ground truth.

        Args:
            log_assignment: tensor whose shape and dtype the result copies;
                indexed here as (batch, rows, cols) with at least ``m + 1``
                rows and ``n + 1`` columns.
            data: mapping with ``"gt_assignment"`` (written to the leading
                ``m x n`` block), ``"gt_matches0"`` (last dim of size ``m``)
                and ``"gt_matches1"`` (last dim of size ``n``).

        Returns:
            A zero-initialised tensor shaped like ``log_assignment`` whose
            leading ``m x n`` block holds ``gt_assignment`` as floats, whose
            last column flags entries of ``gt_matches0`` equal to -1, and
            whose last row flags entries of ``gt_matches1`` equal to -1.
        """
        m, n = data["gt_matches0"].size(-1), data["gt_matches1"].size(-1)
        positive = data["gt_assignment"].float()
        neg0 = (data["gt_matches0"] == -1).float()
        neg1 = (data["gt_matches1"] == -1).float()

        weights = torch.zeros_like(log_assignment)
        weights[:, :m, :n] = positive

        weights[:, :m, -1] = neg0
        # neg1 has n entries per batch element, so the last row must be
        # sliced with n (not m); otherwise the assignment fails when m != n.
        weights[:, -1, :n] = neg1
        return weights
|