From f7b587e881d33d4e9c1904d28f19c5f82c8d275c Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> Date: Mon, 9 Oct 2023 16:30:21 +0200 Subject: [PATCH] Create data folder before downloading MegaDepth1500 (#8) --- gluefactory/eval/megadepth1500.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/gluefactory/eval/megadepth1500.py b/gluefactory/eval/megadepth1500.py index e359361..a9cb10a 100644 --- a/gluefactory/eval/megadepth1500.py +++ b/gluefactory/eval/megadepth1500.py @@ -1,3 +1,4 @@ +import logging import zipfile from collections import defaultdict from collections.abc import Iterable @@ -19,6 +20,8 @@ from .eval_pipeline import EvalPipeline from .io import get_eval_parser, load_model, parse_eval_args from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust +logger = logging.getLogger(__name__) + class MegaDepth1500Pipeline(EvalPipeline): default_conf = { @@ -56,11 +59,13 @@ class MegaDepth1500Pipeline(EvalPipeline): def _init(self, conf): if not (DATA_PATH / "megadepth1500").exists(): + logger.info("Downloading the MegaDepth-1500 dataset.") url = "https://cvg-data.inf.ethz.ch/megadepth/megadepth1500.zip" zip_path = DATA_PATH / url.rsplit("/", 1)[-1] + zip_path.parent.mkdir(exist_ok=True, parents=True) torch.hub.download_url_to_file(url, zip_path) - with zipfile.ZipFile(zip_path) as zip: - zip.extractall(DATA_PATH) + with zipfile.ZipFile(zip_path) as fid: + fid.extractall(DATA_PATH) zip_path.unlink() @classmethod @@ -147,6 +152,8 @@ class MegaDepth1500Pipeline(EvalPipeline): if __name__ == "__main__": + from .. import logger # overwrite the logger + dataset_name = Path(__file__).stem parser = get_eval_parser() args = parser.parse_intermixed_args()