diff --git a/gluefactory/models/cache_loader.py b/gluefactory/models/cache_loader.py index 3fbf0f7..b345a99 100644 --- a/gluefactory/models/cache_loader.py +++ b/gluefactory/models/cache_loader.py @@ -29,6 +29,15 @@ def pad_local_features(pred: dict, seq_l: int): pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros") if "oris" in pred.keys(): pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros") + + if "depth_keypoints" in pred.keys(): + pred["depth_keypoints"] = pad_to_length( + pred["depth_keypoints"], seq_l, -1, mode="zeros" + ) + if "valid_depth_keypoints" in pred.keys(): + pred["valid_depth_keypoints"] = pad_to_length( + pred["valid_depth_keypoints"], seq_l, -1, mode="zeros" + ) return pred diff --git a/gluefactory/scripts/export_megadepth.py b/gluefactory/scripts/export_megadepth.py index 95e89d8..f332f2f 100644 --- a/gluefactory/scripts/export_megadepth.py +++ b/gluefactory/scripts/export_megadepth.py @@ -133,15 +133,18 @@ def run_export(feature_file, scene, args): conf = OmegaConf.create(conf) - keys = configs[args.method]["keys"] + ["depth_keypoints", "valid_depth_keypoints"] + keys = configs[args.method]["keys"] dataset = get_dataset(conf.data.name)(conf.data) loader = dataset.get_data_loader(conf.split or "test") device = "cuda" if torch.cuda.is_available() else "cpu" model = get_model(conf.model.name)(conf.model).eval().to(device) - callback_fn = None - # callback_fn=get_kp_depth # use this to store the depth of each keypoint + if args.export_sparse_depth: + callback_fn = get_kp_depth # use this to store the depth of each keypoint + keys = keys + ["depth_keypoints", "valid_depth_keypoints"] + else: + callback_fn = None export_predictions( loader, model, feature_file, as_half=True, keys=keys, callback_fn=callback_fn ) @@ -153,6 +156,7 @@ if __name__ == "__main__": parser.add_argument("--method", type=str, default="sp") parser.add_argument("--scenes", type=str, default=None) parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--export_sparse_depth", action="store_true") args = parser.parse_args() export_name = configs[args.method]["name"]