313 lines
12 KiB
Python
313 lines
12 KiB
Python
import torch
|
|
from sklearn.cluster import DBSCAN
|
|
|
|
from ..base_model import BaseModel
|
|
from .. import get_model
|
|
|
|
|
|
def sample_descriptors_corner_conv(keypoints, descriptors, s: int = 8):
    """Bilinearly interpolate dense descriptors at keypoint locations.

    Args:
        keypoints: [b, n, 2] keypoint (x, y) positions in image pixels.
        descriptors: [b, c, h, w] dense descriptor map, downsampled by s
            with respect to the image.
        s: downsampling factor between the image and the descriptor map.

    Returns:
        [b, c, n] L2-normalized descriptors sampled at the keypoints.
    """
    b, c, h, w = descriptors.shape
    # Convert pixel coordinates to the [-1, 1] grid convention of grid_sample.
    grid = keypoints / (keypoints.new_tensor([w, h]) * s)
    grid = grid * 2 - 1  # normalize to (-1, 1)
    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.view(b, 1, -1, 2), mode="bilinear", align_corners=False
    )
    # L2-normalize each sampled descriptor along the channel dimension.
    return torch.nn.functional.normalize(sampled.reshape(b, c, -1), p=2, dim=1)
|
|
|
|
|
|
def lines_to_wireframe(
    lines, line_scores, all_descs, s, nms_radius, force_num_lines, max_num_lines
):
    """Given a set of lines, their score and dense descriptors,
    merge close-by endpoints and compute a wireframe defined by
    its junctions and connectivity.

    Args:
        lines: [b_size, num_lines, 2, 2] line endpoints in image pixels.
        line_scores: [b_size, num_lines] line detection scores.
        all_descs: [b_size, dim, h, w] dense descriptor maps, downsampled
            by a factor s with respect to the image.
        s: downsampling factor of the descriptor maps.
        nms_radius: DBSCAN radius (in pixels) used to merge nearby endpoints.
        force_num_lines: if True, pad each image's junctions with random,
            unconnected ones up to max_num_lines * 2 so batches can stack.
        max_num_lines: maximum number of lines, used for the padding above.

    Returns:
        junctions: [b_size, num_junc, 2] tensor listing all wireframe junctions
        junc_scores: list of [num_junc] tensors with the junction score
        junc_descs: [b_size, num_junc, dim] tensor with the junction descriptors
        connectivity: list of [num_junc, num_junc] bool tensors with True when 2
            junctions are connected
        new_lines: the new set of [b_size, num_lines, 2, 2] lines
        lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the
            junctions of each endpoint
        num_true_junctions: a list of the number of valid junctions for each image
            in the batch, i.e. before filling with random ones
    """
    b_size, _, h, w = all_descs.shape
    device = lines.device
    # Recover image dimensions from the downsampled descriptor maps.
    h, w = h * s, w * s
    # Flatten the 2 endpoints of each line: [b_size, num_lines * 2, 2].
    endpoints = lines.reshape(b_size, -1, 2)

    (
        junctions,
        junc_scores,
        connectivity,
        new_lines,
        lines_junc_idx,
        num_true_junctions,
    ) = ([], [], [], [], [], [])
    for bs in range(b_size):
        # Cluster the junctions that are close-by
        db = DBSCAN(eps=nms_radius, min_samples=1).fit(endpoints[bs].cpu().numpy())
        clusters = db.labels_
        # With min_samples=1 every endpoint gets a cluster label (no noise
        # points), so distinct labels == number of true junctions.
        n_clusters = len(set(clusters))
        num_true_junctions.append(n_clusters)

        # Compute the average junction and score for each cluster
        clusters = torch.tensor(clusters, dtype=torch.long, device=device)
        new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, device=device)
        # Average the (x, y) coordinates of all endpoints in each cluster.
        new_junc.scatter_reduce_(
            0,
            clusters[:, None].repeat(1, 2),
            endpoints[bs],
            reduce="mean",
            include_self=False,
        )
        junctions.append(new_junc)
        new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device)
        # Each line contributes its score to both of its endpoints; the
        # junction score is the mean over all endpoints in the cluster.
        new_scores.scatter_reduce_(
            0,
            clusters,
            torch.repeat_interleave(line_scores[bs], 2),
            reduce="mean",
            include_self=False,
        )
        junc_scores.append(new_scores)

        # Compute the new lines
        # Snap each original endpoint to the center of its cluster.
        new_lines.append(junctions[-1][clusters].reshape(-1, 2, 2))
        lines_junc_idx.append(clusters.reshape(-1, 2))

        if force_num_lines:
            # Add random junctions (with no connectivity)
            # NOTE(review): assumes n_clusters <= max_num_lines * 2, otherwise
            # `missing` is negative — verify the line extractor guarantees this.
            missing = max_num_lines * 2 - len(junctions[-1])
            junctions[-1] = torch.cat(
                [
                    junctions[-1],
                    torch.rand(missing, 2).to(lines)
                    * lines.new_tensor([[w - 1, h - 1]]),
                ],
                dim=0,
            )
            # Padded junctions get a zero score so they can be identified.
            junc_scores[-1] = torch.cat(
                [junc_scores[-1], torch.zeros(missing).to(lines)], dim=0
            )

            # Padded junctions are only connected to themselves (eye matrix).
            junc_connect = torch.eye(max_num_lines * 2, dtype=torch.bool, device=device)
            pairs = clusters.reshape(-1, 2)  # these pairs are connected by a line
            junc_connect[pairs[:, 0], pairs[:, 1]] = True
            junc_connect[pairs[:, 1], pairs[:, 0]] = True
            connectivity.append(junc_connect)
        else:
            # Compute the junction connectivity
            junc_connect = torch.eye(n_clusters, dtype=torch.bool, device=device)
            pairs = clusters.reshape(-1, 2)  # these pairs are connected by a line
            junc_connect[pairs[:, 0], pairs[:, 1]] = True
            junc_connect[pairs[:, 1], pairs[:, 0]] = True
            connectivity.append(junc_connect)

    # Stacking requires the same junction count per image: guaranteed when
    # force_num_lines is True, otherwise relies on batch size 1 upstream.
    junctions = torch.stack(junctions, dim=0)
    new_lines = torch.stack(new_lines, dim=0)
    lines_junc_idx = torch.stack(lines_junc_idx, dim=0)

    # Interpolate the new junction descriptors
    junc_descs = sample_descriptors_corner_conv(junctions, all_descs, s).mT

    return (
        junctions,
        junc_scores,
        junc_descs,
        connectivity,
        new_lines,
        lines_junc_idx,
        num_true_junctions,
    )
