Fix sparse depth export for megadepth (#13)
* Never export sparse depth * add sparse depth supportmain
parent
22154a60bc
commit
398c4b8c21
|
@ -29,6 +29,15 @@ def pad_local_features(pred: dict, seq_l: int):
|
|||
pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros")
|
||||
if "oris" in pred.keys():
|
||||
pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros")
|
||||
|
||||
if "depth_keypoints" in pred.keys():
|
||||
pred["depth_keypoints"] = pad_to_length(
|
||||
pred["depth_keypoints"], seq_l, -1, mode="zeros"
|
||||
)
|
||||
if "valid_depth_keypoints" in pred.keys():
|
||||
pred["valid_depth_keypoints"] = pad_to_length(
|
||||
pred["valid_depth_keypoints"], seq_l, -1, mode="zeros"
|
||||
)
|
||||
return pred
|
||||
|
||||
|
||||
|
|
|
@ -133,15 +133,18 @@ def run_export(feature_file, scene, args):
|
|||
|
||||
conf = OmegaConf.create(conf)
|
||||
|
||||
keys = configs[args.method]["keys"] + ["depth_keypoints", "valid_depth_keypoints"]
|
||||
keys = configs[args.method]["keys"]
|
||||
dataset = get_dataset(conf.data.name)(conf.data)
|
||||
loader = dataset.get_data_loader(conf.split or "test")
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = get_model(conf.model.name)(conf.model).eval().to(device)
|
||||
|
||||
callback_fn = None
|
||||
# callback_fn=get_kp_depth # use this to store the depth of each keypoint
|
||||
if args.export_sparse_depth:
|
||||
callback_fn = get_kp_depth # use this to store the depth of each keypoint
|
||||
keys = keys + ["depth_keypoints", "valid_depth_keypoints"]
|
||||
else:
|
||||
callback_fn = None
|
||||
export_predictions(
|
||||
loader, model, feature_file, as_half=True, keys=keys, callback_fn=callback_fn
|
||||
)
|
||||
|
@ -153,6 +156,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--method", type=str, default="sp")
|
||||
parser.add_argument("--scenes", type=str, default=None)
|
||||
parser.add_argument("--num_workers", type=int, default=0)
|
||||
parser.add_argument("--export_sparse_depth", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
export_name = configs[args.method]["name"]
|
||||
|
|
Loading…
Reference in New Issue