Auto sort imports (#6)

* Add isort, merge check runs into one

* Run isort

* Ignor build in flake8 config

* Remove jupyter as dev dependency
main
Paul-Edouard Sarlin 2023-10-09 08:32:43 +02:00 committed by GitHub
parent 1709021473
commit 12640afd36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
69 changed files with 288 additions and 275 deletions

View File

@ -1,3 +1,4 @@
[flake8] [flake8]
max-line-length = 88 max-line-length = 88
extend-ignore = E203 extend-ignore = E203
exclude = .git,__pycache__,build,.venv/

View File

@ -8,24 +8,17 @@ on:
pull_request: pull_request:
types: [ assigned, opened, synchronize, reopened ] types: [ assigned, opened, synchronize, reopened ]
jobs: jobs:
formatting-check: check:
name: Formatting Check name: Format and Lint Checks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: psf/black@stable
with:
jupyter: true
linting-check:
name: Linting Check
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-python@v4 - uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: '3.10'
cache: 'pip' cache: 'pip'
- run: python -m pip install --upgrade pip - run: python -m pip install --upgrade pip
- run: python -m pip install . - run: python -m pip install .[dev]
- run: python -m pip install --upgrade flake8 - run: python -m flake8 .
- run: python -m flake8 . --exclude build/ - run: python -m isort . --check-only --diff
- run: python -m black . --check --diff

3
format.sh Executable file
View File

@ -0,0 +1,3 @@
python -m flake8 .
python -m isort .
python -m black .

View File

@ -1,4 +1,5 @@
import logging import logging
from .utils.experiments import load_experiment # noqa: F401 from .utils.experiments import load_experiment # noqa: F401
formatter = logging.Formatter( formatter = logging.Formatter(

View File

@ -1,6 +1,7 @@
import importlib.util import importlib.util
from .base_dataset import BaseDataset
from ..utils.tools import get_class from ..utils.tools import get_class
from .base_dataset import BaseDataset
def get_dataset(name): def get_dataset(name):

View File

@ -1,11 +1,11 @@
from typing import Union from typing import Union
import albumentations as A import albumentations as A
import cv2
import numpy as np import numpy as np
import torch import torch
from albumentations.pytorch.transforms import ToTensorV2 from albumentations.pytorch.transforms import ToTensorV2
from omegaconf import OmegaConf from omegaconf import OmegaConf
import cv2
class IdentityTransform(A.ImageOnlyTransform): class IdentityTransform(A.ImageOnlyTransform):

View File

@ -3,12 +3,13 @@ Base class for dataset.
See mnist.py for an example of dataset. See mnist.py for an example of dataset.
""" """
from abc import ABCMeta, abstractmethod
import collections import collections
import logging import logging
from omegaconf import OmegaConf from abc import ABCMeta, abstractmethod
import omegaconf import omegaconf
import torch import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Sampler, get_worker_info from torch.utils.data import DataLoader, Sampler, get_worker_info
from torch.utils.data._utils.collate import ( from torch.utils.data._utils.collate import (
default_collate_err_msg_format, default_collate_err_msg_format,

View File

@ -4,18 +4,18 @@ ETH3D multi-view benchmark, used for line matching evaluation.
import logging import logging
import os import os
import shutil import shutil
import numpy as np
import cv2
import torch
from pathlib import Path
import zipfile import zipfile
from pathlib import Path
import cv2
import numpy as np
import torch
from .base_dataset import BaseDataset
from .utils import scale_intrinsics
from ..geometry.wrappers import Camera, Pose from ..geometry.wrappers import Camera, Pose
from ..settings import DATA_PATH from ..settings import DATA_PATH
from ..utils.image import ImagePreprocessor, load_image from ..utils.image import ImagePreprocessor, load_image
from .base_dataset import BaseDataset
from .utils import scale_intrinsics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -11,25 +11,25 @@ import tarfile
from pathlib import Path from pathlib import Path
import cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import omegaconf import omegaconf
import torch import torch
import matplotlib.pyplot as plt
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
from .augmentations import IdentityAugmentation, augmentations
from .base_dataset import BaseDataset
from ..settings import DATA_PATH
from ..models.cache_loader import CacheLoader, pad_local_features
from ..utils.image import read_image
from ..geometry.homography import ( from ..geometry.homography import (
sample_homography_corners,
compute_homography, compute_homography,
sample_homography_corners,
warp_points, warp_points,
) )
from ..models.cache_loader import CacheLoader, pad_local_features
from ..settings import DATA_PATH
from ..utils.image import read_image
from ..utils.tools import fork_rng from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid from ..visualization.viz2d import plot_image_grid
from .augmentations import IdentityAugmentation, augmentations
from .base_dataset import BaseDataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -4,16 +4,17 @@ Simply load images from a folder or nested folders (does not have any split).
import argparse import argparse
import logging import logging
import tarfile import tarfile
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from .base_dataset import BaseDataset
from ..settings import DATA_PATH from ..settings import DATA_PATH
from ..utils.image import load_image, ImagePreprocessor from ..utils.image import ImagePreprocessor, load_image
from ..utils.tools import fork_rng from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid from ..visualization.viz2d import plot_image_grid
from .base_dataset import BaseDataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -2,13 +2,14 @@
Simply load images from a folder or nested folders (does not have any split). Simply load images from a folder or nested folders (does not have any split).
""" """
from pathlib import Path
import torch
import logging import logging
import omegaconf from pathlib import Path
import omegaconf
import torch
from ..utils.image import ImagePreprocessor, load_image
from .base_dataset import BaseDataset from .base_dataset import BaseDataset
from ..utils.image import load_image, ImagePreprocessor
class ImageFolder(BaseDataset, torch.utils.data.Dataset): class ImageFolder(BaseDataset, torch.utils.data.Dataset):

View File

@ -3,13 +3,14 @@ Simply load images from a folder or nested folders (does not have any split).
""" """
from pathlib import Path from pathlib import Path
import torch
import numpy as np
from .base_dataset import BaseDataset
from ..utils.image import load_image, ImagePreprocessor
from ..settings import DATA_PATH import numpy as np
import torch
from ..geometry.wrappers import Camera, Pose from ..geometry.wrappers import Camera, Pose
from ..settings import DATA_PATH
from ..utils.image import ImagePreprocessor, load_image
from .base_dataset import BaseDataset
def names_to_pair(name0, name1, separator="/"): def names_to_pair(name0, name1, separator="/"):

View File

@ -1,9 +1,9 @@
import argparse import argparse
import logging import logging
from pathlib import Path
from collections.abc import Iterable
import tarfile
import shutil import shutil
import tarfile
from collections.abc import Iterable
from pathlib import Path
import h5py import h5py
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -12,18 +12,14 @@ import PIL.Image
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from .base_dataset import BaseDataset
from .utils import (
scale_intrinsics,
rotate_intrinsics,
rotate_pose_inplane,
)
from ..geometry.wrappers import Camera, Pose from ..geometry.wrappers import Camera, Pose
from ..models.cache_loader import CacheLoader from ..models.cache_loader import CacheLoader
from ..utils.tools import fork_rng
from ..utils.image import load_image, ImagePreprocessor
from ..settings import DATA_PATH from ..settings import DATA_PATH
from ..visualization.viz2d import plot_image_grid, plot_heatmaps from ..utils.image import ImagePreprocessor, load_image
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_heatmaps, plot_image_grid
from .base_dataset import BaseDataset
from .utils import rotate_intrinsics, rotate_pose_inplane, scale_intrinsics
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
scene_lists_path = Path(__file__).parent / "megadepth_scene_lists" scene_lists_path = Path(__file__).parent / "megadepth_scene_lists"

View File

@ -1,4 +1,5 @@
import torch import torch
from ..utils.tools import get_class from ..utils.tools import get_class
from .eval_pipeline import EvalPipeline from .eval_pipeline import EvalPipeline

View File

@ -1,23 +1,18 @@
from pathlib import Path
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from collections import defaultdict from collections import defaultdict
from tqdm import tqdm from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
from .io import (
parse_eval_args,
load_model,
get_eval_parser,
)
from .eval_pipeline import EvalPipeline, load_eval
from ..utils.export_predictions import export_predictions
from .utils import get_tp_fp_pts, aggregate_pr_results
from ..settings import EVAL_PATH
from ..models.cache_loader import CacheLoader
from ..datasets import get_dataset from ..datasets import get_dataset
from ..models.cache_loader import CacheLoader
from ..settings import EVAL_PATH
from ..utils.export_predictions import export_predictions
from .eval_pipeline import EvalPipeline, load_eval
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import aggregate_pr_results, get_tp_fp_pts
def eval_dataset(loader, pred_file, suffix=""): def eval_dataset(loader, pred_file, suffix=""):

View File

@ -1,7 +1,8 @@
from omegaconf import OmegaConf
import numpy as np
import json import json
import h5py import h5py
import numpy as np
from omegaconf import OmegaConf
def load_eval(dir): def load_eval(dir):

View File

@ -1,31 +1,27 @@
from pathlib import Path
from omegaconf import OmegaConf
from pprint import pprint
import matplotlib.pyplot as plt
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from tqdm import tqdm from pathlib import Path
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np import numpy as np
from ..visualization.viz2d import plot_cumulative from omegaconf import OmegaConf
from tqdm import tqdm
from .io import (
parse_eval_args,
load_model,
get_eval_parser,
)
from ..utils.export_predictions import export_predictions
from ..settings import EVAL_PATH
from ..models.cache_loader import CacheLoader
from ..datasets import get_dataset from ..datasets import get_dataset
from .utils import ( from ..models.cache_loader import CacheLoader
eval_homography_robust, from ..settings import EVAL_PATH
eval_poses, from ..utils.export_predictions import export_predictions
eval_matches_homography,
eval_homography_dlt,
)
from ..utils.tools import AUCMetric from ..utils.tools import AUCMetric
from ..visualization.viz2d import plot_cumulative
from .eval_pipeline import EvalPipeline from .eval_pipeline import EvalPipeline
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import (
eval_homography_dlt,
eval_homography_robust,
eval_matches_homography,
eval_poses,
)
class HPatchesPipeline(EvalPipeline): class HPatchesPipeline(EvalPipeline):

View File

@ -1,9 +1,10 @@
import argparse import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
from pprint import pprint
from collections import defaultdict from collections import defaultdict
from pathlib import Path
from pprint import pprint
import matplotlib
import matplotlib.pyplot as plt
from ..settings import EVAL_PATH from ..settings import EVAL_PATH
from ..visualization.global_frame import GlobalFrame from ..visualization.global_frame import GlobalFrame
@ -11,7 +12,6 @@ from ..visualization.two_view_frame import TwoViewFrame
from . import get_benchmark from . import get_benchmark
from .eval_pipeline import load_eval from .eval_pipeline import load_eval
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("benchmark", type=str) parser.add_argument("benchmark", type=str)

View File

@ -1,13 +1,14 @@
import pkg_resources
from pathlib import Path
from typing import Optional
from omegaconf import OmegaConf
import argparse import argparse
from pathlib import Path
from pprint import pprint from pprint import pprint
from typing import Optional
import pkg_resources
from omegaconf import OmegaConf
from ..models import get_model from ..models import get_model
from ..utils.experiments import load_experiment
from ..settings import TRAINING_PATH from ..settings import TRAINING_PATH
from ..utils.experiments import load_experiment
def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path: def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path:

View File

@ -1,26 +1,23 @@
import torch import zipfile
from pathlib import Path
from omegaconf import OmegaConf
from pprint import pprint
import matplotlib.pyplot as plt
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from tqdm import tqdm from pathlib import Path
import zipfile from pprint import pprint
import numpy as np
from ..visualization.viz2d import plot_cumulative
from .io import (
parse_eval_args,
load_model,
get_eval_parser,
)
from ..utils.export_predictions import export_predictions
from ..settings import EVAL_PATH, DATA_PATH
from ..models.cache_loader import CacheLoader
from ..datasets import get_dataset
from .eval_pipeline import EvalPipeline
from .utils import eval_relative_pose_robust, eval_poses, eval_matches_epipolar import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
from ..datasets import get_dataset
from ..models.cache_loader import CacheLoader
from ..settings import DATA_PATH, EVAL_PATH
from ..utils.export_predictions import export_predictions
from ..visualization.viz2d import plot_cumulative
from .eval_pipeline import EvalPipeline
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust
class MegaDepth1500Pipeline(EvalPipeline): class MegaDepth1500Pipeline(EvalPipeline):

View File

@ -1,11 +1,12 @@
import kornia
import numpy as np import numpy as np
import torch import torch
import kornia
from ..geometry.epipolar import relative_pose_error, generalized_epi_dist from ..geometry.epipolar import generalized_epi_dist, relative_pose_error
from ..geometry.homography import sym_homography_error, homography_corner_error
from ..geometry.gt_generation import IGNORE_FEATURE from ..geometry.gt_generation import IGNORE_FEATURE
from ..utils.tools import AUCMetric from ..geometry.homography import homography_corner_error, sym_homography_error
from ..robust_estimators import load_estimator from ..robust_estimators import load_estimator
from ..utils.tools import AUCMetric
def check_keys_recursive(d, pattern): def check_keys_recursive(d, pattern):

View File

@ -1,5 +1,5 @@
import torch
import kornia import kornia
import torch
from .utils import get_image_coords from .utils import get_image_coords
from .wrappers import Camera from .wrappers import Camera

View File

@ -1,7 +1,8 @@
import torch
from .utils import skew_symmetric, to_homogeneous
from .wrappers import Pose, Camera
import numpy as np import numpy as np
import torch
from .utils import skew_symmetric, to_homogeneous
from .wrappers import Camera, Pose
def T_to_E(T: Pose): def T_to_E(T: Pose):

View File

@ -2,9 +2,9 @@ import numpy as np
import torch import torch
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from .homography import warp_points_torch from .depth import project, sample_depth
from .epipolar import T_to_E, sym_epipolar_distance_all from .epipolar import T_to_E, sym_epipolar_distance_all
from .depth import sample_depth, project from .homography import warp_points_torch
IGNORE_FEATURE = -2 IGNORE_FEATURE = -2
UNMATCHED_FEATURE = -1 UNMATCHED_FEATURE = -1

View File

@ -1,9 +1,10 @@
from typing import Tuple
import math import math
from typing import Tuple
import numpy as np import numpy as np
import torch import torch
from .utils import to_homogeneous, from_homogeneous from .utils import from_homogeneous, to_homogeneous
def flat2mat(H): def flat2mat(H):

View File

@ -6,13 +6,14 @@ Based on PyTorch tensors: differentiable, batched, with GPU support.
import functools import functools
import inspect import inspect
import math import math
from typing import Union, Tuple, List, Dict, NamedTuple, Optional from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import torch
import numpy as np import numpy as np
import torch
from .utils import ( from .utils import (
distort_points,
J_distort_points, J_distort_points,
distort_points,
skew_symmetric, skew_symmetric,
so3exp_map, so3exp_map,
to_homogeneous, to_homogeneous,

View File

@ -1,6 +1,7 @@
import importlib.util import importlib.util
from .base_model import BaseModel
from ..utils.tools import get_class from ..utils.tools import get_class
from .base_model import BaseModel
def get_model(name): def get_model(name):

View File

@ -3,10 +3,11 @@ Base class for trainable models.
""" """
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from copy import copy
import omegaconf import omegaconf
from omegaconf import OmegaConf from omegaconf import OmegaConf
from torch import nn from torch import nn
from copy import copy
class MetaModel(ABCMeta): class MetaModel(ABCMeta):

View File

@ -1,11 +1,12 @@
import torch
import string import string
import h5py
from .base_model import BaseModel import h5py
from ..settings import DATA_PATH import torch
from ..datasets.base_dataset import collate from ..datasets.base_dataset import collate
from ..settings import DATA_PATH
from ..utils.tensor import batch_to_device from ..utils.tensor import batch_to_device
from .base_model import BaseModel
from .utils.misc import pad_to_length from .utils.misc import pad_to_length

View File

@ -1,10 +1,11 @@
from typing import Callable, Optional
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision.models import resnet
from typing import Optional, Callable
from torch.nn.modules.utils import _pair
import torchvision import torchvision
from torch import nn
from torch.nn.modules.utils import _pair
from torchvision.models import resnet
from gluefactory.models.base_model import BaseModel from gluefactory.models.base_model import BaseModel

View File

@ -1,5 +1,5 @@
import torch
import kornia import kornia
import torch
from ..base_model import BaseModel from ..base_model import BaseModel
from ..utils.misc import pad_and_stack from ..utils.misc import pad_and_stack

View File

@ -1,6 +1,7 @@
import torch
import math import math
import torch
from ..base_model import BaseModel from ..base_model import BaseModel

View File

@ -1,5 +1,5 @@
import torch
import kornia import kornia
import torch
from ..base_model import BaseModel from ..base_model import BaseModel
from ..utils.misc import pad_to_length from ..utils.misc import pad_to_length

View File

@ -1,10 +1,8 @@
from omegaconf import OmegaConf
import torch.nn.functional as F import torch.nn.functional as F
from omegaconf import OmegaConf
from ..base_model import BaseModel
from .. import get_model from .. import get_model
from ..base_model import BaseModel
# from ...geometry.depth import sample_fmap
to_ctr = OmegaConf.to_container # convert DictConfig to dict to_ctr = OmegaConf.to_container # convert DictConfig to dict

View File

@ -1,12 +1,11 @@
import numpy as np
import torch
import pycolmap
from scipy.spatial import KDTree
from omegaconf import OmegaConf
import cv2 import cv2
import numpy as np
import pycolmap
import torch
from omegaconf import OmegaConf
from scipy.spatial import KDTree
from ..base_model import BaseModel from ..base_model import BaseModel
from ..utils.misc import pad_to_length from ..utils.misc import pad_to_length
EPS = 1e-6 EPS = 1e-6

View File

@ -5,11 +5,12 @@
The implementation of this model and its trained weights are made The implementation of this model and its trained weights are made
available under the MIT license. available under the MIT license.
""" """
import torch.nn as nn
import torch
from collections import OrderedDict from collections import OrderedDict
from types import SimpleNamespace from types import SimpleNamespace
import torch
import torch.nn as nn
from ..base_model import BaseModel from ..base_model import BaseModel
from ..utils.misc import pad_and_stack from ..utils.misc import pad_and_stack

View File

@ -1,9 +1,9 @@
import deeplsd.models.deeplsd_inference as deeplsd_inference
import numpy as np import numpy as np
import torch import torch
import deeplsd.models.deeplsd_inference as deeplsd_inference
from ..base_model import BaseModel
from ...settings import DATA_PATH from ...settings import DATA_PATH
from ..base_model import BaseModel
class DeepLSD(BaseModel): class DeepLSD(BaseModel):

View File

@ -1,8 +1,8 @@
import torch import torch
from sklearn.cluster import DBSCAN from sklearn.cluster import DBSCAN
from ..base_model import BaseModel
from .. import get_model from .. import get_model
from ..base_model import BaseModel
def sample_descriptors_corner_conv(keypoints, descriptors, s: int = 8): def sample_descriptors_corner_conv(keypoints, descriptors, s: int = 8):

View File

@ -1,10 +1,11 @@
from ..base_model import BaseModel
from ...geometry.gt_generation import (
gt_matches_from_pose_depth,
gt_line_matches_from_pose_depth,
)
import torch import torch
from ...geometry.gt_generation import (
gt_line_matches_from_pose_depth,
gt_matches_from_pose_depth,
)
from ..base_model import BaseModel
class DepthMatcher(BaseModel): class DepthMatcher(BaseModel):
default_conf = { default_conf = {

View File

@ -7,9 +7,9 @@ import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from ...settings import DATA_PATH
from ..base_model import BaseModel from ..base_model import BaseModel
from ..utils.metrics import matcher_metrics from ..utils.metrics import matcher_metrics
from ...settings import DATA_PATH
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
ETH_EPS = 1e-8 ETH_EPS = 1e-8

View File

@ -1,8 +1,8 @@
from ..base_model import BaseModel
from ...geometry.gt_generation import ( from ...geometry.gt_generation import (
gt_matches_from_homography,
gt_line_matches_from_homography, gt_line_matches_from_homography,
gt_matches_from_homography,
) )
from ..base_model import BaseModel
class HomographyMatcher(BaseModel): class HomographyMatcher(BaseModel):

View File

@ -1,15 +1,17 @@
import warnings import warnings
from pathlib import Path
from typing import Callable, List, Optional
import numpy as np import numpy as np
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional, List, Callable
from torch.utils.checkpoint import checkpoint
from omegaconf import OmegaConf from omegaconf import OmegaConf
from torch import nn
from torch.utils.checkpoint import checkpoint
from ...settings import DATA_PATH from ...settings import DATA_PATH
from ..utils.losses import NLLLoss from ..utils.losses import NLLLoss
from ..utils.metrics import matcher_metrics from ..utils.metrics import matcher_metrics
from pathlib import Path
FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention") FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention")

View File

@ -1,7 +1,8 @@
from ..base_model import BaseModel
from lightglue import LightGlue as LightGlue_ from lightglue import LightGlue as LightGlue_
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ..base_model import BaseModel
class LightGlue(BaseModel): class LightGlue(BaseModel):
default_conf = {"features": "superpoint", **LightGlue_.default_conf} default_conf = {"features": "superpoint", **LightGlue_.default_conf}

View File

@ -3,8 +3,9 @@ Nearest neighbor matcher for normalized descriptors.
Optionally apply the mutual check and threshold the distance or ratio. Optionally apply the mutual check and threshold the distance or ratio.
""" """
import torch
import logging import logging
import torch
import torch.nn.functional as F import torch.nn.functional as F
from ..base_model import BaseModel from ..base_model import BaseModel

View File

@ -9,9 +9,10 @@ Losses and metrics get accumulated accordingly.
If no triplet is found, this falls back to two_view_pipeline.py If no triplet is found, this falls back to two_view_pipeline.py
""" """
from .two_view_pipeline import TwoViewPipeline
import torch import torch
from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews
from .two_view_pipeline import TwoViewPipeline
def has_triplet(data): def has_triplet(data):

View File

@ -11,9 +11,9 @@ that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
""" """
from omegaconf import OmegaConf from omegaconf import OmegaConf
from .base_model import BaseModel
from . import get_model
from . import get_model
from .base_model import BaseModel
to_ctr = OmegaConf.to_container # convert DictConfig to dict to_ctr = OmegaConf.to_container # convert DictConfig to dict

View File

@ -1,5 +1,6 @@
import math import math
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch

View File

@ -1,4 +1,5 @@
import inspect import inspect
from .base_estimator import BaseEstimator from .base_estimator import BaseEstimator

View File

@ -1,6 +1,7 @@
from omegaconf import OmegaConf
from copy import copy from copy import copy
from omegaconf import OmegaConf
class BaseEstimator: class BaseEstimator:
base_default_conf = { base_default_conf = {

View File

@ -1,6 +1,6 @@
import poselib import poselib
from omegaconf import OmegaConf
import torch import torch
from omegaconf import OmegaConf
from ..base_estimator import BaseEstimator from ..base_estimator import BaseEstimator

View File

@ -1,9 +1,9 @@
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from ...geometry.wrappers import Pose
from ...geometry.utils import from_homogeneous
from ...geometry.utils import from_homogeneous
from ...geometry.wrappers import Pose
from ..base_estimator import BaseEstimator from ..base_estimator import BaseEstimator

View File

@ -1,8 +1,8 @@
import poselib import poselib
from omegaconf import OmegaConf
import torch import torch
from ...geometry.wrappers import Pose from omegaconf import OmegaConf
from ...geometry.wrappers import Pose
from ..base_estimator import BaseEstimator from ..base_estimator import BaseEstimator

View File

@ -1,8 +1,8 @@
import pycolmap import pycolmap
from omegaconf import OmegaConf
import torch import torch
from ...geometry.wrappers import Pose from omegaconf import OmegaConf
from ...geometry.wrappers import Pose
from ..base_estimator import BaseEstimator from ..base_estimator import BaseEstimator

View File

@ -1,14 +1,14 @@
import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import argparse
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ..datasets import get_dataset
from ..models import get_model
from ..settings import DATA_PATH from ..settings import DATA_PATH
from ..utils.export_predictions import export_predictions from ..utils.export_predictions import export_predictions
from ..models import get_model
from ..datasets import get_dataset
resize = 1600 resize = 1600

View File

@ -1,14 +1,15 @@
import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import argparse
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ..settings import DATA_PATH
from ..utils.export_predictions import export_predictions
from ..models import get_model
from ..datasets import get_dataset from ..datasets import get_dataset
from ..geometry.depth import sample_depth from ..geometry.depth import sample_depth
from ..models import get_model
from ..settings import DATA_PATH
from ..utils.export_predictions import export_predictions
resize = 1024 resize = 1024
n_kpts = 2048 n_kpts = 2048

View File

@ -5,37 +5,37 @@ Author: Paul-Edouard Sarlin (skydes)
""" """
import argparse import argparse
from pathlib import Path
import signal
import re
import copy import copy
from collections import defaultdict import re
import shutil import shutil
import numpy as np import signal
from collections import defaultdict
from omegaconf import OmegaConf from pathlib import Path
from tqdm import tqdm
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from pydoc import locate from pydoc import locate
from .models import get_model import numpy as np
import torch
from omegaconf import OmegaConf
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from . import __module_name__, logger
from .datasets import get_dataset from .datasets import get_dataset
from .eval import run_benchmark
from .models import get_model
from .settings import EVAL_PATH, TRAINING_PATH
from .utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment
from .utils.stdout_capturing import capture_outputs from .utils.stdout_capturing import capture_outputs
from .utils.tensor import batch_to_device
from .utils.tools import ( from .utils.tools import (
AverageMetric, AverageMetric,
MedianMetric, MedianMetric,
RecallMetric,
PRMetric, PRMetric,
set_seed, RecallMetric,
fork_rng, fork_rng,
set_seed,
) )
from .utils.tensor import batch_to_device
from .utils.experiments import get_last_checkpoint, get_best_checkpoint, save_experiment
from .eval import run_benchmark
from .settings import TRAINING_PATH, EVAL_PATH
from . import __module_name__, logger
# @TODO: Fix pbar pollution in logs # @TODO: Fix pbar pollution in logs
# @TODO: add plotting during evaluation # @TODO: add plotting during evaluation

View File

@ -1,7 +1,8 @@
import torch
import numpy as np
import time import time
import numpy as np
import torch
def benchmark(model, data, device, r=100): def benchmark(model, data, device, r=100):
timings = np.zeros((r, 1)) timings = np.zeros((r, 1))

View File

@ -4,16 +4,17 @@ A set of utilities to manage and load checkpoints of training experiments.
Author: Paul-Edouard Sarlin (skydes) Author: Paul-Edouard Sarlin (skydes)
""" """
from pathlib import Path
import logging import logging
import os
import re import re
import shutil import shutil
from omegaconf import OmegaConf from pathlib import Path
import torch
import os import torch
from omegaconf import OmegaConf
from ..settings import TRAINING_PATH
from ..models import get_model from ..models import get_model
from ..settings import TRAINING_PATH
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -4,11 +4,12 @@ Use a standalone script with `python3 -m dsfm.scipts.export_predictions dir`
or call from another script. or call from another script.
""" """
import torch
import numpy as np
from pathlib import Path from pathlib import Path
from tqdm import tqdm
import h5py import h5py
import numpy as np
import torch
from tqdm import tqdm
from .tensor import batch_to_device from .tensor import batch_to_device

View File

@ -1,10 +1,11 @@
from pathlib import Path
import torch
import kornia
import cv2
import numpy as np
from typing import Tuple, Optional
import collections.abc as collections import collections.abc as collections
from pathlib import Path
from typing import Optional, Tuple
import cv2
import kornia
import numpy as np
import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf

View File

@ -6,11 +6,12 @@ Author: Paul-Edouard Sarlin (skydes)
""" """
from __future__ import division, print_function, unicode_literals from __future__ import division, print_function, unicode_literals
import os import os
import sys
import subprocess import subprocess
from threading import Timer import sys
from contextlib import contextmanager from contextlib import contextmanager
from threading import Timer
def apply_backspaces_and_linefeeds(text): def apply_backspaces_and_linefeeds(text):

View File

@ -3,8 +3,9 @@ Author: Paul-Edouard Sarlin (skydes)
""" """
import collections.abc as collections import collections.abc as collections
import torch
import numpy as np import numpy as np
import torch
string_classes = (str, bytes) string_classes = (str, bytes)

View File

@ -4,13 +4,14 @@ Various handy Python and PyTorch utils.
Author: Paul-Edouard Sarlin (skydes) Author: Paul-Edouard Sarlin (skydes)
""" """
import time
import numpy as np
import os import os
import torch
import random import random
from contextlib import contextmanager import time
from collections.abc import Iterable from collections.abc import Iterable
from contextlib import contextmanager
import numpy as np
import torch
class AverageMetric: class AverageMetric:

View File

@ -1,14 +1,16 @@
import traceback
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from matplotlib.widgets import Button
from copy import deepcopy
import functools import functools
import traceback
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button
from omegaconf import OmegaConf
from ..datasets.base_dataset import collate
# from ..eval.export_predictions import load_predictions # from ..eval.export_predictions import load_predictions
from ..models.cache_loader import CacheLoader from ..models.cache_loader import CacheLoader
from ..datasets.base_dataset import collate
from .tools import RadioHideTool from .tools import RadioHideTool

View File

@ -1,25 +1,25 @@
import inspect
import sys
import warnings
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch
from matplotlib.backend_tools import ToolToggleBase from matplotlib.backend_tools import ToolToggleBase
from matplotlib.widgets import RadioButtons, Slider from matplotlib.widgets import RadioButtons, Slider
import warnings
import torch
from ..geometry.epipolar import T_to_F, generalized_epi_dist
from ..geometry.homography import sym_homography_error
from ..visualization.viz2d import ( from ..visualization.viz2d import (
cm_ranking,
cm_RdGn,
draw_epipolar_line,
get_line,
plot_color_line_matches,
plot_heatmaps, plot_heatmaps,
plot_keypoints, plot_keypoints,
plot_lines, plot_lines,
plot_matches, plot_matches,
plot_color_line_matches,
cm_RdGn,
cm_ranking,
get_line,
draw_epipolar_line,
) )
from ..geometry.homography import sym_homography_error
from ..geometry.epipolar import generalized_epi_dist, T_to_F
import inspect
import sys
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")

View File

@ -1,10 +1,9 @@
import numpy as np
import pprint import pprint
import numpy as np
from . import viz2d from . import viz2d
from .tools import RadioHideTool, ToggleTool, __plot_dict__
from .tools import __plot_dict__
from .tools import RadioHideTool, ToggleTool
class FormatPrinter(pprint.PrettyPrinter): class FormatPrinter(pprint.PrettyPrinter):

View File

@ -1,13 +1,7 @@
import torch import torch
from ..utils.tensor import batch_to_device from ..utils.tensor import batch_to_device
from .viz2d import ( from .viz2d import cm_RdGn, plot_heatmaps, plot_image_grid, plot_keypoints, plot_matches
plot_image_grid,
plot_keypoints,
plot_matches,
cm_RdGn,
plot_heatmaps,
)
def make_match_figures(pred_, data_, n_pairs=2): def make_match_figures(pred_, data_, n_pairs=2):

View File

@ -6,8 +6,8 @@
""" """
import matplotlib import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import seaborn as sns import seaborn as sns

View File

@ -43,10 +43,14 @@ extra = [
"deeplsd @ git+https://github.com/cvg/DeepLSD.git", "deeplsd @ git+https://github.com/cvg/DeepLSD.git",
"homography_est @ git+https://github.com/rpautrat/homography_est.git", "homography_est @ git+https://github.com/rpautrat/homography_est.git",
] ]
dev = ["black", "flake8", "jupyter"] dev = ["black", "flake8", "isort"]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
include = ["gluefactory*"] include = ["gluefactory*"]
[tool.setuptools.package-data] [tool.setuptools.package-data]
gluefactory = ["datasets/megadepth_scene_lists/*.txt", "configs/*.yaml"] gluefactory = ["datasets/megadepth_scene_lists/*.txt", "configs/*.yaml"]
[tool.isort]
profile = "black"
extend_skip = ["gluefactory_nonfree/"]