diff --git a/gluefactory/eval/megadepth1500.py b/gluefactory/eval/megadepth1500.py index e359361..a9cb10a 100644 --- a/gluefactory/eval/megadepth1500.py +++ b/gluefactory/eval/megadepth1500.py @@ -1,3 +1,4 @@ +import logging import zipfile from collections import defaultdict from collections.abc import Iterable @@ -19,6 +20,8 @@ from .eval_pipeline import EvalPipeline from .io import get_eval_parser, load_model, parse_eval_args from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust +logger = logging.getLogger(__name__) + class MegaDepth1500Pipeline(EvalPipeline): default_conf = { @@ -56,11 +59,13 @@ class MegaDepth1500Pipeline(EvalPipeline): def _init(self, conf): if not (DATA_PATH / "megadepth1500").exists(): + logger.info("Downloading the MegaDepth-1500 dataset.") url = "https://cvg-data.inf.ethz.ch/megadepth/megadepth1500.zip" zip_path = DATA_PATH / url.rsplit("/", 1)[-1] + zip_path.parent.mkdir(exist_ok=True, parents=True) torch.hub.download_url_to_file(url, zip_path) - with zipfile.ZipFile(zip_path) as zip: - zip.extractall(DATA_PATH) + with zipfile.ZipFile(zip_path) as fid: + fid.extractall(DATA_PATH) zip_path.unlink() @classmethod @@ -147,6 +152,8 @@ class MegaDepth1500Pipeline(EvalPipeline): if __name__ == "__main__": + from .. import logger # overwrite the logger + dataset_name = Path(__file__).stem parser = get_eval_parser() args = parser.parse_intermixed_args()