2023-10-05 16:53:51 +02:00
|
|
|
import kornia
|
2023-10-09 08:32:43 +02:00
|
|
|
import torch
|
2023-10-05 16:53:51 +02:00
|
|
|
|
|
|
|
from .utils import get_image_coords
|
|
|
|
from .wrappers import Camera
|
|
|
|
|
|
|
|
|
|
|
|
def sample_fmap(pts, fmap):
    """Sample a feature map at sub-pixel point locations.

    The points are divided by the map's (width, height) and rescaled to the
    [-1, 1] range expected by ``grid_sample`` (``align_corners=False``).
    Wherever bilinear interpolation yields NaN, the nearest-neighbour sample
    is used instead.

    Args:
        pts: point coordinates in (x, y) order; expected as (B, N, 2) so that
            the sampling grid becomes (B, 1, N, 2).
        fmap: map to sample from; its last two dims are (height, width) and
            it must be 4-D (B, C, H, W) for ``grid_sample``.

    Returns:
        Tensor of sampled values laid out as (B, N, C).
    """
    height, width = fmap.shape[-2:]
    extent = pts.new_tensor([[width, height]])
    # Insert a singleton "row" axis so all points form one line of the grid.
    grid = (pts / extent * 2 - 1)[:, None]

    def _sample(mode):
        return torch.nn.functional.grid_sample(
            fmap, grid, align_corners=False, mode=mode
        )

    # TODO: bilinear interpolation might still be a source of noise.
    bilinear = _sample("bilinear")
    nearest = _sample("nearest")
    merged = torch.where(torch.isnan(bilinear), nearest, bilinear)
    # Drop the singleton grid row, then move channels last.
    return merged[:, :, 0].permute(0, 2, 1)
|
|
|
|
|
|
|
|
|
|
|
|
def sample_depth(pts, depth_):
    """Sample a depth map at the given points and flag valid samples.

    Non-positive depth values are replaced by NaN before sampling, so they
    are treated as missing by ``sample_fmap``.

    Args:
        pts: point coordinates, forwarded unchanged to ``sample_fmap``.
        depth_: depth map; a channel axis is inserted at dim 1 before
            sampling (presumably (B, H, W) -- confirm against callers).

    Returns:
        Tuple ``(interp, valid)``: the sampled depth with the trailing
        channel axis squeezed out, and a boolean mask that is True where the
        sample is not NaN and strictly positive.
    """
    missing = depth_.new_tensor(float("nan"))
    masked = torch.where(depth_ > 0, depth_, missing)[:, None]
    sampled = sample_fmap(pts, masked).squeeze(-1)
    is_valid = (sampled > 0) & ~torch.isnan(sampled)
    return sampled, is_valid
|
|
|
|
|
|
|
|
|
|
|
|
def sample_normals_from_depth(pts, depth, K):
    """Estimate surface normals from a depth map and sample them at points.

    Normals are computed with ``kornia.geometry.depth.depth_to_normals`` and
    set to zero wherever the depth is not strictly positive, then sampled
    with ``sample_fmap``.

    Args:
        pts: point coordinates, forwarded unchanged to ``sample_fmap``.
        depth: depth map; a channel axis is inserted at dim 1 before use.
        K: camera intrinsics, forwarded unchanged to kornia.

    Returns:
        Tuple ``(interp, valid)``: the sampled normals and an element-wise
        boolean mask with the same shape as ``interp``.
    """
    depth_map = depth[:, None]
    normal_map = kornia.geometry.depth.depth_to_normals(depth_map, K)
    normal_map = torch.where(depth_map > 0, normal_map, 0.0)
    sampled = sample_fmap(pts, normal_map)
    # NOTE(review): the mask is evaluated per normal component and requires
    # each component to be > 0, so negative components are flagged invalid.
    # This mirrors the depth sampler and may be unintended -- confirm.
    is_valid = (sampled > 0) & ~torch.isnan(sampled)
    return sampled, is_valid
