From 12640afd36dd1a11b7d5f72e9259e074691e3961 Mon Sep 17 00:00:00 2001 From: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> Date: Mon, 9 Oct 2023 08:32:43 +0200 Subject: [PATCH] Auto sort imports (#6) * Add isort, merge check runs into one * Run isort * Ignor build in flake8 config * Remove jupyter as dev dependency --- .flake8 | 1 + .github/workflows/code-quality.yml | 21 ++++------ format.sh | 3 ++ gluefactory/__init__.py | 1 + gluefactory/datasets/__init__.py | 3 +- gluefactory/datasets/augmentations.py | 2 +- gluefactory/datasets/base_dataset.py | 5 ++- gluefactory/datasets/eth3d.py | 14 +++---- gluefactory/datasets/homographies.py | 14 +++---- gluefactory/datasets/hpatches.py | 5 ++- gluefactory/datasets/image_folder.py | 9 +++-- gluefactory/datasets/image_pairs.py | 11 +++--- gluefactory/datasets/megadepth.py | 20 ++++------ gluefactory/eval/__init__.py | 1 + gluefactory/eval/eth3d.py | 27 ++++++------- gluefactory/eval/eval_pipeline.py | 5 ++- gluefactory/eval/hpatches.py | 38 +++++++++---------- gluefactory/eval/inspect.py | 10 ++--- gluefactory/eval/io.py | 11 +++--- gluefactory/eval/megadepth1500.py | 37 +++++++++--------- gluefactory/eval/utils.py | 9 +++-- gluefactory/geometry/depth.py | 2 +- gluefactory/geometry/epipolar.py | 7 ++-- gluefactory/geometry/gt_generation.py | 4 +- gluefactory/geometry/homography.py | 5 ++- gluefactory/geometry/wrappers.py | 7 ++-- gluefactory/models/__init__.py | 3 +- gluefactory/models/base_model.py | 3 +- gluefactory/models/cache_loader.py | 9 +++-- gluefactory/models/extractors/aliked.py | 9 +++-- gluefactory/models/extractors/disk_kornia.py | 2 +- .../models/extractors/grid_extractor.py | 3 +- .../extractors/keynet_affnet_hardnet.py | 2 +- gluefactory/models/extractors/mixed.py | 6 +-- gluefactory/models/extractors/sift.py | 11 +++--- .../models/extractors/superpoint_open.py | 5 ++- gluefactory/models/lines/deeplsd.py | 4 +- gluefactory/models/lines/wireframe.py | 2 +- gluefactory/models/matchers/depth_matcher.py | 11 +++--- gluefactory/models/matchers/gluestick.py | 2 +- .../models/matchers/homography_matcher.py | 4 +- gluefactory/models/matchers/lightglue.py | 10 +++-- .../models/matchers/lightglue_pretrained.py | 3 +- .../matchers/nearest_neighbor_matcher.py | 3 +- gluefactory/models/triplet_pipeline.py | 3 +- gluefactory/models/two_view_pipeline.py | 4 +- gluefactory/models/utils/misc.py | 1 + gluefactory/robust_estimators/__init__.py | 1 + .../robust_estimators/base_estimator.py | 3 +- .../robust_estimators/homography/poselib.py | 2 +- .../robust_estimators/relative_pose/opencv.py | 4 +- .../relative_pose/poselib.py | 4 +- .../relative_pose/pycolmap.py | 4 +- gluefactory/scripts/export_local_features.py | 8 ++-- gluefactory/scripts/export_megadepth.py | 9 +++-- gluefactory/train.py | 38 +++++++++---------- gluefactory/utils/benchmark.py | 5 ++- gluefactory/utils/experiments.py | 11 +++--- gluefactory/utils/export_predictions.py | 7 ++-- gluefactory/utils/image.py | 13 ++++--- gluefactory/utils/stdout_capturing.py | 5 ++- gluefactory/utils/tensor.py | 3 +- gluefactory/utils/tools.py | 9 +++-- gluefactory/visualization/global_frame.py | 16 ++++---- gluefactory/visualization/tools.py | 24 ++++++------ gluefactory/visualization/two_view_frame.py | 9 ++--- gluefactory/visualization/visualize_batch.py | 8 +--- gluefactory/visualization/viz2d.py | 2 +- pyproject.toml | 6 ++- 69 files changed, 288 insertions(+), 275 deletions(-) create mode 100755 format.sh diff --git a/.flake8 b/.flake8 index 8dd399a..899119f 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,4 @@ [flake8] max-line-length = 88 extend-ignore = E203 +exclude = .git,__pycache__,build,.venv/ diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 41be302..368b225 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -8,24 +8,17 @@ on: pull_request: types: [ assigned, opened, synchronize, reopened ] jobs: - formatting-check: - name: Formatting Check - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: psf/black@stable - with: - jupyter: true - linting-check: - name: Linting Check + check: + name: Format and Lint Checks runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: '3.10' cache: 'pip' - run: python -m pip install --upgrade pip - - run: python -m pip install . - - run: python -m pip install --upgrade flake8 - - run: python -m flake8 . --exclude build/ + - run: python -m pip install .[dev] + - run: python -m flake8 . + - run: python -m isort . --check-only --diff + - run: python -m black . --check --diff diff --git a/format.sh b/format.sh new file mode 100755 index 0000000..5756e5f --- /dev/null +++ b/format.sh @@ -0,0 +1,3 @@ +python -m flake8 . +python -m isort . +python -m black . diff --git a/gluefactory/__init__.py b/gluefactory/__init__.py index b3d0115..0d83f92 100644 --- a/gluefactory/__init__.py +++ b/gluefactory/__init__.py @@ -1,4 +1,5 @@ import logging + from .utils.experiments import load_experiment # noqa: F401 formatter = logging.Formatter( diff --git a/gluefactory/datasets/__init__.py b/gluefactory/datasets/__init__.py index 2941a4c..ce05e9a 100644 --- a/gluefactory/datasets/__init__.py +++ b/gluefactory/datasets/__init__.py @@ -1,6 +1,7 @@ import importlib.util -from .base_dataset import BaseDataset + from ..utils.tools import get_class +from .base_dataset import BaseDataset def get_dataset(name): diff --git a/gluefactory/datasets/augmentations.py b/gluefactory/datasets/augmentations.py index ea726a0..bd39129 100644 --- a/gluefactory/datasets/augmentations.py +++ b/gluefactory/datasets/augmentations.py @@ -1,11 +1,11 @@ from typing import Union import albumentations as A +import cv2 import numpy as np import torch from albumentations.pytorch.transforms import ToTensorV2 from omegaconf import OmegaConf -import cv2 class IdentityTransform(A.ImageOnlyTransform): diff --git a/gluefactory/datasets/base_dataset.py b/gluefactory/datasets/base_dataset.py index aeb316a..ef622cb 100644 --- a/gluefactory/datasets/base_dataset.py +++ b/gluefactory/datasets/base_dataset.py @@ -3,12 +3,13 @@ Base class for dataset. See mnist.py for an example of dataset. """ -from abc import ABCMeta, abstractmethod import collections import logging -from omegaconf import OmegaConf +from abc import ABCMeta, abstractmethod + import omegaconf import torch +from omegaconf import OmegaConf from torch.utils.data import DataLoader, Sampler, get_worker_info from torch.utils.data._utils.collate import ( default_collate_err_msg_format, diff --git a/gluefactory/datasets/eth3d.py b/gluefactory/datasets/eth3d.py index e0cdf14..ca5e264 100644 --- a/gluefactory/datasets/eth3d.py +++ b/gluefactory/datasets/eth3d.py @@ -4,18 +4,18 @@ ETH3D multi-view benchmark, used for line matching evaluation. import logging import os import shutil - -import numpy as np -import cv2 -import torch -from pathlib import Path import zipfile +from pathlib import Path + +import cv2 +import numpy as np +import torch -from .base_dataset import BaseDataset -from .utils import scale_intrinsics from ..geometry.wrappers import Camera, Pose from ..settings import DATA_PATH from ..utils.image import ImagePreprocessor, load_image +from .base_dataset import BaseDataset +from .utils import scale_intrinsics logger = logging.getLogger(__name__) diff --git a/gluefactory/datasets/homographies.py b/gluefactory/datasets/homographies.py index f5a2131..08f7563 100644 --- a/gluefactory/datasets/homographies.py +++ b/gluefactory/datasets/homographies.py @@ -11,25 +11,25 @@ import tarfile from pathlib import Path import cv2 +import matplotlib.pyplot as plt import numpy as np import omegaconf import torch -import matplotlib.pyplot as plt from omegaconf import OmegaConf from tqdm import tqdm -from .augmentations import IdentityAugmentation, augmentations -from .base_dataset import BaseDataset -from ..settings import DATA_PATH -from ..models.cache_loader import CacheLoader, pad_local_features -from ..utils.image import read_image from ..geometry.homography import ( - sample_homography_corners, compute_homography, + sample_homography_corners, warp_points, ) +from ..models.cache_loader import CacheLoader, pad_local_features +from ..settings import DATA_PATH +from ..utils.image import read_image from ..utils.tools import fork_rng from ..visualization.viz2d import plot_image_grid +from .augmentations import IdentityAugmentation, augmentations +from .base_dataset import BaseDataset logger = logging.getLogger(__name__) diff --git a/gluefactory/datasets/hpatches.py b/gluefactory/datasets/hpatches.py index d3054cd..baf4ac8 100644 --- a/gluefactory/datasets/hpatches.py +++ b/gluefactory/datasets/hpatches.py @@ -4,16 +4,17 @@ Simply load images from a folder or nested folders (does not have any split). import argparse import logging import tarfile + import matplotlib.pyplot as plt import numpy as np import torch from omegaconf import OmegaConf -from .base_dataset import BaseDataset from ..settings import DATA_PATH -from ..utils.image import load_image, ImagePreprocessor +from ..utils.image import ImagePreprocessor, load_image from ..utils.tools import fork_rng from ..visualization.viz2d import plot_image_grid +from .base_dataset import BaseDataset logger = logging.getLogger(__name__) diff --git a/gluefactory/datasets/image_folder.py b/gluefactory/datasets/image_folder.py index 474a6c1..ecbd3ab 100644 --- a/gluefactory/datasets/image_folder.py +++ b/gluefactory/datasets/image_folder.py @@ -2,13 +2,14 @@ Simply load images from a folder or nested folders (does not have any split). """ -from pathlib import Path -import torch import logging -import omegaconf +from pathlib import Path +import omegaconf +import torch + +from ..utils.image import ImagePreprocessor, load_image from .base_dataset import BaseDataset -from ..utils.image import load_image, ImagePreprocessor class ImageFolder(BaseDataset, torch.utils.data.Dataset): diff --git a/gluefactory/datasets/image_pairs.py b/gluefactory/datasets/image_pairs.py index da0706a..08bd760 100644 --- a/gluefactory/datasets/image_pairs.py +++ b/gluefactory/datasets/image_pairs.py @@ -3,13 +3,14 @@ Simply load images from a folder or nested folders (does not have any split). """ from pathlib import Path -import torch -import numpy as np -from .base_dataset import BaseDataset -from ..utils.image import load_image, ImagePreprocessor -from ..settings import DATA_PATH +import numpy as np +import torch + from ..geometry.wrappers import Camera, Pose +from ..settings import DATA_PATH +from ..utils.image import ImagePreprocessor, load_image +from .base_dataset import BaseDataset def names_to_pair(name0, name1, separator="/"): diff --git a/gluefactory/datasets/megadepth.py b/gluefactory/datasets/megadepth.py index d4b6002..19a7586 100644 --- a/gluefactory/datasets/megadepth.py +++ b/gluefactory/datasets/megadepth.py @@ -1,9 +1,9 @@ import argparse import logging -from pathlib import Path -from collections.abc import Iterable -import tarfile import shutil +import tarfile +from collections.abc import Iterable +from pathlib import Path import h5py import matplotlib.pyplot as plt @@ -12,18 +12,14 @@ import PIL.Image import torch from omegaconf import OmegaConf -from .base_dataset import BaseDataset -from .utils import ( - scale_intrinsics, - rotate_intrinsics, - rotate_pose_inplane, -) from ..geometry.wrappers import Camera, Pose from ..models.cache_loader import CacheLoader -from ..utils.tools import fork_rng -from ..utils.image import load_image, ImagePreprocessor from ..settings import DATA_PATH -from ..visualization.viz2d import plot_image_grid, plot_heatmaps +from ..utils.image import ImagePreprocessor, load_image +from ..utils.tools import fork_rng +from ..visualization.viz2d import plot_heatmaps, plot_image_grid +from .base_dataset import BaseDataset +from .utils import rotate_intrinsics, rotate_pose_inplane, scale_intrinsics logger = logging.getLogger(__name__) scene_lists_path = Path(__file__).parent / "megadepth_scene_lists" diff --git a/gluefactory/eval/__init__.py b/gluefactory/eval/__init__.py index e072cf9..0d451e0 100644 --- a/gluefactory/eval/__init__.py +++ b/gluefactory/eval/__init__.py @@ -1,4 +1,5 @@ import torch + from ..utils.tools import get_class from .eval_pipeline import EvalPipeline diff --git a/gluefactory/eval/eth3d.py b/gluefactory/eval/eth3d.py index ef2b3a7..d2fe3a5 100644 --- a/gluefactory/eval/eth3d.py +++ b/gluefactory/eval/eth3d.py @@ -1,23 +1,18 @@ -from pathlib import Path -from omegaconf import OmegaConf -import matplotlib.pyplot as plt from collections import defaultdict -from tqdm import tqdm +from pathlib import Path + +import matplotlib.pyplot as plt import numpy as np +from omegaconf import OmegaConf +from tqdm import tqdm -from .io import ( - parse_eval_args, - load_model, - get_eval_parser, -) - -from .eval_pipeline import EvalPipeline, load_eval - -from ..utils.export_predictions import export_predictions -from .utils import get_tp_fp_pts, aggregate_pr_results -from ..settings import EVAL_PATH -from ..models.cache_loader import CacheLoader from ..datasets import get_dataset +from ..models.cache_loader import CacheLoader +from ..settings import EVAL_PATH +from ..utils.export_predictions import export_predictions +from .eval_pipeline import EvalPipeline, load_eval +from .io import get_eval_parser, load_model, parse_eval_args +from .utils import aggregate_pr_results, get_tp_fp_pts def eval_dataset(loader, pred_file, suffix=""): diff --git a/gluefactory/eval/eval_pipeline.py b/gluefactory/eval/eval_pipeline.py index 750969a..ac56237 100644 --- a/gluefactory/eval/eval_pipeline.py +++ b/gluefactory/eval/eval_pipeline.py @@ -1,7 +1,8 @@ -from omegaconf import OmegaConf -import numpy as np import json + import h5py +import numpy as np +from omegaconf import OmegaConf def load_eval(dir): diff --git a/gluefactory/eval/hpatches.py b/gluefactory/eval/hpatches.py index c714bf1..8be7b70 100644 --- a/gluefactory/eval/hpatches.py +++ b/gluefactory/eval/hpatches.py @@ -1,31 +1,27 @@ -from pathlib import Path -from omegaconf import OmegaConf -from pprint import pprint -import matplotlib.pyplot as plt from collections import defaultdict from collections.abc import Iterable -from tqdm import tqdm +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt import numpy as np -from ..visualization.viz2d import plot_cumulative +from omegaconf import OmegaConf +from tqdm import tqdm -from .io import ( - parse_eval_args, - load_model, - get_eval_parser, -) -from ..utils.export_predictions import export_predictions -from ..settings import EVAL_PATH -from ..models.cache_loader import CacheLoader from ..datasets import get_dataset -from .utils import ( - eval_homography_robust, - eval_poses, - eval_matches_homography, - eval_homography_dlt, -) +from ..models.cache_loader import CacheLoader +from ..settings import EVAL_PATH +from ..utils.export_predictions import export_predictions from ..utils.tools import AUCMetric - +from ..visualization.viz2d import plot_cumulative from .eval_pipeline import EvalPipeline +from .io import get_eval_parser, load_model, parse_eval_args +from .utils import ( + eval_homography_dlt, + eval_homography_robust, + eval_matches_homography, + eval_poses, +) class HPatchesPipeline(EvalPipeline): diff --git a/gluefactory/eval/inspect.py b/gluefactory/eval/inspect.py index 913371b..1b7a392 100644 --- a/gluefactory/eval/inspect.py +++ b/gluefactory/eval/inspect.py @@ -1,9 +1,10 @@ import argparse -from pathlib import Path -import matplotlib.pyplot as plt -import matplotlib -from pprint import pprint from collections import defaultdict +from pathlib import Path +from pprint import pprint + +import matplotlib +import matplotlib.pyplot as plt from ..settings import EVAL_PATH from ..visualization.global_frame import GlobalFrame @@ -11,7 +12,6 @@ from ..visualization.two_view_frame import TwoViewFrame from . import get_benchmark from .eval_pipeline import load_eval - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("benchmark", type=str) diff --git a/gluefactory/eval/io.py b/gluefactory/eval/io.py index 93b7259..067e845 100644 --- a/gluefactory/eval/io.py +++ b/gluefactory/eval/io.py @@ -1,13 +1,14 @@ -import pkg_resources -from pathlib import Path -from typing import Optional -from omegaconf import OmegaConf import argparse +from pathlib import Path from pprint import pprint +from typing import Optional + +import pkg_resources +from omegaconf import OmegaConf from ..models import get_model -from ..utils.experiments import load_experiment from ..settings import TRAINING_PATH +from ..utils.experiments import load_experiment def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path: diff --git a/gluefactory/eval/megadepth1500.py b/gluefactory/eval/megadepth1500.py index d9eb337..e359361 100644 --- a/gluefactory/eval/megadepth1500.py +++ b/gluefactory/eval/megadepth1500.py @@ -1,26 +1,23 @@ -import torch -from pathlib import Path -from omegaconf import OmegaConf -from pprint import pprint -import matplotlib.pyplot as plt +import zipfile from collections import defaultdict from collections.abc import Iterable -from tqdm import tqdm -import zipfile -import numpy as np -from ..visualization.viz2d import plot_cumulative -from .io import ( - parse_eval_args, - load_model, - get_eval_parser, -) -from ..utils.export_predictions import export_predictions -from ..settings import EVAL_PATH, DATA_PATH -from ..models.cache_loader import CacheLoader -from ..datasets import get_dataset -from .eval_pipeline import EvalPipeline +from pathlib import Path +from pprint import pprint -from .utils import eval_relative_pose_robust, eval_poses, eval_matches_epipolar +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from ..datasets import get_dataset +from ..models.cache_loader import CacheLoader +from ..settings import DATA_PATH, EVAL_PATH +from ..utils.export_predictions import export_predictions +from ..visualization.viz2d import plot_cumulative +from .eval_pipeline import EvalPipeline +from .io import get_eval_parser, load_model, parse_eval_args +from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust class MegaDepth1500Pipeline(EvalPipeline): diff --git a/gluefactory/eval/utils.py b/gluefactory/eval/utils.py index 77adb8d..c6e6f00 100644 --- a/gluefactory/eval/utils.py +++ b/gluefactory/eval/utils.py @@ -1,11 +1,12 @@ +import kornia import numpy as np import torch -import kornia -from ..geometry.epipolar import relative_pose_error, generalized_epi_dist -from ..geometry.homography import sym_homography_error, homography_corner_error + +from ..geometry.epipolar import generalized_epi_dist, relative_pose_error from ..geometry.gt_generation import IGNORE_FEATURE -from ..utils.tools import AUCMetric +from ..geometry.homography import homography_corner_error, sym_homography_error from ..robust_estimators import load_estimator +from ..utils.tools import AUCMetric def check_keys_recursive(d, pattern): diff --git a/gluefactory/geometry/depth.py b/gluefactory/geometry/depth.py index ea2da60..ca68bc5 100644 --- a/gluefactory/geometry/depth.py +++ b/gluefactory/geometry/depth.py @@ -1,5 +1,5 @@ -import torch import kornia +import torch from .utils import get_image_coords from .wrappers import Camera diff --git a/gluefactory/geometry/epipolar.py b/gluefactory/geometry/epipolar.py index d7c7129..7e1507c 100644 --- a/gluefactory/geometry/epipolar.py +++ b/gluefactory/geometry/epipolar.py @@ -1,7 +1,8 @@ -import torch -from .utils import skew_symmetric, to_homogeneous -from .wrappers import Pose, Camera import numpy as np +import torch + +from .utils import skew_symmetric, to_homogeneous +from .wrappers import Camera, Pose def T_to_E(T: Pose): diff --git a/gluefactory/geometry/gt_generation.py b/gluefactory/geometry/gt_generation.py index 52acc0f..21390cd 100644 --- a/gluefactory/geometry/gt_generation.py +++ b/gluefactory/geometry/gt_generation.py @@ -2,9 +2,9 @@ import numpy as np import torch from scipy.optimize import linear_sum_assignment -from .homography import warp_points_torch +from .depth import project, sample_depth from .epipolar import T_to_E, sym_epipolar_distance_all -from .depth import sample_depth, project +from .homography import warp_points_torch IGNORE_FEATURE = -2 UNMATCHED_FEATURE = -1 diff --git a/gluefactory/geometry/homography.py b/gluefactory/geometry/homography.py index 7679bf9..3acb930 100644 --- a/gluefactory/geometry/homography.py +++ b/gluefactory/geometry/homography.py @@ -1,9 +1,10 @@ -from typing import Tuple import math +from typing import Tuple + import numpy as np import torch -from .utils import to_homogeneous, from_homogeneous +from .utils import from_homogeneous, to_homogeneous def flat2mat(H): diff --git a/gluefactory/geometry/wrappers.py b/gluefactory/geometry/wrappers.py index 0886f58..9d4a1b1 100644 --- a/gluefactory/geometry/wrappers.py +++ b/gluefactory/geometry/wrappers.py @@ -6,13 +6,14 @@ Based on PyTorch tensors: differentiable, batched, with GPU support. import functools import inspect import math -from typing import Union, Tuple, List, Dict, NamedTuple, Optional -import torch +from typing import Dict, List, NamedTuple, Optional, Tuple, Union + import numpy as np +import torch from .utils import ( - distort_points, J_distort_points, + distort_points, skew_symmetric, so3exp_map, to_homogeneous, diff --git a/gluefactory/models/__init__.py b/gluefactory/models/__init__.py index 5d3f71b..a9d1a05 100644 --- a/gluefactory/models/__init__.py +++ b/gluefactory/models/__init__.py @@ -1,6 +1,7 @@ import importlib.util -from .base_model import BaseModel + from ..utils.tools import get_class +from .base_model import BaseModel def get_model(name): diff --git a/gluefactory/models/base_model.py b/gluefactory/models/base_model.py index ed4e107..7313d98 100644 --- a/gluefactory/models/base_model.py +++ b/gluefactory/models/base_model.py @@ -3,10 +3,11 @@ Base class for trainable models. """ from abc import ABCMeta, abstractmethod +from copy import copy + import omegaconf from omegaconf import OmegaConf from torch import nn -from copy import copy class MetaModel(ABCMeta): diff --git a/gluefactory/models/cache_loader.py b/gluefactory/models/cache_loader.py index 40cc55d..3fbf0f7 100644 --- a/gluefactory/models/cache_loader.py +++ b/gluefactory/models/cache_loader.py @@ -1,11 +1,12 @@ -import torch import string -import h5py -from .base_model import BaseModel -from ..settings import DATA_PATH +import h5py +import torch + from ..datasets.base_dataset import collate +from ..settings import DATA_PATH from ..utils.tensor import batch_to_device +from .base_model import BaseModel from .utils.misc import pad_to_length diff --git a/gluefactory/models/extractors/aliked.py b/gluefactory/models/extractors/aliked.py index 45bc46f..80cd348 100644 --- a/gluefactory/models/extractors/aliked.py +++ b/gluefactory/models/extractors/aliked.py @@ -1,10 +1,11 @@ +from typing import Callable, Optional + import torch -from torch import nn import torch.nn.functional as F -from torchvision.models import resnet -from typing import Optional, Callable -from torch.nn.modules.utils import _pair import torchvision +from torch import nn +from torch.nn.modules.utils import _pair +from torchvision.models import resnet from gluefactory.models.base_model import BaseModel diff --git a/gluefactory/models/extractors/disk_kornia.py b/gluefactory/models/extractors/disk_kornia.py index b403b04..4d60973 100644 --- a/gluefactory/models/extractors/disk_kornia.py +++ b/gluefactory/models/extractors/disk_kornia.py @@ -1,5 +1,5 @@ -import torch import kornia +import torch from ..base_model import BaseModel from ..utils.misc import pad_and_stack diff --git a/gluefactory/models/extractors/grid_extractor.py b/gluefactory/models/extractors/grid_extractor.py index 882a125..dd221d9 100644 --- a/gluefactory/models/extractors/grid_extractor.py +++ b/gluefactory/models/extractors/grid_extractor.py @@ -1,6 +1,7 @@ -import torch import math +import torch + from ..base_model import BaseModel diff --git a/gluefactory/models/extractors/keynet_affnet_hardnet.py b/gluefactory/models/extractors/keynet_affnet_hardnet.py index 15f1dca..b9091ea 100644 --- a/gluefactory/models/extractors/keynet_affnet_hardnet.py +++ b/gluefactory/models/extractors/keynet_affnet_hardnet.py @@ -1,5 +1,5 @@ -import torch import kornia +import torch from ..base_model import BaseModel from ..utils.misc import pad_to_length diff --git a/gluefactory/models/extractors/mixed.py b/gluefactory/models/extractors/mixed.py index 3bef2a4..5524cb6 100644 --- a/gluefactory/models/extractors/mixed.py +++ b/gluefactory/models/extractors/mixed.py @@ -1,10 +1,8 @@ -from omegaconf import OmegaConf import torch.nn.functional as F +from omegaconf import OmegaConf -from ..base_model import BaseModel from .. import get_model - -# from ...geometry.depth import sample_fmap +from ..base_model import BaseModel to_ctr = OmegaConf.to_container # convert DictConfig to dict diff --git a/gluefactory/models/extractors/sift.py b/gluefactory/models/extractors/sift.py index 24d7b7b..5eb0c95 100644 --- a/gluefactory/models/extractors/sift.py +++ b/gluefactory/models/extractors/sift.py @@ -1,12 +1,11 @@ -import numpy as np -import torch -import pycolmap -from scipy.spatial import KDTree -from omegaconf import OmegaConf import cv2 +import numpy as np +import pycolmap +import torch +from omegaconf import OmegaConf +from scipy.spatial import KDTree from ..base_model import BaseModel - from ..utils.misc import pad_to_length EPS = 1e-6 diff --git a/gluefactory/models/extractors/superpoint_open.py b/gluefactory/models/extractors/superpoint_open.py index 8da32a4..1f96040 100644 --- a/gluefactory/models/extractors/superpoint_open.py +++ b/gluefactory/models/extractors/superpoint_open.py @@ -5,11 +5,12 @@ The implementation of this model and its trained weights are made available under the MIT license. """ -import torch.nn as nn -import torch from collections import OrderedDict from types import SimpleNamespace +import torch +import torch.nn as nn + from ..base_model import BaseModel from ..utils.misc import pad_and_stack diff --git a/gluefactory/models/lines/deeplsd.py b/gluefactory/models/lines/deeplsd.py index 72fb532..c35aa01 100644 --- a/gluefactory/models/lines/deeplsd.py +++ b/gluefactory/models/lines/deeplsd.py @@ -1,9 +1,9 @@ +import deeplsd.models.deeplsd_inference as deeplsd_inference import numpy as np import torch -import deeplsd.models.deeplsd_inference as deeplsd_inference -from ..base_model import BaseModel from ...settings import DATA_PATH +from ..base_model import BaseModel class DeepLSD(BaseModel): diff --git a/gluefactory/models/lines/wireframe.py b/gluefactory/models/lines/wireframe.py index c2d086c..ac0d0b5 100644 --- a/gluefactory/models/lines/wireframe.py +++ b/gluefactory/models/lines/wireframe.py @@ -1,8 +1,8 @@ import torch from sklearn.cluster import DBSCAN -from ..base_model import BaseModel from .. import get_model +from ..base_model import BaseModel def sample_descriptors_corner_conv(keypoints, descriptors, s: int = 8): diff --git a/gluefactory/models/matchers/depth_matcher.py b/gluefactory/models/matchers/depth_matcher.py index 1d22365..125ded2 100644 --- a/gluefactory/models/matchers/depth_matcher.py +++ b/gluefactory/models/matchers/depth_matcher.py @@ -1,10 +1,11 @@ -from ..base_model import BaseModel -from ...geometry.gt_generation import ( - gt_matches_from_pose_depth, - gt_line_matches_from_pose_depth, -) import torch +from ...geometry.gt_generation import ( + gt_line_matches_from_pose_depth, + gt_matches_from_pose_depth, +) +from ..base_model import BaseModel + class DepthMatcher(BaseModel): default_conf = { diff --git a/gluefactory/models/matchers/gluestick.py b/gluefactory/models/matchers/gluestick.py index 1df19b5..0187e0c 100644 --- a/gluefactory/models/matchers/gluestick.py +++ b/gluefactory/models/matchers/gluestick.py @@ -7,9 +7,9 @@ import torch import torch.utils.checkpoint from torch import nn +from ...settings import DATA_PATH from ..base_model import BaseModel from ..utils.metrics import matcher_metrics -from ...settings import DATA_PATH warnings.filterwarnings("ignore", category=UserWarning) ETH_EPS = 1e-8 diff --git a/gluefactory/models/matchers/homography_matcher.py b/gluefactory/models/matchers/homography_matcher.py index 3ef346e..d3642fb 100644 --- a/gluefactory/models/matchers/homography_matcher.py +++ b/gluefactory/models/matchers/homography_matcher.py @@ -1,8 +1,8 @@ -from ..base_model import BaseModel from ...geometry.gt_generation import ( - gt_matches_from_homography, gt_line_matches_from_homography, + gt_matches_from_homography, ) +from ..base_model import BaseModel class HomographyMatcher(BaseModel): diff --git a/gluefactory/models/matchers/lightglue.py b/gluefactory/models/matchers/lightglue.py index 8589fa1..7671f60 100644 --- a/gluefactory/models/matchers/lightglue.py +++ b/gluefactory/models/matchers/lightglue.py @@ -1,15 +1,17 @@ import warnings +from pathlib import Path +from typing import Callable, List, Optional + import numpy as np import torch -from torch import nn import torch.nn.functional as F -from typing import Optional, List, Callable -from torch.utils.checkpoint import checkpoint from omegaconf import OmegaConf +from torch import nn +from torch.utils.checkpoint import checkpoint + from ...settings import DATA_PATH from ..utils.losses import NLLLoss from ..utils.metrics import matcher_metrics -from pathlib import Path FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention") diff --git a/gluefactory/models/matchers/lightglue_pretrained.py b/gluefactory/models/matchers/lightglue_pretrained.py index 034684a..2e7c71b 100644 --- a/gluefactory/models/matchers/lightglue_pretrained.py +++ b/gluefactory/models/matchers/lightglue_pretrained.py @@ -1,7 +1,8 @@ -from ..base_model import BaseModel from lightglue import LightGlue as LightGlue_ from omegaconf import OmegaConf +from ..base_model import BaseModel + class LightGlue(BaseModel): default_conf = {"features": "superpoint", **LightGlue_.default_conf} diff --git a/gluefactory/models/matchers/nearest_neighbor_matcher.py b/gluefactory/models/matchers/nearest_neighbor_matcher.py index b3ad427..7bbc8ae 100644 --- a/gluefactory/models/matchers/nearest_neighbor_matcher.py +++ b/gluefactory/models/matchers/nearest_neighbor_matcher.py @@ -3,8 +3,9 @@ Nearest neighbor matcher for normalized descriptors. Optionally apply the mutual check and threshold the distance or ratio. """ -import torch import logging + +import torch import torch.nn.functional as F from ..base_model import BaseModel diff --git a/gluefactory/models/triplet_pipeline.py b/gluefactory/models/triplet_pipeline.py index 9bcc8da..2538517 100644 --- a/gluefactory/models/triplet_pipeline.py +++ b/gluefactory/models/triplet_pipeline.py @@ -9,9 +9,10 @@ Losses and metrics get accumulated accordingly. If no triplet is found, this falls back to two_view_pipeline.py """ -from .two_view_pipeline import TwoViewPipeline import torch + from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews +from .two_view_pipeline import TwoViewPipeline def has_triplet(data): diff --git a/gluefactory/models/two_view_pipeline.py b/gluefactory/models/two_view_pipeline.py index 2f521e9..9c517dc 100644 --- a/gluefactory/models/two_view_pipeline.py +++ b/gluefactory/models/two_view_pipeline.py @@ -11,9 +11,9 @@ that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched. """ from omegaconf import OmegaConf -from .base_model import BaseModel -from . import get_model +from . import get_model +from .base_model import BaseModel to_ctr = OmegaConf.to_container # convert DictConfig to dict diff --git a/gluefactory/models/utils/misc.py b/gluefactory/models/utils/misc.py index 2cb03d6..e86d1ad 100644 --- a/gluefactory/models/utils/misc.py +++ b/gluefactory/models/utils/misc.py @@ -1,5 +1,6 @@ import math from typing import List, Optional, Tuple + import torch diff --git a/gluefactory/robust_estimators/__init__.py b/gluefactory/robust_estimators/__init__.py index f5a85cd..a9d9c9b 100644 --- a/gluefactory/robust_estimators/__init__.py +++ b/gluefactory/robust_estimators/__init__.py @@ -1,4 +1,5 @@ import inspect + from .base_estimator import BaseEstimator diff --git a/gluefactory/robust_estimators/base_estimator.py b/gluefactory/robust_estimators/base_estimator.py index a94e35b..29f8dd4 100644 --- a/gluefactory/robust_estimators/base_estimator.py +++ b/gluefactory/robust_estimators/base_estimator.py @@ -1,6 +1,7 @@ -from omegaconf import OmegaConf from copy import copy +from omegaconf import OmegaConf + class BaseEstimator: base_default_conf = { diff --git a/gluefactory/robust_estimators/homography/poselib.py b/gluefactory/robust_estimators/homography/poselib.py index 0edfe10..e99e949 100644 --- a/gluefactory/robust_estimators/homography/poselib.py +++ b/gluefactory/robust_estimators/homography/poselib.py @@ -1,6 +1,6 @@ import poselib -from omegaconf import OmegaConf import torch +from omegaconf import OmegaConf from ..base_estimator import BaseEstimator diff --git a/gluefactory/robust_estimators/relative_pose/opencv.py b/gluefactory/robust_estimators/relative_pose/opencv.py index b212ea3..34442a0 100644 --- a/gluefactory/robust_estimators/relative_pose/opencv.py +++ b/gluefactory/robust_estimators/relative_pose/opencv.py @@ -1,9 +1,9 @@ import cv2 import numpy as np import torch -from ...geometry.wrappers import Pose -from ...geometry.utils import from_homogeneous +from ...geometry.utils import from_homogeneous +from ...geometry.wrappers import Pose from ..base_estimator import BaseEstimator diff --git a/gluefactory/robust_estimators/relative_pose/poselib.py b/gluefactory/robust_estimators/relative_pose/poselib.py index 35ab87c..6c736e4 100644 --- a/gluefactory/robust_estimators/relative_pose/poselib.py +++ b/gluefactory/robust_estimators/relative_pose/poselib.py @@ -1,8 +1,8 @@ import poselib -from omegaconf import OmegaConf import torch -from ...geometry.wrappers import Pose +from omegaconf import OmegaConf +from ...geometry.wrappers import Pose from ..base_estimator import BaseEstimator diff --git a/gluefactory/robust_estimators/relative_pose/pycolmap.py b/gluefactory/robust_estimators/relative_pose/pycolmap.py index c7d0946..21cb272 100644 --- a/gluefactory/robust_estimators/relative_pose/pycolmap.py +++ b/gluefactory/robust_estimators/relative_pose/pycolmap.py @@ -1,8 +1,8 @@ import pycolmap -from omegaconf import OmegaConf import torch -from ...geometry.wrappers import Pose +from omegaconf import OmegaConf +from ...geometry.wrappers import Pose from ..base_estimator import BaseEstimator diff --git a/gluefactory/scripts/export_local_features.py b/gluefactory/scripts/export_local_features.py index 892f333..7f3f0a9 100644 --- a/gluefactory/scripts/export_local_features.py +++ b/gluefactory/scripts/export_local_features.py @@ -1,14 +1,14 @@ +import argparse import logging from pathlib import Path -import argparse + import torch from omegaconf import OmegaConf +from ..datasets import get_dataset +from ..models import get_model from ..settings import DATA_PATH from ..utils.export_predictions import export_predictions -from ..models import get_model -from ..datasets import get_dataset - resize = 1600 diff --git a/gluefactory/scripts/export_megadepth.py b/gluefactory/scripts/export_megadepth.py index c94caec..95e89d8 100644 --- a/gluefactory/scripts/export_megadepth.py +++ b/gluefactory/scripts/export_megadepth.py @@ -1,14 +1,15 @@ +import argparse import logging from pathlib import Path -import argparse + import torch from omegaconf import OmegaConf -from ..settings import DATA_PATH -from ..utils.export_predictions import export_predictions -from ..models import get_model from ..datasets import get_dataset from ..geometry.depth import sample_depth +from ..models import get_model +from ..settings import DATA_PATH +from ..utils.export_predictions import export_predictions resize = 1024 n_kpts = 2048 diff --git a/gluefactory/train.py b/gluefactory/train.py index 2d5b639..08895d7 100644 --- a/gluefactory/train.py +++ b/gluefactory/train.py @@ -5,37 +5,37 @@ Author: Paul-Edouard Sarlin (skydes) """ import argparse -from pathlib import Path -import signal -import re import copy -from collections import defaultdict +import re import shutil -import numpy as np - -from omegaconf import OmegaConf -from tqdm import tqdm -import torch -from torch.utils.tensorboard import SummaryWriter -from torch.cuda.amp import GradScaler, autocast +import signal +from collections import defaultdict +from pathlib import Path from pydoc import locate -from .models import get_model +import numpy as np +import torch +from omegaconf import OmegaConf +from torch.cuda.amp import GradScaler, autocast +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +from . import __module_name__, logger from .datasets import get_dataset +from .eval import run_benchmark +from .models import get_model +from .settings import EVAL_PATH, TRAINING_PATH +from .utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment from .utils.stdout_capturing import capture_outputs +from .utils.tensor import batch_to_device from .utils.tools import ( AverageMetric, MedianMetric, - RecallMetric, PRMetric, - set_seed, + RecallMetric, fork_rng, + set_seed, ) -from .utils.tensor import batch_to_device -from .utils.experiments import get_last_checkpoint, get_best_checkpoint, save_experiment -from .eval import run_benchmark -from .settings import TRAINING_PATH, EVAL_PATH -from . import __module_name__, logger # @TODO: Fix pbar pollution in logs # @TODO: add plotting during evaluation diff --git a/gluefactory/utils/benchmark.py b/gluefactory/utils/benchmark.py index 401578b..99b4f85 100644 --- a/gluefactory/utils/benchmark.py +++ b/gluefactory/utils/benchmark.py @@ -1,7 +1,8 @@ -import torch -import numpy as np import time +import numpy as np +import torch + def benchmark(model, data, device, r=100): timings = np.zeros((r, 1)) diff --git a/gluefactory/utils/experiments.py b/gluefactory/utils/experiments.py index 849d0bc..7723fce 100644 --- a/gluefactory/utils/experiments.py +++ b/gluefactory/utils/experiments.py @@ -4,16 +4,17 @@ A set of utilities to manage and load checkpoints of training experiments. Author: Paul-Edouard Sarlin (skydes) """ -from pathlib import Path import logging +import os import re import shutil -from omegaconf import OmegaConf -import torch -import os +from pathlib import Path + +import torch +from omegaconf import OmegaConf -from ..settings import TRAINING_PATH from ..models import get_model +from ..settings import TRAINING_PATH logger = logging.getLogger(__name__) diff --git a/gluefactory/utils/export_predictions.py b/gluefactory/utils/export_predictions.py index 084227f..1157a52 100644 --- a/gluefactory/utils/export_predictions.py +++ b/gluefactory/utils/export_predictions.py @@ -4,11 +4,12 @@ Use a standalone script with `python3 -m dsfm.scipts.export_predictions dir` or call from another script. """ -import torch -import numpy as np from pathlib import Path -from tqdm import tqdm + import h5py +import numpy as np +import torch +from tqdm import tqdm from .tensor import batch_to_device diff --git a/gluefactory/utils/image.py b/gluefactory/utils/image.py index 1e6a7e2..1a9b125 100644 --- a/gluefactory/utils/image.py +++ b/gluefactory/utils/image.py @@ -1,10 +1,11 @@ -from pathlib import Path -import torch -import kornia -import cv2 -import numpy as np -from typing import Tuple, Optional import collections.abc as collections +from pathlib import Path +from typing import Optional, Tuple + +import cv2 +import kornia +import numpy as np +import torch from omegaconf import OmegaConf diff --git a/gluefactory/utils/stdout_capturing.py b/gluefactory/utils/stdout_capturing.py index 9baef92..bfa2b83 100644 --- a/gluefactory/utils/stdout_capturing.py +++ b/gluefactory/utils/stdout_capturing.py @@ -6,11 +6,12 @@ Author: Paul-Edouard Sarlin (skydes) """ from __future__ import division, print_function, unicode_literals + import os -import sys import subprocess -from threading import Timer +import sys from contextlib import contextmanager +from threading import Timer def apply_backspaces_and_linefeeds(text): diff --git a/gluefactory/utils/tensor.py b/gluefactory/utils/tensor.py index a20c641..f31bb58 100644 --- a/gluefactory/utils/tensor.py +++ b/gluefactory/utils/tensor.py @@ -3,8 +3,9 @@ Author: Paul-Edouard Sarlin (skydes) """ import collections.abc as collections -import torch + import numpy as np +import torch string_classes = (str, bytes) diff --git a/gluefactory/utils/tools.py b/gluefactory/utils/tools.py index 21541e6..6a27f4a 100644 --- a/gluefactory/utils/tools.py +++ b/gluefactory/utils/tools.py @@ -4,13 +4,14 @@ Various handy Python and PyTorch utils. Author: Paul-Edouard Sarlin (skydes) """ -import time -import numpy as np import os -import torch import random -from contextlib import contextmanager +import time from collections.abc import Iterable +from contextlib import contextmanager + +import numpy as np +import torch class AverageMetric: diff --git a/gluefactory/visualization/global_frame.py b/gluefactory/visualization/global_frame.py index 41d33ec..a403c9c 100644 --- a/gluefactory/visualization/global_frame.py +++ b/gluefactory/visualization/global_frame.py @@ -1,14 +1,16 @@ -import traceback -import numpy as np -import matplotlib.pyplot as plt -from omegaconf import OmegaConf -from matplotlib.widgets import Button -from copy import deepcopy import functools +import traceback +from copy import deepcopy + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.widgets import Button +from omegaconf import OmegaConf + +from ..datasets.base_dataset import collate # from ..eval.export_predictions import load_predictions from ..models.cache_loader import CacheLoader -from ..datasets.base_dataset import collate from .tools import RadioHideTool diff --git a/gluefactory/visualization/tools.py b/gluefactory/visualization/tools.py index 1415807..a095d06 100644 --- a/gluefactory/visualization/tools.py +++ b/gluefactory/visualization/tools.py @@ -1,25 +1,25 @@ +import inspect +import sys +import warnings + import matplotlib.pyplot as plt +import torch from matplotlib.backend_tools import ToolToggleBase from matplotlib.widgets import RadioButtons, Slider -import warnings -import torch +from ..geometry.epipolar import T_to_F, generalized_epi_dist +from ..geometry.homography import sym_homography_error from ..visualization.viz2d import ( + cm_ranking, + cm_RdGn, + draw_epipolar_line, + get_line, + plot_color_line_matches, plot_heatmaps, plot_keypoints, plot_lines, plot_matches, - plot_color_line_matches, - cm_RdGn, - cm_ranking, - get_line, - draw_epipolar_line, ) -from ..geometry.homography import sym_homography_error -from ..geometry.epipolar import generalized_epi_dist, T_to_F - -import inspect -import sys with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/gluefactory/visualization/two_view_frame.py b/gluefactory/visualization/two_view_frame.py index fac2222..3461eb0 100644 --- a/gluefactory/visualization/two_view_frame.py +++ b/gluefactory/visualization/two_view_frame.py @@ -1,10 +1,9 @@ -import numpy as np import pprint + +import numpy as np + from . import viz2d - -from .tools import __plot_dict__ - -from .tools import RadioHideTool, ToggleTool +from .tools import RadioHideTool, ToggleTool, __plot_dict__ class FormatPrinter(pprint.PrettyPrinter): diff --git a/gluefactory/visualization/visualize_batch.py b/gluefactory/visualization/visualize_batch.py index 09bdcbf..3bd3f7b 100644 --- a/gluefactory/visualization/visualize_batch.py +++ b/gluefactory/visualization/visualize_batch.py @@ -1,13 +1,7 @@ import torch from ..utils.tensor import batch_to_device -from .viz2d import ( - plot_image_grid, - plot_keypoints, - plot_matches, - cm_RdGn, - plot_heatmaps, -) +from .viz2d import cm_RdGn, plot_heatmaps, plot_image_grid, plot_keypoints, plot_matches def make_match_figures(pred_, data_, n_pairs=2): diff --git a/gluefactory/visualization/viz2d.py b/gluefactory/visualization/viz2d.py index 4a3a636..42a000a 100644 --- a/gluefactory/visualization/viz2d.py +++ b/gluefactory/visualization/viz2d.py @@ -6,8 +6,8 @@ """ import matplotlib -import matplotlib.pyplot as plt import matplotlib.patheffects as path_effects +import matplotlib.pyplot as plt import numpy as np import seaborn as sns diff --git a/pyproject.toml b/pyproject.toml index b0cc6d7..5185a75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,10 +43,14 @@ extra = [ "deeplsd @ git+https://github.com/cvg/DeepLSD.git", "homography_est @ git+https://github.com/rpautrat/homography_est.git", ] -dev = ["black", "flake8", "jupyter"] +dev = ["black", "flake8", "isort"] [tool.setuptools.packages.find] include = ["gluefactory*"] [tool.setuptools.package-data] gluefactory = ["datasets/megadepth_scene_lists/*.txt", "configs/*.yaml"] + +[tool.isort] +profile = "black" +extend_skip = ["gluefactory_nonfree/"]