glue-factory-custom/gluefactory/models/extractors/sift.py

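"""SIFT keypoint detection and description with pycolmap and OpenCV backends,
including RootSIFT normalization, keypoint NMS, top-k selection, and optional
padding to a fixed number of keypoints for batched training."""
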
import numpy as np
import torch
import pycolmap
from scipy.spatial import KDTree
from omegaconf import OmegaConf
import cv2

from ..base_model import BaseModel
from ..utils.misc import pad_to_length

EPS = 1e-6


def sift_to_rootsift(x):
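    """Convert SIFT descriptors to RootSIFT: L1-normalize, take the
    element-wise square root, then L2-normalize (Arandjelovic & Zisserman,
    CVPR 2012)."""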
    x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS)
    x = np.sqrt(x.clip(min=EPS))
    x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS)
    return x


# from OpenGlue
def nms_keypoints(
    kpts: np.ndarray, responses: np.ndarray, radius: float
) -> np.ndarray:
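    """Greedy non-maximum suppression: visit keypoints in decreasing order of
    response, keep each one not yet removed, and remove all of its neighbors
    within `radius`. Returns a boolean keep-mask over `kpts`."""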
    # TODO: add approximate tree
    kd_tree = KDTree(kpts)

    sorted_idx = np.argsort(-responses)
    kpts_to_keep_idx = []
    removed_idx = set()

    for idx in sorted_idx:
        # skip point if it was already removed
        if idx in removed_idx:
            continue

        kpts_to_keep_idx.append(idx)
        point = kpts[idx]
        neighbors = kd_tree.query_ball_point(point, r=radius)
        # Variable `neighbors` contains the `point` itself
        removed_idx.update(neighbors)

    mask = np.zeros((kpts.shape[0],), dtype=bool)
    mask[kpts_to_keep_idx] = True
    return mask


def detect_kpts_opencv(
    features: cv2.Feature2D, image: np.ndarray, describe: bool = True
) -> tuple:
"""
Detect keypoints using OpenCV Detector.
Optionally, perform NMS and filter top-response keypoints.
Optionally, perform description.
Args:
features: OpenCV based keypoints detector and descriptor
image: Grayscale image of uint8 data type
describe: flag indicating whether to simultaneously compute descriptors
Returns:
kpts: 1D array of detected cv2.KeyPoint
"""
    if describe:
        kpts, descriptors = features.detectAndCompute(image, None)
    else:
        kpts = features.detect(image, None)
    kpts = np.array(kpts)

    responses = np.array([k.response for k in kpts], dtype=np.float32)

    # select all keypoints: indexing with Ellipsis keeps every entry
    top_score_idx = ...
    pts = np.array([k.pt for k in kpts], dtype=np.float32)
    scales = np.array([k.size for k in kpts], dtype=np.float32)
    # cv2.KeyPoint.angle is in degrees; convert to radians to match pycolmap
    angles = np.deg2rad(np.array([k.angle for k in kpts], dtype=np.float32))
    spts = np.concatenate([pts, scales[..., None], angles[..., None]], -1)

    if describe:
        return (
            spts[top_score_idx],
            responses[top_score_idx],
            descriptors[top_score_idx],
        )
    else:
        return spts[top_score_idx], responses[top_score_idx]


class SIFT(BaseModel):
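    """SIFT detector and descriptor, with pycolmap (CPU/CUDA) or OpenCV backends."""
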
    default_conf = {
        "has_detector": True,
        "has_descriptor": True,
        "descriptor_dim": 128,
        "pycolmap_options": {
            "first_octave": 0,
            "peak_threshold": 0.005,
            "edge_threshold": 10,
        },
        "rootsift": True,  # apply the RootSIFT transform to the descriptors
        "nms_radius": None,  # disable NMS when None
        "max_num_keypoints": -1,  # -1 keeps all detected keypoints
        "max_num_keypoints_val": None,  # overrides max_num_keypoints at validation
        "force_num_keypoints": False,  # pad to exactly max_num_keypoints
        "randomize_keypoints_training": False,
        "detector": "pycolmap",  # ['pycolmap', 'pycolmap_cpu', 'pycolmap_cuda', 'cv2']
        "detection_threshold": None,  # overrides the backend's default threshold
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        self.sift = None  # lazy loading

    @torch.no_grad()
    def extract_features(self, image):
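        """Extract SIFT features from a single (1, H, W) image tensor in [0, 1]."""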
        image_np = image.cpu().numpy()[0]
        assert image.shape[0] == 1
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS
        detector = str(self.conf.detector)

        if self.sift is None and detector.startswith("pycolmap"):
            options = OmegaConf.to_container(self.conf.pycolmap_options)
            device = (
                "auto" if detector == "pycolmap" else detector.replace("pycolmap_", "")
            )
            if self.conf.rootsift:
                options["normalization"] = pycolmap.Normalization.L1_ROOT
            else:
                options["normalization"] = pycolmap.Normalization.L2
            if self.conf.detection_threshold is not None:
                options["peak_threshold"] = self.conf.detection_threshold
            options["max_num_features"] = self.conf.max_num_keypoints
            self.sift = pycolmap.Sift(options=options, device=device)
        elif self.sift is None and detector == "cv2":
            # cv2.SIFT_create does not accept contrastThreshold=None
            if self.conf.detection_threshold is None:
                self.sift = cv2.SIFT_create()
            else:
                self.sift = cv2.SIFT_create(
                    contrastThreshold=self.conf.detection_threshold
                )

        if detector.startswith("pycolmap"):
            keypoints, scores, descriptors = self.sift.extract(image_np)
        elif detector == "cv2":
            # TODO: Check if opencv keypoints are already in corner convention
            keypoints, scores, descriptors = detect_kpts_opencv(
                self.sift, (image_np * 255.0).astype(np.uint8)
            )

        if self.conf.nms_radius is not None:
            mask = nms_keypoints(keypoints[:, :2], scores, self.conf.nms_radius)
            keypoints = keypoints[mask]
            scores = scores[mask]
            descriptors = descriptors[mask]

        scales = keypoints[:, 2]
        oris = np.rad2deg(keypoints[:, 3])  # angles are in radians up to here

        if self.conf.has_descriptor:
            # We still renormalize because COLMAP does not normalize well,
            # maybe due to numerical errors
            if self.conf.rootsift:
                descriptors = sift_to_rootsift(descriptors)
            descriptors = torch.from_numpy(descriptors)
        keypoints = torch.from_numpy(keypoints[:, :2])  # keep only x, y
        scales = torch.from_numpy(scales)
        oris = torch.from_numpy(oris)
        scores = torch.from_numpy(scores)

        # Keep the k keypoints with the highest score
        max_kps = self.conf.max_num_keypoints
        # at validation time, optionally allow a different number of keypoints
        if not self.training and self.conf.max_num_keypoints_val is not None:
            max_kps = self.conf.max_num_keypoints_val

        if max_kps is not None and max_kps > 0:
            if self.conf.randomize_keypoints_training and self.training:
                # instead of selecting the top-k, sample k keypoints by score weights
                raise NotImplementedError
            elif max_kps < scores.shape[0]:
                # TODO: check that the scores from PyCOLMAP are 100% correct,
                # follow https://github.com/mihaidusmanu/pycolmap/issues/8
                indices = torch.topk(scores, max_kps).indices
                keypoints = keypoints[indices]
                scales = scales[indices]
                oris = oris[indices]
                scores = scores[indices]
                if self.conf.has_descriptor:
                    descriptors = descriptors[indices]

            if self.conf.force_num_keypoints:
                # pad with random keypoints so every image has exactly max_kps
                keypoints = pad_to_length(
                    keypoints,
                    max_kps,
                    -2,
                    mode="random_c",
                    bounds=(0, min(image.shape[1:])),
                )
                scores = pad_to_length(scores, max_kps, -1, mode="zeros")
                scales = pad_to_length(scales, max_kps, -1, mode="zeros")
                oris = pad_to_length(oris, max_kps, -1, mode="zeros")
                if self.conf.has_descriptor:
                    descriptors = pad_to_length(descriptors, max_kps, -2, mode="zeros")

        pred = {
            "keypoints": keypoints,
            "scales": scales,
            "oris": oris,
            "keypoint_scores": scores,
        }
        if self.conf.has_descriptor:
            pred["descriptors"] = descriptors
        return pred

    @torch.no_grad()
    def _forward(self, data):
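        """Extract features for each image in the batch and stack the results."""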
        pred = {
            "keypoints": [],
            "scales": [],
            "oris": [],
            "keypoint_scores": [],
            "descriptors": [],
        }
        image = data["image"]
        if image.shape[1] == 3:
            # convert RGB to grayscale with the standard BT.601 luma weights
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True).cpu()

        for k in range(image.shape[0]):
            img = image[k]
            if "image_size" in data.keys():
                # avoid extracting points in padded areas
                w, h = data["image_size"][k]
                img = img[:, :h, :w]
            p = self.extract_features(img)
            for key, v in p.items():  # do not shadow the batch index k
                pred[key].append(v)

        if (image.shape[0] == 1) or self.conf.force_num_keypoints:
            # all images have the same number of keypoints: batch into tensors
            pred = {k: torch.stack(pred[k], 0) for k in pred.keys()}

        pred = {k: pred[k].to(device=data["image"].device) for k in pred.keys()}
        pred["oris"] = torch.deg2rad(pred["oris"])
        return pred

    def loss(self, pred, data):
        raise NotImplementedError
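

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not executed): assuming the usual gluefactory
# BaseModel convention where the constructor takes a partial config dict and
# calling the module runs `_forward` on a data dict with an "image" key.
#
#   model = SIFT({"max_num_keypoints": 512, "detector": "pycolmap_cpu"})
#   pred = model({"image": torch.rand(1, 1, 240, 320)})  # B x 1 x H x W in [0, 1]
#   # pred["keypoints"]: (1, K, 2) xy pixel coordinates
#   # pred["oris"]: (1, K) orientations in radians
#   # pred["descriptors"]: (1, K, 128) (Root)SIFT descriptors
# ---------------------------------------------------------------------------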