Fix sparse depth export for megadepth (#13)

* Never export sparse depth

* add sparse depth support
main
Philipp Lindenberger 2023-10-13 13:19:19 +02:00 committed by GitHub
parent 22154a60bc
commit 398c4b8c21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 3 deletions

View File

@ -29,6 +29,15 @@ def pad_local_features(pred: dict, seq_l: int):
pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros") pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros")
if "oris" in pred.keys(): if "oris" in pred.keys():
pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros") pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros")
if "depth_keypoints" in pred.keys():
pred["depth_keypoints"] = pad_to_length(
pred["depth_keypoints"], seq_l, -1, mode="zeros"
)
if "valid_depth_keypoints" in pred.keys():
pred["valid_depth_keypoints"] = pad_to_length(
pred["valid_depth_keypoints"], seq_l, -1, mode="zeros"
)
return pred return pred

View File

@ -133,15 +133,18 @@ def run_export(feature_file, scene, args):
conf = OmegaConf.create(conf) conf = OmegaConf.create(conf)
keys = configs[args.method]["keys"] + ["depth_keypoints", "valid_depth_keypoints"] keys = configs[args.method]["keys"]
dataset = get_dataset(conf.data.name)(conf.data) dataset = get_dataset(conf.data.name)(conf.data)
loader = dataset.get_data_loader(conf.split or "test") loader = dataset.get_data_loader(conf.split or "test")
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_model(conf.model.name)(conf.model).eval().to(device) model = get_model(conf.model.name)(conf.model).eval().to(device)
if args.export_sparse_depth:
callback_fn = get_kp_depth # use this to store the depth of each keypoint
keys = keys + ["depth_keypoints", "valid_depth_keypoints"]
else:
callback_fn = None callback_fn = None
# callback_fn=get_kp_depth # use this to store the depth of each keypoint
export_predictions( export_predictions(
loader, model, feature_file, as_half=True, keys=keys, callback_fn=callback_fn loader, model, feature_file, as_half=True, keys=keys, callback_fn=callback_fn
) )
@ -153,6 +156,7 @@ if __name__ == "__main__":
parser.add_argument("--method", type=str, default="sp") parser.add_argument("--method", type=str, default="sp")
parser.add_argument("--scenes", type=str, default=None) parser.add_argument("--scenes", type=str, default=None)
parser.add_argument("--num_workers", type=int, default=0) parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--export_sparse_depth", action="store_true")
args = parser.parse_args() args = parser.parse_args()
export_name = configs[args.method]["name"] export_name = configs[args.method]["name"]