|
|
|
|
|
|
|
|
|
|
|
|
def project(
    kpi,
    di,
    depthj,
    camera_i,
    camera_j,
    T_itoj,
    validi,
    ccth=None,
    sample_depth_fun=sample_depth,
    sample_depth_kwargs=None,
):
    """Project keypoints from view i into view j, optionally checking cycle consistency.

    Keypoints are lifted to 3D with ``camera_i.image2cam`` scaled by ``di``,
    moved with ``T_itoj.transform`` and projected with ``camera_j.cam2image``.
    When both ``depthj`` and ``ccth`` are given, the projected points are
    lifted again using depth sampled in view j, mapped back to view i, and
    compared with the original keypoints.

    Args:
        kpi: keypoints in view i.
        di: depth of each keypoint in view i (a trailing axis is added
            before it scales the lifted points).
        depthj: depth map of view j, or None to skip the consistency check.
        camera_i: camera of view i (provides ``image2cam`` / ``cam2image``).
        camera_j: camera of view j (provides ``image2cam`` / ``cam2image``).
        T_itoj: transform from i to j (provides ``transform`` and ``inv``).
        validi: validity mask of the keypoints in view i.
        ccth: threshold on the squared reprojection distance, or None to
            skip the consistency check.
        sample_depth_fun: callable ``(pts, depth, **kwargs) -> (depth, valid)``.
        sample_depth_kwargs: extra keyword arguments for ``sample_depth_fun``.

    Returns:
        Tuple ``(kpi_j, visible)``: the keypoints projected into view j and
        the mask of keypoints that passed every check that was performed.
    """
    depth_kwargs = {} if sample_depth_kwargs is None else sample_depth_kwargs

    # Forward pass: image i -> 3D in frame i -> 3D in frame j -> image j.
    pts_i = camera_i.image2cam(kpi) * di[..., None]
    pts_j = T_itoj.transform(pts_i)
    kpi_j, projected_ok = camera_j.cam2image(pts_j)
    valid = validi & projected_ok

    if depthj is None or ccth is None:
        return kpi_j, valid

    # Cycle consistency: lift with the depth observed in j and map back to i.
    dj, depth_ok = sample_depth_fun(kpi_j, depthj, **depth_kwargs)
    pts_j_observed = camera_j.image2cam(kpi_j) * dj[..., None]
    pts_i_back = T_itoj.inv().transform(pts_j_observed)
    kpi_back, back_ok = camera_i.cam2image(pts_i_back)

    squared_error = ((kpi - kpi_back) ** 2).sum(-1)
    consistent = squared_error < ccth
    return kpi_j, valid & consistent & back_ok & depth_ok
|
|
|
|
|
|
|
|
|
|
|
|
def dense_warp_consistency(
    depthi: torch.Tensor,
    depthj: torch.Tensor,
    T_itoj: torch.Tensor,
    camerai: Camera,
    cameraj: Camera,
    **kwargs,
):
    """Warp every pixel of view i into view j and report which warps are valid.

    Every pixel coordinate of ``depthi`` is projected into view j with
    ``project``; a pixel starts out valid when its depth is strictly positive.

    Args:
        depthi: depth map of view i; its last two dims define the pixel grid.
        depthj: depth map of view j, forwarded to ``project``.
        T_itoj: transform from i to j, forwarded to ``project``.
            NOTE(review): ``project`` calls ``.transform`` / ``.inv`` on it,
            so the ``torch.Tensor`` annotation is probably too narrow.
        camerai: camera of view i.
        cameraj: camera of view j.
        **kwargs: forwarded to ``project`` (e.g. ``ccth``).

    Returns:
        Tuple ``(kpir, validir)``: the warped coordinates and the validity
        mask, both reshaped back onto the pixel grid of ``depthi``.
    """
    kpi = get_image_coords(depthi).flatten(-3, -2)
    di = depthi.flatten(-2)
    validi = di > 0
    kpir, validir = project(kpi, di, depthj, camerai, cameraj, T_itoj, validi, **kwargs)

    # Both outputs hold one entry per pixel of view i, so both are reshaped
    # with depthi's grid. Using depthj's shape for the mask fails when the
    # two maps differ in resolution or when depthj is None.
    grid_shape = depthi.shape[-2:]
    return kpir.unflatten(-2, grid_shape), validir.unflatten(-1, grid_shape)
|