|
|
|
|
|
|
class WireframeExtractor(BaseModel):
    """Combine a keypoint extractor and a line extractor into a wireframe.

    Line endpoints are merged into junctions, keypoints too close to a line
    endpoint are removed (or re-sampled), and junctions + keypoints are
    concatenated into a single point list with an associativity matrix
    recording which points are connected by a line.
    """

    default_conf = {
        "point_extractor": {
            "name": None,  # registry name of the keypoint extractor model
            "trainable": False,
            "dense_outputs": True,  # the extractor must output dense descriptors
            "max_num_keypoints": None,
            "force_num_keypoints": False,  # pad to a fixed number of keypoints
        },
        "line_extractor": {
            "name": None,  # registry name of the line extractor model
            "trainable": False,
            "max_num_lines": None,
            "force_num_lines": False,  # pad to a fixed number of lines
            "min_length": 15,
        },
        "wireframe_params": {
            "merge_points": True,  # remove keypoints close to line endpoints
            "merge_line_endpoints": True,  # merge endpoints into junctions
            "nms_radius": 3,  # radius (in pixels) used by both merge steps
        },
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        # Instantiate both sub-models from the model registry with their
        # respective sub-configurations.
        self.point_extractor = get_model(self.conf.point_extractor.name)(
            self.conf.point_extractor
        )
        self.line_extractor = get_model(self.conf.line_extractor.name)(
            self.conf.line_extractor
        )

    def _forward(self, data):
        """Run both extractors and merge their outputs into a wireframe.

        Expects data["image"] of shape [b, c, h, w]; returns the combined
        prediction dict (keypoints, scores, descriptors, lines, junction
        bookkeeping).
        """
        b_size, _, h, w = data["image"].shape
        device = data["image"].device

        # Without fixed-size padding, outputs are ragged across the batch.
        if (
            not self.conf.point_extractor.force_num_keypoints
            or not self.conf.line_extractor.force_num_lines
        ):
            assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"

        # Line detection
        pred = self.line_extractor(data)
        # Normalize line scores to [0, 1] per image (epsilon avoids /0 when
        # all scores are zero).
        if pred["line_scores"].shape[-1] != 0:
            pred["line_scores"] /= pred["line_scores"].max(dim=1)[0][:, None] + 1e-8

        # Keypoint prediction
        pred = {**pred, **self.point_extractor(data)}
        assert (
            "dense_descriptors" in pred
        ), "The KP extractor should return dense descriptors"
        # Stride between the image and the dense descriptor map.
        s_desc = data["image"].shape[2] // pred["dense_descriptors"].shape[2]

        # Remove keypoints that are too close to line endpoints
        if self.conf.wireframe_params.merge_points:
            line_endpts = pred["lines"].reshape(b_size, -1, 2)
            # Pairwise distances: [b, n_keypoints, n_endpoints].
            dist_pt_lines = torch.norm(
                pred["keypoints"][:, :, None] - line_endpts[:, None], dim=-1
            )
            # For each keypoint, mark it as valid or to remove
            pts_to_remove = torch.any(
                dist_pt_lines < self.conf.wireframe_params.nms_radius, dim=2
            )
            if self.conf.point_extractor.force_num_keypoints:
                # Replace the points with random ones
                # (keeps the keypoint count fixed; scores are zeroed so the
                # replacements can be identified downstream).
                num_to_remove = pts_to_remove.int().sum().item()
                pred["keypoints"][pts_to_remove] = torch.rand(
                    num_to_remove, 2, device=device
                ) * pred["keypoints"].new_tensor([[w - 1, h - 1]])
                pred["keypoint_scores"][pts_to_remove] = 0
                # Re-sample descriptors at the new random locations.
                for bs in range(b_size):
                    descrs = sample_descriptors_corner_conv(
                        pred["keypoints"][bs][pts_to_remove[bs]][None],
                        pred["dense_descriptors"][bs][None],
                        s_desc,
                    )
                    pred["descriptors"][bs][pts_to_remove[bs]] = descrs[0].T
            else:
                # Simply remove them (we assume batch_size = 1 here)
                assert len(pred["keypoints"]) == 1
                pred["keypoints"] = pred["keypoints"][0][~pts_to_remove[0]][None]
                pred["keypoint_scores"] = pred["keypoint_scores"][0][~pts_to_remove[0]][
                    None
                ]
                pred["descriptors"] = pred["descriptors"][0][~pts_to_remove[0]][None]

        # Connect the lines together to form a wireframe
        # Keep the original (unmerged) lines for downstream consumers.
        orig_lines = pred["lines"].clone()
        if (
            self.conf.wireframe_params.merge_line_endpoints
            and len(pred["lines"][0]) > 0
        ):
            # Merge first close-by endpoints to connect lines
            (
                line_points,
                line_pts_scores,
                line_descs,
                line_association,
                pred["lines"],
                lines_junc_idx,
                n_true_junctions,
            ) = lines_to_wireframe(
                pred["lines"],
                pred["line_scores"],
                pred["dense_descriptors"],
                s=s_desc,
                nms_radius=self.conf.wireframe_params.nms_radius,
                force_num_lines=self.conf.line_extractor.force_num_lines,
                max_num_lines=self.conf.line_extractor.max_num_lines,
            )

            # Add the keypoints to the junctions and fill the rest with random keypoints
            # Junctions come first in the concatenated lists, so the first
            # n_true_junctions[bs] rows of the associativity matrix are theirs.
            (all_points, all_scores, all_descs, pl_associativity) = [], [], [], []
            for bs in range(b_size):
                all_points.append(
                    torch.cat([line_points[bs], pred["keypoints"][bs]], dim=0)
                )
                all_scores.append(
                    torch.cat([line_pts_scores[bs], pred["keypoint_scores"][bs]], dim=0)
                )
                all_descs.append(
                    torch.cat([line_descs[bs], pred["descriptors"][bs]], dim=0)
                )

                # Start from the identity (each point connected to itself),
                # then copy the junction connectivity into the top-left block.
                associativity = torch.eye(
                    len(all_points[-1]), dtype=torch.bool, device=device
                )
                associativity[
                    : n_true_junctions[bs], : n_true_junctions[bs]
                ] = line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
                pl_associativity.append(associativity)

            all_points = torch.stack(all_points, dim=0)
            all_scores = torch.stack(all_scores, dim=0)
            all_descs = torch.stack(all_descs, dim=0)
            pl_associativity = torch.stack(pl_associativity, dim=0)
        else:
            # Lines are independent
            # No merging: every line endpoint becomes its own junction.
            all_points = torch.cat(
                [pred["lines"].reshape(b_size, -1, 2), pred["keypoints"]], dim=1
            )
            n_pts = all_points.shape[1]
            num_lines = pred["lines"].shape[1]
            n_true_junctions = [num_lines * 2] * b_size
            # Each line score is duplicated for its two endpoints.
            all_scores = torch.cat(
                [
                    torch.repeat_interleave(pred["line_scores"], 2, dim=1),
                    pred["keypoint_scores"],
                ],
                dim=1,
            )
            line_descs = sample_descriptors_corner_conv(
                pred["lines"].reshape(b_size, -1, 2), pred["dense_descriptors"], s_desc
            ).mT  # [B, n_lines * 2, desc_dim]
            all_descs = torch.cat([line_descs, pred["descriptors"]], dim=1)
            # Identity associativity: no two points are connected.
            pl_associativity = torch.eye(n_pts, dtype=torch.bool, device=device)[
                None
            ].repeat(b_size, 1, 1)
            # Endpoints 2i and 2i+1 belong to line i.
            lines_junc_idx = (
                torch.arange(num_lines * 2, device=device)
                .reshape(1, -1, 2)
                .repeat(b_size, 1, 1)
            )

        del pred["dense_descriptors"]  # Remove dense descriptors to save memory
        torch.cuda.empty_cache()

        # Junctions and keypoints share a single point list from here on.
        pred["keypoints"] = all_points
        pred["keypoint_scores"] = all_scores
        pred["descriptors"] = all_descs
        pred["pl_associativity"] = pl_associativity
        # NOTE(review): created on CPU, unlike the other outputs — confirm
        # downstream consumers do not require it on `device`.
        pred["num_junctions"] = torch.tensor(n_true_junctions)
        pred["orig_lines"] = orig_lines
        pred["lines_junc_idx"] = lines_junc_idx
        return pred

    def loss(self, pred, data):
        # This model is used for inference only; no training loss is defined.
        raise NotImplementedError

    def metrics(self, _pred, _data):
        # No metrics are computed for this extractor.
        return {}
|