Initial commit

Co-authored-by: Philipp Lindenberger <phil.lindenberger@gmail.com>
Co-authored-by: Iago Suárez <iago_h92@hotmail.com>
Co-authored-by: Paul-Edouard Sarlin <paul.edouard.sarlin@gmail.com>
Branch: main
Rémi Pautrat, 2023-10-05 16:53:51 +02:00
commit 55c4fbd454
124 changed files with 22069 additions and 0 deletions

.flake8

@@ -0,0 +1,3 @@
[flake8]
max-line-length = 88
extend-ignore = E203

.github/workflows/code-quality.yml

@@ -0,0 +1,31 @@
name: Format and Lint Checks
on:
push:
branches:
- main
paths:
- '*.py'
pull_request:
types: [ assigned, opened, synchronize, reopened ]
jobs:
formatting-check:
name: Formatting Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: psf/black@stable
with:
jupyter: true
linting-check:
name: Linting Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: 'pip'
- run: python -m pip install --upgrade pip
- run: python -m pip install .
- run: python -m pip install --upgrade flake8
- run: python -m flake8 . --exclude build/

.gitignore

@@ -0,0 +1,9 @@
.venv
/build/
*.egg-info
*.pyc
/.idea/
/venv/
/data/
/outputs/
__pycache__

LICENSE

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md

@@ -0,0 +1,335 @@
# Glue Factory
Glue Factory is CVG's library for training and evaluating deep neural networks that extract and match local visual features. It enables you to:
- Reproduce the training of state-of-the-art models for point and line matching, like [LightGlue](https://github.com/cvg/LightGlue) and [GlueStick](https://github.com/cvg/GlueStick) (ICCV 2023)
- Train these models on multiple datasets using your own local features or lines
- Evaluate feature extractors or matchers on standard benchmarks like HPatches or MegaDepth-1500
<p align="center">
<a href="https://github.com/cvg/LightGlue"><img src="docs/lightglue_matches.svg" width="60%"/></a>
<a href="https://github.com/cvg/GlueStick"><img src="docs/gluestick_img.svg" width="60%"/></a>
<br /><em>Point and line matching with LightGlue and GlueStick.</em>
</p>
## Installation
Glue Factory runs with Python 3 and [PyTorch](https://pytorch.org/). The following installs the library and its basic dependencies:
```bash
git clone https://github.com/cvg/glue-factory
cd glue-factory
python3 -m pip install -e . # editable mode
```
Some advanced features might require installing the full set of dependencies:
```bash
python3 -m pip install -e .[extra]
```
All models and datasets in gluefactory have auto-downloaders, so you can get started right away!
## License
The code and trained models in Glue Factory are released with an Apache-2.0 license. This includes LightGlue trained with an [open version of SuperPoint](https://github.com/rpautrat/SuperPoint). Third-party models that are not compatible with this license, such as SuperPoint (original) and SuperGlue, are provided in `gluefactory_nonfree`, where each model might follow its own, restrictive license.
## Evaluation
#### HPatches
Running the evaluation commands automatically downloads the dataset, by default to the directory `data/`. You will need about 1.8 GB of free disk space.
<details>
<summary>[Evaluating LightGlue]</summary>
To evaluate the pre-trained SuperPoint+LightGlue model on HPatches, run:
```bash
python -m gluefactory.eval.hpatches --conf superpoint+lightglue-official --overwrite
```
You should expect the following results:
```
{'H_error_dlt@1px': 0.3515,
'H_error_dlt@3px': 0.6723,
'H_error_dlt@5px': 0.7756,
'H_error_ransac@1px': 0.3428,
'H_error_ransac@3px': 0.5763,
'H_error_ransac@5px': 0.6943,
'mnum_keypoints': 1024.0,
'mnum_matches': 560.756,
'mprec@1px': 0.337,
'mprec@3px': 0.89,
'mransac_inl': 130.081,
'mransac_inl%': 0.217,
'ransac_mAA': 0.5378}
```
The default robust estimator is `opencv`, but we strongly recommend using `poselib` instead:
```bash
python -m gluefactory.eval.hpatches --conf superpoint+lightglue-official --overwrite \
eval.estimator=poselib eval.ransac_th=-1
```
Setting `eval.ransac_th=-1` auto-tunes the RANSAC inlier threshold by running the evaluation with a range of thresholds and reports results for the optimal value.
Here are the results as Area Under the Curve (AUC) of the homography error at 1/3/5 pixels:
| Methods | DLT | [OpenCV](../gluefactory/robust_estimators/homography/opencv.py) | [PoseLib](../gluefactory/robust_estimators/homography/poselib.py) |
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue.yaml) | 32.1 / 65.0 / 75.7 | 32.9 / 55.7 / 68.0 | 37.0 / 68.2 / 78.7 |
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue.yaml) | 35.1 / 67.2 / 77.6 | 34.2 / 57.9 / 69.9 | 37.1 / 67.4 / 77.8 |
</details>
<details>
<summary>[Evaluating GlueStick]</summary>
To evaluate GlueStick on HPatches, run:
```bash
python -m gluefactory.eval.hpatches --conf gluefactory/configs/superpoint+lsd+gluestick.yaml --overwrite
```
You should expect the following results:
```
{"mprec@1px": 0.245,
"mprec@3px": 0.838,
"mnum_matches": 1290.5,
"mnum_keypoints": 2287.5,
"mH_error_dlt": null,
"H_error_dlt@1px": 0.3355,
"H_error_dlt@3px": 0.6637,
"H_error_dlt@5px": 0.7713,
"H_error_ransac@1px": 0.3915,
"H_error_ransac@3px": 0.6972,
"H_error_ransac@5px": 0.7955,
"H_error_ransac_mAA": 0.62806,
"mH_error_ransac": null}
```
Since we use points and lines to solve for the homography, we use a different robust estimator here: [Hest](https://github.com/rpautrat/homography_est/). Here are the results as Area Under the Curve (AUC) of the homography error at 1/3/5 pixels:
| Methods | DLT | [Hest](gluefactory/robust_estimators/homography/homography_est.py) |
| ------------------------------------------------------------ | ------------------ | ------------------ |
| [SP + LSD + GlueStick](gluefactory/configs/superpoint+lsd+gluestick.yaml) | 33.6 / 66.4 / 77.1 | 39.2 / 69.7 / 79.6 |
</details>
#### MegaDepth-1500
Running the evaluation commands automatically downloads the dataset, which takes about 1.5 GB of disk space.
<details>
<summary>[Evaluating LightGlue]</summary>
To evaluate the pre-trained SuperPoint+LightGlue model on MegaDepth-1500, run:
```bash
python -m gluefactory.eval.megadepth1500 --conf superpoint+lightglue-official
# or the adaptive variant
python -m gluefactory.eval.megadepth1500 --conf superpoint+lightglue-official \
model.matcher.{depth_confidence=0.95,width_confidence=0.95}
```
The first command should print the following results:
```
{'mepi_prec@1e-3': 0.795,
'mepi_prec@1e-4': 0.15,
'mepi_prec@5e-4': 0.567,
'mnum_keypoints': 2048.0,
'mnum_matches': 613.287,
'mransac_inl': 280.518,
'mransac_inl%': 0.442,
'rel_pose_error@10°': 0.681,
'rel_pose_error@20°': 0.8065,
'rel_pose_error@5°': 0.5102,
'ransac_mAA': 0.6659}
```
To use the PoseLib estimator:
```bash
python -m gluefactory.eval.megadepth1500 --conf superpoint+lightglue-official \
eval.estimator=poselib eval.ransac_th=2.0
```
</details>
<details>
<summary>[Evaluating GlueStick]</summary>
To evaluate the pre-trained SuperPoint+GlueStick model on MegaDepth-1500, run:
```bash
python -m gluefactory.eval.megadepth1500 --conf gluefactory/configs/superpoint+lsd+gluestick.yaml
```
</details>
<details>
Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20 degrees:
| Methods | [pycolmap](../gluefactory/robust_estimators/relative_pose/pycolmap.py) | [OpenCV](../gluefactory/robust_estimators/relative_pose/opencv.py) | [PoseLib](../gluefactory/robust_estimators/relative_pose/poselib.py) |
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue.yaml) | 54.4 / 70.4 / 82.4 | 48.7 / 65.6 / 79.0 | 64.8 / 77.9 / 87.0 |
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue.yaml) | 56.7 / 72.4 / 83.7 | 51.0 / 68.1 / 80.7 | 66.8 / 79.3 / 87.9 |
| [SuperPoint + GlueStick](../gluefactory/configs/superpoint+lsd+gluestick.yaml) | 53.2 / 69.8 / 81.9 | 46.3 / 64.2 / 78.1 | 64.4 / 77.5 / 86.5 |
</details>
#### ETH3D
The dataset will be auto-downloaded if it is not found on disk, and will need about 6 GB of free disk space.
<details>
<summary>[Evaluating GlueStick]</summary>
To evaluate GlueStick on ETH3D, run:
```bash
python -m gluefactory.eval.eth3d --conf gluefactory/configs/superpoint+lsd+gluestick.yaml
```
You should expect the following results:
```
AP: 77.92
AP_lines: 69.22
```
</details>
#### Image Matching Challenge 2021
Coming soon!
#### Image Matching Challenge 2023
Coming soon!
#### Visual inspection
<details>
To inspect the evaluation visually, you can run:
```bash
python -m gluefactory.eval.inspect hpatches superpoint+lightglue-official
```
Click on a point to visualize matches on this pair.
To compare multiple methods on a dataset:
```bash
python -m gluefactory.eval.inspect hpatches superpoint+lightglue-official superpoint+superglue-official
```
All current benchmarks are supported by the viewer.
</details>
Detailed evaluation instructions can be found [here](./docs/evaluation.md).
## Training
We generally follow a two-stage training:
1. Pre-train on a large dataset of synthetic homographies applied to internet images. We use the 1M-image distractor set of the Oxford-Paris retrieval dataset. It requires about 450 GB of disk space.
2. Fine-tune on the MegaDepth dataset, which is based on PhotoTourism pictures of popular landmarks around the world. It exhibits more complex and realistic appearance and viewpoint changes. It requires about 420 GB of disk space.
All training commands automatically download the datasets.
<details>
<summary>[Training LightGlue]</summary>
We show how to train LightGlue with [SuperPoint open](https://github.com/rpautrat/SuperPoint).
We first pre-train LightGlue on the homography dataset:
```bash
# the first argument, sp+lg_homography, is the experiment name
python -m gluefactory.train sp+lg_homography \
--conf gluefactory/configs/superpoint-open+lightglue_homography.yaml
```
Feel free to use any other experiment name. By default, the checkpoints are written to `outputs/training/`. The default batch size of 128 corresponds to the results reported in the paper and requires two 3090 GPUs with 24 GB of VRAM each, as well as PyTorch >= 2.0 (FlashAttention).
Configurations are managed by [OmegaConf](https://omegaconf.readthedocs.io/), so any entry can be overridden from the command line.
If you have PyTorch < 2.0 or weaker GPUs, you may need to reduce the batch size, for example:
```bash
python -m gluefactory.train sp+lg_homography \
--conf gluefactory/configs/superpoint-open+lightglue_homography.yaml \
data.batch_size=32 # for 1x 1080 GPU
```
Be aware that this can impact the overall performance. You might need to adjust the learning rate accordingly.
We then fine-tune the model on the MegaDepth dataset:
```bash
python -m gluefactory.train sp+lg_megadepth \
--conf gluefactory/configs/superpoint-open+lightglue_megadepth.yaml \
train.load_experiment=sp+lg_homography
```
Here the default batch size is 32. To speed up training on MegaDepth, we suggest caching the local features before training (requires around 150 GB of disk space):
```bash
# extract features
python -m gluefactory.scripts.export_megadepth --method sp_open --num_workers 8
# run training with cached features
python -m gluefactory.train sp+lg_megadepth \
--conf gluefactory/configs/superpoint-open+lightglue_megadepth.yaml \
train.load_experiment=sp+lg_homography \
data.load_features.do=True
```
The model can then be evaluated using its experiment name:
```bash
python -m gluefactory.eval.megadepth1500 --checkpoint sp+lg_megadepth
```
You can also run all benchmarks after each training epoch with the option `--run_benchmarks`.
</details>
<details>
<summary>[Training GlueStick]</summary>
We first pre-train GlueStick on the homography dataset:
```bash
python -m gluefactory.train gluestick_H --conf gluefactory/configs/superpoint+lsd+gluestick-homography.yaml --distributed
```
Feel free to use any other experiment name. Configurations are managed by [OmegaConf](https://omegaconf.readthedocs.io/), so any entry can be overridden from the command line.
We then fine-tune the model on the MegaDepth dataset:
```bash
python -m gluefactory.train gluestick_MD --conf gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml --distributed
```
Note that we used the training splits `train_scenes.txt` and `valid_scenes.txt` to train the original model; these splits contain some overlap with the IMC challenge. The new default splits are now `train_scenes_clean.txt` and `valid_scenes_clean.txt`, without this overlap.
</details>
### Available models
Glue Factory supports training and evaluating the following deep matchers:
| Model | Training? | Evaluation? |
| --------- | --------- | ----------- |
| [LightGlue](https://github.com/cvg/LightGlue) | ✅ | ✅ |
| [GlueStick](https://github.com/cvg/GlueStick) | ✅ | ✅ |
| [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) | ✅ | ✅ |
| [LoFTR](https://github.com/zju3dv/LoFTR) | ❌ | ✅ |
Using the following local feature extractors:
| Model | LightGlue config |
| --------- | --------- |
| [SuperPoint (open)](https://github.com/rpautrat/SuperPoint) | `superpoint-open+lightglue_{homography,megadepth}.yaml` |
| [SuperPoint (official)](https://github.com/magicleap/SuperPointPretrainedNetwork) | ❌ TODO |
| SIFT (via [pycolmap](https://github.com/colmap/pycolmap)) | `sift+lightglue_{homography,megadepth}.yaml` |
| [ALIKED](https://github.com/Shiaoming/ALIKED) | `aliked+lightglue_{homography,megadepth}.yaml` |
| [DISK](https://github.com/cvlab-epfl/disk) | ❌ TODO |
| Key.Net + HardNet | ❌ TODO |
## Coming soon
- [ ] More baselines (LoFTR, ASpanFormer, MatchFormer, SGMNet, DKM, RoMa)
- [ ] Training deep detectors and descriptors like SuperPoint
- [ ] IMC evaluations
- [ ] Better documentation
## BibTeX Citation
Please consider citing the following papers if you found this library useful:
```bibtex
@InProceedings{lindenberger_2023_lightglue,
title = {{LightGlue: Local Feature Matching at Light Speed}},
author = {Philipp Lindenberger and
Paul-Edouard Sarlin and
Marc Pollefeys},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2023}
}
```
```bibtex
@InProceedings{pautrat_suarez_2023_gluestick,
title = {{GlueStick: Robust Image Matching by Sticking Points and Lines Together}},
author = {R{\'e}mi Pautrat* and
Iago Su{\'a}rez* and
Yifan Yu and
Marc Pollefeys and
Viktor Larsson},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2023}
}
```

docs/evaluation.md

@@ -0,0 +1,121 @@
# Evaluation
Glue Factory is designed for simple and tight integration between training and evaluation.
All benchmarks are designed around one principle: only evaluate on cached results.
This enforces reproducible baselines.
Therefore, we first export model predictions for each dataset (`export`), and evaluate the cached results in a second pass (`evaluation`).
### Running an evaluation
We currently provide evaluation scripts for [MegaDepth-1500](../gluefactory/eval/megadepth1500.py), [HPatches](../gluefactory/eval/hpatches.py), and [ETH3D](../gluefactory/eval/eth3d.py).
You can run them with:
```bash
python -m gluefactory.eval.<benchmark_name> --conf "a name in gluefactory/configs/ or path" --checkpoint "and/or a checkpoint name"
```
Each evaluation run is assigned a `tag`, which can (optionally) be customized from the command line with `--tag <your_tag>`.
To overwrite an experiment, add `--overwrite`. To only overwrite the results of the evaluation loop, add `--overwrite_eval`. We perform config checks to warn the user about inconsistent configurations between runs.
The following files are written to `outputs/results/<benchmark_name>/<tag>`:
```yaml
conf.yaml # the config which was used
predictions.h5 # cached predictions
results.h5 # Results for each data point in eval, in the format <metric_name>: List[float]
summaries.json # Aggregated results for the entire dataset <agg_metric_name>: float
<plots> # some benchmarks add plots as png files here
```
Some datasets further output plots (add `--plot` to the command line).
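If you want to post-process these outputs yourself, they are plain JSON and HDF5 files. A minimal sketch for loading them (the benchmark name `hpatches` and the tag `my_tag` below are placeholders, and the exact HDF5 layout can differ between benchmarks):
```python
import json
from pathlib import Path

import h5py

# placeholder benchmark and tag; adapt to your own run
run_dir = Path("outputs/results/hpatches/my_tag")

# aggregated metrics for the entire dataset (<agg_metric_name>: float)
summaries = json.loads((run_dir / "summaries.json").read_text())
for name, value in summaries.items():
    print(f"{name}: {value}")

# per-data-point results (<metric_name>: List[float])
with h5py.File(run_dir / "results.h5", "r") as f:
    print("available metrics:", list(f.keys()))
```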
<details>
<summary>[Configuration]</summary>
Each evaluation has 3 main configurations:
```yaml
data:
... # How to load the data. The user can overwrite this only during "export". The defaults are used in "evaluation".
model:
... # model configuration: this is only required for "export".
eval:
... # configuration for the "evaluation" loop, e.g. pose estimators and ransac thresholds.
```
The default configurations can be found in the respective evaluation scripts, e.g. [MegaDepth1500](../gluefactory/eval/megadepth1500.py).
To run an evaluation with a custom config, we expect them to be in the following format ([example](../gluefactory/configs/superpoint+lightglue.yaml)):
```yaml
model:
... # <your model configs>
benchmarks:
<benchmark_name1>:
data:
... # <your data configs for "export">
model:
... # <your benchmark-specific model configs>
eval:
... # <your evaluation configs, e.g. pose estimators>
<benchmark_name2>:
... # <same structure as above>
```
The configs are then merged in the following order (taking megadepth1500 as an example):
```yaml
data:
default < custom.benchmarks.megadepth1500.data
model:
default < custom.model < custom.benchmarks.megadepth1500.model
eval:
default < custom.benchmarks.megadepth1500.eval
```
You can then use the command line to further customize this configuration.
</details>
### Robust estimators
Gluefactory offers a flexible interface to state-of-the-art [robust estimators](../gluefactory/robust_estimators/) for points and lines.
You can configure the estimator in the benchmarks with the following config structure:
```yaml
eval:
estimator: <estimator_name> # poselib, opencv, pycolmap, ...
ransac_th: 0.5 # run evaluation on fixed threshold
#or
ransac_th: [0.5, 1.0, 1.5] # test on multiple thresholds, autoselect best
<extra configs for the estimator, e.g. max iters, ...>
```
For convenience, most benchmarks convert `eval.ransac_th=-1` to a default range of thresholds.
> [!NOTE]
> Gluefactory follows the corner convention of COLMAP, i.e. the top-left corner of the top-left pixel is (0, 0) and the center of the top-left pixel is (0.5, 0.5).
### Visualization
We provide a powerful, interactive visualization tool for our benchmarks, based on matplotlib.
You can run the visualization (after running the evaluations) with:
```bash
python -m gluefactory.eval.inspect <benchmark_name> <experiment_name1> <experiment_name2> ...
```
This prints the summaries of each experiment on the respective benchmark and visualizes the data as a scatter plot, where each point is the result of an experiment on a specific data point in the dataset.
<details>
- Clicking on one of the data points opens a new frame showing the prediction on this specific data point for all experiments listed.
- You can customize the x / y axis from the navigation bar or by clicking `x` or `y`.
- Hitting `diff_only` computes the difference between `<experiment_name1>` and all other experiments.
- Hovering over a point shows lines to the results of other experiments on the same data.
- You can switch the visualization (matches, keypoints, ...) from the navigation bar or by clicking `shift+r`.
- Clicking `t` prints a summary of the eval on this data point.
- Hitting the `left` or `right` arrows circles between data points. `shift+left` opens an extra window.
When working on a remote machine (e.g. over ssh), the plots can be forwarded to the browser with the option `--backend webagg`. Note that you need to refresh the page every time you load a new figure (e.g. when clicking on a scatter point). This part requires some more work, and we would highly appreciate any contributions!
</details>

docs/gluestick_img.svg (new file, 375 KiB; diff suppressed because one or more lines are too long)

docs/lightglue_matches.svg (new file, 603 KiB; diff suppressed because one or more lines are too long)

gluefactory/__init__.py

@@ -0,0 +1,16 @@
import logging
from .utils.experiments import load_experiment # noqa: F401
formatter = logging.Formatter(
fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False
__module_name__ = __name__
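As a small usage sketch of this package-level setup: importing `gluefactory` configures the logger defined above and re-exports `load_experiment`. The experiment name below is a placeholder, and the exact signature of `load_experiment` lives in `gluefactory/utils/experiments.py`; treating it as taking an experiment name is an assumption here.
```python
import gluefactory

# the package-level logger configured in gluefactory/__init__.py
gluefactory.logger.info("Glue Factory imported")

# assumption: load_experiment(<experiment name>) loads a model trained with
# gluefactory.train and stored under outputs/training/ (see the README)
# model = gluefactory.load_experiment("sp+lg_megadepth")
```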

@@ -0,0 +1,24 @@
model:
name: two_view_pipeline
extractor:
name: extractors.aliked
max_num_keypoints: 2048
detection_threshold: 0.0
matcher:
name: matchers.nearest_neighbor_matcher
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024 # overwrite config above

@@ -0,0 +1,50 @@
data:
name: homographies
data_dir: revisitop1m
train_size: 150000
val_size: 2000
batch_size: 128
num_workers: 14
homography:
difficulty: 0.7
max_angle: 45
photometric:
name: lg
model:
name: two_view_pipeline
extractor:
name: extractors.aliked
max_num_keypoints: 512
detection_threshold: 0.0
trainable: False
detector:
name: null
descriptor:
name: null
ground_truth:
name: matchers.homography_matcher
th_positive: 3
th_negative: 3
matcher:
name: matchers.lightglue
filter_threshold: 0.1
flash: false
checkpointed: true
input_dim: 128
train:
seed: 0
epochs: 40
log_every_iter: 100
eval_every_iter: 500
lr: 1e-4
lr_schedule:
start: 20
type: exp
on_epoch: true
exp_div_10: 10
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
benchmarks:
hpatches:
eval:
estimator: opencv
ransac_th: 0.5

@@ -0,0 +1,70 @@
data:
name: megadepth
preprocessing:
resize: 1024
side: long
square_pad: True
train_split: train_scenes_clean.txt
train_num_per_scene: 300
val_split: valid_scenes_clean.txt
val_pairs: valid_pairs.txt
min_overlap: 0.1
max_overlap: 0.7
num_overlap_bins: 3
read_depth: true
read_image: true
batch_size: 32
num_workers: 14
load_features:
do: false # enable this if you have cached predictions
path: exports/megadepth-undist-depth-r1024_ALIKED-k2048-n16/{scene}.h5
padding_length: 2048
padding_fn: pad_local_features
model:
name: two_view_pipeline
extractor:
name: extractors.aliked
max_num_keypoints: 2048
detection_threshold: 0.0
trainable: False
matcher:
name: matchers.lightglue
filter_threshold: 0.1
flash: false
checkpointed: true
input_dim: 128
ground_truth:
name: matchers.depth_matcher
th_positive: 3
th_negative: 5
th_epi: 5
allow_no_extract: True
train:
seed: 0
epochs: 50
log_every_iter: 100
eval_every_iter: 1000
lr: 1e-4
lr_schedule:
start: 30
type: exp
on_epoch: true
exp_div_10: 10
dataset_callback_fn: sample_new_items
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024

@@ -0,0 +1,24 @@
model:
name: two_view_pipeline
extractor:
name: extractors.disk_kornia
max_num_keypoints: 2048
detection_threshold: 0.0
matcher:
name: matchers.nearest_neighbor_matcher
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024 # overwrite config above

@@ -0,0 +1,28 @@
model:
name: two_view_pipeline
extractor:
name: extractors.disk_kornia
max_num_keypoints: 2048
detection_threshold: 0.0
matcher:
name: matchers.lightglue_pretrained
features: disk
depth_confidence: -1
width_confidence: -1
filter_threshold: 0.1
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024 # overwrite config above

@@ -0,0 +1,28 @@
model:
name: two_view_pipeline
extractor:
name: extractors.sift
detector: pycolmap_cuda
max_num_keypoints: 2048
detection_threshold: 0.00666666
nms_radius: -1
pycolmap_options:
first_octave: -1
matcher:
name: matchers.nearest_neighbor_matcher
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024 # overwrite config above

@@ -0,0 +1,48 @@
data:
name: homographies
data_dir: revisitop1m
train_size: 150000
val_size: 2000
batch_size: 64
num_workers: 14
homography:
difficulty: 0.7
max_angle: 45
photometric:
name: lg
model:
name: two_view_pipeline
extractor:
name: extractors.sift
detector: pycolmap_cuda
max_num_keypoints: 1024
force_num_keypoints: True
detection_threshold: 0.0001
trainable: False
ground_truth:
name: matchers.homography_matcher
th_positive: 3
th_negative: 3
matcher:
name: matchers.lightglue
filter_threshold: 0.1
flash: false
checkpointed: true
input_dim: 128
train:
seed: 0
epochs: 40
log_every_iter: 100
eval_every_iter: 500
lr: 1e-4
lr_schedule:
start: 20
type: exp
on_epoch: true
exp_div_10: 10
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
benchmarks:
hpatches:
eval:
estimator: opencv
ransac_th: 0.5

@@ -0,0 +1,74 @@
data:
name: megadepth
preprocessing:
resize: 1024
side: long
square_pad: True
train_split: train_scenes_clean.txt
train_num_per_scene: 300
val_split: valid_scenes_clean.txt
val_pairs: valid_pairs.txt
min_overlap: 0.1
max_overlap: 0.7
num_overlap_bins: 3
read_depth: true
read_image: true
batch_size: 32
num_workers: 14
load_features:
do: false # enable this if you have cached predictions
path: exports/megadepth-undist-depth-r1024_pycolmap_SIFTGPU-nms3-fixed-k2048/{scene}.h5
padding_length: 2048
padding_fn: pad_local_features
data_keys: ["keypoints", "keypoint_scores", "descriptors", "oris", "scales"]
model:
name: two_view_pipeline
extractor:
name: extractors.sift
detector: pycolmap_cuda
max_num_keypoints: 2048
force_num_keypoints: True
detection_threshold: 0.0001
trainable: False
matcher:
name: matchers.lightglue
filter_threshold: 0.1
flash: false
checkpointed: true
add_scale_ori: true
input_dim: 128
ground_truth:
name: matchers.depth_matcher
th_positive: 3
th_negative: 5
th_epi: 5
allow_no_extract: True
train:
seed: 0
epochs: 50
log_every_iter: 100
eval_every_iter: 1000
lr: 1e-4
lr_schedule:
start: 30
type: exp
on_epoch: true
exp_div_10: 10
dataset_callback_fn: sample_new_items
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024

@@ -0,0 +1,25 @@
model:
name: two_view_pipeline
extractor:
name: gluefactory_nonfree.superpoint
max_num_keypoints: 2048
detection_threshold: 0.0
nms_radius: 3
matcher:
name: matchers.nearest_neighbor_matcher
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 1.0
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024 # overwrite config above

@@ -0,0 +1,29 @@
model:
name: two_view_pipeline
extractor:
name: gluefactory_nonfree.superpoint
max_num_keypoints: 2048
detection_threshold: 0.0
nms_radius: 3
matcher:
name: matchers.lightglue_pretrained
features: superpoint
depth_confidence: -1
width_confidence: -1
filter_threshold: 0.1
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024 # overwrite config above

@@ -0,0 +1,73 @@
data:
name: homographies
homography:
difficulty: 0.5
max_angle: 30
patch_shape: [640, 480]
photometric:
p: 0.75
train_size: 900000
val_size: 1000
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet)
num_workers: 15
model:
name: gluefactory.models.two_view_pipeline
extractor:
name: gluefactory.models.lines.wireframe
trainable: False
point_extractor:
name: gluefactory.models.extractors.superpoint_open
# name: disk
# chunk: 10
max_num_keypoints: 1000
force_num_keypoints: true
trainable: False
line_extractor:
name: gluefactory.models.lines.lsd
max_num_lines: 250
force_num_lines: True
min_length: 15
trainable: False
wireframe_params:
merge_points: True
merge_line_endpoints: True
nms_radius: 4
detector:
name: null
descriptor:
name: null
ground_truth:
name: gluefactory.models.matchers.homography_matcher
trainable: False
use_points: True
use_lines: True
th_positive: 3
th_negative: 5
matcher:
name: gluefactory.models.matchers.gluestick
input_dim: 256 # 128 for DISK
descriptor_dim: 256 # 128 for DISK
inter_supervision: [2, 5]
GNN_layers: [
self, cross, self, cross, self, cross,
self, cross, self, cross, self, cross,
self, cross, self, cross, self, cross,
]
checkpointed: true
train:
seed: 0
epochs: 200
log_every_iter: 400
eval_every_iter: 700
save_every_iter: 1400
lr: 1e-4
lr_schedule:
type: exp # exp or multi_step
start: 200e3
exp_div_10: 200e3
gamma: 0.5
step: 50e3
n_steps: 4
submodules: []
# clip_grad: 10 # Use only with mixed precision
# load_experiment:

@@ -0,0 +1,69 @@
data:
name: gluefactory.datasets.megadepth
views: 2
preprocessing:
resize: 640
square_pad: True
batch_size: 60
num_workers: 15
model:
name: gluefactory.models.two_view_pipeline
extractor:
name: gluefactory.models.lines.wireframe
trainable: False
point_extractor:
name: gluefactory.models.extractors.superpoint_open
# name: disk
# chunk: 10
max_num_keypoints: 1000
force_num_keypoints: true
trainable: False
line_extractor:
name: gluefactory.models.lines.lsd
max_num_lines: 250
force_num_lines: True
min_length: 15
trainable: False
wireframe_params:
merge_points: True
merge_line_endpoints: True
nms_radius: 4
detector:
name: null
descriptor:
name: null
ground_truth:
name: gluefactory.models.matchers.depth_matcher
trainable: False
use_points: True
use_lines: True
th_positive: 3
th_negative: 5
matcher:
name: gluefactory.models.matchers.gluestick
input_dim: 256 # 128 for DISK
descriptor_dim: 256 # 128 for DISK
inter_supervision: null
GNN_layers: [
self, cross, self, cross, self, cross,
self, cross, self, cross, self, cross,
self, cross, self, cross, self, cross,
]
checkpointed: true
train:
seed: 0
epochs: 200
log_every_iter: 10
eval_every_iter: 100
save_every_iter: 500
lr: 1e-4
lr_schedule:
type: exp # exp or multi_step
start: 200e3
exp_div_10: 200e3
gamma: 0.5
step: 50e3
n_steps: 4
submodules: []
# clip_grad: 10 # Use only with mixed precision
load_experiment: gluestick_H

@@ -0,0 +1,49 @@
model:
name: gluefactory.models.two_view_pipeline
extractor:
name: gluefactory.models.lines.wireframe
point_extractor:
name: gluefactory_nonfree.superpoint
trainable: False
dense_outputs: True
max_num_keypoints: 2048
force_num_keypoints: False
detection_threshold: 0
line_extractor:
name: gluefactory.models.lines.lsd
trainable: False
max_num_lines: 512
force_num_lines: False
min_length: 15
wireframe_params:
merge_points: True
merge_line_endpoints: True
nms_radius: 3
matcher:
name: gluefactory.models.matchers.gluestick
weights: checkpoint_GlueStick_MD # This will download weights from internet
# ground_truth: # for ETH3D, comment otherwise
# name: gluefactory.models.matchers.depth_matcher
# use_lines: True
benchmarks:
hpatches:
eval:
estimator: homography_est
ransac_th: -1 # [1., 1.5, 2., 2.5, 3.]
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: poselib
ransac_th: -1
eth3d:
ground_truth:
name: gluefactory.models.matchers.depth_matcher
use_lines: True
eval:
plot_methods: [ ] # ['sp+NN', 'sp+sg', 'superpoint+lsd+gluestick']
plot_line_methods: [ ] # ['superpoint+lsd+gluestick', 'sp+deeplsd+gs']

@@ -0,0 +1,26 @@
model:
name: two_view_pipeline
extractor:
name: gluefactory_nonfree.superpoint
max_num_keypoints: 2048
detection_threshold: 0.0
nms_radius: 3
matcher:
name: gluefactory_nonfree.superglue
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024 # overwrite config above

@@ -0,0 +1,25 @@
model:
name: two_view_pipeline
extractor:
name: extractors.superpoint_open
max_num_keypoints: 2048
detection_threshold: 0.0
nms_radius: 3
matcher:
name: matchers.nearest_neighbor_matcher
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 1.0
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024 # overwrite config above

@@ -0,0 +1,47 @@
data:
name: homographies
data_dir: revisitop1m
train_size: 150000
val_size: 2000
batch_size: 128
num_workers: 14
homography:
difficulty: 0.7
max_angle: 45
photometric:
name: lg
model:
name: two_view_pipeline
extractor:
name: extractors.superpoint_open
max_num_keypoints: 512
force_num_keypoints: True
detection_threshold: -1
nms_radius: 3
trainable: False
ground_truth:
name: matchers.homography_matcher
th_positive: 3
th_negative: 3
matcher:
name: matchers.lightglue
filter_threshold: 0.1
flash: false
checkpointed: true
train:
seed: 0
epochs: 40
log_every_iter: 100
eval_every_iter: 500
lr: 1e-4
lr_schedule:
start: 20
type: exp
on_epoch: true
exp_div_10: 10
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
benchmarks:
hpatches:
eval:
estimator: opencv
ransac_th: 0.5

@@ -0,0 +1,71 @@
data:
name: megadepth
preprocessing:
resize: 1024
side: long
square_pad: True
train_split: train_scenes_clean.txt
train_num_per_scene: 300
val_split: valid_scenes_clean.txt
val_pairs: valid_pairs.txt
min_overlap: 0.1
max_overlap: 0.7
num_overlap_bins: 3
read_depth: true
read_image: true
batch_size: 32
num_workers: 14
load_features:
do: false # enable this if you have cached predictions
path: exports/megadepth-undist-depth-r1024_SP-open-k2048-nms3/{scene}.h5
padding_length: 2048
padding_fn: pad_local_features
model:
name: two_view_pipeline
extractor:
name: extractors.superpoint_open
max_num_keypoints: 2048
force_num_keypoints: True
detection_threshold: -1
nms_radius: 3
trainable: False
matcher:
name: matchers.lightglue
filter_threshold: 0.1
flash: false
checkpointed: true
ground_truth:
name: matchers.depth_matcher
th_positive: 3
th_negative: 5
th_epi: 5
allow_no_extract: True
train:
seed: 0
epochs: 50
log_every_iter: 100
eval_every_iter: 1000
lr: 1e-4
lr_schedule:
start: 30
type: exp
on_epoch: true
exp_div_10: 10
dataset_callback_fn: sample_new_items
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
benchmarks:
megadepth1500:
data:
preprocessing:
side: long
resize: 1600
eval:
estimator: opencv
ransac_th: 0.5
hpatches:
eval:
estimator: opencv
ransac_th: 0.5
model:
extractor:
max_num_keypoints: 1024

@@ -0,0 +1,24 @@
import importlib.util
from .base_dataset import BaseDataset
from ..utils.tools import get_class
def get_dataset(name):
import_paths = [name, f"{__name__}.{name}"]
for path in import_paths:
try:
spec = importlib.util.find_spec(path)
except ModuleNotFoundError:
spec = None
if spec is not None:
try:
return get_class(path, BaseDataset)
except AssertionError:
mod = __import__(path, fromlist=[""])
try:
return mod.__main_dataset__
except AttributeError as exc:
print(exc)
continue
raise RuntimeError(f'Dataset {name} not found in any of [{" ".join(import_paths)}]')
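A short usage sketch for `get_dataset`, assuming this file is the `gluefactory.datasets` package `__init__` (the configs above reference `gluefactory.datasets.*` modules): it resolves a dataset class by name and returns it, and the class is then used through the `BaseDataset` interface shown further below.
```python
from gluefactory.datasets import get_dataset

# resolves either an absolute module path or gluefactory.datasets.<name>;
# "homographies" is the dataset name used by the training configs above
dataset_cls = get_dataset("homographies")

# instantiation then follows the BaseDataset interface, e.g.:
# dataset = dataset_cls({"batch_size": 2, "num_workers": 2})
# loader = dataset.get_data_loader("train")
# (this may trigger the ~450 GB revisitop1m auto-download mentioned in the README)
```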

@@ -0,0 +1,244 @@
from typing import Union
import albumentations as A
import numpy as np
import torch
from albumentations.pytorch.transforms import ToTensorV2
from omegaconf import OmegaConf
import cv2
class IdentityTransform(A.ImageOnlyTransform):
def apply(self, img, **params):
return img
def get_transform_init_args_names(self):
return ()
class RandomAdditiveShade(A.ImageOnlyTransform):
def __init__(
self,
nb_ellipses=10,
transparency_limit=[-0.5, 0.8],
kernel_size_limit=[150, 350],
always_apply=False,
p=0.5,
):
super().__init__(always_apply, p)
self.nb_ellipses = nb_ellipses
self.transparency_limit = transparency_limit
self.kernel_size_limit = kernel_size_limit
def apply(self, img, **params):
if img.dtype == np.float32:
shaded = self._py_additive_shade(img * 255.0)
shaded /= 255.0
elif img.dtype == np.uint8:
shaded = self._py_additive_shade(img.astype(np.float32))
shaded = shaded.astype(np.uint8)
else:
raise NotImplementedError(
f"Data augmentation not available for type: {img.dtype}"
)
return shaded
def _py_additive_shade(self, img):
grayscale = len(img.shape) == 2
if grayscale:
img = img[None]
min_dim = min(img.shape[:2]) / 4
mask = np.zeros(img.shape[:2], img.dtype)
for i in range(self.nb_ellipses):
ax = int(max(np.random.rand() * min_dim, min_dim / 5))
ay = int(max(np.random.rand() * min_dim, min_dim / 5))
max_rad = max(ax, ay)
x = np.random.randint(max_rad, img.shape[1] - max_rad) # center
y = np.random.randint(max_rad, img.shape[0] - max_rad)
angle = np.random.rand() * 90
cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)
transparency = np.random.uniform(*self.transparency_limit)
ks = np.random.randint(*self.kernel_size_limit)
if (ks % 2) == 0: # kernel_size has to be odd
ks += 1
mask = cv2.GaussianBlur(mask.astype(np.float32), (ks, ks), 0)
shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.0)
out = np.clip(shaded, 0, 255)
if grayscale:
out = out.squeeze(0)
return out
def get_transform_init_args_names(self):
return "transparency_limit", "kernel_size_limit", "nb_ellipses"
def kw(entry: Union[float, dict], n=None, **default):
if not isinstance(entry, dict):
entry = {"p": entry}
entry = OmegaConf.create(entry)
if n is not None:
entry = default.get(n, entry)
return OmegaConf.merge(default, entry)
def kwi(entry: Union[float, dict], n=None, **default):
conf = kw(entry, n=n, **default)
return {k: conf[k] for k in set(default.keys()).union(set(["p"]))}
def replay_str(transforms, s="Replay:\n", log_inactive=True):
for t in transforms:
if "transforms" in t.keys():
s = replay_str(t["transforms"], s=s)
elif t["applied"] or log_inactive:
s += t["__class_fullname__"] + " " + str(t["applied"]) + "\n"
return s
class BaseAugmentation(object):
base_default_conf = {
"name": "???",
"shuffle": False,
"p": 1.0,
"verbose": False,
"dtype": "uint8", # (byte, float)
}
default_conf = {}
def __init__(self, conf={}):
"""Perform some logic and call the _init method of the child model."""
default_conf = OmegaConf.merge(
OmegaConf.create(self.base_default_conf),
OmegaConf.create(self.default_conf),
)
OmegaConf.set_struct(default_conf, True)
if isinstance(conf, dict):
conf = OmegaConf.create(conf)
self.conf = OmegaConf.merge(default_conf, conf)
OmegaConf.set_readonly(self.conf, True)
self._init(self.conf)
self.conf = OmegaConf.merge(self.conf, conf)
if self.conf.verbose:
self.compose = A.ReplayCompose
else:
self.compose = A.Compose
if self.conf.dtype == "uint8":
self.dtype = np.uint8
self.preprocess = A.FromFloat(always_apply=True, dtype="uint8")
self.postprocess = A.ToFloat(always_apply=True)
elif self.conf.dtype == "float32":
self.dtype = np.float32
self.preprocess = A.ToFloat(always_apply=True)
self.postprocess = IdentityTransform()
else:
raise ValueError(f"Unsupported dtype {self.conf.dtype}")
self.to_tensor = ToTensorV2()
def _init(self, conf):
"""Child class overwrites this, setting up a list of transforms"""
self.transforms = []
def __call__(self, image, return_tensor=False):
"""image as HW or HWC"""
if isinstance(image, torch.Tensor):
image = image.cpu().detach().numpy()
data = {"image": image}
if image.dtype != self.dtype:
data = self.preprocess(**data)
transforms = self.transforms
if self.conf.shuffle:
order = [i for i, _ in enumerate(transforms)]
np.random.shuffle(order)
transforms = [transforms[i] for i in order]
transformed = self.compose(transforms, p=self.conf.p)(**data)
if self.conf.verbose:
print(replay_str(transformed["replay"]["transforms"]))
transformed = self.postprocess(**transformed)
if return_tensor:
return self.to_tensor(**transformed)["image"]
else:
return transformed["image"]
class IdentityAugmentation(BaseAugmentation):
default_conf = {}
def _init(self, conf):
self.transforms = [IdentityTransform(p=1.0)]
class DarkAugmentation(BaseAugmentation):
default_conf = {"p": 0.75}
def _init(self, conf):
bright_contr = 0.5
blur = 0.1
random_gamma = 0.1
hue = 0.1
self.transforms = [
A.RandomRain(p=0.2),
A.RandomBrightnessContrast(
**kw(
bright_contr,
brightness_limit=(-0.4, 0.0),
contrast_limit=(-0.3, 0.0),
)
),
A.OneOf(
[
A.Blur(**kwi(blur, p=0.1, blur_limit=(3, 9), n="blur")),
A.MotionBlur(
**kwi(blur, p=0.2, blur_limit=(3, 25), n="motion_blur")
),
A.ISONoise(),
A.ImageCompression(),
],
**kwi(blur, p=0.1),
),
A.RandomGamma(**kw(random_gamma, gamma_limit=(15, 65))),
A.OneOf(
[
A.Equalize(),
A.CLAHE(p=0.2),
A.ToGray(),
A.ToSepia(p=0.1),
A.HueSaturationValue(**kw(hue, val_shift_limit=(-100, -40))),
],
p=0.5,
),
]
class LGAugmentation(BaseAugmentation):
default_conf = {"p": 0.95}
def _init(self, conf):
self.transforms = [
A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)),
A.OneOf(
[
A.Blur(blur_limit=(3, 9)),
A.MotionBlur(blur_limit=(3, 25)),
A.ISONoise(),
A.ImageCompression(),
],
p=0.1,
),
A.Blur(p=0.1, blur_limit=(3, 9)),
A.MotionBlur(p=0.1, blur_limit=(3, 25)),
A.RandomBrightnessContrast(
p=0.5, brightness_limit=(-0.4, 0.0), contrast_limit=(-0.3, 0.0)
),
A.CLAHE(p=0.2),
]
augmentations = {
"dark": DarkAugmentation,
"lg": LGAugmentation,
"identity": IdentityAugmentation,
}
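A minimal usage sketch for the pipelines defined above (run it in the same module scope, or import `augmentations` from wherever this file lives in the package; albumentations, numpy and torch are required, and the random image is only a stand-in for a real photo):
```python
import numpy as np

# pick one of the augmentation pipelines registered above
aug = augmentations["lg"]({})  # empty conf: keep the defaults (p=0.95, uint8 processing)

# HWC uint8 image; a random image stands in for a real one
image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)

augmented = aug(image)  # float HWC image in [0, 1] after A.ToFloat
augmented_tensor = aug(image, return_tensor=True)  # CHW torch.Tensor via ToTensorV2
```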

@@ -0,0 +1,205 @@
"""
Base class for dataset.
See mnist.py for an example of dataset.
"""
from abc import ABCMeta, abstractmethod
import collections
import logging
from omegaconf import OmegaConf
import omegaconf
import torch
from torch.utils.data import DataLoader, Sampler, get_worker_info
from torch.utils.data._utils.collate import (
default_collate_err_msg_format,
np_str_obj_array_pattern,
)
from ..utils.tensor import string_classes
from ..utils.tools import set_num_threads, set_seed
logger = logging.getLogger(__name__)
class LoopSampler(Sampler):
def __init__(self, loop_size, total_size=None):
self.loop_size = loop_size
self.total_size = total_size - (total_size % loop_size)
def __iter__(self):
return (i % self.loop_size for i in range(self.total_size))
def __len__(self):
return self.total_size
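# e.g. LoopSampler(3, total_size=10) yields 0, 1, 2, 0, 1, 2, 0, 1, 2:
# total_size is rounded down to a multiple of loop_size, and only the first
# loop_size items of the dataset are sampled, over and over.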
def worker_init_fn(i):
info = get_worker_info()
if hasattr(info.dataset, "conf"):
conf = info.dataset.conf
set_seed(info.id + conf.seed)
set_num_threads(conf.num_threads)
else:
set_num_threads(1)
def collate(batch):
"""Difference with PyTorch default_collate: it can stack of other objects."""
if not isinstance(batch, list): # no batching
return batch
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
try:
storage = elem.untyped_storage()._new_shared(numel) # noqa: F841
except AttributeError:
storage = elem.storage()._new_shared(numel) # noqa: F841
return torch.stack(batch, dim=0)
elif (
elem_type.__module__ == "numpy"
and elem_type.__name__ != "str_"
and elem_type.__name__ != "string_"
):
if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
return elem_type(*(collate(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError("each element in list of batch should be of equal size")
transposed = zip(*batch)
return [collate(samples) for samples in transposed]
elif elem is None:
return elem
else:
# try to stack anyway in case the object implements stacking.
return torch.stack(batch, 0)
class BaseDataset(metaclass=ABCMeta):
"""
What the dataset model is expected to declare:
default_conf: dictionary of the default configuration of the dataset.
It overwrites base_default_conf in BaseDataset, and it is overwritten by
the user-provided configuration passed to __init__.
Configurations can be nested.
_init(self, conf): initialization method, where conf is the final
configuration object (also accessible with `self.conf`). Accessing
unknown configuration entries will raise an error.
get_dataset(self, split): method that returns an instance of
torch.utils.data.Dataset corresponding to the requested split string,
which can be `'train'`, `'val'`, or `'test'`.
"""
base_default_conf = {
"name": "???",
"num_workers": "???",
"train_batch_size": "???",
"val_batch_size": "???",
"test_batch_size": "???",
"shuffle_training": True,
"batch_size": 1,
"num_threads": 1,
"seed": 0,
"prefetch_factor": 2,
}
default_conf = {}
def __init__(self, conf):
"""Perform some logic and call the _init method of the child model."""
default_conf = OmegaConf.merge(
OmegaConf.create(self.base_default_conf),
OmegaConf.create(self.default_conf),
)
OmegaConf.set_struct(default_conf, True)
if isinstance(conf, dict):
conf = OmegaConf.create(conf)
self.conf = OmegaConf.merge(default_conf, conf)
OmegaConf.set_readonly(self.conf, True)
logger.info(f"Creating dataset {self.__class__.__name__}")
self._init(self.conf)
@abstractmethod
def _init(self, conf):
"""To be implemented by the child class."""
raise NotImplementedError
@abstractmethod
def get_dataset(self, split):
"""To be implemented by the child class."""
raise NotImplementedError
def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False):
"""Return a data loader for a given split."""
assert split in ["train", "val", "test"]
dataset = self.get_dataset(split)
try:
batch_size = self.conf[split + "_batch_size"]
except omegaconf.MissingMandatoryValue:
batch_size = self.conf.batch_size
num_workers = self.conf.get("num_workers", batch_size)
if distributed:
shuffle = False
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
sampler = None
if shuffle is None:
shuffle = split == "train" and self.conf.shuffle_training
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=pinned,
collate_fn=collate,
num_workers=num_workers,
worker_init_fn=worker_init_fn,
prefetch_factor=self.conf.prefetch_factor,
drop_last=True if split == "train" else False,
)
def get_overfit_loader(self, split):
"""Return an overfit data loader.
The training set is composed of a single duplicated batch, while
the validation and test sets contain a single copy of this same batch.
This is useful to debug a model and make sure that losses and metrics
correlate well.
"""
assert split in ["train", "val", "test"]
dataset = self.get_dataset("train")
sampler = LoopSampler(
self.conf.batch_size,
len(dataset) if split == "train" else self.conf.batch_size,
)
num_workers = self.conf.get("num_workers", self.conf.batch_size)
return DataLoader(
dataset,
batch_size=self.conf.batch_size,
pin_memory=True,
num_workers=num_workers,
sampler=sampler,
worker_init_fn=worker_init_fn,
collate_fn=collate,
)
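# A minimal sketch (not part of the original code) of a dataset following the
# contract documented in BaseDataset above; `_DummyDataset` and its config
# entries are hypothetical and only illustrate default_conf, _init and
# get_dataset:
#
#   class _DummyDataset(BaseDataset, torch.utils.data.Dataset):
#       default_conf = {"num_items": 8, "num_workers": 2, "train_batch_size": 2,
#                       "val_batch_size": 2, "test_batch_size": 2}
#
#       def _init(self, conf):
#           self.items = list(range(conf.num_items))
#
#       def get_dataset(self, split):
#           return self
#
#       def __getitem__(self, idx):
#           return {"index": torch.tensor(self.items[idx])}
#
#       def __len__(self):
#           return len(self.items)
#
#   loader = _DummyDataset({"name": "dummy"}).get_data_loader("train")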

View File

@ -0,0 +1,254 @@
"""
ETH3D multi-view benchmark, used for line matching evaluation.
"""
import logging
import os
import shutil
import numpy as np
import cv2
import torch
from pathlib import Path
import zipfile
from .base_dataset import BaseDataset
from .utils import scale_intrinsics
from ..geometry.wrappers import Camera, Pose
from ..settings import DATA_PATH
from ..utils.image import ImagePreprocessor, load_image
logger = logging.getLogger(__name__)
def read_cameras(camera_file, scale_factor=None):
"""Read the camera intrinsics from a file in COLMAP format."""
with open(camera_file, "r") as f:
raw_cameras = f.read().rstrip().split("\n")
raw_cameras = raw_cameras[3:]
cameras = []
for c in raw_cameras:
data = c.split(" ")
fx, fy, cx, cy = np.array(list(map(float, data[4:])))
K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32)
if scale_factor is not None:
K = scale_intrinsics(K, np.array([scale_factor, scale_factor]))
cameras.append(Camera.from_calibration_matrix(K).float())
return cameras
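# Example (hypothetical values) of a line that `read_cameras` expects after the
# 3-line header, in COLMAP PINHOLE format with fx, fy, cx, cy as the last four
# fields:
#   0 PINHOLE 6048 4032 3410.34 3409.98 3024.0 2016.0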
def qvec2rotmat(qvec):
"""Convert from quaternions to rotation matrix."""
return np.array(
[
[
1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
],
[
2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
],
[
2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
],
]
)
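# Sanity check (illustration only): the quaternion is ordered (w, x, y, z), so
# the identity quaternion maps to the identity rotation:
#   np.allclose(qvec2rotmat(np.array([1.0, 0.0, 0.0, 0.0])), np.eye(3))  # True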
class ETH3DDataset(BaseDataset):
default_conf = {
"data_dir": "ETH3D_undistorted",
"grayscale": True,
"downsize_factor": 8,
"min_covisibility": 500,
"batch_size": 1,
"two_view": True,
"min_overlap": 0.5,
"max_overlap": 1.0,
"sort_by_overlap": False,
"seed": 0,
}
def _init(self, conf):
self.grayscale = conf.grayscale
self.downsize_factor = conf.downsize_factor
# Set random seeds
np.random.seed(conf.seed)
torch.manual_seed(conf.seed)
# Auto-download the dataset
if not (DATA_PATH / conf.data_dir).exists():
logger.info("Downloading the ETH3D dataset...")
self.download_eth3d()
# Form pairs of images from the multiview dataset
self.img_dir = DATA_PATH / conf.data_dir
self.data = []
for folder in self.img_dir.iterdir():
img_folder = Path(folder, "images", "dslr_images_undistorted")
depth_folder = Path(folder, "ground_truth_depth/undistorted_depth")
depth_ext = ".png"
names = [img.name for img in img_folder.iterdir()]
names.sort()
# Read intrinsics and extrinsics data
cameras = read_cameras(
str(Path(folder, "dslr_calibration_undistorted", "cameras.txt")),
1 / self.downsize_factor,
)
name_to_cam_idx = {name: {} for name in names}
with open(
str(Path(folder, "dslr_calibration_jpg", "images.txt")), "r"
) as f:
raw_data = f.read().rstrip().split("\n")[4::2]
for raw_line in raw_data:
line = raw_line.split(" ")
img_name = os.path.basename(line[-1])
name_to_cam_idx[img_name]["dist_camera_idx"] = int(line[-2])
T_world_to_camera = {}
image_visible_points3D = {}
with open(
str(Path(folder, "dslr_calibration_undistorted", "images.txt")), "r"
) as f:
lines = f.readlines()[4:] # Skip the header
raw_poses = [line.strip("\n").split(" ") for line in lines[::2]]
raw_points = [line.strip("\n").split(" ") for line in lines[1::2]]
for raw_pose, raw_pts in zip(raw_poses, raw_points):
img_name = os.path.basename(raw_pose[-1])
# Extract the transform from world to camera
target_extrinsics = list(map(float, raw_pose[1:8]))
pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = qvec2rotmat(target_extrinsics[:4])
pose[:3, 3] = target_extrinsics[4:]
T_world_to_camera[img_name] = pose
name_to_cam_idx[img_name]["undist_camera_idx"] = int(raw_pose[-2])
# Extract the visible 3D points
point3D_ids = [id for id in map(int, raw_pts[2::3]) if id != -1]
image_visible_points3D[img_name] = set(point3D_ids)
# Extract the covisibility of each image
num_imgs = len(names)
n_covisible_points = np.zeros((num_imgs, num_imgs))
for i in range(num_imgs - 1):
for j in range(i + 1, num_imgs):
visible_points3D1 = image_visible_points3D[names[i]]
visible_points3D2 = image_visible_points3D[names[j]]
n_covisible_points[i, j] = len(
visible_points3D1 & visible_points3D2
)
# Keep only the pairs with enough covisibility
valid_pairs = np.where(n_covisible_points >= conf.min_covisibility)
valid_pairs = np.stack(valid_pairs, axis=1)
self.data += [
{
"view0": {
"name": names[i][:-4],
"img_path": str(Path(img_folder, names[i])),
"depth_path": str(Path(depth_folder, names[i][:-4]))
+ depth_ext,
"camera": cameras[name_to_cam_idx[names[i]]["dist_camera_idx"]],
"T_w2cam": Pose.from_4x4mat(T_world_to_camera[names[i]]),
},
"view1": {
"name": names[j][:-4],
"img_path": str(Path(img_folder, names[j])),
"depth_path": str(Path(depth_folder, names[j][:-4]))
+ depth_ext,
"camera": cameras[name_to_cam_idx[names[j]]["dist_camera_idx"]],
"T_w2cam": Pose.from_4x4mat(T_world_to_camera[names[j]]),
},
"T_world_to_ref": Pose.from_4x4mat(T_world_to_camera[names[i]]),
"T_world_to_target": Pose.from_4x4mat(T_world_to_camera[names[j]]),
"T_0to1": Pose.from_4x4mat(
np.float32(
T_world_to_camera[names[j]]
@ np.linalg.inv(T_world_to_camera[names[i]])
)
),
"T_1to0": Pose.from_4x4mat(
np.float32(
T_world_to_camera[names[i]]
@ np.linalg.inv(T_world_to_camera[names[j]])
)
),
"n_covisible_points": n_covisible_points[i, j],
}
for (i, j) in valid_pairs
]
# Print some info
print("[Info] Successfully initialized dataset")
print("\t Name: ETH3D")
print("----------------------------------------")
def download_eth3d(self):
data_dir = DATA_PATH / self.conf.data_dir
tmp_dir = data_dir.parent / "ETH3D_tmp"
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True, parents=True)
url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/"
zip_name = "ETH3D_undistorted.zip"
zip_path = tmp_dir / zip_name
torch.hub.download_url_to_file(url_base + zip_name, zip_path)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(tmp_dir)
shutil.move(tmp_dir / zip_name.split(".")[0], data_dir)
def get_dataset(self, split):
return ETH3DDataset(self.conf)
def _read_image(self, img_path):
img = load_image(img_path, grayscale=self.grayscale)
shape = img.shape[-2:]
# instead of INTER_AREA this does bilinear interpolation with antialiasing
img_data = ImagePreprocessor({"resize": max(shape) // self.downsize_factor})(
img
)
return img_data
def read_depth(self, depth_path):
if self.downsize_factor != 8:
raise ValueError(
"Undistorted depth only available for low res"
+ " images(downsize_factor = 8)."
)
depth_img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
depth_img = depth_img.astype(np.float32) / 256
return depth_img
def __getitem__(self, idx):
"""Returns the data associated to a pair of images (reference, target)
that are co-visible."""
data = self.data[idx]
# Load the images
view0 = data.pop("view0")
view1 = data.pop("view1")
view0 = {**view0, **self._read_image(view0["img_path"])}
view1 = {**view1, **self._read_image(view1["img_path"])}
view0["scales"] = np.array([1.0, 1]).astype(np.float32)
view1["scales"] = np.array([1.0, 1]).astype(np.float32)
# Load the depths
view0["depth"] = self.read_depth(view0["depth_path"])
view1["depth"] = self.read_depth(view1["depth_path"])
outputs = {
**data,
"view0": view0,
"view1": view1,
"name": f"{view0['name']}_{view1['name']}",
}
return outputs
def __len__(self):
return len(self.data)

View File

@ -0,0 +1,311 @@
"""
Simply load images from a folder or nested folders (does not have any split),
and apply homographic adaptations to them. Yields an image pair without border
artifacts.
"""
import argparse
import logging
import shutil
import tarfile
from pathlib import Path
import cv2
import numpy as np
import omegaconf
import torch
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from tqdm import tqdm
from .augmentations import IdentityAugmentation, augmentations
from .base_dataset import BaseDataset
from ..settings import DATA_PATH
from ..models.cache_loader import CacheLoader, pad_local_features
from ..utils.image import read_image
from ..geometry.homography import (
sample_homography_corners,
compute_homography,
warp_points,
)
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
logger = logging.getLogger(__name__)
def sample_homography(img, conf: dict, size: list):
data = {}
H, _, coords, _ = sample_homography_corners(img.shape[:2][::-1], **conf)
data["image"] = cv2.warpPerspective(img, H, tuple(size))
data["H_"] = H.astype(np.float32)
data["coords"] = coords.astype(np.float32)
data["image_size"] = np.array(size, dtype=np.float32)
return data
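# Usage sketch (not in the original file); the conf keys mirror the
# "homography" entry of HomographyDataset.default_conf below, and `img` is any
# HxWx3 float image:
#   warped = sample_homography(
#       img,
#       {"difficulty": 0.8, "translation": 1.0, "max_angle": 60, "n_angles": 10,
#        "patch_shape": [640, 480], "min_convexity": 0.05},
#       size=[640, 480],
#   )
#   # warped["image"] is the warped patch of size patch_shape,
#   # warped["H_"] the corresponding 3x3 homography.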
class HomographyDataset(BaseDataset):
default_conf = {
# image search
"data_dir": "revisitop1m", # the top-level directory
"image_dir": "jpg/", # the subdirectory with the images
"image_list": "revisitop1m.txt", # optional: list or filename of list
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
# splits
"train_size": 100,
"val_size": 10,
"shuffle_seed": 0, # or None to skip
# image loading
"grayscale": False,
"triplet": False,
"right_only": False, # image0 is orig (rescaled), image1 is right
"reseed": False,
"homography": {
"difficulty": 0.8,
"translation": 1.0,
"max_angle": 60,
"n_angles": 10,
"patch_shape": [640, 480],
"min_convexity": 0.05,
},
"photometric": {
"name": "dark",
"p": 0.75,
# 'difficulty': 1.0, # currently unused
},
# feature loading
"load_features": {
"do": False,
**CacheLoader.default_conf,
"collate": False,
"thresh": 0.0,
"max_num_keypoints": -1,
"force_num_keypoints": False,
},
}
def _init(self, conf):
data_dir = DATA_PATH / conf.data_dir
if not data_dir.exists():
if conf.data_dir == "revisitop1m":
logger.info("Downloading the revisitop1m dataset.")
self.download_revisitop1m()
else:
raise FileNotFoundError(data_dir)
image_dir = data_dir / conf.image_dir
images = []
if conf.image_list is None:
glob = [conf.glob] if isinstance(conf.glob, str) else conf.glob
for g in glob:
images += list(image_dir.glob("**/" + g))
if len(images) == 0:
raise ValueError(f"Cannot find any image in folder: {image_dir}.")
images = [i.relative_to(image_dir).as_posix() for i in images]
images = sorted(images) # for deterministic behavior
logger.info("Found %d images in folder.", len(images))
elif isinstance(conf.image_list, (str, Path)):
image_list = data_dir / conf.image_list
if not image_list.exists():
raise FileNotFoundError(f"Cannot find image list {image_list}.")
images = image_list.read_text().rstrip("\n").split("\n")
for image in images:
if not (image_dir / image).exists():
raise FileNotFoundError(image_dir / image)
logger.info("Found %d images in list file.", len(images))
elif isinstance(conf.image_list, omegaconf.listconfig.ListConfig):
images = conf.image_list.to_container()
for image in images:
if not (image_dir / image).exists():
raise FileNotFoundError(image_dir / image)
else:
raise ValueError(conf.image_list)
if conf.shuffle_seed is not None:
np.random.RandomState(conf.shuffle_seed).shuffle(images)
train_images = images[: conf.train_size]
val_images = images[conf.train_size : conf.train_size + conf.val_size]
self.images = {"train": train_images, "val": val_images}
def download_revisitop1m(self):
data_dir = DATA_PATH / self.conf.data_dir
tmp_dir = data_dir.parent / "revisitop1m_tmp"
if tmp_dir.exists(): # The previous download failed.
shutil.rmtree(tmp_dir)
image_dir = tmp_dir / self.conf.image_dir
image_dir.mkdir(exist_ok=True, parents=True)
num_files = 100
url_base = "http://ptak.felk.cvut.cz/revisitop/revisitop1m/"
list_name = "revisitop1m.txt"
torch.hub.download_url_to_file(url_base + list_name, tmp_dir / list_name)
for n in tqdm(range(num_files), position=1):
tar_name = "revisitop1m.{}.tar.gz".format(n + 1)
tar_path = image_dir / tar_name
torch.hub.download_url_to_file(url_base + "jpg/" + tar_name, tar_path)
with tarfile.open(tar_path) as tar:
tar.extractall(path=image_dir)
tar_path.unlink()
shutil.move(tmp_dir, data_dir)
def get_dataset(self, split):
return _Dataset(self.conf, self.images[split], split)
class _Dataset(torch.utils.data.Dataset):
def __init__(self, conf, image_names, split):
self.conf = conf
self.split = split
self.image_names = np.array(image_names)
self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir
aug_conf = conf.photometric
aug_name = aug_conf.name
assert (
aug_name in augmentations.keys()
), f'{aug_name} not in {" ".join(augmentations.keys())}'
self.photo_augment = augmentations[aug_name](aug_conf)
self.left_augment = (
IdentityAugmentation() if conf.right_only else self.photo_augment
)
self.img_to_tensor = IdentityAugmentation()
if conf.load_features.do:
self.feature_loader = CacheLoader(conf.load_features)
def _transform_keypoints(self, features, data):
"""Transform keypoints by a homography, threshold them,
and potentially keep only the best ones."""
# Warp points
features["keypoints"] = warp_points(
features["keypoints"], data["H_"], inverse=False
)
h, w = data["image"].shape[1:3]
valid = (
(features["keypoints"][:, 0] >= 0)
& (features["keypoints"][:, 0] <= w - 1)
& (features["keypoints"][:, 1] >= 0)
& (features["keypoints"][:, 1] <= h - 1)
)
features["keypoints"] = features["keypoints"][valid]
# Threshold
if self.conf.load_features.thresh > 0:
valid = features["keypoint_scores"] >= self.conf.load_features.thresh
features = {k: v[valid] for k, v in features.items()}
# Get the top keypoints and pad
n = self.conf.load_features.max_num_keypoints
if n > -1:
inds = np.argsort(-features["keypoint_scores"])
features = {k: v[inds[:n]] for k, v in features.items()}
if self.conf.load_features.force_num_keypoints:
features = pad_local_features(
features, self.conf.load_features.max_num_keypoints
)
return features
def __getitem__(self, idx):
if self.conf.reseed:
with fork_rng(self.conf.seed + idx, False):
return self.getitem(idx)
else:
return self.getitem(idx)
def _read_view(self, img, H_conf, ps, left=False):
data = sample_homography(img, H_conf, ps)
if left:
data["image"] = self.left_augment(data["image"], return_tensor=True)
else:
data["image"] = self.photo_augment(data["image"], return_tensor=True)
gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
if self.conf.grayscale:
data["image"] = (data["image"] * gs).sum(0, keepdim=True)
if self.conf.load_features.do:
features = self.feature_loader({k: [v] for k, v in data.items()})
features = self._transform_keypoints(features, data)
data["cache"] = features
return data
def getitem(self, idx):
name = self.image_names[idx]
img = read_image(self.image_dir / name, False)
if img is None:
logging.warning("Image %s could not be read.", name)
img = np.zeros((1024, 1024) + (() if self.conf.grayscale else (3,)))
img = img.astype(np.float32) / 255.0
size = img.shape[:2][::-1]
ps = self.conf.homography.patch_shape
left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
if self.conf.right_only:
left_conf["difficulty"] = 0.0
data0 = self._read_view(img, left_conf, ps, left=True)
data1 = self._read_view(img, self.conf.homography, ps, left=False)
H = compute_homography(data0["coords"], data1["coords"], [1, 1])
data = {
"name": name,
"original_image_size": np.array(size),
"H_0to1": H.astype(np.float32),
"idx": idx,
"view0": data0,
"view1": data1,
}
if self.conf.triplet:
# Generate third image
data2 = self._read_view(img, self.conf.homography, ps, left=False)
H02 = compute_homography(data0["coords"], data2["coords"], [1, 1])
H12 = compute_homography(data1["coords"], data2["coords"], [1, 1])
data = {
"H_0to2": H02.astype(np.float32),
"H_1to2": H12.astype(np.float32),
"view2": data2,
**data,
}
return data
def __len__(self):
return len(self.image_names)
def visualize(args):
conf = {
"batch_size": 1,
"num_workers": 1,
"prefetch_factor": 1,
}
conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
dataset = HomographyDataset(conf)
loader = dataset.get_data_loader("train")
logger.info("The dataset has %d elements.", len(loader))
with fork_rng(seed=dataset.conf.seed):
images = []
for _, data in zip(range(args.num_items), loader):
images.append(
(data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2))
)
plot_image_grid(images, dpi=args.dpi)
plt.tight_layout()
plt.show()
if __name__ == "__main__":
from .. import logger # overwrite the logger
parser = argparse.ArgumentParser()
parser.add_argument("--num_items", type=int, default=8)
parser.add_argument("--dpi", type=int, default=100)
parser.add_argument("dotlist", nargs="*")
args = parser.parse_intermixed_args()
visualize(args)

View File

@ -0,0 +1,144 @@
"""
Simply load images from a folder or nested folders (does not have any split).
"""
import argparse
import logging
import tarfile
import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from .base_dataset import BaseDataset
from ..settings import DATA_PATH
from ..utils.image import load_image, ImagePreprocessor
from ..utils.tools import fork_rng
from ..visualization.viz2d import plot_image_grid
logger = logging.getLogger(__name__)
def read_homography(path):
with open(path) as f:
result = []
for line in f.readlines():
while " " in line: # Remove double spaces
line = line.replace(" ", " ")
line = line.replace(" \n", "").replace("\n", "")
# Split and discard empty strings
elements = list(filter(lambda s: s, line.split(" ")))
if elements:
result.append(elements)
return np.array(result).astype(float)
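# `read_homography` parses the plain-text 3x3 homographies shipped with
# HPatches (e.g. the file H_1_2); such a file typically looks like (values made
# up for illustration):
#    0.87  0.01  12.3
#   -0.02  0.91   4.7
#    0.00  0.00   1.0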
class HPatches(BaseDataset, torch.utils.data.Dataset):
default_conf = {
"preprocessing": ImagePreprocessor.default_conf,
"data_dir": "hpatches-sequences-release",
"subset": None,
"ignore_large_images": True,
"grayscale": False,
}
# Large images that were ignored in previous papers
ignored_scenes = (
"i_contruction",
"i_crownnight",
"i_dc",
"i_pencils",
"i_whitebuilding",
"v_artisans",
"v_astronautis",
"v_talent",
)
url = "http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz"
def _init(self, conf):
assert conf.batch_size == 1
self.preprocessor = ImagePreprocessor(conf.preprocessing)
self.root = DATA_PATH / conf.data_dir
if not self.root.exists():
logger.info("Downloading the HPatches dataset.")
self.download()
self.sequences = sorted([x.name for x in self.root.iterdir()])
if not self.sequences:
raise ValueError("No image found!")
self.items = [] # (seq, q_idx, is_illu)
for seq in self.sequences:
if conf.ignore_large_images and seq in self.ignored_scenes:
continue
if conf.subset is not None and conf.subset != seq[0]:
continue
for i in range(2, 7):
self.items.append((seq, i, seq[0] == "i"))
def download(self):
data_dir = self.root.parent
data_dir.mkdir(exist_ok=True, parents=True)
tar_path = data_dir / self.url.rsplit("/", 1)[-1]
torch.hub.download_url_to_file(self.url, tar_path)
with tarfile.open(tar_path) as tar:
tar.extractall(data_dir)
tar_path.unlink()
def get_dataset(self, split):
assert split in ["val", "test"]
return self
def _read_image(self, seq: str, idx: int) -> dict:
img = load_image(self.root / seq / f"{idx}.ppm", self.conf.grayscale)
return self.preprocessor(img)
def __getitem__(self, idx):
seq, q_idx, is_illu = self.items[idx]
data0 = self._read_image(seq, 1)
data1 = self._read_image(seq, q_idx)
H = read_homography(self.root / seq / f"H_1_{q_idx}")
H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])
return {
"H_0to1": H.astype(np.float32),
"scene": seq,
"idx": idx,
"is_illu": is_illu,
"name": f"{seq}/{idx}.ppm",
"view0": data0,
"view1": data1,
}
def __len__(self):
return len(self.items)
def visualize(args):
conf = {
"batch_size": 1,
"num_workers": 8,
"prefetch_factor": 1,
}
conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
dataset = HPatches(conf)
loader = dataset.get_data_loader("test")
logger.info("The dataset has %d elements.", len(loader))
with fork_rng(seed=dataset.conf.seed):
images = []
for _, data in zip(range(args.num_items), loader):
images.append(
(data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2))
)
plot_image_grid(images, dpi=args.dpi)
plt.tight_layout()
plt.show()
if __name__ == "__main__":
from .. import logger # overwrite the logger
parser = argparse.ArgumentParser()
parser.add_argument("--num_items", type=int, default=8)
parser.add_argument("--dpi", type=int, default=100)
parser.add_argument("dotlist", nargs="*")
args = parser.parse_intermixed_args()
visualize(args)

View File

@ -0,0 +1,58 @@
"""
Simply load images from a folder or nested folders (does not have any split).
"""
from pathlib import Path
import torch
import logging
import omegaconf
from .base_dataset import BaseDataset
from ..utils.image import load_image, ImagePreprocessor
class ImageFolder(BaseDataset, torch.utils.data.Dataset):
default_conf = {
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
"images": "???",
"root_folder": "/",
"preprocessing": ImagePreprocessor.default_conf,
}
def _init(self, conf):
self.root = conf.root_folder
if isinstance(conf.images, str):
if not Path(conf.images).is_dir():
with open(conf.images, "r") as f:
self.images = f.read().rstrip("\n").split("\n")
logging.info(f"Found {len(self.images)} images in list file.")
else:
self.images = []
glob = [conf.glob] if isinstance(conf.glob, str) else conf.glob
for g in glob:
self.images += list(Path(conf.images).glob("**/" + g))
if len(self.images) == 0:
raise ValueError(
f"Could not find any image in folder: {conf.images}."
)
self.images = [i.relative_to(conf.images) for i in self.images]
self.root = conf.images
logging.info(f"Found {len(self.images)} images in folder.")
elif isinstance(conf.images, omegaconf.listconfig.ListConfig):
self.images = conf.images.to_container()
else:
raise ValueError(conf.images)
self.preprocessor = ImagePreprocessor(conf.preprocessing)
def get_dataset(self, split):
return self
def __getitem__(self, idx):
path = self.images[idx]
img = load_image(Path(self.root) / path)
data = {"name": str(path), **self.preprocessor(img)}
return data
def __len__(self):
return len(self.images)

View File

@ -0,0 +1,99 @@
"""
Simply load images from a folder or nested folders (does not have any split).
"""
from pathlib import Path
import torch
import numpy as np
from .base_dataset import BaseDataset
from ..utils.image import load_image, ImagePreprocessor
from ..settings import DATA_PATH
from ..geometry.wrappers import Camera, Pose
def names_to_pair(name0, name1, separator="/"):
return separator.join((name0.replace("/", "-"), name1.replace("/", "-")))
def parse_homography(homography_elems) -> np.ndarray:
return (
np.array([float(x) for x in homography_elems[:9]])
.reshape(3, 3)
.astype(np.float32)
)
def parse_camera(calib_elems) -> Camera:
# assert len(calib_elems) == 9
K = np.array([float(x) for x in calib_elems[:9]]).reshape(3, 3).astype(np.float32)
return Camera.from_calibration_matrix(K)
def parse_relative_pose(pose_elems) -> Pose:
# assert len(pose_elems) == 12
R, t = pose_elems[:9], pose_elems[9:12]
R = np.array([float(x) for x in R]).reshape(3, 3).astype(np.float32)
t = np.array([float(x) for x in t]).astype(np.float32)
return Pose.from_Rt(R, t)
class ImagePairs(BaseDataset, torch.utils.data.Dataset):
default_conf = {
"pairs": "???", # ToDo: add image folder interface
"root": "???",
"preprocessing": ImagePreprocessor.default_conf,
"extra_data": None, # relative_pose, homography
}
def _init(self, conf):
pair_f = (
Path(conf.pairs) if Path(conf.pairs).exists() else DATA_PATH / conf.pairs
)
with open(str(pair_f), "r") as f:
self.items = [line.rstrip() for line in f]
self.preprocessor = ImagePreprocessor(conf.preprocessing)
def get_dataset(self, split):
return self
def _read_view(self, name):
path = DATA_PATH / self.conf.root / name
img = load_image(path)
return self.preprocessor(img)
def __getitem__(self, idx):
line = self.items[idx]
pair_data = line.split(" ")
name0, name1 = pair_data[:2]
data0 = self._read_view(name0)
data1 = self._read_view(name1)
data = {
"view0": data0,
"view1": data1,
}
if self.conf.extra_data == "relative_pose":
data["view0"]["camera"] = parse_camera(pair_data[2:11]).scale(
data0["scales"]
)
data["view1"]["camera"] = parse_camera(pair_data[11:20]).scale(
data1["scales"]
)
data["T_0to1"] = parse_relative_pose(pair_data[20:32])
elif self.conf.extra_data == "homography":
data["H_0to1"] = (
data1["transform"]
@ parse_homography(pair_data[2:11])
@ np.linalg.inv(data0["transform"])
)
else:
assert (
self.conf.extra_data is None
), f"Unknown extra data format {self.conf.extra_data}"
data["name"] = names_to_pair(name0, name1)
return data
def __len__(self):
return len(self.items)
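# Layout of one line of the pairs file, as implied by the parsing above
# (fields are 0-indexed after splitting on spaces):
#   name0 name1                                -> fields [0:2]
#   with extra_data == "relative_pose":
#     K0 (9 values, row-major 3x3)             -> fields [2:11]
#     K1 (9 values, row-major 3x3)             -> fields [11:20]
#     R (9 values) followed by t (3 values)    -> fields [20:32]
#   with extra_data == "homography":
#     H_0to1 (9 values, row-major 3x3)         -> fields [2:11]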

View File

@ -0,0 +1,514 @@
import argparse
import logging
from pathlib import Path
from collections.abc import Iterable
import tarfile
import shutil
import h5py
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import torch
from omegaconf import OmegaConf
from .base_dataset import BaseDataset
from .utils import (
scale_intrinsics,
rotate_intrinsics,
rotate_pose_inplane,
)
from ..geometry.wrappers import Camera, Pose
from ..models.cache_loader import CacheLoader
from ..utils.tools import fork_rng
from ..utils.image import load_image, ImagePreprocessor
from ..settings import DATA_PATH
from ..visualization.viz2d import plot_image_grid, plot_heatmaps
logger = logging.getLogger(__name__)
scene_lists_path = Path(__file__).parent / "megadepth_scene_lists"
def sample_n(data, num, seed=None):
if len(data) > num:
selected = np.random.RandomState(seed).choice(len(data), num, replace=False)
return data[selected]
else:
return data
class MegaDepth(BaseDataset):
default_conf = {
# paths
"data_dir": "megadepth/",
"depth_subpath": "depth_undistorted/",
"image_subpath": "Undistorted_SfM/",
"info_dir": "scene_info/", # @TODO: intrinsics problem?
# Training
"train_split": "train_scenes_clean.txt",
"train_num_per_scene": 500,
# Validation
"val_split": "valid_scenes_clean.txt",
"val_num_per_scene": None,
"val_pairs": None,
# Test
"test_split": "test_scenes_clean.txt",
"test_num_per_scene": None,
"test_pairs": None,
# data sampling
"views": 2,
"min_overlap": 0.3, # only with D2-Net format
"max_overlap": 1.0, # only with D2-Net format
"num_overlap_bins": 1,
"sort_by_overlap": False,
"triplet_enforce_overlap": False, # only with views==3
# image options
"read_depth": True,
"read_image": True,
"grayscale": False,
"preprocessing": ImagePreprocessor.default_conf,
"p_rotate": 0.0, # probability to rotate image by +/- 90°
"reseed": False,
"seed": 0,
# features from cache
"load_features": {
"do": False,
**CacheLoader.default_conf,
"collate": False,
},
}
def _init(self, conf):
if not (DATA_PATH / conf.data_dir).exists():
logger.info("Downloading the MegaDepth dataset.")
self.download()
def download(self):
data_dir = DATA_PATH / self.conf.data_dir
tmp_dir = data_dir.parent / "megadepth_tmp"
if tmp_dir.exists(): # The previous download failed.
shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True, parents=True)
url_base = "https://cvg-data.inf.ethz.ch/megadepth/"
for tar_name, out_name in (
("Undistorted_SfM.tar.gz", self.conf.image_subpath),
("depth_undistorted.tar.gz", self.conf.depth_subpath),
("scene_info.tar.gz", self.conf.info_dir),
):
tar_path = tmp_dir / tar_name
torch.hub.download_url_to_file(url_base + tar_name, tar_path)
with tarfile.open(tar_path) as tar:
tar.extractall(path=tmp_dir)
tar_path.unlink()
shutil.move(tmp_dir / tar_name.split(".")[0], tmp_dir / out_name)
shutil.move(tmp_dir, data_dir)
def get_dataset(self, split):
assert self.conf.views in [1, 2, 3]
if self.conf.views == 3:
return _TripletDataset(self.conf, split)
else:
return _PairDataset(self.conf, split)
class _PairDataset(torch.utils.data.Dataset):
def __init__(self, conf, split, load_sample=True):
self.root = DATA_PATH / conf.data_dir
assert self.root.exists(), self.root
self.split = split
self.conf = conf
split_conf = conf[split + "_split"]
if isinstance(split_conf, (str, Path)):
scenes_path = scene_lists_path / split_conf
scenes = scenes_path.read_text().rstrip("\n").split("\n")
elif isinstance(split_conf, Iterable):
scenes = list(split_conf)
else:
raise ValueError(f"Unknown split configuration: {split_conf}.")
scenes = sorted(set(scenes))
if conf.load_features.do:
self.feature_loader = CacheLoader(conf.load_features)
self.preprocessor = ImagePreprocessor(conf.preprocessing)
self.images = {}
self.depths = {}
self.poses = {}
self.intrinsics = {}
self.valid = {}
# load metadata
self.info_dir = self.root / self.conf.info_dir
self.scenes = []
for scene in scenes:
path = self.info_dir / (scene + ".npz")
try:
info = np.load(str(path), allow_pickle=True)
except Exception:
logger.warning(
"Cannot load scene info for scene %s at %s.", scene, path
)
continue
self.images[scene] = info["image_paths"]
self.depths[scene] = info["depth_paths"]
self.poses[scene] = info["poses"]
self.intrinsics[scene] = info["intrinsics"]
self.scenes.append(scene)
if load_sample:
self.sample_new_items(conf.seed)
assert len(self.items) > 0
def sample_new_items(self, seed):
logger.info("Sampling new %s data with seed %d.", self.split, seed)
self.items = []
split = self.split
num_per_scene = self.conf[self.split + "_num_per_scene"]
if isinstance(num_per_scene, Iterable):
num_pos, num_neg = num_per_scene
else:
num_pos = num_per_scene
num_neg = None
if split != "train" and self.conf[split + "_pairs"] is not None:
# Fixed validation or test pairs
assert num_pos is None
assert num_neg is None
assert self.conf.views == 2
pairs_path = scene_lists_path / self.conf[split + "_pairs"]
for line in pairs_path.read_text().rstrip("\n").split("\n"):
im0, im1 = line.split(" ")
scene = im0.split("/")[0]
assert im1.split("/")[0] == scene
im0, im1 = [self.conf.image_subpath + im for im in [im0, im1]]
assert im0 in self.images[scene]
assert im1 in self.images[scene]
idx0 = np.where(self.images[scene] == im0)[0][0]
idx1 = np.where(self.images[scene] == im1)[0][0]
self.items.append((scene, idx0, idx1, 1.0))
elif self.conf.views == 1:
for scene in self.scenes:
if scene not in self.images:
continue
valid = (self.images[scene] != None) | ( # noqa: E711
self.depths[scene] != None # noqa: E711
)
ids = np.where(valid)[0]
if num_pos and len(ids) > num_pos:
ids = np.random.RandomState(seed).choice(
ids, num_pos, replace=False
)
ids = [(scene, i) for i in ids]
self.items.extend(ids)
else:
for scene in self.scenes:
path = self.info_dir / (scene + ".npz")
assert path.exists(), path
info = np.load(str(path), allow_pickle=True)
valid = (self.images[scene] != None) & ( # noqa: E711
self.depths[scene] != None # noqa: E711
)
ind = np.where(valid)[0]
mat = info["overlap_matrix"][valid][:, valid]
if num_pos is not None:
# Sample a subset of pairs, binned by overlap.
num_bins = self.conf.num_overlap_bins
assert num_bins > 0
bin_width = (
self.conf.max_overlap - self.conf.min_overlap
) / num_bins
num_per_bin = num_pos // num_bins
pairs_all = []
for k in range(num_bins):
bin_min = self.conf.min_overlap + k * bin_width
bin_max = bin_min + bin_width
pairs_bin = (mat > bin_min) & (mat <= bin_max)
pairs_bin = np.stack(np.where(pairs_bin), -1)
pairs_all.append(pairs_bin)
# Skip bins with too few samples
has_enough_samples = [len(p) >= num_per_bin * 2 for p in pairs_all]
num_per_bin_2 = num_pos // max(1, sum(has_enough_samples))
pairs = []
for pairs_bin, keep in zip(pairs_all, has_enough_samples):
if keep:
pairs.append(sample_n(pairs_bin, num_per_bin_2, seed))
pairs = np.concatenate(pairs, 0)
else:
pairs = (mat > self.conf.min_overlap) & (
mat <= self.conf.max_overlap
)
pairs = np.stack(np.where(pairs), -1)
pairs = [(scene, ind[i], ind[j], mat[i, j]) for i, j in pairs]
if num_neg is not None:
neg_pairs = np.stack(np.where(mat <= 0.0), -1)
neg_pairs = sample_n(neg_pairs, num_neg, seed)
pairs += [(scene, ind[i], ind[j], mat[i, j]) for i, j in neg_pairs]
self.items.extend(pairs)
if self.conf.views == 2 and self.conf.sort_by_overlap:
self.items.sort(key=lambda i: i[-1], reverse=True)
else:
np.random.RandomState(seed).shuffle(self.items)
def _read_view(self, scene, idx):
path = self.root / self.images[scene][idx]
# read pose data
K = self.intrinsics[scene][idx].astype(np.float32, copy=False)
T = self.poses[scene][idx].astype(np.float32, copy=False)
# read image
if self.conf.read_image:
img = load_image(self.root / self.images[scene][idx], self.conf.grayscale)
else:
size = PIL.Image.open(path).size[::-1]
img = torch.zeros(
[3 - 2 * int(self.conf.grayscale), size[0], size[1]]
).float()
# read depth
if self.conf.read_depth:
depth_path = (
self.root / self.conf.depth_subpath / scene / (path.stem + ".h5")
)
with h5py.File(str(depth_path), "r") as f:
depth = f["/depth"].__array__().astype(np.float32, copy=False)
depth = torch.Tensor(depth)[None]
assert depth.shape[-2:] == img.shape[-2:]
else:
depth = None
# add random rotations
do_rotate = self.conf.p_rotate > 0.0 and self.split == "train"
if do_rotate:
p = self.conf.p_rotate
k = 0
if np.random.rand() < p:
k = np.random.choice(2, 1, replace=False)[0] * 2 - 1
img = np.rot90(img, k=-k, axes=(-2, -1))
if self.conf.read_depth:
depth = np.rot90(depth, k=-k, axes=(-2, -1)).copy()
K = rotate_intrinsics(K, img.shape, k + 2)
T = rotate_pose_inplane(T, k + 2)
name = path.name
data = self.preprocessor(img)
if depth is not None:
data["depth"] = self.preprocessor(depth, interpolation="nearest")["image"][
0
]
K = scale_intrinsics(K, data["scales"])
data = {
"name": name,
"scene": scene,
"T_w2cam": Pose.from_4x4mat(T),
"depth": depth,
"camera": Camera.from_calibration_matrix(K).float(),
**data,
}
if self.conf.load_features.do:
features = self.feature_loader({k: [v] for k, v in data.items()})
if do_rotate and k != 0:
# ang = np.deg2rad(k * 90.)
kpts = features["keypoints"].copy()
x, y = kpts[:, 0].copy(), kpts[:, 1].copy()
w, h = data["image_size"]
if k == 1:
kpts[:, 0] = w - y
kpts[:, 1] = x
elif k == -1:
kpts[:, 0] = y
kpts[:, 1] = h - x
else:
raise ValueError
features["keypoints"] = kpts
data = {"cache": features, **data}
return data
def __getitem__(self, idx):
if self.conf.reseed:
with fork_rng(self.conf.seed + idx, False):
return self.getitem(idx)
else:
return self.getitem(idx)
def getitem(self, idx):
if self.conf.views == 2:
if isinstance(idx, list):
scene, idx0, idx1, overlap = idx
else:
scene, idx0, idx1, overlap = self.items[idx]
data0 = self._read_view(scene, idx0)
data1 = self._read_view(scene, idx1)
data = {
"view0": data0,
"view1": data1,
}
data["T_0to1"] = data1["T_w2cam"] @ data0["T_w2cam"].inv()
data["T_1to0"] = data0["T_w2cam"] @ data1["T_w2cam"].inv()
data["overlap_0to1"] = overlap
data["name"] = f"{scene}/{data0['name']}_{data1['name']}"
else:
assert self.conf.views == 1
scene, idx0 = self.items[idx]
data = self._read_view(scene, idx0)
data["scene"] = scene
data["idx"] = idx
return data
def __len__(self):
return len(self.items)
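# Worked example (not in the original file) of the overlap binning performed in
# _PairDataset.sample_new_items above: with min_overlap=0.3, max_overlap=1.0 and
# num_overlap_bins=3, bin_width = (1.0 - 0.3) / 3 ~= 0.233, so pairs are drawn
# in equal numbers from the overlap ranges (0.30, 0.53], (0.53, 0.77] and
# (0.77, 1.00] (bounds rounded); bins holding fewer than 2 * num_per_bin
# candidate pairs are skipped and their share is redistributed over the
# remaining bins.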
class _TripletDataset(_PairDataset):
def sample_new_items(self, seed):
logging.info("Sampling new triplets with seed %d", seed)
self.items = []
split = self.split
num = self.conf[self.split + "_num_per_scene"]
if split != "train" and self.conf[split + "_pairs"] is not None:
if Path(self.conf[split + "_pairs"]).exists():
pairs_path = Path(self.conf[split + "_pairs"])
else:
pairs_path = DATA_PATH / "configs" / self.conf[split + "_pairs"]
for line in pairs_path.read_text().rstrip("\n").split("\n"):
im0, im1, im2 = line.split(" ")
assert im0[:4] == im1[:4]
scene = im1[:4]
idx0 = np.where(self.images[scene] == im0)[0][0]
idx1 = np.where(self.images[scene] == im1)[0][0]
idx2 = np.where(self.images[scene] == im2)[0][0]
self.items.append((scene, idx0, idx1, idx2, 1.0, 1.0, 1.0))
else:
for scene in self.scenes:
path = self.info_dir / (scene + ".npz")
assert path.exists(), path
info = np.load(str(path), allow_pickle=True)
if self.conf.num_overlap_bins > 1:
raise NotImplementedError("TODO")
valid = (self.images[scene] != None) & ( # noqa: E711
self.depths[scene] != None  # noqa: E711
)
ind = np.where(valid)[0]
mat = info["overlap_matrix"][valid][:, valid]
good = (mat > self.conf.min_overlap) & (mat <= self.conf.max_overlap)
triplets = []
if self.conf.triplet_enforce_overlap:
pairs = np.stack(np.where(good), -1)
for i0, i1 in pairs:
for i2 in pairs[pairs[:, 0] == i0, 1]:
if good[i1, i2]:
triplets.append((i0, i1, i2))
if len(triplets) > num:
selected = np.random.RandomState(seed).choice(
len(triplets), num, replace=False
)
triplets = np.array(triplets)[selected]
else:
# we first enforce that each row has more than one pair
non_unique = good.sum(-1) > 1
ind_r = np.where(non_unique)[0]
good = good[non_unique]
pairs = np.stack(np.where(good), -1)
if len(pairs) > num:
selected = np.random.RandomState(seed).choice(
len(pairs), num, replace=False
)
pairs = pairs[selected]
for idx, (k, i) in enumerate(pairs):
# We now sample a j from row k s.t. i != j
possible_j = np.where(good[k])[0]
possible_j = possible_j[possible_j != i]
selected = np.random.RandomState(seed + idx).choice(
len(possible_j), 1, replace=False
)[0]
triplets.append((ind_r[k], i, possible_j[selected]))
triplets = [
(scene, ind[k], ind[i], ind[j], mat[k, i], mat[k, j], mat[i, j])
for k, i, j in triplets
]
self.items.extend(triplets)
np.random.RandomState(seed).shuffle(self.items)
def __getitem__(self, idx):
scene, idx0, idx1, idx2, overlap01, overlap02, overlap12 = self.items[idx]
data0 = self._read_view(scene, idx0)
data1 = self._read_view(scene, idx1)
data2 = self._read_view(scene, idx2)
data = {
"view0": data0,
"view1": data1,
"view2": data2,
}
data["T_0to1"] = data1["T_w2cam"] @ data0["T_w2cam"].inv()
data["T_0to2"] = data2["T_w2cam"] @ data0["T_w2cam"].inv()
data["T_1to2"] = data2["T_w2cam"] @ data1["T_w2cam"].inv()
data["T_1to0"] = data0["T_w2cam"] @ data1["T_w2cam"].inv()
data["T_2to0"] = data0["T_w2cam"] @ data2["T_w2cam"].inv()
data["T_2to1"] = data1["T_w2cam"] @ data2["T_w2cam"].inv()
data["overlap_0to1"] = overlap01
data["overlap_0to2"] = overlap02
data["overlap_1to2"] = overlap12
data["scene"] = scene
data["name"] = f"{scene}/{data0['name']}_{data1['name']}_{data2['name']}"
return data
def __len__(self):
return len(self.items)
def visualize(args):
conf = {
"min_overlap": 0.1,
"max_overlap": 0.7,
"num_overlap_bins": 3,
"sort_by_overlap": False,
"train_num_per_scene": 5,
"batch_size": 1,
"num_workers": 0,
"prefetch_factor": None,
"val_num_per_scene": None,
}
conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
dataset = MegaDepth(conf)
loader = dataset.get_data_loader(args.split)
logger.info("The dataset has elements.", len(loader))
with fork_rng(seed=dataset.conf.seed):
images, depths = [], []
for _, data in zip(range(args.num_items), loader):
images.append(
[
data[f"view{i}"]["image"][0].permute(1, 2, 0)
for i in range(dataset.conf.views)
]
)
depths.append(
[data[f"view{i}"]["depth"][0] for i in range(dataset.conf.views)]
)
axes = plot_image_grid(images, dpi=args.dpi)
for i in range(len(images)):
plot_heatmaps(depths[i], axes=axes[i])
plt.show()
if __name__ == "__main__":
from .. import logger # overwrite the logger
parser = argparse.ArgumentParser()
parser.add_argument("--split", type=str, default="val")
parser.add_argument("--num_items", type=int, default=4)
parser.add_argument("--dpi", type=int, default=100)
parser.add_argument("dotlist", nargs="*")
args = parser.parse_intermixed_args()
visualize(args)

View File

@ -0,0 +1,8 @@
0008
0019
0021
0024
0025
0032
0063
1589

View File

@ -0,0 +1,118 @@
0000
0001
0002
0003
0004
0005
0007
0008
0011
0012
0013
0015
0017
0019
0020
0021
0022
0023
0024
0025
0026
0027
0032
0035
0036
0037
0039
0042
0043
0046
0048
0050
0056
0057
0060
0061
0063
0065
0070
0080
0083
0086
0087
0092
0095
0098
0100
0101
0103
0104
0105
0107
0115
0117
0122
0130
0137
0143
0147
0148
0149
0150
0156
0160
0176
0183
0189
0190
0200
0214
0224
0235
0237
0240
0243
0258
0265
0269
0299
0312
0326
0327
0331
0335
0341
0348
0366
0377
0380
0394
0407
0411
0430
0446
0455
0472
0474
0476
0478
0493
0494
0496
0505
0559
0733
0860
1017
1589
4541
5004
5005
5006
5007
5009
5010
5012
5013
5017

View File

@ -0,0 +1,153 @@
0001
0003
0004
0005
0007
0012
0013
0016
0017
0023
0026
0027
0034
0035
0036
0037
0039
0041
0042
0043
0044
0046
0047
0048
0049
0056
0057
0058
0060
0061
0062
0064
0065
0067
0070
0071
0076
0078
0080
0083
0086
0087
0090
0094
0095
0098
0099
0100
0101
0102
0104
0107
0115
0117
0122
0129
0130
0137
0141
0147
0148
0149
0150
0151
0156
0160
0162
0175
0181
0183
0185
0186
0189
0190
0197
0200
0204
0205
0212
0214
0217
0223
0224
0231
0235
0237
0238
0240
0243
0252
0257
0258
0269
0271
0275
0277
0281
0285
0286
0290
0294
0299
0303
0306
0307
0312
0323
0326
0327
0331
0335
0341
0348
0360
0377
0380
0387
0389
0394
0402
0406
0407
0411
0446
0455
0472
0476
0478
0482
0493
0496
0505
0559
0733
0768
1017
3346
5000
5001
5002
5003
5004
5005
5006
5007
5008
5009
5010
5011
5012
5013
5017
5018

File diff suppressed because it is too large

View File

@ -0,0 +1,77 @@
0016
0033
0034
0041
0044
0047
0049
0058
0062
0064
0067
0071
0076
0078
0090
0094
0099
0102
0121
0129
0133
0141
0151
0162
0168
0175
0177
0178
0181
0185
0186
0197
0204
0205
0209
0212
0217
0223
0229
0231
0238
0252
0257
0271
0275
0277
0281
0285
0286
0290
0294
0303
0306
0307
0323
0349
0360
0387
0389
0402
0406
0412
0443
0482
0768
1001
3346
5000
5001
5002
5003
5008
5011
5014
5015
5016
5018

View File

@ -0,0 +1,2 @@
0015
0022

View File

@ -0,0 +1,131 @@
import cv2
import numpy as np
import torch
def read_image(path, grayscale=False):
"""Read an image from path as RGB or grayscale"""
mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
image = cv2.imread(str(path), mode)
if image is None:
raise IOError(f"Could not read image at {path}.")
if not grayscale:
image = image[..., ::-1]
return image
def numpy_image_to_torch(image):
"""Normalize the image tensor and reorder the dimensions."""
if image.ndim == 3:
image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
elif image.ndim == 2:
image = image[None] # add channel axis
else:
raise ValueError(f"Not an image: {image.shape}")
return torch.tensor(image / 255.0, dtype=torch.float)
def rotate_intrinsics(K, image_shape, rot):
"""image_shape is the shape of the image after rotation"""
assert rot <= 3
h, w = image_shape[:2][:: -1 if (rot % 2) else 1]
fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
rot = rot % 4
if rot == 1:
return np.array(
[[fy, 0.0, cy], [0.0, fx, w - cx], [0.0, 0.0, 1.0]], dtype=K.dtype
)
elif rot == 2:
return np.array(
[[fx, 0.0, w - cx], [0.0, fy, h - cy], [0.0, 0.0, 1.0]],
dtype=K.dtype,
)
else: # if rot == 3:
return np.array(
[[fy, 0.0, h - cy], [0.0, fx, cx], [0.0, 0.0, 1.0]], dtype=K.dtype
)
def rotate_pose_inplane(i_T_w, rot):
rotation_matrices = [
np.array(
[
[np.cos(r), -np.sin(r), 0.0, 0.0],
[np.sin(r), np.cos(r), 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
],
dtype=np.float32,
)
for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
]
return np.dot(rotation_matrices[rot], i_T_w)
def scale_intrinsics(K, scales):
"""Scale intrinsics after resizing the corresponding image."""
scales = np.diag(np.concatenate([scales, [1.0]]))
return np.dot(scales.astype(K.dtype, copy=False), K)
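# Example: scales = (0.5, 0.25) left-multiplies K by diag(0.5, 0.25, 1.0), so
# fx and cx are halved while fy and cy are quartered (values illustrative):
#   K = np.array([[1000., 0., 500.], [0., 1000., 400.], [0., 0., 1.]])
#   scale_intrinsics(K, np.array([0.5, 0.25]))
#   # -> [[500., 0., 250.], [0., 250., 100.], [0., 0., 1.]]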
def get_divisible_wh(w, h, df=None):
if df is not None:
w_new, h_new = map(lambda x: int(x // df * df), [w, h])
else:
w_new, h_new = w, h
return w_new, h_new
def resize(image, size, fn=None, interp="linear", df=None):
"""Resize an image to a fixed size, or according to max or min edge."""
h, w = image.shape[:2]
if isinstance(size, int):
scale = size / fn(h, w)
h_new, w_new = int(round(h * scale)), int(round(w * scale))
w_new, h_new = get_divisible_wh(w_new, h_new, df)
scale = (w_new / w, h_new / h)
elif isinstance(size, (tuple, list)):
h_new, w_new = size
scale = (w_new / w, h_new / h)
else:
raise ValueError(f"Incorrect new size: {size}")
mode = {
"linear": cv2.INTER_LINEAR,
"cubic": cv2.INTER_CUBIC,
"nearest": cv2.INTER_NEAREST,
"area": cv2.INTER_AREA,
}[interp]
return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
def crop(image, size, random=True, other=None, K=None, return_bbox=False):
"""Random or deterministic crop of an image, adjust depth and intrinsics."""
h, w = image.shape[:2]
h_new, w_new = (size, size) if isinstance(size, int) else size
top = np.random.randint(0, h - h_new + 1) if random else 0
left = np.random.randint(0, w - w_new + 1) if random else 0
image = image[top : top + h_new, left : left + w_new]
ret = [image]
if other is not None:
ret += [other[top : top + h_new, left : left + w_new]]
if K is not None:
K[0, 2] -= left
K[1, 2] -= top
ret += [K]
if return_bbox:
ret += [(top, top + h_new, left, left + w_new)]
return ret
def zero_pad(size, *images):
"""zero pad images to size x size"""
ret = []
for image in images:
if image is None:
ret.append(None)
continue
h, w = image.shape[:2]
padded = np.zeros((size, size) + image.shape[2:], dtype=image.dtype)
padded[:h, :w] = image
ret.append(padded)
return ret

View File

@ -0,0 +1,19 @@
import torch
from ..utils.tools import get_class
from .eval_pipeline import EvalPipeline
def get_benchmark(benchmark):
return get_class(f"{__name__}.{benchmark}", EvalPipeline)
@torch.no_grad()
def run_benchmark(benchmark, eval_conf, experiment_dir, model=None):
"""This overwrites existing benchmarks"""
experiment_dir.mkdir(exist_ok=True, parents=True)
bm = get_benchmark(benchmark)
pipeline = bm(eval_conf)
return pipeline.run(
experiment_dir, model=model, overwrite=True, overwrite_eval=True
)
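# Usage sketch (hypothetical conf and output path; the conf must additionally
# describe the model or checkpoint to evaluate, as in the pipelines defined in
# the modules below):
#   summaries, figures, results = run_benchmark(
#       "hpatches", eval_conf, Path("outputs/eval/my_experiment")
#   )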

215
gluefactory/eval/eth3d.py Normal file
View File

@ -0,0 +1,215 @@
import torch
from pathlib import Path
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import resource
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from .io import (
parse_eval_args,
load_model,
get_eval_parser,
)
from .eval_pipeline import EvalPipeline, load_eval
from ..utils.export_predictions import export_predictions
from .utils import get_tp_fp_pts, aggregate_pr_results
from ..settings import EVAL_PATH
from ..models.cache_loader import CacheLoader
from ..datasets import get_dataset
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
torch.set_grad_enabled(False)
def eval_dataset(loader, pred_file, suffix=""):
results = defaultdict(list)
results["num_pos" + suffix] = 0
cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
for data in tqdm(loader):
pred = cache_loader(data)
if suffix == "":
scores = pred["matching_scores0"].numpy()
sort_indices = np.argsort(scores)[::-1]
gt_matches = pred["gt_matches0"].numpy()[sort_indices]
pred_matches = pred["matches0"].numpy()[sort_indices]
else:
scores = pred["line_matching_scores0"].numpy()
sort_indices = np.argsort(scores)[::-1]
gt_matches = pred["gt_line_matches0"].numpy()[sort_indices]
pred_matches = pred["line_matches0"].numpy()[sort_indices]
scores = scores[sort_indices]
tp, fp, scores, num_pos = get_tp_fp_pts(pred_matches, gt_matches, scores)
results["tp" + suffix].append(tp)
results["fp" + suffix].append(fp)
results["scores" + suffix].append(scores)
results["num_pos" + suffix] += num_pos
# Aggregate the results
return aggregate_pr_results(results, suffix=suffix)
class ETH3DPipeline(EvalPipeline):
default_conf = {
"data": {
"name": "eth3d",
"batch_size": 1,
"train_batch_size": 1,
"val_batch_size": 1,
"test_batch_size": 1,
"num_workers": 16,
},
"model": {
"name": "gluefactory.models.two_view_pipeline",
"ground_truth": {
"name": "gluefactory.models.matchers.depth_matcher",
"use_lines": False,
},
"run_gt_in_forward": True,
},
"eval": {"plot_methods": [], "plot_line_methods": [], "eval_lines": False},
}
export_keys = [
"gt_matches0",
"matches0",
"matching_scores0",
]
optional_export_keys = [
"gt_line_matches0",
"line_matches0",
"line_matching_scores0",
]
def get_dataloader(self, data_conf=None):
data_conf = data_conf if data_conf is not None else self.default_conf["data"]
dataset = get_dataset("eth3d")(data_conf)
return dataset.get_data_loader("test")
def get_predictions(self, experiment_dir, model=None, overwrite=False):
pred_file = experiment_dir / "predictions.h5"
if not pred_file.exists() or overwrite:
if model is None:
model = load_model(self.conf.model, self.conf.checkpoint)
export_predictions(
self.get_dataloader(self.conf.data),
model,
pred_file,
keys=self.export_keys,
optional_keys=self.optional_export_keys,
)
return pred_file
def run_eval(self, loader, pred_file):
eval_conf = self.conf.eval
r = eval_dataset(loader, pred_file)
if self.conf.eval.eval_lines:
r.update(eval_dataset(loader, pred_file, suffix="_lines"))
s = {}
return s, {}, r
def plot_pr_curve(
models_name, results, dst_file="eth3d_pr_curve.pdf", title=None, suffix=""
):
plt.figure()
f_scores = np.linspace(0.2, 0.9, num=8)
for f_score in f_scores:
x = np.linspace(0.01, 1)
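# Iso-F-score curve: from F = 2*P*R / (P + R), precision P = F*R / (2*R - F).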
y = f_score * x / (2 * x - f_score)
plt.plot(x[y >= 0], y[y >= 0], color=[0, 0.5, 0], alpha=0.3)
plt.annotate(
"f={0:0.1}".format(f_score),
xy=(0.9, y[45] + 0.02),
alpha=0.4,
fontsize=14,
)
plt.rcParams.update({"font.size": 12})
# plt.rc('legend', fontsize=10)
plt.grid(True)
plt.axis([0.0, 1.0, 0.0, 1.0])
plt.xticks(np.arange(0, 1.05, step=0.1), fontsize=16)
plt.xlabel("Recall", fontsize=18)
plt.ylabel("Precision", fontsize=18)
plt.yticks(np.arange(0, 1.05, step=0.1), fontsize=16)
plt.ylim([0.3, 1.0])
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]
for m, c in zip(models_name, colors):
sAP_string = f'{m}: {results[m]["AP" + suffix]:.1f}'
plt.plot(
results[m]["curve_recall" + suffix],
results[m]["curve_precision" + suffix],
label=sAP_string,
color=c,
)
plt.legend(fontsize=16, loc="lower right")
if title:
plt.title(title)
plt.tight_layout(pad=0.5)
print(f"Saving plot to: {dst_file}")
plt.savefig(dst_file)
plt.show()
if __name__ == "__main__":
dataset_name = Path(__file__).stem
parser = get_eval_parser()
args = parser.parse_intermixed_args()
default_conf = OmegaConf.create(ETH3DPipeline.default_conf)
# mingle paths
output_dir = Path(EVAL_PATH, dataset_name)
output_dir.mkdir(exist_ok=True, parents=True)
name, conf = parse_eval_args(
dataset_name,
args,
"configs/",
default_conf,
)
experiment_dir = output_dir / name
experiment_dir.mkdir(exist_ok=True)
pipeline = ETH3DPipeline(conf)
s, f, r = pipeline.run(
experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
)
# print results
for k, v in r.items():
if k.startswith("AP"):
print(f"{k}: {v:.2f}")
if args.plot:
results = {}
for m in conf.eval.plot_methods:
exp_dir = output_dir / m
results[m] = load_eval(exp_dir)[1]
plot_pr_curve(conf.eval.plot_methods, results, dst_file="eth3d_pr_curve.pdf")
if conf.eval.eval_lines:
for m in conf.eval.plot_line_methods:
exp_dir = output_dir / m
results[m] = load_eval(exp_dir)[1]
plot_pr_curve(
conf.eval.plot_line_methods,
results,
dst_file="eth3d_pr_curve_lines.pdf",
suffix="_lines",
)

View File

@ -0,0 +1,108 @@
from omegaconf import OmegaConf
import numpy as np
import json
import h5py
def load_eval(dir):
summaries, results = {}, {}
with h5py.File(str(dir / "results.h5"), "r") as hfile:
for k in hfile.keys():
r = np.array(hfile[k])
if len(r.shape) < 3:
results[k] = r
for k, v in hfile.attrs.items():
summaries[k] = v
with open(dir / "summaries.json", "r") as f:
s = json.load(f)
summaries = {k: v if v is not None else np.nan for k, v in s.items()}
return summaries, results
def save_eval(dir, summaries, figures, results):
with h5py.File(str(dir / "results.h5"), "w") as hfile:
for k, v in results.items():
arr = np.array(v)
if not np.issubdtype(arr.dtype, np.number):
arr = arr.astype("object")
hfile.create_dataset(k, data=arr)
# just to be safe, not used in practice
for k, v in summaries.items():
hfile.attrs[k] = v
s = {
k: float(v) if np.isfinite(v) else None
for k, v in summaries.items()
if not isinstance(v, list)
}
s = {**s, **{k: v for k, v in summaries.items() if isinstance(v, list)}}
with open(dir / "summaries.json", "w") as f:
json.dump(s, f, indent=4)
for fig_name, fig in figures.items():
fig.savefig(dir / f"{fig_name}.png")
def exists_eval(dir):
return (dir / "results.h5").exists() and (dir / "summaries.json").exists()
class EvalPipeline:
default_conf = {}
export_keys = []
optional_export_keys = []
def __init__(self, conf):
"""Assumes"""
self.default_conf = OmegaConf.create(self.default_conf)
self.conf = OmegaConf.merge(self.default_conf, conf)
self._init(self.conf)
def _init(self, conf):
pass
@classmethod
def get_dataloader(self, data_conf=None):
"""Returns a data loader with samples for each eval datapoint"""
raise NotImplementedError
def get_predictions(self, experiment_dir, model=None, overwrite=False):
"""Export a prediction file for each eval datapoint"""
raise NotImplementedError
def run_eval(self, loader, pred_file):
"""Run the eval on cached predictions"""
raise NotImplementedError
def run(self, experiment_dir, model=None, overwrite=False, overwrite_eval=False):
"""Run export+eval loop"""
self.save_conf(
experiment_dir, overwrite=overwrite, overwrite_eval=overwrite_eval
)
pred_file = self.get_predictions(
experiment_dir, model=model, overwrite=overwrite
)
f = {}
if not exists_eval(experiment_dir) or overwrite_eval or overwrite:
s, f, r = self.run_eval(self.get_dataloader(), pred_file)
save_eval(experiment_dir, s, f, r)
s, r = load_eval(experiment_dir)
return s, f, r
def save_conf(self, experiment_dir, overwrite=False, overwrite_eval=False):
# store config
conf_output_path = experiment_dir / "conf.yaml"
if conf_output_path.exists():
saved_conf = OmegaConf.load(conf_output_path)
if (saved_conf.data != self.conf.data) or (
saved_conf.model != self.conf.model
):
assert (
overwrite
), "configs changed, add --overwrite to rerun experiment with new conf"
if saved_conf.eval != self.conf.eval:
assert (
overwrite or overwrite_eval
), "eval configs changed, add --overwrite_eval to rerun evaluation"
OmegaConf.save(self.conf, experiment_dir / "conf.yaml")
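# Usage sketch (hypothetical names): a concrete subclass such as the HPatches or
# ETH3D pipelines defined elsewhere in gluefactory/eval is driven through run():
#   pipeline = SomePipeline(conf)  # conf is a dict/OmegaConf matching default_conf
#   summaries, figures, results = pipeline.run(experiment_dir)
# Predictions (predictions.h5) and evaluation outputs (results.h5,
# summaries.json) are cached in experiment_dir and only recomputed when
# overwrite / overwrite_eval is set.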

View File

@ -0,0 +1,211 @@
import torch
from pathlib import Path
from omegaconf import OmegaConf
from pprint import pprint
import matplotlib.pyplot as plt
import resource
from collections import defaultdict
from collections.abc import Iterable
from tqdm import tqdm
import numpy as np
from ..visualization.viz2d import plot_cumulative
from .io import (
parse_eval_args,
load_model,
get_eval_parser,
)
from ..utils.export_predictions import export_predictions
from ..settings import EVAL_PATH
from ..models.cache_loader import CacheLoader
from ..datasets import get_dataset
from .utils import (
eval_homography_robust,
eval_poses,
eval_matches_homography,
eval_homography_dlt,
)
from ..utils.tools import AUCMetric
from .eval_pipeline import EvalPipeline
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
torch.set_grad_enabled(False)
class HPatchesPipeline(EvalPipeline):
default_conf = {
"data": {
"batch_size": 1,
"name": "hpatches",
"num_workers": 16,
"preprocessing": {
"resize": 480, # we also resize during eval to have comparable metrics
"side": "short",
},
},
"model": {
"ground_truth": {
"name": None, # remove gt matches
}
},
"eval": {
"estimator": "poselib",
"ransac_th": 1.0, # -1 runs a bunch of thresholds and selects the best
},
}
export_keys = [
"keypoints0",
"keypoints1",
"keypoint_scores0",
"keypoint_scores1",
"matches0",
"matches1",
"matching_scores0",
"matching_scores1",
]
optional_export_keys = [
"lines0",
"lines1",
"orig_lines0",
"orig_lines1",
"line_matches0",
"line_matches1",
"line_matching_scores0",
"line_matching_scores1",
]
def _init(self, conf):
pass
@classmethod
def get_dataloader(self, data_conf=None):
data_conf = data_conf if data_conf else self.default_conf["data"]
dataset = get_dataset("hpatches")(data_conf)
return dataset.get_data_loader("test")
def get_predictions(self, experiment_dir, model=None, overwrite=False):
pred_file = experiment_dir / "predictions.h5"
if not pred_file.exists() or overwrite:
if model is None:
model = load_model(self.conf.model, self.conf.checkpoint)
export_predictions(
self.get_dataloader(self.conf.data),
model,
pred_file,
keys=self.export_keys,
optional_keys=self.optional_export_keys,
)
return pred_file
def run_eval(self, loader, pred_file):
assert pred_file.exists()
results = defaultdict(list)
conf = self.conf.eval
test_thresholds = (
([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
if not isinstance(conf.ransac_th, Iterable)
else conf.ransac_th
)
pose_results = defaultdict(lambda: defaultdict(list))
cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
for i, data in enumerate(tqdm(loader)):
pred = cache_loader(data)
# add custom evaluations here
if "keypoints0" in pred:
results_i = eval_matches_homography(data, pred, {})
results_i = {**results_i, **eval_homography_dlt(data, pred)}
else:
results_i = {}
for th in test_thresholds:
pose_results_i = eval_homography_robust(
data,
pred,
{"estimator": conf.estimator, "ransac_th": th},
)
[pose_results[th][k].append(v) for k, v in pose_results_i.items()]
# we also store the names for later reference
results_i["names"] = data["name"][0]
results_i["scenes"] = data["scene"][0]
for k, v in results_i.items():
results[k].append(v)
# summarize results as a dict[str, float]
# you can also add your custom evaluations here
summaries = {}
for k, v in results.items():
arr = np.array(v)
if not np.issubdtype(np.array(v).dtype, np.number):
continue
summaries[f"m{k}"] = round(np.median(arr), 3)
auc_ths = [1, 3, 5]
best_pose_results, best_th = eval_poses(
pose_results, auc_ths=auc_ths, key="H_error_ransac", unit="px"
)
if "H_error_dlt" in results.keys():
dlt_aucs = AUCMetric(auc_ths, results["H_error_dlt"]).compute()
for i, ath in enumerate(auc_ths):
summaries[f"H_error_dlt@{ath}px"] = dlt_aucs[i]
results = {**results, **pose_results[best_th]}
summaries = {
**summaries,
**best_pose_results,
}
figures = {
"homography_recall": plot_cumulative(
{
"DLT": results["H_error_dlt"],
self.conf.eval.estimator: results["H_error_ransac"],
},
[0, 10],
unit="px",
title="Homography ",
)
}
return summaries, figures, results
if __name__ == "__main__":
dataset_name = Path(__file__).stem
parser = get_eval_parser()
args = parser.parse_intermixed_args()
default_conf = OmegaConf.create(HPatchesPipeline.default_conf)
    # assemble the output paths
output_dir = Path(EVAL_PATH, dataset_name)
output_dir.mkdir(exist_ok=True, parents=True)
name, conf = parse_eval_args(
dataset_name,
args,
"configs/",
default_conf,
)
experiment_dir = output_dir / name
experiment_dir.mkdir(exist_ok=True)
pipeline = HPatchesPipeline(conf)
s, f, r = pipeline.run(
experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
)
# print results
pprint(s)
if args.plot:
for name, fig in f.items():
fig.canvas.manager.set_window_title(name)
plt.show()
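# Illustrative usage sketch (not part of the original file). The module path and
# config name are assumptions; the flags come from get_eval_parser in io.py:
#   python -m gluefactory.eval.hpatches --conf <config> --overwrite
# Config entries can be overridden through the trailing dotlist, e.g. sweep the
# RANSAC thresholds and keep the best one:
#   python -m gluefactory.eval.hpatches --conf <config> eval.ransac_th=-1
# Add --plot to display the homography recall curve produced by run_eval.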

View File

@ -0,0 +1,61 @@
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib
from pprint import pprint
from collections import defaultdict
from ..settings import EVAL_PATH
from ..visualization.global_frame import GlobalFrame
from ..visualization.two_view_frame import TwoViewFrame
from . import get_benchmark
from .eval_pipeline import load_eval
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("benchmark", type=str)
parser.add_argument("--x", type=str, default=None)
parser.add_argument("--y", type=str, default=None)
parser.add_argument("--backend", type=str, default=None)
parser.add_argument(
"--default_plot", type=str, default=TwoViewFrame.default_conf["default"]
)
parser.add_argument("dotlist", nargs="*")
args = parser.parse_intermixed_args()
output_dir = Path(EVAL_PATH, args.benchmark)
results = {}
summaries = defaultdict(dict)
predictions = {}
if args.backend:
matplotlib.use(args.backend)
bm = get_benchmark(args.benchmark)
loader = bm.get_dataloader()
for name in args.dotlist:
experiment_dir = output_dir / name
pred_file = experiment_dir / "predictions.h5"
s, results[name] = load_eval(experiment_dir)
predictions[name] = pred_file
for k, v in s.items():
summaries[k][name] = v
pprint(summaries)
plt.close("all")
frame = GlobalFrame(
{"child": {"default": args.default_plot}, **vars(args)},
results,
loader,
predictions,
child_frame=TwoViewFrame,
)
frame.draw()
plt.show()
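# Illustrative usage sketch (not part of the original file; the module path is an
# assumption, the arguments come from the parser above). Cached runs of a benchmark
# are compared by passing their experiment names as the trailing dotlist:
#   python -m gluefactory.eval.inspect hpatches <experiment_a> <experiment_b>
# --backend selects the matplotlib backend (e.g. qt5agg) for interactive use, and
# --x/--y/--default_plot are forwarded to GlobalFrame/TwoViewFrame, presumably to
# choose which metrics and which per-pair view are plotted.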

103
gluefactory/eval/io.py Normal file
View File

@ -0,0 +1,103 @@
import pkg_resources
from pathlib import Path
from typing import Optional
from omegaconf import OmegaConf
import argparse
from pprint import pprint
from ..models import get_model
from ..utils.experiments import load_experiment
from ..settings import TRAINING_PATH
def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path:
default_configs = {}
for c in pkg_resources.resource_listdir("gluefactory", str(defaults)):
if c.endswith(".yaml"):
default_configs[Path(c).stem] = Path(
pkg_resources.resource_filename("gluefactory", defaults + c)
)
if name_or_path is None:
return None
if name_or_path in default_configs:
return default_configs[name_or_path]
path = Path(name_or_path)
if not path.exists():
raise FileNotFoundError(
f"Cannot find the config file: {name_or_path}. "
f"Not in the default configs {list(default_configs.keys())} "
"and not an existing path."
)
return Path(path)
def extract_benchmark_conf(conf, benchmark):
mconf = OmegaConf.create(
{
"model": conf.get("model", {}),
}
)
if "benchmarks" in conf.keys():
return OmegaConf.merge(mconf, conf.benchmarks.get(benchmark, {}))
else:
return mconf
def parse_eval_args(benchmark, args, configs_path, default=None):
conf = {"data": {}, "model": {}, "eval": {}}
if args.conf:
conf_path = parse_config_path(args.conf, configs_path)
custom_conf = OmegaConf.load(conf_path)
conf = extract_benchmark_conf(OmegaConf.merge(conf, custom_conf), benchmark)
args.tag = (
args.tag if args.tag is not None else conf_path.name.replace(".yaml", "")
)
cli_conf = OmegaConf.from_cli(args.dotlist)
conf = OmegaConf.merge(conf, cli_conf)
conf.checkpoint = args.checkpoint if args.checkpoint else conf.get("checkpoint")
if conf.checkpoint and not conf.checkpoint.endswith(".tar"):
checkpoint_conf = OmegaConf.load(
TRAINING_PATH / conf.checkpoint / "config.yaml"
)
conf = OmegaConf.merge(extract_benchmark_conf(checkpoint_conf, benchmark), conf)
if default:
conf = OmegaConf.merge(default, conf)
if args.tag is not None:
name = args.tag
elif args.conf and conf.checkpoint:
name = f"{args.conf}_{conf.checkpoint}"
elif args.conf:
name = args.conf
elif conf.checkpoint:
name = conf.checkpoint
if len(args.dotlist) > 0 and not args.tag:
name = name + "_" + ":".join(args.dotlist)
print("Running benchmark:", benchmark)
print("Experiment tag:", name)
print("Config:")
pprint(OmegaConf.to_container(conf))
return name, conf
def load_model(model_conf, checkpoint):
if checkpoint:
model = load_experiment(checkpoint, conf=model_conf).eval()
else:
model = get_model("two_view_pipeline")(model_conf).eval()
return model
def get_eval_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--tag", type=str, default=None)
parser.add_argument("--checkpoint", type=str, default=None)
parser.add_argument("--conf", type=str, default=None)
parser.add_argument("--overwrite", action="store_true")
parser.add_argument("--overwrite_eval", action="store_true")
parser.add_argument("--plot", action="store_true")
parser.add_argument("dotlist", nargs="*")
return parser
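# Illustrative note (not part of the original file; config and checkpoint names
# are hypothetical): how parse_eval_args derives the experiment name:
#   --tag my_run                            -> "my_run"
#   --conf my_conf                          -> "my_conf" (tag = config stem)
#   --conf my_conf eval.ransac_th=2.0       -> "my_conf" (a derived tag suppresses
#                                              the dotlist suffix)
#   --checkpoint my_ckpt                    -> "my_ckpt"
#   --checkpoint my_ckpt eval.ransac_th=2.0 -> "my_ckpt_eval.ransac_th=2.0"
# Because --conf always sets args.tag, the "{conf}_{checkpoint}" branch is in
# practice unreachable, and name stays unassigned if none of --tag, --conf, or
# --checkpoint is given.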

View File

@ -0,0 +1,191 @@
import torch
from pathlib import Path
from omegaconf import OmegaConf
from pprint import pprint
import matplotlib.pyplot as plt
import resource
from collections import defaultdict
from collections.abc import Iterable
from tqdm import tqdm
import zipfile
import numpy as np
from ..visualization.viz2d import plot_cumulative
from .io import (
parse_eval_args,
load_model,
get_eval_parser,
)
from ..utils.export_predictions import export_predictions
from ..settings import EVAL_PATH, DATA_PATH
from ..models.cache_loader import CacheLoader
from ..datasets import get_dataset
from .eval_pipeline import EvalPipeline
from .utils import eval_relative_pose_robust, eval_poses, eval_matches_epipolar
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
torch.set_grad_enabled(False)
class MegaDepth1500Pipeline(EvalPipeline):
default_conf = {
"data": {
"name": "image_pairs",
"pairs": "megadepth1500/pairs_calibrated.txt",
"root": "megadepth1500/images/",
"extra_data": "relative_pose",
"preprocessing": {
"side": "long",
},
},
"model": {
"ground_truth": {
"name": None, # remove gt matches
}
},
"eval": {
"estimator": "poselib",
"ransac_th": 1.0, # -1 runs a bunch of thresholds and selects the best
},
}
export_keys = [
"keypoints0",
"keypoints1",
"keypoint_scores0",
"keypoint_scores1",
"matches0",
"matches1",
"matching_scores0",
"matching_scores1",
]
optional_export_keys = []
def _init(self, conf):
if not (DATA_PATH / "megadepth1500").exists():
url = "https://cvg-data.inf.ethz.ch/megadepth/megadepth1500.zip"
zip_path = DATA_PATH / url.rsplit("/", 1)[-1]
torch.hub.download_url_to_file(url, zip_path)
with zipfile.ZipFile(zip_path) as zip:
zip.extractall(DATA_PATH)
zip_path.unlink()
@classmethod
def get_dataloader(self, data_conf=None):
"""Returns a data loader with samples for each eval datapoint"""
data_conf = data_conf if data_conf else self.default_conf["data"]
dataset = get_dataset(data_conf["name"])(data_conf)
return dataset.get_data_loader("test")
def get_predictions(self, experiment_dir, model=None, overwrite=False):
"""Export a prediction file for each eval datapoint"""
pred_file = experiment_dir / "predictions.h5"
if not pred_file.exists() or overwrite:
if model is None:
model = load_model(self.conf.model, self.conf.checkpoint)
export_predictions(
self.get_dataloader(self.conf.data),
model,
pred_file,
keys=self.export_keys,
optional_keys=self.optional_export_keys,
)
return pred_file
def run_eval(self, loader, pred_file):
"""Run the eval on cached predictions"""
conf = self.conf.eval
results = defaultdict(list)
test_thresholds = (
([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
if not isinstance(conf.ransac_th, Iterable)
else conf.ransac_th
)
pose_results = defaultdict(lambda: defaultdict(list))
cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
for i, data in enumerate(tqdm(loader)):
pred = cache_loader(data)
# add custom evaluations here
results_i = eval_matches_epipolar(data, pred)
for th in test_thresholds:
pose_results_i = eval_relative_pose_robust(
data,
pred,
{"estimator": conf.estimator, "ransac_th": th},
)
[pose_results[th][k].append(v) for k, v in pose_results_i.items()]
# we also store the names for later reference
results_i["names"] = data["name"][0]
if "scene" in data.keys():
results_i["scenes"] = data["scene"][0]
for k, v in results_i.items():
results[k].append(v)
# summarize results as a dict[str, float]
# you can also add your custom evaluations here
summaries = {}
for k, v in results.items():
arr = np.array(v)
if not np.issubdtype(np.array(v).dtype, np.number):
continue
summaries[f"m{k}"] = round(np.mean(arr), 3)
best_pose_results, best_th = eval_poses(
pose_results, auc_ths=[5, 10, 20], key="rel_pose_error"
)
results = {**results, **pose_results[best_th]}
summaries = {
**summaries,
**best_pose_results,
}
figures = {
"pose_recall": plot_cumulative(
{self.conf.eval.estimator: results["rel_pose_error"]},
[0, 30],
unit="°",
title="Pose ",
)
}
return summaries, figures, results
if __name__ == "__main__":
dataset_name = Path(__file__).stem
parser = get_eval_parser()
args = parser.parse_intermixed_args()
default_conf = OmegaConf.create(MegaDepth1500Pipeline.default_conf)
    # assemble the output paths
output_dir = Path(EVAL_PATH, dataset_name)
output_dir.mkdir(exist_ok=True, parents=True)
name, conf = parse_eval_args(
dataset_name,
args,
"configs/",
default_conf,
)
experiment_dir = output_dir / name
experiment_dir.mkdir(exist_ok=True)
pipeline = MegaDepth1500Pipeline(conf)
s, f, r = pipeline.run(
experiment_dir,
overwrite=args.overwrite,
overwrite_eval=args.overwrite_eval,
)
pprint(s)
if args.plot:
for name, fig in f.items():
fig.canvas.manager.set_window_title(name)
plt.show()
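# Illustrative note (not part of the original file): how conf.eval.ransac_th maps
# to the thresholds actually tested in run_eval above:
#   ransac_th = 1.0        -> test_thresholds == [1.0]
#   ransac_th = -1         -> test_thresholds == [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
#   ransac_th = [1.0, 2.0] -> test_thresholds == [1.0, 2.0] (iterables pass through)
# eval_poses then keeps the threshold with the best mean AUC (mAA) of the relative
# pose error over the [5, 10, 20] degree thresholds.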

254
gluefactory/eval/utils.py Normal file
View File

@ -0,0 +1,254 @@
import numpy as np
import torch
import kornia
from ..geometry.epipolar import relative_pose_error, generalized_epi_dist
from ..geometry.homography import sym_homography_error, homography_corner_error
from ..geometry.gt_generation import IGNORE_FEATURE
from ..utils.tools import AUCMetric
from ..robust_estimators import load_estimator
def check_keys_recursive(d, pattern):
if isinstance(pattern, dict):
{check_keys_recursive(d[k], v) for k, v in pattern.items()}
else:
for k in pattern:
assert k in d.keys()
def get_matches_scores(kpts0, kpts1, matches0, mscores0):
m0 = matches0 > -1
m1 = matches0[m0]
pts0 = kpts0[m0]
pts1 = kpts1[m1]
scores = mscores0[m0]
return pts0, pts1, scores
def eval_matches_epipolar(data: dict, pred: dict) -> dict:
check_keys_recursive(data, ["view0", "view1", "T_0to1"])
check_keys_recursive(
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
)
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0, scores0 = pred["matches0"], pred["matching_scores0"]
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
results = {}
# match metrics
n_epi_err = generalized_epi_dist(
pts0[None],
pts1[None],
data["view0"]["camera"],
data["view1"]["camera"],
data["T_0to1"],
False,
essential=True,
)[0]
results["epi_prec@1e-4"] = (n_epi_err < 1e-4).float().mean()
results["epi_prec@5e-4"] = (n_epi_err < 5e-4).float().mean()
results["epi_prec@1e-3"] = (n_epi_err < 1e-3).float().mean()
results["num_matches"] = pts0.shape[0]
results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0
return results
def eval_matches_homography(data: dict, pred: dict, conf) -> dict:
check_keys_recursive(data, ["H_0to1"])
check_keys_recursive(
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
)
H_gt = data["H_0to1"]
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0, scores0 = pred["matches0"], pred["matching_scores0"]
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
err = sym_homography_error(pts0, pts1, H_gt[0])
results = {}
results["prec@1px"] = (err < 1).float().mean().nan_to_num().item()
results["prec@3px"] = (err < 3).float().mean().nan_to_num().item()
results["num_matches"] = pts0.shape[0]
results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0
return results
def eval_relative_pose_robust(data, pred, conf):
check_keys_recursive(data, ["view0", "view1", "T_0to1"])
check_keys_recursive(
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
)
T_gt = data["T_0to1"][0]
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0, scores0 = pred["matches0"], pred["matching_scores0"]
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
results = {}
estimator = load_estimator("relative_pose", conf["estimator"])(conf)
data_ = {
"m_kpts0": pts0,
"m_kpts1": pts1,
"camera0": data["view0"]["camera"][0],
"camera1": data["view1"]["camera"][0],
}
est = estimator(data_)
if not est["success"]:
results["rel_pose_error"] = float("inf")
results["ransac_inl"] = 0
results["ransac_inl%"] = 0
else:
# R, t, inl = ret
M = est["M_0to1"]
R, t = M.numpy()
inl = est["inliers"].numpy()
r_error, t_error = relative_pose_error(T_gt, R, t)
results["rel_pose_error"] = max(r_error, t_error)
results["ransac_inl"] = np.sum(inl)
results["ransac_inl%"] = np.mean(inl)
return results
def eval_homography_robust(data, pred, conf):
H_gt = data["H_0to1"]
estimator = load_estimator("homography", conf["estimator"])(conf)
data_ = {}
if "keypoints0" in pred:
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0, scores0 = pred["matches0"], pred["matching_scores0"]
pts0, pts1, _ = get_matches_scores(kp0, kp1, m0, scores0)
data_["m_kpts0"] = pts0
data_["m_kpts1"] = pts1
if "lines0" in pred:
if "orig_lines0" in pred:
lines0 = pred["orig_lines0"]
lines1 = pred["orig_lines1"]
else:
lines0 = pred["lines0"]
lines1 = pred["lines1"]
m_lines0, m_lines1, _ = get_matches_scores(
lines0, lines1, pred["line_matches0"], pred["line_matching_scores0"]
)
data_["m_lines0"] = m_lines0
data_["m_lines1"] = m_lines1
est = estimator(data_)
if est["success"]:
M = est["M_0to1"]
error_r = homography_corner_error(M, H_gt, data["view0"]["image_size"]).item()
else:
error_r = float("inf")
results = {}
results["H_error_ransac"] = error_r
if "inliers" in est:
inl = est["inliers"]
results["ransac_inl"] = inl.float().sum().item()
results["ransac_inl%"] = inl.float().sum().item() / max(len(inl), 1)
return results
def eval_homography_dlt(data, pred, *args):
H_gt = data["H_0to1"]
H_inf = torch.ones_like(H_gt) * float("inf")
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0, scores0 = pred["matches0"], pred["matching_scores0"]
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
results = {}
try:
Hdlt = kornia.geometry.homography.find_homography_dlt(
pts0[None], pts1[None], scores[None].to(pts0)
)[0]
except AssertionError:
Hdlt = H_inf
error_dlt = homography_corner_error(Hdlt, H_gt, data["view0"]["image_size"])
results["H_error_dlt"] = error_dlt.item()
return results
def eval_poses(pose_results, auc_ths, key, unit="°"):
pose_aucs = {}
best_th = -1
for th, results_i in pose_results.items():
pose_aucs[th] = AUCMetric(auc_ths, results_i[key]).compute()
mAAs = {k: np.mean(v) for k, v in pose_aucs.items()}
best_th = max(mAAs, key=mAAs.get)
    if len(pose_aucs) > 1:  # only print the sweep if several thresholds were tested
        print("Tested RANSAC setups with the following results:")
print("AUC", pose_aucs)
print("mAA", mAAs)
print("best threshold =", best_th)
summaries = {}
for i, ath in enumerate(auc_ths):
summaries[f"{key}@{ath}{unit}"] = pose_aucs[best_th][i]
summaries[f"{key}_mAA"] = mAAs[best_th]
for k, v in pose_results[best_th].items():
arr = np.array(v)
if not np.issubdtype(np.array(v).dtype, np.number):
continue
summaries[f"m{k}"] = round(np.median(arr), 3)
return summaries, best_th
def get_tp_fp_pts(pred_matches, gt_matches, pred_scores):
"""
Computes the True Positives (TP), False positives (FP), the score associated
to each match and the number of positives for a set of matches.
"""
assert pred_matches.shape == pred_scores.shape
ignore_mask = gt_matches != IGNORE_FEATURE
pred_matches, gt_matches, pred_scores = (
pred_matches[ignore_mask],
gt_matches[ignore_mask],
pred_scores[ignore_mask],
)
num_pos = np.sum(gt_matches != -1)
pred_positives = pred_matches != -1
tp = pred_matches[pred_positives] == gt_matches[pred_positives]
fp = pred_matches[pred_positives] != gt_matches[pred_positives]
scores = pred_scores[pred_positives]
return tp, fp, scores, num_pos
def AP(tp, fp):
recall = tp
precision = tp / np.maximum(tp + fp, 1e-9)
recall = np.concatenate(([0.0], recall, [1.0]))
precision = np.concatenate(([0.0], precision, [0.0]))
for i in range(precision.size - 1, 0, -1):
precision[i - 1] = max(precision[i - 1], precision[i])
i = np.where(recall[1:] != recall[:-1])[0]
ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
return ap
def aggregate_pr_results(results, suffix=""):
tp_list = np.concatenate(results["tp" + suffix], axis=0)
fp_list = np.concatenate(results["fp" + suffix], axis=0)
scores_list = np.concatenate(results["scores" + suffix], axis=0)
n_gt = max(results["num_pos" + suffix], 1)
out = {}
idx = np.argsort(scores_list)[::-1]
tp_vals = np.cumsum(tp_list[idx]) / n_gt
fp_vals = np.cumsum(fp_list[idx]) / n_gt
out["curve_recall" + suffix] = tp_vals
out["curve_precision" + suffix] = tp_vals / np.maximum(tp_vals + fp_vals, 1e-9)
out["AP" + suffix] = AP(tp_vals, fp_vals) * 100
return out
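def _example_get_matches_scores():
    """Minimal usage sketch (toy values, not part of the original file): how the
    matches0 indices are turned into matched point pairs by get_matches_scores."""
    kpts0 = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
    kpts1 = torch.tensor([[0.1, 0.1], [1.1, 1.1]])
    matches0 = torch.tensor([1, -1, 0])  # 0 -> 1, 1 unmatched, 2 -> 0
    mscores0 = torch.tensor([0.9, 0.0, 0.8])
    pts0, pts1, scores = get_matches_scores(kpts0, kpts1, matches0, mscores0)
    # pts0 == [[0, 0], [2, 2]], pts1 == [[1.1, 1.1], [0.1, 0.1]], scores == [0.9, 0.8]
    return pts0, pts1, scores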

View File

@ -0,0 +1,88 @@
import torch
import kornia
from .utils import get_image_coords
from .wrappers import Camera
def sample_fmap(pts, fmap):
h, w = fmap.shape[-2:]
grid_sample = torch.nn.functional.grid_sample
pts = (pts / pts.new_tensor([[w, h]]) * 2 - 1)[:, None]
# @TODO: This might still be a source of noise --> bilinear interpolation dangerous
interp_lin = grid_sample(fmap, pts, align_corners=False, mode="bilinear")
interp_nn = grid_sample(fmap, pts, align_corners=False, mode="nearest")
return torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)[:, :, 0].permute(
0, 2, 1
)
def sample_depth(pts, depth_):
depth = torch.where(depth_ > 0, depth_, depth_.new_tensor(float("nan")))
depth = depth[:, None]
interp = sample_fmap(pts, depth).squeeze(-1)
valid = (~torch.isnan(interp)) & (interp > 0)
return interp, valid
def sample_normals_from_depth(pts, depth, K):
depth = depth[:, None]
normals = kornia.geometry.depth.depth_to_normals(depth, K)
normals = torch.where(depth > 0, normals, 0.0)
interp = sample_fmap(pts, normals)
valid = (~torch.isnan(interp)) & (interp > 0)
return interp, valid
def project(
kpi,
di,
depthj,
camera_i,
camera_j,
T_itoj,
validi,
ccth=None,
sample_depth_fun=sample_depth,
sample_depth_kwargs=None,
):
if sample_depth_kwargs is None:
sample_depth_kwargs = {}
kpi_3d_i = camera_i.image2cam(kpi)
kpi_3d_i = kpi_3d_i * di[..., None]
kpi_3d_j = T_itoj.transform(kpi_3d_i)
kpi_j, validj = camera_j.cam2image(kpi_3d_j)
# di_j = kpi_3d_j[..., -1]
validi = validi & validj
if depthj is None or ccth is None:
return kpi_j, validi & validj
else:
# circle consistency
dj, validj = sample_depth_fun(kpi_j, depthj, **sample_depth_kwargs)
kpi_j_3d_j = camera_j.image2cam(kpi_j) * dj[..., None]
kpi_j_i, validj_i = camera_i.cam2image(T_itoj.inv().transform(kpi_j_3d_j))
consistent = ((kpi - kpi_j_i) ** 2).sum(-1) < ccth
visible = validi & consistent & validj_i & validj
# visible = validi
return kpi_j, visible
def dense_warp_consistency(
depthi: torch.Tensor,
depthj: torch.Tensor,
T_itoj: torch.Tensor,
camerai: Camera,
cameraj: Camera,
**kwargs,
):
kpi = get_image_coords(depthi).flatten(-3, -2)
di = depthi.flatten(
-2,
)
validi = di > 0
kpir, validir = project(kpi, di, depthj, camerai, cameraj, T_itoj, validi, **kwargs)
return kpir.unflatten(-2, depthi.shape[-2:]), validir.unflatten(
-1, (depthj.shape[-2:])
)

View File

@ -0,0 +1,161 @@
import torch
from .utils import skew_symmetric, to_homogeneous
from .wrappers import Pose, Camera
import numpy as np
def T_to_E(T: Pose):
"""Convert batched poses (..., 4, 4) to batched essential matrices."""
return skew_symmetric(T.t) @ T.R
def T_to_F(cam0: Camera, cam1: Camera, T_0to1: Pose):
return E_to_F(cam0, cam1, T_to_E(T_0to1))
def E_to_F(cam0: Camera, cam1: Camera, E: torch.Tensor):
assert cam0._data.shape[-1] == 6, "only pinhole cameras supported"
assert cam1._data.shape[-1] == 6, "only pinhole cameras supported"
K0 = cam0.calibration_matrix()
K1 = cam1.calibration_matrix()
return K1.inverse().transpose(-1, -2) @ E @ K0.inverse()
def F_to_E(cam0: Camera, cam1: Camera, F: torch.Tensor):
assert cam0._data.shape[-1] == 6, "only pinhole cameras supported"
assert cam1._data.shape[-1] == 6, "only pinhole cameras supported"
K0 = cam0.calibration_matrix()
K1 = cam1.calibration_matrix()
return K1.transpose(-1, -2) @ F @ K0
def sym_epipolar_distance(p0, p1, E, squared=True):
"""Compute batched symmetric epipolar distances.
Args:
p0, p1: batched tensors of N 2D points of size (..., N, 2).
E: essential matrices from camera 0 to camera 1, size (..., 3, 3).
Returns:
The symmetric epipolar distance of each point-pair: (..., N).
"""
assert p0.shape[-2] == p1.shape[-2]
if p0.shape[-2] == 0:
return torch.zeros(p0.shape[:-1]).to(p0)
if p0.shape[-1] != 3:
p0 = to_homogeneous(p0)
if p1.shape[-1] != 3:
p1 = to_homogeneous(p1)
p1_E_p0 = torch.einsum("...ni,...ij,...nj->...n", p1, E, p0)
E_p0 = torch.einsum("...ij,...nj->...ni", E, p0)
Et_p1 = torch.einsum("...ij,...ni->...nj", E, p1)
d0 = (E_p0[..., 0] ** 2 + E_p0[..., 1] ** 2).clamp(min=1e-6)
d1 = (Et_p1[..., 0] ** 2 + Et_p1[..., 1] ** 2).clamp(min=1e-6)
if squared:
d = p1_E_p0**2 * (1 / d0 + 1 / d1)
else:
d = p1_E_p0.abs() * (1 / d0.sqrt() + 1 / d1.sqrt()) / 2
return d
def sym_epipolar_distance_all(p0, p1, E, eps=1e-15):
if p0.shape[-1] != 3:
p0 = to_homogeneous(p0)
if p1.shape[-1] != 3:
p1 = to_homogeneous(p1)
p1_E_p0 = torch.einsum("...mi,...ij,...nj->...nm", p1, E, p0).abs()
E_p0 = torch.einsum("...ij,...nj->...ni", E, p0)
Et_p1 = torch.einsum("...ij,...mi->...mj", E, p1)
d0 = p1_E_p0 / (E_p0[..., None, 0] ** 2 + E_p0[..., None, 1] ** 2 + eps).sqrt()
d1 = (
p1_E_p0
/ (Et_p1[..., None, :, 0] ** 2 + Et_p1[..., None, :, 1] ** 2 + eps).sqrt()
)
return (d0 + d1) / 2
def generalized_epi_dist(
kpts0, kpts1, cam0: Camera, cam1: Camera, T_0to1: Pose, all=True, essential=True
):
if essential:
E = T_to_E(T_0to1)
p0 = cam0.image2cam(kpts0)
p1 = cam1.image2cam(kpts1)
if all:
            # note: sym_epipolar_distance_all takes no aggregation argument
            return sym_epipolar_distance_all(p0, p1, E)
else:
return sym_epipolar_distance(p0, p1, E, squared=False)
else:
assert cam0._data.shape[-1] == 6
assert cam1._data.shape[-1] == 6
K0, K1 = cam0.calibration_matrix(), cam1.calibration_matrix()
F = K1.inverse().transpose(-1, -2) @ T_to_E(T_0to1) @ K0.inverse()
if all:
return sym_epipolar_distance_all(kpts0, kpts1, F)
else:
return sym_epipolar_distance(kpts0, kpts1, F, squared=False)
def decompose_essential_matrix(E):
# decompose matrix by its singular values
U, _, V = torch.svd(E)
Vt = V.transpose(-2, -1)
mask = torch.ones_like(E)
mask[..., -1:] *= -1.0 # fill last column with negative values
maskt = mask.transpose(-2, -1)
# avoid singularities
U = torch.where((torch.det(U) < 0.0)[..., None, None], U * mask, U)
Vt = torch.where((torch.det(Vt) < 0.0)[..., None, None], Vt * maskt, Vt)
W = skew_symmetric(E.new_tensor([[0, 0, 1]]))
W[..., 2, 2] += 1.0
# reconstruct rotations and retrieve translation vector
U_W_Vt = U @ W @ Vt
U_Wt_Vt = U @ W.transpose(-2, -1) @ Vt
# return values
R1 = U_W_Vt
R2 = U_Wt_Vt
T = U[..., -1]
return R1, R2, T
# pose errors
# TODO: port to torch and batch
def angle_error_mat(R1, R2):
cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
return np.rad2deg(np.abs(np.arccos(cos)))
def angle_error_vec(v1, v2):
n = np.linalg.norm(v1) * np.linalg.norm(v2)
return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
def compute_pose_error(T_0to1, R, t):
R_gt = T_0to1[:3, :3]
t_gt = T_0to1[:3, 3]
error_t = angle_error_vec(t, t_gt)
error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
error_R = angle_error_mat(R, R_gt)
return error_t, error_R
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
# angle error between 2 vectors
R_gt, t_gt = T_0to1.numpy()
n = np.linalg.norm(t) * np.linalg.norm(t_gt)
t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
t_err = 0
# angle error between 2 rotation matrices
cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
    cos = np.clip(cos, -1.0, 1.0)  # handle numerical errors
R_err = np.rad2deg(np.abs(np.arccos(cos)))
return t_err, R_err
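def _example_sym_epipolar_distance():
    """Minimal usage sketch (toy values, not part of the original file). For a pure
    translation along x (R = I, t = (1, 0, 0)), E = [t]_x, and normalized points
    lying on the same horizontal line satisfy the epipolar constraint exactly."""
    E = skew_symmetric(torch.tensor([1.0, 0.0, 0.0]))
    p0 = torch.tensor([[0.2, 0.3], [0.5, -0.1]])
    p1_on = torch.tensor([[0.7, 0.3], [0.1, -0.1]])  # same y -> on the epipolar line
    p1_off = p1_on + torch.tensor([0.0, 0.2])  # shifted off the line
    d_on = sym_epipolar_distance(p0, p1_on, E, squared=False)  # ~[0, 0]
    d_off = sym_epipolar_distance(p0, p1_off, E, squared=False)  # ~[0.2, 0.2]
    return d_on, d_off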

View File

@ -0,0 +1,558 @@
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from .homography import warp_points_torch
from .epipolar import T_to_E, sym_epipolar_distance_all
from .depth import sample_depth, project
IGNORE_FEATURE = -2
UNMATCHED_FEATURE = -1
@torch.no_grad()
def gt_matches_from_pose_depth(
kp0, kp1, data, pos_th=3, neg_th=5, epi_th=None, cc_th=None, **kw
):
if kp0.shape[1] == 0 or kp1.shape[1] == 0:
b_size, n_kp0 = kp0.shape[:2]
n_kp1 = kp1.shape[1]
assignment = torch.zeros(
b_size, n_kp0, n_kp1, dtype=torch.bool, device=kp0.device
)
m0 = -torch.ones_like(kp0[:, :, 0]).long()
m1 = -torch.ones_like(kp1[:, :, 0]).long()
return assignment, m0, m1
camera0, camera1 = data["view0"]["camera"], data["view1"]["camera"]
T_0to1, T_1to0 = data["T_0to1"], data["T_1to0"]
depth0 = data["view0"].get("depth")
depth1 = data["view1"].get("depth")
if "depth_keypoints0" in kw and "depth_keypoints1" in kw:
d0, valid0 = kw["depth_keypoints0"], kw["valid_depth_keypoints0"]
d1, valid1 = kw["depth_keypoints1"], kw["valid_depth_keypoints1"]
else:
assert depth0 is not None
assert depth1 is not None
d0, valid0 = sample_depth(kp0, depth0)
d1, valid1 = sample_depth(kp1, depth1)
kp0_1, visible0 = project(
kp0, d0, depth1, camera0, camera1, T_0to1, valid0, ccth=cc_th
)
kp1_0, visible1 = project(
kp1, d1, depth0, camera1, camera0, T_1to0, valid1, ccth=cc_th
)
mask_visible = visible0.unsqueeze(-1) & visible1.unsqueeze(-2)
# build a distance matrix of size [... x M x N]
dist0 = torch.sum((kp0_1.unsqueeze(-2) - kp1.unsqueeze(-3)) ** 2, -1)
dist1 = torch.sum((kp0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1)
dist = torch.max(dist0, dist1)
inf = dist.new_tensor(float("inf"))
dist = torch.where(mask_visible, dist, inf)
min0 = dist.min(-1).indices
min1 = dist.min(-2).indices
ismin0 = torch.zeros(dist.shape, dtype=torch.bool, device=dist.device)
ismin1 = ismin0.clone()
ismin0.scatter_(-1, min0.unsqueeze(-1), value=1)
ismin1.scatter_(-2, min1.unsqueeze(-2), value=1)
positive = ismin0 & ismin1 & (dist < pos_th**2)
negative0 = (dist0.min(-1).values > neg_th**2) & valid0
negative1 = (dist1.min(-2).values > neg_th**2) & valid1
# pack the indices of positive matches
# if -1: unmatched point
# if -2: ignore point
unmatched = min0.new_tensor(UNMATCHED_FEATURE)
ignore = min0.new_tensor(IGNORE_FEATURE)
m0 = torch.where(positive.any(-1), min0, ignore)
m1 = torch.where(positive.any(-2), min1, ignore)
m0 = torch.where(negative0, unmatched, m0)
m1 = torch.where(negative1, unmatched, m1)
F = (
camera1.calibration_matrix().inverse().transpose(-1, -2)
@ T_to_E(T_0to1)
@ camera0.calibration_matrix().inverse()
)
epi_dist = sym_epipolar_distance_all(kp0, kp1, F)
# Add some more unmatched points using epipolar geometry
if epi_th is not None:
mask_ignore = (m0.unsqueeze(-1) == ignore) & (m1.unsqueeze(-2) == ignore)
epi_dist = torch.where(mask_ignore, epi_dist, inf)
exclude0 = epi_dist.min(-1).values > neg_th
exclude1 = epi_dist.min(-2).values > neg_th
m0 = torch.where((~valid0) & exclude0, ignore.new_tensor(-1), m0)
m1 = torch.where((~valid1) & exclude1, ignore.new_tensor(-1), m1)
return {
"assignment": positive,
"reward": (dist < pos_th**2).float() - (epi_dist > neg_th).float(),
"matches0": m0,
"matches1": m1,
"matching_scores0": (m0 > -1).float(),
"matching_scores1": (m1 > -1).float(),
"depth_keypoints0": d0,
"depth_keypoints1": d1,
"proj_0to1": kp0_1,
"proj_1to0": kp1_0,
"visible0": visible0,
"visible1": visible1,
}
@torch.no_grad()
def gt_matches_from_homography(kp0, kp1, H, pos_th=3, neg_th=6, **kw):
if kp0.shape[1] == 0 or kp1.shape[1] == 0:
b_size, n_kp0 = kp0.shape[:2]
n_kp1 = kp1.shape[1]
assignment = torch.zeros(
b_size, n_kp0, n_kp1, dtype=torch.bool, device=kp0.device
)
m0 = -torch.ones_like(kp0[:, :, 0]).long()
m1 = -torch.ones_like(kp1[:, :, 0]).long()
return assignment, m0, m1
kp0_1 = warp_points_torch(kp0, H, inverse=False)
kp1_0 = warp_points_torch(kp1, H, inverse=True)
# build a distance matrix of size [... x M x N]
dist0 = torch.sum((kp0_1.unsqueeze(-2) - kp1.unsqueeze(-3)) ** 2, -1)
dist1 = torch.sum((kp0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1)
dist = torch.max(dist0, dist1)
reward = (dist < pos_th**2).float() - (dist > neg_th**2).float()
min0 = dist.min(-1).indices
min1 = dist.min(-2).indices
ismin0 = torch.zeros(dist.shape, dtype=torch.bool, device=dist.device)
ismin1 = ismin0.clone()
ismin0.scatter_(-1, min0.unsqueeze(-1), value=1)
ismin1.scatter_(-2, min1.unsqueeze(-2), value=1)
positive = ismin0 & ismin1 & (dist < pos_th**2)
negative0 = dist0.min(-1).values > neg_th**2
negative1 = dist1.min(-2).values > neg_th**2
# pack the indices of positive matches
# if -1: unmatched point
# if -2: ignore point
unmatched = min0.new_tensor(UNMATCHED_FEATURE)
ignore = min0.new_tensor(IGNORE_FEATURE)
m0 = torch.where(positive.any(-1), min0, ignore)
m1 = torch.where(positive.any(-2), min1, ignore)
m0 = torch.where(negative0, unmatched, m0)
m1 = torch.where(negative1, unmatched, m1)
return {
"assignment": positive,
"reward": reward,
"matches0": m0,
"matches1": m1,
"matching_scores0": (m0 > -1).float(),
"matching_scores1": (m1 > -1).float(),
"proj_0to1": kp0_1,
"proj_1to0": kp1_0,
}
def sample_pts(lines, npts):
dir_vec = (lines[..., 2:4] - lines[..., :2]) / (npts - 1)
pts = lines[..., :2, np.newaxis] + dir_vec[..., np.newaxis].expand(
dir_vec.shape + (npts,)
) * torch.arange(npts).to(lines)
pts = torch.transpose(pts, -1, -2)
return pts
def torch_perp_dist(segs2d, points_2d):
# Check batch size and segments format
assert segs2d.shape[0] == points_2d.shape[0]
assert segs2d.shape[-1] == 4
dir = segs2d[..., 2:] - segs2d[..., :2]
sizes = torch.norm(dir, dim=-1).half()
norm_dir = dir / torch.unsqueeze(sizes, dim=-1)
# middle_ptn = 0.5 * (segs2d[..., 2:] + segs2d[..., :2])
# centered [batch, nsegs0, nsegs1, n_sampled_pts, 2]
centered = points_2d[:, None] - segs2d[..., None, None, 2:]
R = torch.cat(
[
norm_dir[..., 0, None],
norm_dir[..., 1, None],
-norm_dir[..., 1, None],
norm_dir[..., 0, None],
],
dim=2,
).reshape((len(segs2d), -1, 2, 2))
# Try to reduce the memory consumption by using float16 type
if centered.is_cuda:
centered, R = centered.half(), R.half()
# R: [batch, nsegs0, 2, 2] , centered: [batch, nsegs1, n_sampled_pts, 2]
# -> [batch, nsegs0, nsegs1, n_sampled_pts, 2]
rotated = torch.einsum("bdji,bdepi->bdepj", R, centered)
overlaping = (rotated[..., 0] <= 0) & (
torch.abs(rotated[..., 0]) <= sizes[..., None, None]
)
return torch.abs(rotated[..., 1]), overlaping
@torch.no_grad()
def gt_line_matches_from_pose_depth(
pred_lines0,
pred_lines1,
valid_lines0,
valid_lines1,
data,
npts=50,
dist_th=5,
overlap_th=0.2,
min_visibility_th=0.5,
):
"""Compute ground truth line matches and label the remaining the lines as:
- UNMATCHED: if reprojection is outside the image
or far away from any other line.
- IGNORE: if a line has not enough valid depth pixels along itself
or it is labeled as invalid."""
lines0 = pred_lines0.clone()
lines1 = pred_lines1.clone()
if pred_lines0.shape[1] == 0 or pred_lines1.shape[1] == 0:
bsize, nlines0, nlines1 = (
pred_lines0.shape[0],
pred_lines0.shape[1],
pred_lines1.shape[1],
)
positive = torch.zeros(
(bsize, nlines0, nlines1), dtype=torch.bool, device=pred_lines0.device
)
m0 = torch.full((bsize, nlines0), -1, device=pred_lines0.device)
m1 = torch.full((bsize, nlines1), -1, device=pred_lines0.device)
return positive, m0, m1
if lines0.shape[-2:] == (2, 2):
lines0 = torch.flatten(lines0, -2)
elif lines0.dim() == 4:
lines0 = torch.cat([lines0[:, :, 0], lines0[:, :, -1]], dim=2)
if lines1.shape[-2:] == (2, 2):
lines1 = torch.flatten(lines1, -2)
elif lines1.dim() == 4:
lines1 = torch.cat([lines1[:, :, 0], lines1[:, :, -1]], dim=2)
b_size, n_lines0, _ = lines0.shape
b_size, n_lines1, _ = lines1.shape
h0, w0 = data["view0"]["depth"][0].shape
h1, w1 = data["view1"]["depth"][0].shape
lines0 = torch.min(
torch.max(lines0, torch.zeros_like(lines0)),
lines0.new_tensor([w0 - 1, h0 - 1, w0 - 1, h0 - 1], dtype=torch.float),
)
lines1 = torch.min(
torch.max(lines1, torch.zeros_like(lines1)),
lines1.new_tensor([w1 - 1, h1 - 1, w1 - 1, h1 - 1], dtype=torch.float),
)
# Sample points along each line
pts0 = sample_pts(lines0, npts).reshape(b_size, n_lines0 * npts, 2)
pts1 = sample_pts(lines1, npts).reshape(b_size, n_lines1 * npts, 2)
# Sample depth and valid points
d0, valid0_pts0 = sample_depth(pts0, data["view0"]["depth"])
d1, valid1_pts1 = sample_depth(pts1, data["view1"]["depth"])
# Reproject to the other view
pts0_1, visible0 = project(
pts0,
d0,
data["view1"]["depth"],
data["view0"]["camera"],
data["view1"]["camera"],
data["T_0to1"],
valid0_pts0,
)
pts1_0, visible1 = project(
pts1,
d1,
data["view0"]["depth"],
data["view1"]["camera"],
data["view0"]["camera"],
data["T_1to0"],
valid1_pts1,
)
h0, w0 = data["view0"]["image"].shape[-2:]
h1, w1 = data["view1"]["image"].shape[-2:]
    # If less than min_visibility_th of a line's points fall inside the image,
    # the line is considered OUTSIDE
pts_out_of0 = (pts1_0 < 0).any(-1) | (
pts1_0 >= torch.tensor([w0, h0]).to(pts1_0)
).any(-1)
pts_out_of0 = pts_out_of0.reshape(b_size, n_lines1, npts).float()
out_of0 = pts_out_of0.mean(dim=-1) >= (1 - min_visibility_th)
pts_out_of1 = (pts0_1 < 0).any(-1) | (
pts0_1 >= torch.tensor([w1, h1]).to(pts0_1)
).any(-1)
pts_out_of1 = pts_out_of1.reshape(b_size, n_lines0, npts).float()
out_of1 = pts_out_of1.mean(dim=-1) >= (1 - min_visibility_th)
# visible0 is [bs, nl0 * npts]
pts0_1 = pts0_1.reshape(b_size, n_lines0, npts, 2)
pts1_0 = pts1_0.reshape(b_size, n_lines1, npts, 2)
perp_dists0, overlaping0 = torch_perp_dist(lines0, pts1_0)
close_points0 = (perp_dists0 < dist_th) & overlaping0 # [bs, nl0, nl1, npts]
del perp_dists0, overlaping0
close_points0 = close_points0 * visible1.reshape(b_size, 1, n_lines1, npts)
perp_dists1, overlaping1 = torch_perp_dist(lines1, pts0_1)
close_points1 = (perp_dists1 < dist_th) & overlaping1 # [bs, nl1, nl0, npts]
del perp_dists1, overlaping1
close_points1 = close_points1 * visible0.reshape(b_size, 1, n_lines0, npts)
torch.cuda.empty_cache()
# For each segment detected in 0, how many sampled points from
# reprojected segments 1 are close
num_close_pts0 = close_points0.sum(dim=-1) # [bs, nl0, nl1]
# num_close_pts0_t = num_close_pts0.transpose(-1, -2)
# For each segment detected in 1, how many sampled points from
# reprojected segments 0 are close
num_close_pts1 = close_points1.sum(dim=-1)
num_close_pts1_t = num_close_pts1.transpose(-1, -2) # [bs, nl1, nl0]
num_close_pts = num_close_pts0 * num_close_pts1_t
mask_close = (
num_close_pts1_t
> visible0.reshape(b_size, n_lines0, npts).float().sum(-1)[:, :, None]
* overlap_th
) & (
num_close_pts0
> visible1.reshape(b_size, n_lines1, npts).float().sum(-1)[:, None] * overlap_th
)
# mask_close = (num_close_pts1_t > npts * overlap_th) & (
# num_close_pts0 > npts * overlap_th)
# Define the unmatched lines
unmatched0 = torch.all(~mask_close, dim=2) | out_of1
unmatched1 = torch.all(~mask_close, dim=1) | out_of0
# Define the lines to ignore
ignore0 = (
valid0_pts0.reshape(b_size, n_lines0, npts).float().mean(dim=-1)
< min_visibility_th
) | ~valid_lines0
ignore1 = (
valid1_pts1.reshape(b_size, n_lines1, npts).float().mean(dim=-1)
< min_visibility_th
) | ~valid_lines1
cost = -num_close_pts.clone()
# High score for unmatched and non-valid lines
cost[unmatched0] = 1e6
cost[ignore0] = 1e6
    # TODO: Is it reasonable to forbid matching with a segment because it
    # has no GT depth?
cost = cost.transpose(1, 2)
cost[unmatched1] = 1e6
cost[ignore1] = 1e6
cost = cost.transpose(1, 2)
# For each row, returns the col of max number of points
assignation = np.array(
[linear_sum_assignment(C) for C in cost.detach().cpu().numpy()]
)
assignation = torch.tensor(assignation).to(num_close_pts)
# Set ignore and unmatched labels
unmatched = assignation.new_tensor(UNMATCHED_FEATURE)
ignore = assignation.new_tensor(IGNORE_FEATURE)
positive = num_close_pts.new_zeros(num_close_pts.shape, dtype=torch.bool)
all_in_batch = (
torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten()
)
positive[
all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()
] = True
m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
m1 = assignation.new_full((b_size, n_lines1), unmatched, dtype=torch.long)
m1.scatter_(-1, assignation[:, 1], assignation[:, 0])
positive = positive & mask_close
# Remove values to be ignored or unmatched
positive[unmatched0] = False
positive[ignore0] = False
positive = positive.transpose(1, 2)
positive[unmatched1] = False
positive[ignore1] = False
positive = positive.transpose(1, 2)
m0[~positive.any(-1)] = unmatched
m0[unmatched0] = unmatched
m0[ignore0] = ignore
m1[~positive.any(-2)] = unmatched
m1[unmatched1] = unmatched
m1[ignore1] = ignore
if num_close_pts.numel() == 0:
no_matches = torch.zeros(positive.shape[0], 0).to(positive)
return positive, no_matches, no_matches
return positive, m0, m1
@torch.no_grad()
def gt_line_matches_from_homography(
pred_lines0,
pred_lines1,
valid_lines0,
valid_lines1,
shape0,
shape1,
H,
npts=50,
dist_th=5,
overlap_th=0.2,
min_visibility_th=0.2,
):
"""Compute ground truth line matches and label the remaining the lines as:
- UNMATCHED: if reprojection is outside the image or far away from any other line.
- IGNORE: if a line is labeled as invalid."""
h0, w0 = shape0[-2:]
h1, w1 = shape1[-2:]
lines0 = pred_lines0.clone()
lines1 = pred_lines1.clone()
if lines0.shape[-2:] == (2, 2):
lines0 = torch.flatten(lines0, -2)
elif lines0.dim() == 4:
lines0 = torch.cat([lines0[:, :, 0], lines0[:, :, -1]], dim=2)
if lines1.shape[-2:] == (2, 2):
lines1 = torch.flatten(lines1, -2)
elif lines1.dim() == 4:
lines1 = torch.cat([lines1[:, :, 0], lines1[:, :, -1]], dim=2)
b_size, n_lines0, _ = lines0.shape
b_size, n_lines1, _ = lines1.shape
lines0 = torch.min(
torch.max(lines0, torch.zeros_like(lines0)),
lines0.new_tensor([w0 - 1, h0 - 1, w0 - 1, h0 - 1], dtype=torch.float),
)
lines1 = torch.min(
torch.max(lines1, torch.zeros_like(lines1)),
lines1.new_tensor([w1 - 1, h1 - 1, w1 - 1, h1 - 1], dtype=torch.float),
)
# Sample points along each line
pts0 = sample_pts(lines0, npts).reshape(b_size, n_lines0 * npts, 2)
pts1 = sample_pts(lines1, npts).reshape(b_size, n_lines1 * npts, 2)
# Project the points to the other image
pts0_1 = warp_points_torch(pts0, H, inverse=False)
pts1_0 = warp_points_torch(pts1, H, inverse=True)
pts0_1 = pts0_1.reshape(b_size, n_lines0, npts, 2)
pts1_0 = pts1_0.reshape(b_size, n_lines1, npts, 2)
    # If less than min_visibility_th of a line's points fall inside the image,
    # the line is considered OUTSIDE
pts_out_of0 = (pts1_0 < 0).any(-1) | (
pts1_0 >= torch.tensor([w0, h0]).to(pts1_0)
).any(-1)
pts_out_of0 = pts_out_of0.reshape(b_size, n_lines1, npts).float()
out_of0 = pts_out_of0.mean(dim=-1) >= (1 - min_visibility_th)
pts_out_of1 = (pts0_1 < 0).any(-1) | (
pts0_1 >= torch.tensor([w1, h1]).to(pts0_1)
).any(-1)
pts_out_of1 = pts_out_of1.reshape(b_size, n_lines0, npts).float()
out_of1 = pts_out_of1.mean(dim=-1) >= (1 - min_visibility_th)
perp_dists0, overlaping0 = torch_perp_dist(lines0, pts1_0)
close_points0 = (perp_dists0 < dist_th) & overlaping0 # [bs, nl0, nl1, npts]
del perp_dists0, overlaping0
perp_dists1, overlaping1 = torch_perp_dist(lines1, pts0_1)
close_points1 = (perp_dists1 < dist_th) & overlaping1 # [bs, nl1, nl0, npts]
del perp_dists1, overlaping1
torch.cuda.empty_cache()
# For each segment detected in 0,
# how many sampled points from reprojected segments 1 are close
num_close_pts0 = close_points0.sum(dim=-1) # [bs, nl0, nl1]
# num_close_pts0_t = num_close_pts0.transpose(-1, -2)
# For each segment detected in 1,
# how many sampled points from reprojected segments 0 are close
num_close_pts1 = close_points1.sum(dim=-1)
num_close_pts1_t = num_close_pts1.transpose(-1, -2) # [bs, nl1, nl0]
num_close_pts = num_close_pts0 * num_close_pts1_t
mask_close = (
(num_close_pts1_t > npts * overlap_th)
& (num_close_pts0 > npts * overlap_th)
& ~out_of0.unsqueeze(1)
& ~out_of1.unsqueeze(-1)
)
# Define the unmatched lines
unmatched0 = torch.all(~mask_close, dim=2) | out_of1
unmatched1 = torch.all(~mask_close, dim=1) | out_of0
# Define the lines to ignore
ignore0 = ~valid_lines0
ignore1 = ~valid_lines1
cost = -num_close_pts.clone()
# High score for unmatched and non-valid lines
cost[unmatched0] = 1e6
cost[ignore0] = 1e6
cost = cost.transpose(1, 2)
cost[unmatched1] = 1e6
cost[ignore1] = 1e6
cost = cost.transpose(1, 2)
# For each row, returns the col of max number of points
assignation = np.array(
[linear_sum_assignment(C) for C in cost.detach().cpu().numpy()]
)
assignation = torch.tensor(assignation).to(num_close_pts)
# Set unmatched labels
unmatched = assignation.new_tensor(UNMATCHED_FEATURE)
ignore = assignation.new_tensor(IGNORE_FEATURE)
positive = num_close_pts.new_zeros(num_close_pts.shape, dtype=torch.bool)
# TODO Do with a single and beautiful call
# for b in range(b_size):
# positive[b][assignation[b, 0], assignation[b, 1]] = True
positive[
torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten(),
assignation[:, 0].flatten(),
assignation[:, 1].flatten(),
] = True
m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
m1 = assignation.new_full((b_size, n_lines1), unmatched, dtype=torch.long)
m1.scatter_(-1, assignation[:, 1], assignation[:, 0])
positive = positive & mask_close
# Remove values to be ignored or unmatched
positive[unmatched0] = False
positive[ignore0] = False
positive = positive.transpose(1, 2)
positive[unmatched1] = False
positive[ignore1] = False
positive = positive.transpose(1, 2)
m0[~positive.any(-1)] = unmatched
m0[unmatched0] = unmatched
m0[ignore0] = ignore
m1[~positive.any(-2)] = unmatched
m1[unmatched1] = unmatched
m1[ignore1] = ignore
if num_close_pts.numel() == 0:
no_matches = torch.zeros(positive.shape[0], 0).to(positive)
return positive, no_matches, no_matches
return positive, m0, m1
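# Illustrative note (not part of the original file): the ground-truth generators in
# this module share one label convention for the returned m0/m1 match indices:
#   m0[i] >= 0                        index of the match in the other image
#   m0[i] == UNMATCHED_FEATURE (-1)   confidently unmatched (a valid negative)
#   m0[i] == IGNORE_FEATURE (-2)      ambiguous or invalid, meant to be ignored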

View File

@ -0,0 +1,340 @@
from typing import Tuple
import math
import numpy as np
import torch
from .utils import to_homogeneous, from_homogeneous
def flat2mat(H):
return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])
# Homography creation
def create_center_patch(shape, patch_shape=None):
if patch_shape is None:
patch_shape = shape
width, height = shape
pwidth, pheight = patch_shape
left = int((width - pwidth) / 2)
bottom = int((height - pheight) / 2)
right = int((width + pwidth) / 2)
top = int((height + pheight) / 2)
return np.array([[left, bottom], [left, top], [right, top], [right, bottom]])
def check_convex(patch, min_convexity=0.05):
"""Checks if given polygon vertices [N,2] form a convex shape"""
for i in range(patch.shape[0]):
x1, y1 = patch[(i - 1) % patch.shape[0]]
x2, y2 = patch[i]
x3, y3 = patch[(i + 1) % patch.shape[0]]
if (x2 - x1) * (y3 - y2) - (x3 - x2) * (y2 - y1) > -min_convexity:
return False
return True
def sample_homography_corners(
shape,
patch_shape,
difficulty=1.0,
translation=0.4,
n_angles=10,
max_angle=90,
min_convexity=0.05,
rng=np.random,
):
max_angle = max_angle / 180.0 * math.pi
width, height = shape
pwidth, pheight = width * (1 - difficulty), height * (1 - difficulty)
min_pts1 = create_center_patch(shape, (pwidth, pheight))
full = create_center_patch(shape)
pts2 = create_center_patch(patch_shape)
scale = min_pts1 - full
found_valid = False
cnt = -1
while not found_valid:
offsets = rng.uniform(0.0, 1.0, size=(4, 2)) * scale
pts1 = full + offsets
found_valid = check_convex(pts1 / np.array(shape), min_convexity)
cnt += 1
# re-center
pts1 = pts1 - np.mean(pts1, axis=0, keepdims=True)
pts1 = pts1 + np.mean(min_pts1, axis=0, keepdims=True)
# Rotation
if n_angles > 0 and difficulty > 0:
angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
        rng.shuffle(angles)
angles = np.concatenate([[0.0], angles], axis=0)
center = np.mean(pts1, axis=0, keepdims=True)
rot_mat = np.reshape(
np.stack(
[np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)],
axis=1,
),
[-1, 2, 2],
)
rotated = (
np.matmul(
np.tile(np.expand_dims(pts1 - center, axis=0), [n_angles + 1, 1, 1]),
rot_mat,
)
+ center
)
for idx in range(1, n_angles):
warped_points = rotated[idx] / np.array(shape)
if np.all((warped_points >= 0.0) & (warped_points < 1.0)):
pts1 = rotated[idx]
break
# Translation
if translation > 0:
min_trans = -np.min(pts1, axis=0)
max_trans = shape - np.max(pts1, axis=0)
trans = rng.uniform(min_trans, max_trans)[None]
pts1 += trans * translation * difficulty
H = compute_homography(pts1, pts2, [1.0, 1.0])
warped = warp_points(full, H, inverse=False)
return H, full, warped, patch_shape
def compute_homography(pts1_, pts2_, shape):
"""Compute the homography matrix from 4 point correspondences"""
# Rescale to actual size
shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x]
pts1 = pts1_ * np.expand_dims(shape, axis=0)
pts2 = pts2_ * np.expand_dims(shape, axis=0)
def ax(p, q):
return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
def ay(p, q):
return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
p_mat = np.transpose(
np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
)
homography = np.transpose(np.linalg.solve(a_mat, p_mat))
return flat2mat(homography)
# Point warping utils
def warp_points(points, homography, inverse=True):
"""
Warp a list of points with the INVERSE of the given homography.
The inverse is used to be coherent with tf.contrib.image.transform
Arguments:
points: list of N points, shape (N, 2).
homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
Returns: a Tensor of shape (N, 2) or (B, N, 2) (depending on whether the homography
is batched) containing the new coordinates of the warped points.
"""
H = homography[None] if len(homography.shape) == 2 else homography
# Get the points to the homogeneous format
num_points = points.shape[0]
# points = points.astype(np.float32)[:, ::-1]
points = np.concatenate([points, np.ones([num_points, 1], dtype=np.float32)], -1)
H_inv = np.transpose(np.linalg.inv(H) if inverse else H)
warped_points = np.tensordot(points, H_inv, axes=[[1], [0]])
warped_points = np.transpose(warped_points, [2, 0, 1])
warped_points[np.abs(warped_points[:, :, 2]) < 1e-8, 2] = 1e-8
warped_points = warped_points[:, :, :2] / warped_points[:, :, 2:]
return warped_points[0] if len(homography.shape) == 2 else warped_points
def warp_points_torch(points, H, inverse=True):
"""
Warp a list of points with the INVERSE of the given homography.
The inverse is used to be coherent with tf.contrib.image.transform
Arguments:
points: batched list of N points, shape (B, N, 2).
        H: the homography, batched or not (shapes (B, 3, 3) and (3, 3) respectively).
    Returns: a Tensor of shape (B, N, 2) with the coordinates of the warped points.
"""
# Get the points to the homogeneous format
points = to_homogeneous(points)
# Apply the homography
H_mat = (torch.inverse(H) if inverse else H).transpose(-2, -1)
warped_points = torch.einsum("...nj,...ji->...ni", points, H_mat)
warped_points = from_homogeneous(warped_points, eps=1e-5)
return warped_points
# Line warping utils
def seg_equation(segs):
# calculate list of start, end and midpoints points from both lists
start_points, end_points = to_homogeneous(segs[..., 0, :]), to_homogeneous(
segs[..., 1, :]
)
# Compute the line equations as ax + by + c = 0 , where x^2 + y^2 = 1
lines = torch.cross(start_points, end_points, dim=-1)
lines_norm = torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]
assert torch.all(
lines_norm > 0
), "Error: trying to compute the equation of a line with a single point"
lines = lines / lines_norm
return lines
def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]):
h, w = img_shape
return (
(pts >= 0).all(dim=-1)
& (pts[..., 0] < w)
& (pts[..., 1] < h)
& (~torch.isinf(pts).any(dim=-1))
)
def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor:
"""
Shrink an array of segments to fit inside the image.
:param segs: The tensor of segments with shape (N, 2, 2)
:param img_shape: The image shape in format (H, W)
"""
EPS = 1e-4
device = segs.device
w, h = img_shape[1], img_shape[0]
# Project the segments to the reference image
segs = segs.clone()
eqs = seg_equation(segs)
x0, y0 = torch.tensor([1.0, 0, 0.0], device=device), torch.tensor(
[0.0, 1, 0], device=device
)
x0 = x0.repeat(eqs.shape[:-1] + (1,))
y0 = y0.repeat(eqs.shape[:-1] + (1,))
pt_x0s = torch.cross(eqs, x0, dim=-1)
pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1]
pt_x0s_valid = is_inside_img(pt_x0s, img_shape)
pt_y0s = torch.cross(eqs, y0, dim=-1)
pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1]
pt_y0s_valid = is_inside_img(pt_y0s, img_shape)
xW = torch.tensor([1.0, 0, EPS - w], device=device)
yH = torch.tensor([0.0, 1, EPS - h], device=device)
xW = xW.repeat(eqs.shape[:-1] + (1,))
yH = yH.repeat(eqs.shape[:-1] + (1,))
pt_xWs = torch.cross(eqs, xW, dim=-1)
pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1]
pt_xWs_valid = is_inside_img(pt_xWs, img_shape)
pt_yHs = torch.cross(eqs, yH, dim=-1)
pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1]
pt_yHs_valid = is_inside_img(pt_yHs, img_shape)
# If the X coordinate of the first endpoint is out
mask = (segs[..., 0, 0] < 0) & pt_x0s_valid
segs[mask, 0, :] = pt_x0s[mask]
mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid
segs[mask, 0, :] = pt_xWs[mask]
# If the X coordinate of the second endpoint is out
mask = (segs[..., 1, 0] < 0) & pt_x0s_valid
segs[mask, 1, :] = pt_x0s[mask]
mask = (segs[:, 1, 0] > (w - 1)) & pt_xWs_valid
segs[mask, 1, :] = pt_xWs[mask]
# If the Y coordinate of the first endpoint is out
mask = (segs[..., 0, 1] < 0) & pt_y0s_valid
segs[mask, 0, :] = pt_y0s[mask]
mask = (segs[..., 0, 1] > (h - 1)) & pt_yHs_valid
segs[mask, 0, :] = pt_yHs[mask]
# If the Y coordinate of the second endpoint is out
mask = (segs[..., 1, 1] < 0) & pt_y0s_valid
segs[mask, 1, :] = pt_y0s[mask]
mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid
segs[mask, 1, :] = pt_yHs[mask]
assert (
torch.all(segs >= 0)
and torch.all(segs[..., 0] < w)
and torch.all(segs[..., 1] < h)
)
return segs
def warp_lines_torch(
lines, H, inverse=True, dst_shape: Tuple[int, int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param lines: A tensor of shape (B, N, 2, 2)
where B is the batch size, N the number of lines.
:param H: The homography used to convert the lines.
batched or not (shapes (B, 3, 3) and (3, 3) respectively).
:param inverse: Whether to apply H or the inverse of H
    :param dst_shape: If provided, lines are trimmed to be inside the image
"""
device = lines.device
batch_size = len(lines)
lines = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(
lines.shape
)
if dst_shape is None:
return lines, torch.ones(lines.shape[:-2], dtype=torch.bool, device=device)
out_img = torch.any(
(lines < 0) | (lines >= torch.tensor(dst_shape[::-1], device=device)), -1
)
valid = ~out_img.all(-1)
any_out_of_img = out_img.any(-1)
lines_to_trim = valid & any_out_of_img
for b in range(batch_size):
lines_to_trim_mask_b = lines_to_trim[b]
lines_to_trim_b = lines[b][lines_to_trim_mask_b]
corrected_lines = shrink_segs_to_img(lines_to_trim_b, dst_shape)
lines[b][lines_to_trim_mask_b] = corrected_lines
return lines, valid
# Homography evaluation utils
def sym_homography_error(kpts0, kpts1, T_0to1):
kpts0_1 = from_homogeneous(to_homogeneous(kpts0) @ T_0to1.transpose(-1, -2))
dist0_1 = ((kpts0_1 - kpts1) ** 2).sum(-1).sqrt()
kpts1_0 = from_homogeneous(
to_homogeneous(kpts1) @ torch.pinverse(T_0to1.transpose(-1, -2))
)
dist1_0 = ((kpts1_0 - kpts0) ** 2).sum(-1).sqrt()
return (dist0_1 + dist1_0) / 2.0
def sym_homography_error_all(kpts0, kpts1, H):
kp0_1 = warp_points_torch(kpts0, H, inverse=False)
kp1_0 = warp_points_torch(kpts1, H, inverse=True)
# build a distance matrix of size [... x M x N]
dist0 = torch.sum((kp0_1.unsqueeze(-2) - kpts1.unsqueeze(-3)) ** 2, -1).sqrt()
dist1 = torch.sum((kpts0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1).sqrt()
return (dist0 + dist1) / 2.0
def homography_corner_error(T, T_gt, image_size):
W, H = image_size[:, 0], image_size[:, 1]
corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2))
corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2))
d = torch.sqrt(((corners1 - corners1_gt) ** 2).sum(-1))
return d.mean(-1)
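# Minimal usage sketch (toy values, not part of the original file): with an
# identity homography, warped points coincide with the inputs and the symmetric
# error is ~0 (up to the 1e-5 epsilon used in from_homogeneous).
def _example_warp_points_torch():
    H = torch.eye(3)[None]  # identity homography, batch of one
    kpts = torch.tensor([[[10.0, 20.0], [30.0, 5.0]]])
    warped = warp_points_torch(kpts, H, inverse=False)  # ~= kpts
    err = sym_homography_error(kpts, warped, H[0])  # ~= [[0, 0]]
    return warped, err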

View File

@ -0,0 +1,166 @@
import numpy as np
import torch
def to_homogeneous(points):
"""Convert N-dimensional points to homogeneous coordinates.
Args:
points: torch.Tensor or numpy.ndarray with size (..., N).
Returns:
A torch.Tensor or numpy.ndarray with size (..., N+1).
"""
if isinstance(points, torch.Tensor):
pad = points.new_ones(points.shape[:-1] + (1,))
return torch.cat([points, pad], dim=-1)
elif isinstance(points, np.ndarray):
pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
return np.concatenate([points, pad], axis=-1)
else:
raise ValueError
def from_homogeneous(points, eps=0.0):
"""Remove the homogeneous dimension of N-dimensional points.
Args:
points: torch.Tensor or numpy.ndarray with size (..., N+1).
Returns:
A torch.Tensor or numpy ndarray with size (..., N).
"""
return points[..., :-1] / (points[..., -1:] + eps)
def batched_eye_like(x: torch.Tensor, n: int):
"""Create a batch of identity matrices.
Args:
x: a reference torch.Tensor whose batch dimension will be copied.
n: the size of each identity matrix.
Returns:
A torch.Tensor of size (B, n, n), with same dtype and device as x.
"""
return torch.eye(n).to(x)[None].repeat(len(x), 1, 1)
def skew_symmetric(v):
"""Create a skew-symmetric matrix from a (batched) vector of size (..., 3)."""
z = torch.zeros_like(v[..., 0])
M = torch.stack(
[
z,
-v[..., 2],
v[..., 1],
v[..., 2],
z,
-v[..., 0],
-v[..., 1],
v[..., 0],
z,
],
dim=-1,
).reshape(v.shape[:-1] + (3, 3))
return M
def transform_points(T, points):
return from_homogeneous(to_homogeneous(points) @ T.transpose(-1, -2))
def is_inside(pts, shape):
return (pts > 0).all(-1) & (pts < shape[:, None]).all(-1)
def so3exp_map(w, eps: float = 1e-7):
"""Compute rotation matrices from batched twists.
Args:
w: batched 3D axis-angle vectors of size (..., 3).
Returns:
A batch of rotation matrices of size (..., 3, 3).
"""
theta = w.norm(p=2, dim=-1, keepdim=True)
small = theta < eps
div = torch.where(small, torch.ones_like(theta), theta)
W = skew_symmetric(w / div)
theta = theta[..., None] # ... x 1 x 1
res = W * torch.sin(theta) + (W @ W) * (1 - torch.cos(theta))
res = torch.where(small[..., None], W, res) # first-order Taylor approx
return torch.eye(3).to(W) + res
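Two quick properties of so3exp_map, shown as a sketch (assumes the function above is in scope): a zero twist maps to the identity, and any output is an orthonormal rotation.
import torch

assert torch.allclose(so3exp_map(torch.zeros(1, 3)), torch.eye(3)[None])
R = so3exp_map(torch.randn(1, 3))
assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3)[None], atol=1e-5)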
@torch.jit.script
def distort_points(pts, dist):
"""Distort normalized 2D coordinates
and check for validity of the distortion model.
"""
dist = dist.unsqueeze(-2) # add point dimension
ndist = dist.shape[-1]
undist = pts
valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool)
if ndist > 0:
k1, k2 = dist[..., :2].split(1, -1)
r2 = torch.sum(pts**2, -1, keepdim=True)
radial = k1 * r2 + k2 * r2**2
undist = undist + pts * radial
# The distortion model is supposedly only valid within the image
# boundaries. Because of the negative radial distortion, points that
# are far outside of the boundaries might actually be mapped back
# within the image. To account for this, we discard points that are
# beyond the inflection point of the distortion model,
# i.e. such that d(r + k1 r^3 + k2 r^5)/dr = 0
limited = ((k2 > 0) & ((9 * k1**2 - 20 * k2) > 0)) | ((k2 <= 0) & (k1 > 0))
limit = torch.abs(
torch.where(
k2 > 0,
(torch.sqrt(9 * k1**2 - 20 * k2) - 3 * k1) / (10 * k2),
1 / (3 * k1),
)
)
valid = valid & torch.squeeze(~limited | (r2 < limit), -1)
if ndist > 2:
p12 = dist[..., 2:]
p21 = p12.flip(-1)
uv = torch.prod(pts, -1, keepdim=True)
undist = undist + 2 * p12 * uv + p21 * (r2 + 2 * pts**2)
# TODO: handle tangential boundaries
return undist, valid
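An illustrative call (sketch only; the k1/k2 values are made up) showing the expected output shapes:
import torch

pts = torch.tensor([[[0.1, 0.2], [0.5, -0.3]]])   # B x N x 2 normalized coords
coeffs = torch.tensor([[-0.1, 0.01]])             # B x 2 radial coefficients k1, k2
undist, valid = distort_points(pts, coeffs)
print(undist.shape, valid.shape)                  # torch.Size([1, 2, 2]) torch.Size([1, 2])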
@torch.jit.script
def J_distort_points(pts, dist):
dist = dist.unsqueeze(-2) # add point dimension
ndist = dist.shape[-1]
J_diag = torch.ones_like(pts)
J_cross = torch.zeros_like(pts)
if ndist > 0:
k1, k2 = dist[..., :2].split(1, -1)
r2 = torch.sum(pts**2, -1, keepdim=True)
uv = torch.prod(pts, -1, keepdim=True)
radial = k1 * r2 + k2 * r2**2
d_radial = 2 * k1 + 4 * k2 * r2
J_diag += radial + (pts**2) * d_radial
J_cross += uv * d_radial
if ndist > 2:
p12 = dist[..., 2:]
p21 = p12.flip(-1)
J_diag += 2 * p12 * pts.flip(-1) + 6 * p21 * pts
J_cross += 2 * p12 * pts + 2 * p21 * pts.flip(-1)
J = torch.diag_embed(J_diag) + torch.diag_embed(J_cross).flip(-1)
return J
def get_image_coords(img):
h, w = img.shape[-2:]
return (
torch.stack(
torch.meshgrid(
torch.arange(h, dtype=torch.float32, device=img.device),
torch.arange(w, dtype=torch.float32, device=img.device),
indexing="ij",
)[::-1],
dim=0,
).permute(1, 2, 0)
)[None] + 0.5
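The +0.5 offset places coordinates at pixel centers; a quick sketch of the output convention (assumes get_image_coords above):
import torch

img = torch.zeros(1, 3, 2, 3)       # B x C x H x W
coords = get_image_coords(img)      # -> 1 x H x W x 2, last dim is (x, y)
print(coords[0, 0, 0])              # tensor([0.5000, 0.5000]): top-left pixel center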

View File

@ -0,0 +1,424 @@
"""
Convenience classes for an SE3 pose and a pinhole Camera with lens distortion.
Based on PyTorch tensors: differentiable, batched, with GPU support.
"""
import functools
import inspect
import math
from typing import Union, Tuple, List, Dict, NamedTuple, Optional
import torch
import numpy as np
from .utils import (
distort_points,
J_distort_points,
skew_symmetric,
so3exp_map,
to_homogeneous,
)
def autocast(func):
"""Cast the inputs of a TensorWrapper method to PyTorch tensors
if they are numpy arrays. Use the device and dtype of the wrapper.
"""
@functools.wraps(func)
def wrap(self, *args):
device = torch.device("cpu")
dtype = None
if isinstance(self, TensorWrapper):
if self._data is not None:
device = self.device
dtype = self.dtype
elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
raise ValueError(self)
cast_args = []
for arg in args:
if isinstance(arg, np.ndarray):
arg = torch.from_numpy(arg)
arg = arg.to(device=device, dtype=dtype)
cast_args.append(arg)
return func(self, *cast_args)
return wrap
class TensorWrapper:
_data = None
@autocast
def __init__(self, data: torch.Tensor):
self._data = data
@property
def shape(self):
return self._data.shape[:-1]
@property
def device(self):
return self._data.device
@property
def dtype(self):
return self._data.dtype
def __getitem__(self, index):
return self.__class__(self._data[index])
def __setitem__(self, index, item):
self._data[index] = item.data
def to(self, *args, **kwargs):
return self.__class__(self._data.to(*args, **kwargs))
def cpu(self):
return self.__class__(self._data.cpu())
def cuda(self):
return self.__class__(self._data.cuda())
def pin_memory(self):
return self.__class__(self._data.pin_memory())
def float(self):
return self.__class__(self._data.float())
def double(self):
return self.__class__(self._data.double())
def detach(self):
return self.__class__(self._data.detach())
@classmethod
def stack(cls, objects: List, dim=0, *, out=None):
data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
return cls(data)
@classmethod
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.stack:
return self.stack(*args, **kwargs)
else:
return NotImplemented
class Pose(TensorWrapper):
def __init__(self, data: torch.Tensor):
assert data.shape[-1] == 12
super().__init__(data)
@classmethod
@autocast
def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):
"""Pose from a rotation matrix and translation vector.
Accepts numpy arrays or PyTorch tensors.
Args:
R: rotation matrix with shape (..., 3, 3).
t: translation vector with shape (..., 3).
"""
assert R.shape[-2:] == (3, 3)
assert t.shape[-1] == 3
assert R.shape[:-2] == t.shape[:-1]
data = torch.cat([R.flatten(start_dim=-2), t], -1)
return cls(data)
@classmethod
@autocast
def from_aa(cls, aa: torch.Tensor, t: torch.Tensor):
"""Pose from an axis-angle rotation vector and translation vector.
Accepts numpy arrays or PyTorch tensors.
Args:
aa: axis-angle rotation vector with shape (..., 3).
t: translation vector with shape (..., 3).
"""
assert aa.shape[-1] == 3
assert t.shape[-1] == 3
assert aa.shape[:-1] == t.shape[:-1]
return cls.from_Rt(so3exp_map(aa), t)
@classmethod
def from_4x4mat(cls, T: torch.Tensor):
"""Pose from an SE(3) transformation matrix.
Args:
T: transformation matrix with shape (..., 4, 4).
"""
assert T.shape[-2:] == (4, 4)
R, t = T[..., :3, :3], T[..., :3, 3]
return cls.from_Rt(R, t)
@classmethod
def from_colmap(cls, image: NamedTuple):
"""Pose from a COLMAP Image."""
return cls.from_Rt(image.qvec2rotmat(), image.tvec)
@property
def R(self) -> torch.Tensor:
"""Underlying rotation matrix with shape (..., 3, 3)."""
rvec = self._data[..., :9]
return rvec.reshape(rvec.shape[:-1] + (3, 3))
@property
def t(self) -> torch.Tensor:
"""Underlying translation vector with shape (..., 3)."""
return self._data[..., -3:]
def inv(self) -> "Pose":
"""Invert an SE(3) pose."""
R = self.R.transpose(-1, -2)
t = -(R @ self.t.unsqueeze(-1)).squeeze(-1)
return self.__class__.from_Rt(R, t)
def compose(self, other: "Pose") -> "Pose":
"""Chain two SE(3) poses: T_B2C.compose(T_A2B) -> T_A2C."""
R = self.R @ other.R
t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1)
return self.__class__.from_Rt(R, t)
@autocast
def transform(self, p3d: torch.Tensor) -> torch.Tensor:
"""Transform a set of 3D points.
Args:
p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
"""
assert p3d.shape[-1] == 3
# assert p3d.shape[:-2] == self.shape # allow broadcasting
return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2)
def __mul__(self, p3D: torch.Tensor) -> torch.Tensor:
"""Transform a set of 3D points: T_A2B * p3D_A -> p3D_B."""
return self.transform(p3D)
def __matmul__(
self, other: Union["Pose", torch.Tensor]
) -> Union["Pose", torch.Tensor]:
"""Transform a set of 3D points: T_A2B * p3D_A -> p3D_B.
or chain two SE(3) poses: T_B2C @ T_A2B -> T_A2C."""
if isinstance(other, self.__class__):
return self.compose(other)
else:
return self.transform(other)
@autocast
def J_transform(self, p3d_out: torch.Tensor):
# [[1,0,0,0,-pz,py],
# [0,1,0,pz,0,-px],
# [0,0,1,-py,px,0]]
J_t = torch.diag_embed(torch.ones_like(p3d_out))
J_rot = -skew_symmetric(p3d_out)
J = torch.cat([J_t, J_rot], dim=-1)
return J # N x 3 x 6
def numpy(self) -> Tuple[np.ndarray]:
return self.R.numpy(), self.t.numpy()
def magnitude(self) -> Tuple[torch.Tensor]:
"""Magnitude of the SE(3) transformation.
Returns:
dr: rotation angle in degrees.
dt: translation distance in meters.
"""
trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)
cos = torch.clamp((trace - 1) / 2, -1, 1)
dr = torch.acos(cos).abs() / math.pi * 180
dt = torch.norm(self.t, dim=-1)
return dr, dt
def __repr__(self):
return f"Pose: {self.shape} {self.dtype} {self.device}"
class Camera(TensorWrapper):
eps = 1e-4
def __init__(self, data: torch.Tensor):
assert data.shape[-1] in {6, 8, 10}
super().__init__(data)
@classmethod
def from_colmap(cls, camera: Union[Dict, NamedTuple]):
"""Camera from a COLMAP Camera tuple or dictionary.
We use the corner convention from COLMAP (the center of the top-left pixel is (0.5, 0.5)).
"""
if isinstance(camera, tuple):
camera = camera._asdict()
model = camera["model"]
params = camera["params"]
if model in ["OPENCV", "PINHOLE", "RADIAL"]:
(fx, fy, cx, cy), params = np.split(params, [4])
elif model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL"]:
(f, cx, cy), params = np.split(params, [3])
fx = fy = f
if model == "SIMPLE_RADIAL":
params = np.r_[params, 0.0]
else:
raise NotImplementedError(model)
data = np.r_[camera["width"], camera["height"], fx, fy, cx, cy, params]
return cls(data)
@classmethod
@autocast
def from_calibration_matrix(cls, K: torch.Tensor):
cx, cy = K[..., 0, 2], K[..., 1, 2]
fx, fy = K[..., 0, 0], K[..., 1, 1]
data = torch.stack([2 * cx, 2 * cy, fx, fy, cx, cy], -1)
return cls(data)
@autocast
def calibration_matrix(self):
K = torch.zeros(
*self._data.shape[:-1],
3,
3,
device=self._data.device,
dtype=self._data.dtype,
)
K[..., 0, 2] = self._data[..., 4]
K[..., 1, 2] = self._data[..., 5]
K[..., 0, 0] = self._data[..., 2]
K[..., 1, 1] = self._data[..., 3]
K[..., 2, 2] = 1.0
return K
@property
def size(self) -> torch.Tensor:
"""Size (width height) of the images, with shape (..., 2)."""
return self._data[..., :2]
@property
def f(self) -> torch.Tensor:
"""Focal lengths (fx, fy) with shape (..., 2)."""
return self._data[..., 2:4]
@property
def c(self) -> torch.Tensor:
"""Principal points (cx, cy) with shape (..., 2)."""
return self._data[..., 4:6]
@property
def dist(self) -> torch.Tensor:
"""Distortion parameters, with shape (..., {0, 2, 4})."""
return self._data[..., 6:]
@autocast
def scale(self, scales: torch.Tensor):
"""Update the camera parameters after resizing an image."""
s = scales
data = torch.cat([self.size * s, self.f * s, self.c * s, self.dist], -1)
return self.__class__(data)
def crop(self, left_top: Tuple[float], size: Tuple[int]):
"""Update the camera parameters after cropping an image."""
left_top = self._data.new_tensor(left_top)
size = self._data.new_tensor(size)
data = torch.cat([size, self.f, self.c - left_top, self.dist], -1)
return self.__class__(data)
@autocast
def in_image(self, p2d: torch.Tensor):
"""Check if 2D points are within the image boundaries."""
assert p2d.shape[-1] == 2
# assert p2d.shape[:-2] == self.shape # allow broadcasting
size = self.size.unsqueeze(-2)
valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
return valid
@autocast
def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
"""Project 3D points into the camera plane and check for visibility."""
z = p3d[..., -1]
valid = z > self.eps
z = z.clamp(min=self.eps)
p2d = p3d[..., :-1] / z.unsqueeze(-1)
return p2d, valid
def J_project(self, p3d: torch.Tensor):
x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
zero = torch.zeros_like(z)
z = z.clamp(min=self.eps)
J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
J = J.reshape(p3d.shape[:-1] + (2, 3))
return J # N x 2 x 3
@autocast
def distort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
"""Distort normalized 2D coordinates
and check for validity of the distortion model.
"""
assert pts.shape[-1] == 2
# assert pts.shape[:-2] == self.shape # allow broadcasting
return distort_points(pts, self.dist)
def J_distort(self, pts: torch.Tensor):
return J_distort_points(pts, self.dist) # N x 2 x 2
@autocast
def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
"""Convert normalized 2D coordinates into pixel coordinates."""
return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
@autocast
def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
"""Convert normalized 2D coordinates into pixel coordinates."""
return (p2d - self.c.unsqueeze(-2)) / self.f.unsqueeze(-2)
def J_denormalize(self):
return torch.diag_embed(self.f).unsqueeze(-3) # 1 x 2 x 2
@autocast
def cam2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
"""Transform 3D points into 2D pixel coordinates."""
p2d, visible = self.project(p3d)
p2d, mask = self.distort(p2d)
p2d = self.denormalize(p2d)
valid = visible & mask & self.in_image(p2d)
return p2d, valid
def J_world2image(self, p3d: torch.Tensor):
p2d_dist, valid = self.project(p3d)
J = self.J_denormalize() @ self.J_distort(p2d_dist) @ self.J_project(p3d)
return J, valid
@autocast
def image2cam(self, p2d: torch.Tensor) -> torch.Tensor:
"""Convert 2D pixel corrdinates to 3D points with z=1"""
assert self._data.shape
p2d = self.normalize(p2d)
# iterative undistortion
return to_homogeneous(p2d)
def to_cameradict(self, camera_model: Optional[str] = None) -> List[Dict]:
data = self._data.clone()
if data.dim() == 1:
data = data.unsqueeze(0)
assert data.dim() == 2
b, d = data.shape
if camera_model is None:
camera_model = {6: "PINHOLE", 8: "RADIAL", 10: "OPENCV"}[d]
cameras = []
for i in range(b):
if camera_model.startswith("SIMPLE_"):
params = [x.item() for x in data[i, 3 : min(d, 7)]]
else:
params = [x.item() for x in data[i, 2:]]
cameras.append(
{
"model": camera_model,
"width": int(data[i, 0].item()),
"height": int(data[i, 1].item()),
"params": params,
}
)
return cameras if self._data.dim() == 2 else cameras[0]
def __repr__(self):
return f"Camera {self.shape} {self.dtype} {self.device}"

View File

@ -0,0 +1,29 @@
import importlib.util
from .base_model import BaseModel
from ..utils.tools import get_class
def get_model(name):
import_paths = [
name,
f"{__name__}.{name}",
f"{__name__}.extractors.{name}", # backward compatibility
f"{__name__}.matchers.{name}", # backward compatibility
]
for path in import_paths:
try:
spec = importlib.util.find_spec(path)
except ModuleNotFoundError:
spec = None
if spec is not None:
try:
return get_class(path, BaseModel)
except AssertionError:
mod = __import__(path, fromlist=[""])
try:
return mod.__main_model__
except AttributeError as exc:
print(exc)
continue
raise RuntimeError(f'Model {name} not found in any of [{" ".join(import_paths)}]')
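A usage sketch (not part of the file): get_model resolves a module by name under this package, including the extractors and matchers subpackages, and returns its BaseModel subclass; the module name used below is only an assumption.
try:
    SuperPoint = get_model("superpoint")             # assumed module name
    model = SuperPoint({"max_num_keypoints": 1024})
except RuntimeError as err:
    print(err)                                       # unknown names list the searched paths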

View File

View File

@ -0,0 +1,29 @@
import torch
import torch.nn.functional as F
from ..base_model import BaseModel
class DinoV2(BaseModel):
default_conf = {"weights": "dinov2_vits14", "allow_resize": False}
required_data_keys = ["image"]
def _init(self, conf):
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
def _forward(self, data):
img = data["image"]
if self.conf.allow_resize:
img = F.upsample(img, [int(x // 14 * 14) for x in img.shape[-2:]])
desc, cls_token = self.net.get_intermediate_layers(
img, n=1, return_class_token=True, reshape=True
)[0]
return {
"features": desc,
"global_descriptor": cls_token,
"descriptors": desc.flatten(-2).transpose(-2, -1),
}
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,126 @@
"""
Base class for trainable models.
"""
from abc import ABCMeta, abstractmethod
import omegaconf
from omegaconf import OmegaConf
from torch import nn
from copy import copy
class MetaModel(ABCMeta):
def __prepare__(name, bases, **kwds):
total_conf = OmegaConf.create()
for base in bases:
for key in ("base_default_conf", "default_conf"):
update = getattr(base, key, {})
if isinstance(update, dict):
update = OmegaConf.create(update)
total_conf = OmegaConf.merge(total_conf, update)
return dict(base_default_conf=total_conf)
class BaseModel(nn.Module, metaclass=MetaModel):
"""
What the child model is expected to declare:
default_conf: dictionary of the default configuration of the model.
It recursively updates the default_conf of all parent classes, and
it is updated by the user-provided configuration passed to __init__.
Configurations can be nested.
required_data_keys: list of expected keys in the input data dictionary.
strict_conf (optional): boolean. If false, BaseModel does not raise
an error when the user provides an unknown configuration entry.
_init(self, conf): initialization method, where conf is the final
configuration object (also accessible with `self.conf`). Accessing
unknown configuration entries will raise an error.
_forward(self, data): method that returns a dictionary of batched
prediction tensors based on a dictionary of batched input data tensors.
loss(self, pred, data): method that returns a dictionary of losses,
computed from model predictions and input data. Each loss is a batch
of scalars, i.e. a torch.Tensor of shape (B,).
The total loss to be optimized has the key `'total'`.
metrics(self, pred, data): method that returns a dictionary of metrics,
each as a batch of scalars.
"""
default_conf = {
"name": None,
"trainable": True, # if false: do not optimize this model parameters
"freeze_batch_normalization": False, # use test-time statistics
"timeit": False, # time forward pass
}
required_data_keys = []
strict_conf = False
def __init__(self, conf):
"""Perform some logic and call the _init method of the child model."""
super().__init__()
default_conf = OmegaConf.merge(
self.base_default_conf, OmegaConf.create(self.default_conf)
)
if self.strict_conf:
OmegaConf.set_struct(default_conf, True)
# fixme: backward compatibility
if "pad" in conf and "pad" not in default_conf: # backward compat.
with omegaconf.read_write(conf):
with omegaconf.open_dict(conf):
conf["interpolation"] = {"pad": conf.pop("pad")}
if isinstance(conf, dict):
conf = OmegaConf.create(conf)
self.conf = conf = OmegaConf.merge(default_conf, conf)
OmegaConf.set_readonly(conf, True)
OmegaConf.set_struct(conf, True)
self.required_data_keys = copy(self.required_data_keys)
self._init(conf)
if not conf.trainable:
for p in self.parameters():
p.requires_grad = False
def train(self, mode=True):
super().train(mode)
def freeze_bn(module):
if isinstance(module, nn.modules.batchnorm._BatchNorm):
module.eval()
if self.conf.freeze_batch_normalization:
self.apply(freeze_bn)
return self
def forward(self, data):
"""Check the data and call the _forward method of the child model."""
def recursive_key_check(expected, given):
for key in expected:
assert key in given, f"Missing key {key} in data"
if isinstance(expected, dict):
recursive_key_check(expected[key], given[key])
recursive_key_check(self.required_data_keys, data)
return self._forward(data)
@abstractmethod
def _init(self, conf):
"""To be implemented by the child class."""
raise NotImplementedError
@abstractmethod
def _forward(self, data):
"""To be implemented by the child class."""
raise NotImplementedError
@abstractmethod
def loss(self, pred, data):
"""To be implemented by the child class."""
raise NotImplementedError
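The contract above can be illustrated with a minimal toy subclass (a sketch only, not shipped with the repository; it assumes BaseModel from this file is in scope):
import torch
from torch import nn


class TinyModel(BaseModel):
    default_conf = {"dim": 8}            # merged with the base default_conf
    required_data_keys = ["image"]

    def _init(self, conf):
        self.proj = nn.Conv2d(3, conf.dim, 1)

    def _forward(self, data):
        return {"features": self.proj(data["image"])}

    def loss(self, pred, data):
        # (losses, metrics) tuple, as other models in this repository return
        return {"total": pred["features"].abs().mean(dim=(1, 2, 3))}, {}


model = TinyModel({"dim": 16})
out = model({"image": torch.rand(2, 3, 32, 32)})   # a missing required key would raise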

View File

@ -0,0 +1,129 @@
import torch
import string
import h5py
from .base_model import BaseModel
from ..settings import DATA_PATH
from ..datasets.base_dataset import collate
from ..utils.tensor import batch_to_device
from .utils.misc import pad_to_length
def pad_local_features(pred: dict, seq_l: int):
pred["keypoints"] = pad_to_length(
pred["keypoints"],
seq_l,
-2,
mode="random_c",
)
if "keypoint_scores" in pred.keys():
pred["keypoint_scores"] = pad_to_length(
pred["keypoint_scores"], seq_l, -1, mode="zeros"
)
if "descriptors" in pred.keys():
pred["descriptors"] = pad_to_length(
pred["descriptors"], seq_l, -2, mode="random"
)
if "scales" in pred.keys():
pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros")
if "oris" in pred.keys():
pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros")
return pred
def pad_line_features(pred, seq_l: int = None):
raise NotImplementedError
def recursive_load(grp, pkeys):
return {
k: torch.from_numpy(grp[k].__array__())
if isinstance(grp[k], h5py.Dataset)
else recursive_load(grp[k], list(grp.keys()))
for k in pkeys
}
class CacheLoader(BaseModel):
default_conf = {
"path": "???", # can be a format string like exports/{scene}/
"data_keys": None, # load all keys
"device": None, # load to same device as data
"trainable": False,
"add_data_path": True,
"collate": True,
"scale": ["keypoints", "lines", "orig_lines"],
"padding_fn": None,
"padding_length": None, # required for batching!
"numeric_type": "float32", # [None, "float16", "float32", "float64"]
}
required_data_keys = ["name"] # we need an identifier
def _init(self, conf):
self.hfiles = {}
self.padding_fn = conf.padding_fn
if self.padding_fn is not None:
self.padding_fn = eval(self.padding_fn)
self.numeric_dtype = {
None: None,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
}[conf.numeric_type]
def _forward(self, data):
preds = []
device = self.conf.device
if not device:
devices = set(
[v.device for v in data.values() if isinstance(v, torch.Tensor)]
)
if len(devices) == 0:
device = "cpu"
else:
assert len(devices) == 1
device = devices.pop()
var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]
for i, name in enumerate(data["name"]):
fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
if self.conf.add_data_path:
fpath = DATA_PATH / fpath
hfile = h5py.File(str(fpath), "r")
grp = hfile[name]
pkeys = (
self.conf.data_keys if self.conf.data_keys is not None else grp.keys()
)
pred = recursive_load(grp, pkeys)
if self.numeric_dtype is not None:
pred = {
k: v
if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
else v.to(dtype=self.numeric_dtype)
for k, v in pred.items()
}
pred = batch_to_device(pred, device)
for k, v in pred.items():
for pattern in self.conf.scale:
if k.startswith(pattern):
view_idx = k.replace(pattern, "")
scales = (
data["scales"]
if len(view_idx) == 0
else data[f"view{view_idx}"]["scales"]
)
pred[k] = pred[k] * scales[i]
# use this function to fix number of keypoints etc.
if self.padding_fn is not None:
pred = self.padding_fn(pred, self.conf.padding_length)
preds.append(pred)
hfile.close()
if self.conf.collate:
return batch_to_device(collate(preds), device)
else:
assert len(preds) == 1
return batch_to_device(preds[0], device)
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,785 @@
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet
from typing import Optional, Callable
from torch.nn.modules.utils import _pair
import torchvision
from gluefactory.models.base_model import BaseModel
# coordinate system
# ------------------------------> [ x: range=-1.0~1.0; w: range=0~W ]
# | -----------------------------
# | | |
# | | |
# | | |
# | | image |
# | | |
# | | |
# | | |
# | |---------------------------|
# v
# [ y: range=-1.0~1.0; h: range=0~H ]
def get_patches(
tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
) -> torch.Tensor:
c, h, w = tensor.shape
corner = (required_corners - ps / 2 + 1).long()
corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
offset = torch.arange(0, ps)
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
x, y = torch.meshgrid(offset, offset, **kw)
patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
patches = patches.to(corner) + corner[None, None]
pts = patches.reshape(-1, 2)
sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
sampled = sampled.reshape(ps, ps, -1, c)
assert sampled.shape[:3] == patches.shape[:3]
return sampled.permute(2, 3, 0, 1)
def simple_nms(scores: torch.Tensor, nms_radius: int):
"""Fast Non-maximum suppression to remove nearby points"""
zeros = torch.zeros_like(scores)
max_mask = scores == torch.nn.functional.max_pool2d(
scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
for _ in range(2):
supp_mask = (
torch.nn.functional.max_pool2d(
max_mask.float(),
kernel_size=nms_radius * 2 + 1,
stride=1,
padding=nms_radius,
)
> 0
)
supp_scores = torch.where(supp_mask, zeros, scores)
new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
max_mask = max_mask | (new_max_mask & (~supp_mask))
return torch.where(max_mask, scores, zeros)
class DKD(nn.Module):
def __init__(
self,
radius: int = 2,
top_k: int = 0,
scores_th: float = 0.2,
n_limit: int = 20000,
):
"""
Args:
radius: soft detection radius, kernel size is (2 * radius + 1)
top_k: top_k > 0: return top k keypoints
scores_th: top_k <= 0 threshold mode:
scores_th > 0: return keypoints with scores>scores_th
else: return keypoints with scores > scores.mean()
n_limit: max number of keypoints in threshold mode
"""
super().__init__()
self.radius = radius
self.top_k = top_k
self.scores_th = scores_th
self.n_limit = n_limit
self.kernel_size = 2 * self.radius + 1
self.temperature = 0.1 # tuned temperature
self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
# local xy grid
x = torch.linspace(-self.radius, self.radius, self.kernel_size)
# (kernel_size*kernel_size) x 2 : (w,h)
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
self.hw_grid = (
torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
)
def forward(
self,
scores_map: torch.Tensor,
sub_pixel: bool = True,
image_size: Optional[torch.Tensor] = None,
):
"""
:param scores_map: Bx1xHxW
:param sub_pixel: whether to use sub-pixel keypoint detection
:param image_size: optional Bx2 tensor of per-image (width, height), used to mask out padded borders
:return: kpts: list[Nx2,...]; kptscores: list[N,...] normalised position: -1~1
"""
b, c, h, w = scores_map.shape
scores_nograd = scores_map.detach()
nms_scores = simple_nms(scores_nograd, self.radius)
# remove border
nms_scores[:, :, : self.radius, :] = 0
nms_scores[:, :, :, : self.radius] = 0
if image_size is not None:
for i in range(scores_map.shape[0]):
w, h = image_size[i].long()
nms_scores[i, :, h.item() - self.radius :, :] = 0
nms_scores[i, :, :, w.item() - self.radius :] = 0
else:
nms_scores[:, :, -self.radius :, :] = 0
nms_scores[:, :, :, -self.radius :] = 0
# detect keypoints without grad
if self.top_k > 0:
topk = torch.topk(nms_scores.view(b, -1), self.top_k)
indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
else:
if self.scores_th > 0:
masks = nms_scores > self.scores_th
if masks.sum() == 0:
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
masks = nms_scores > th.reshape(b, 1, 1, 1)
else:
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
masks = nms_scores > th.reshape(b, 1, 1, 1)
masks = masks.reshape(b, -1)
indices_keypoints = [] # list, B x (any size)
scores_view = scores_nograd.reshape(b, -1)
for mask, scores in zip(masks, scores_view):
indices = mask.nonzero()[:, 0]
if len(indices) > self.n_limit:
kpts_sc = scores[indices]
sort_idx = kpts_sc.sort(descending=True)[1]
sel_idx = sort_idx[: self.n_limit]
indices = indices[sel_idx]
indices_keypoints.append(indices)
wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
keypoints = []
scoredispersitys = []
kptscores = []
if sub_pixel:
# detect soft keypoints with grad backpropagation
patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
self.hw_grid = self.hw_grid.to(scores_map) # to device
for b_idx in range(b):
patch = patches[b_idx].t() # (H*W) x (kernel**2)
indices_kpt = indices_keypoints[
b_idx
] # one dimension vector, say its size is M
patch_scores = patch[indices_kpt] # M x (kernel**2)
keypoints_xy_nms = torch.stack(
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
dim=1,
) # Mx2
# max is detached to prevent undesired backprop loops in the graph
max_v = patch_scores.max(dim=1).values.detach()[:, None]
x_exp = (
(patch_scores - max_v) / self.temperature
).exp() # M * (kernel**2), in [0, 1]
# \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
xy_residual = (
x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
) # Soft-argmax, Mx2
hw_grid_dist2 = (
torch.norm(
(self.hw_grid[None, :, :] - xy_residual[:, None, :])
/ self.radius,
dim=-1,
)
** 2
)
scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
# compute result keypoints
keypoints_xy = keypoints_xy_nms + xy_residual
keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
kptscore = torch.nn.functional.grid_sample(
scores_map[b_idx].unsqueeze(0),
keypoints_xy.view(1, 1, -1, 2),
mode="bilinear",
align_corners=True,
)[
0, 0, 0, :
] # CxN
keypoints.append(keypoints_xy)
scoredispersitys.append(scoredispersity)
kptscores.append(kptscore)
else:
for b_idx in range(b):
indices_kpt = indices_keypoints[
b_idx
] # one dimension vector, say its size is M
# To avoid warning: UserWarning: __floordiv__ is deprecated
keypoints_xy_nms = torch.stack(
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
dim=1,
) # Mx2
keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
kptscore = torch.nn.functional.grid_sample(
scores_map[b_idx].unsqueeze(0),
keypoints_xy.view(1, 1, -1, 2),
mode="bilinear",
align_corners=True,
)[
0, 0, 0, :
] # CxN
keypoints.append(keypoints_xy)
scoredispersitys.append(kptscore) # for jit.script compatibility
kptscores.append(kptscore)
return keypoints, scoredispersitys, kptscores
class InputPadder(object):
"""Pads images such that dimensions are divisible by 8"""
def __init__(self, h: int, w: int, divis_by: int = 8):
self.ht = h
self.wd = w
pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
self._pad = [
pad_wd // 2,
pad_wd - pad_wd // 2,
pad_ht // 2,
pad_ht - pad_ht // 2,
]
def pad(self, x: torch.Tensor):
assert x.ndim == 4
return F.pad(x, self._pad, mode="replicate")
def unpad(self, x: torch.Tensor):
assert x.ndim == 4
ht = x.shape[-2]
wd = x.shape[-1]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
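A short sketch of the padder (illustration only): pad to a multiple of 32, then undo it.
import torch

x = torch.rand(1, 3, 100, 90)
padder = InputPadder(x.shape[-2], x.shape[-1], divis_by=32)
y = padder.pad(x)           # -> 1 x 3 x 128 x 96
x_back = padder.unpad(y)    # -> back to 1 x 3 x 100 x 90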
class DeformableConv2d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
mask=False,
):
super(DeformableConv2d, self).__init__()
self.padding = padding
self.mask = mask
self.channel_num = (
3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
)
self.offset_conv = nn.Conv2d(
in_channels,
self.channel_num,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True,
)
self.regular_conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias,
)
def forward(self, x):
h, w = x.shape[2:]
max_offset = max(h, w) / 4.0
out = self.offset_conv(x)
if self.mask:
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
else:
offset = out
mask = None
offset = offset.clamp(-max_offset, max_offset)
x = torchvision.ops.deform_conv2d(
input=x,
offset=offset,
weight=self.regular_conv.weight,
bias=self.regular_conv.bias,
padding=self.padding,
mask=mask,
)
return x
def get_conv(
inplanes,
planes,
kernel_size=3,
stride=1,
padding=1,
bias=False,
conv_type="conv",
mask=False,
):
if conv_type == "conv":
conv = nn.Conv2d(
inplanes,
planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
elif conv_type == "dcn":
conv = DeformableConv2d(
inplanes,
planes,
kernel_size=kernel_size,
stride=stride,
padding=_pair(padding),
bias=bias,
mask=mask,
)
else:
raise TypeError
return conv
class ConvBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
gate: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
conv_type: str = "conv",
mask: bool = False,
):
super().__init__()
if gate is None:
self.gate = nn.ReLU(inplace=True)
else:
self.gate = gate
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.conv1 = get_conv(
in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
)
self.bn1 = norm_layer(out_channels)
self.conv2 = get_conv(
out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
)
self.bn2 = norm_layer(out_channels)
def forward(self, x):
x = self.gate(self.bn1(self.conv1(x))) # B x out_channels x H x W
x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
return x
# modified based on torchvision\models\resnet.py#27->BasicBlock
class ResBlock(nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
gate: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
conv_type: str = "conv",
mask: bool = False,
) -> None:
super(ResBlock, self).__init__()
if gate is None:
self.gate = nn.ReLU(inplace=True)
else:
self.gate = gate
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("ResBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in ResBlock")
# Both self.conv1 and self.downsample layers
# downsample the input when stride != 1
self.conv1 = get_conv(
inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
)
self.bn1 = norm_layer(planes)
self.conv2 = get_conv(
planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.gate(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.gate(out)
return out
class SDDH(nn.Module):
def __init__(
self,
dims: int,
kernel_size: int = 3,
n_pos: int = 8,
gate=nn.ReLU(),
conv2D=False,
mask=False,
):
super(SDDH, self).__init__()
self.kernel_size = kernel_size
self.n_pos = n_pos
self.conv2D = conv2D
self.mask = mask
self.get_patches_func = get_patches
# estimate offsets
self.channel_num = 3 * n_pos if mask else 2 * n_pos
self.offset_conv = nn.Sequential(
nn.Conv2d(
dims,
self.channel_num,
kernel_size=kernel_size,
stride=1,
padding=0,
bias=True,
),
gate,
nn.Conv2d(
self.channel_num,
self.channel_num,
kernel_size=1,
stride=1,
padding=0,
bias=True,
),
)
# sampled feature conv
self.sf_conv = nn.Conv2d(
dims, dims, kernel_size=1, stride=1, padding=0, bias=False
)
# convM
if not conv2D:
# deformable desc weights
agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
self.register_parameter("agg_weights", agg_weights)
else:
self.convM = nn.Conv2d(
dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
)
def forward(self, x, keypoints):
# x: [B,C,H,W]
# keypoints: list, [[N_kpts,2], ...] (w,h)
b, c, h, w = x.shape
wh = torch.tensor([[w - 1, h - 1]], device=x.device)
max_offset = max(h, w) / 4.0
offsets = []
descriptors = []
# get offsets for each keypoint
for ib in range(b):
xi, kptsi = x[ib], keypoints[ib]
kptsi_wh = (kptsi / 2 + 0.5) * wh
N_kpts = len(kptsi)
if self.kernel_size > 1:
patch = self.get_patches_func(
xi, kptsi_wh.long(), self.kernel_size
) # [N_kpts, C, K, K]
else:
kptsi_wh_long = kptsi_wh.long()
patch = (
xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
.permute(1, 0)
.reshape(N_kpts, c, 1, 1)
)
offset = self.offset_conv(patch).clamp(
-max_offset, max_offset
) # [N_kpts, 2*n_pos, 1, 1]
if self.mask:
offset = (
offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 3]
offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
else:
offset = (
offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 2]
offsets.append(offset) # for visualization
# get sample positions
pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2]
pos = 2.0 * pos / wh[None] - 1
pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
# sample features
features = F.grid_sample(
xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
) # [1,C,(N_kpts*n_pos),1]
features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
1, 0, 2, 3
) # [N_kpts, C, n_pos, 1]
if self.mask:
features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
features = torch.selu_(self.sf_conv(features)).squeeze(
-1
) # [N_kpts, C, n_pos]
# convM
if not self.conv2D:
descs = torch.einsum(
"ncp,pcd->nd", features, self.agg_weights
) # [N_kpts, C]
else:
features = features.reshape(N_kpts, -1)[
:, :, None, None
] # [N_kpts, C*n_pos, 1, 1]
descs = self.convM(features).squeeze() # [N_kpts, C]
# normalize
descs = F.normalize(descs, p=2.0, dim=1)
descriptors.append(descs)
return descriptors, offsets
class ALIKED(BaseModel):
default_conf = {
"model_name": "aliked-n16",
"max_num_keypoints": -1,
"detection_threshold": 0.2,
"force_num_keypoints": False,
"pretrained": True,
"nms_radius": 2,
}
checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
n_limit_max = 20000
cfgs = {
"aliked-t16": {
"c1": 8,
"c2": 16,
"c3": 32,
"c4": 64,
"dim": 64,
"K": 3,
"M": 16,
},
"aliked-n16": {
"c1": 16,
"c2": 32,
"c3": 64,
"c4": 128,
"dim": 128,
"K": 3,
"M": 16,
},
"aliked-n16rot": {
"c1": 16,
"c2": 32,
"c3": 64,
"c4": 128,
"dim": 128,
"K": 3,
"M": 16,
},
"aliked-n32": {
"c1": 16,
"c2": 32,
"c3": 64,
"c4": 128,
"dim": 128,
"K": 3,
"M": 32,
},
}
required_data_keys = ["image"]
def _init(self, conf):
if conf.force_num_keypoints:
assert conf.detection_threshold <= 0 and conf.max_num_keypoints > 0
# get configurations
c1, c2, c3, c4, dim, K, M = [v for _, v in self.cfgs[conf.model_name].items()]
conv_types = ["conv", "conv", "dcn", "dcn"]
conv2D = False
mask = False
# build model
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
self.norm = nn.BatchNorm2d
self.gate = nn.SELU(inplace=True)
self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
self.block2 = ResBlock(
c1,
c2,
1,
nn.Conv2d(c1, c2, 1),
gate=self.gate,
norm_layer=self.norm,
conv_type=conv_types[1],
)
self.block3 = ResBlock(
c2,
c3,
1,
nn.Conv2d(c2, c3, 1),
gate=self.gate,
norm_layer=self.norm,
conv_type=conv_types[2],
mask=mask,
)
self.block4 = ResBlock(
c3,
c4,
1,
nn.Conv2d(c3, c4, 1),
gate=self.gate,
norm_layer=self.norm,
conv_type=conv_types[3],
mask=mask,
)
self.conv1 = resnet.conv1x1(c1, dim // 4)
self.conv2 = resnet.conv1x1(c2, dim // 4)
self.conv3 = resnet.conv1x1(c3, dim // 4)
self.conv4 = resnet.conv1x1(dim, dim // 4)
self.upsample2 = nn.Upsample(
scale_factor=2, mode="bilinear", align_corners=True
)
self.upsample4 = nn.Upsample(
scale_factor=4, mode="bilinear", align_corners=True
)
self.upsample8 = nn.Upsample(
scale_factor=8, mode="bilinear", align_corners=True
)
self.upsample32 = nn.Upsample(
scale_factor=32, mode="bilinear", align_corners=True
)
self.score_head = nn.Sequential(
resnet.conv1x1(dim, 8),
self.gate,
resnet.conv3x3(8, 4),
self.gate,
resnet.conv3x3(4, 4),
self.gate,
resnet.conv3x3(4, 1),
)
self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
self.dkd = DKD(
radius=conf.nms_radius,
top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
scores_th=conf.detection_threshold,
n_limit=conf.max_num_keypoints
if conf.max_num_keypoints > 0
else self.n_limit_max,
)
# load pretrained
if conf.pretrained:
state_dict = torch.hub.load_state_dict_from_url(
self.checkpoint_url.format(conf.model_name), map_location="cpu"
)
self.load_state_dict(state_dict, strict=True)
def extract_dense_map(self, image):
# Pads images such that dimensions are divisible by 32
div_by = 2**5
padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
image = padder.pad(image)
# ================================== feature encoder
x1 = self.block1(image) # B x c1 x H x W
x2 = self.pool2(x1)
x2 = self.block2(x2) # B x c2 x H/2 x W/2
x3 = self.pool4(x2)
x3 = self.block3(x3) # B x c3 x H/8 x W/8
x4 = self.pool4(x3)
x4 = self.block4(x4) # B x dim x H/32 x W/32
# ================================== feature aggregation
x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
x2_up = self.upsample2(x2) # B x dim//4 x H x W
x3_up = self.upsample8(x3) # B x dim//4 x H x W
x4_up = self.upsample32(x4) # B x dim//4 x H x W
x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
# ================================== score head
score_map = torch.sigmoid(self.score_head(x1234))
feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
# Unpads images
feature_map = padder.unpad(feature_map)
score_map = padder.unpad(score_map)
return feature_map, score_map
def _forward(self, data):
image = data["image"]
feature_map, score_map = self.extract_dense_map(image)
keypoints, kptscores, scoredispersitys = self.dkd(
score_map, image_size=data.get("image_size")
)
descriptors, offsets = self.desc_head(feature_map, keypoints)
_, _, h, w = image.shape
wh = torch.tensor([w, h], device=image.device)
# no padding required: to get a fixed number of keypoints,
# set detection_threshold=-1 and max_num_keypoints > 0
return {
"keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B N 2
"descriptors": torch.stack(descriptors), # B N D
"keypoint_scores": torch.stack(kptscores), # B N
"score_dispersity": torch.stack(scoredispersitys),
"score_map": score_map, # Bx1xHxW
}
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,107 @@
import torch
import kornia
from ..base_model import BaseModel
from ..utils.misc import pad_and_stack
class DISK(BaseModel):
default_conf = {
"weights": "depth",
"dense_outputs": False,
"max_num_keypoints": None,
"desc_dim": 128,
"nms_window_size": 5,
"detection_threshold": 0.0,
"force_num_keypoints": False,
"pad_if_not_divisible": True,
"chunk": 4, # for reduced VRAM in training
}
required_data_keys = ["image"]
def _init(self, conf):
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
def _get_dense_outputs(self, images):
B = images.shape[0]
if self.conf.pad_if_not_divisible:
h, w = images.shape[2:]
pd_h = 16 - h % 16 if h % 16 > 0 else 0
pd_w = 16 - w % 16 if w % 16 > 0 else 0
images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
heatmaps, descriptors = self.model.heatmap_and_dense_descriptors(images)
if self.conf.pad_if_not_divisible:
heatmaps = heatmaps[..., :h, :w]
descriptors = descriptors[..., :h, :w]
keypoints = kornia.feature.disk.detector.heatmap_to_keypoints(
heatmaps,
n=self.conf.max_num_keypoints,
window_size=self.conf.nms_window_size,
score_threshold=self.conf.detection_threshold,
)
features = []
for i in range(B):
features.append(keypoints[i].merge_with_descriptors(descriptors[i]))
return features, descriptors
def _forward(self, data):
image = data["image"]
keypoints, scores, descriptors = [], [], []
if self.conf.dense_outputs:
dense_descriptors = []
chunk = self.conf.chunk
for i in range(0, image.shape[0], chunk):
if self.conf.dense_outputs:
features, d_descriptors = self._get_dense_outputs(
image[i : min(image.shape[0], i + chunk)]
)
dense_descriptors.append(d_descriptors)
else:
features = self.model(
image[i : min(image.shape[0], i + chunk)],
n=self.conf.max_num_keypoints,
window_size=self.conf.nms_window_size,
score_threshold=self.conf.detection_threshold,
pad_if_not_divisible=self.conf.pad_if_not_divisible,
)
keypoints += [f.keypoints for f in features]
scores += [f.detection_scores for f in features]
descriptors += [f.descriptors for f in features]
del features
if self.conf.force_num_keypoints:
# pad to target_length
target_length = self.conf.max_num_keypoints
keypoints = pad_and_stack(
keypoints,
target_length,
-2,
mode="random_c",
bounds=(
0,
data.get("image_size", torch.tensor(image.shape[-2:])).min().item(),
),
)
scores = pad_and_stack(scores, target_length, -1, mode="zeros")
descriptors = pad_and_stack(descriptors, target_length, -2, mode="zeros")
else:
keypoints = torch.stack(keypoints, 0)
scores = torch.stack(scores, 0)
descriptors = torch.stack(descriptors, 0)
pred = {
"keypoints": keypoints.to(image) + 0.5,
"keypoint_scores": scores.to(image),
"descriptors": descriptors.to(image),
}
if self.conf.dense_outputs:
pred["dense_descriptors"] = torch.cat(dense_descriptors, 0)
return pred
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,59 @@
import torch
import math
from ..base_model import BaseModel
def to_sequence(map):
return map.flatten(-2).transpose(-1, -2)
def to_map(sequence):
n = sequence.shape[-2]
e = math.isqrt(n)
assert e * e == n
return sequence.transpose(-1, -2).unflatten(-1, [e, e])
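A quick round trip (sketch; assumes the two helpers above): a dense map flattened to a token sequence and reshaped back.
import torch

m = torch.rand(1, 64, 16, 16)     # B x C x H x W with H == W
seq = to_sequence(m)              # -> B x (H*W) x C
assert torch.allclose(to_map(seq), m)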
class GridExtractor(BaseModel):
default_conf = {"cell_size": 14}
required_data_keys = ["image"]
def _init(self, conf):
pass
def _forward(self, data):
b, c, h, w = data["image"].shape
cgrid = (
torch.stack(
torch.meshgrid(
torch.arange(
h // self.conf.cell_size,
dtype=torch.float32,
device=data["image"].device,
),
torch.arange(
w // self.conf.cell_size,
dtype=torch.float32,
device=data["image"].device,
),
indexing="ij",
)[::-1],
dim=0,
)
.unsqueeze(0)
.repeat([b, 1, 1, 1])
* self.conf.cell_size
+ self.conf.cell_size / 2
)
pred = {
"grid": cgrid + 0.5,
"keypoints": to_sequence(cgrid) + 0.5,
}
return pred
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,73 @@
import torch
import kornia
from ..base_model import BaseModel
from ..utils.misc import pad_to_length
class KeyNetAffNetHardNet(BaseModel):
default_conf = {
"max_num_keypoints": None,
"desc_dim": 128,
"upright": False,
"scale_laf": 1.0,
"chunk": 4, # for reduced VRAM in training
}
required_data_keys = ["image"]
def _init(self, conf):
self.model = kornia.feature.KeyNetHardNet(
num_features=conf.max_num_keypoints,
upright=conf.upright,
scale_laf=conf.scale_laf,
)
def _forward(self, data):
image = data["image"]
if image.shape[1] == 3: # RGB
scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
image = (image * scale).sum(1, keepdim=True)
lafs, scores, descs = [], [], []
im_size = data.get("image_size")
for i in range(image.shape[0]):
img_i = image[i : i + 1, :1]
if im_size is not None:
img_i = img_i[:, :, : im_size[i, 1], : im_size[i, 0]]
laf, score, desc = self.model(img_i)
xn = pad_to_length(
kornia.feature.get_laf_center(laf),
self.conf.max_num_keypoints,
pad_dim=-2,
mode="random_c",
bounds=(0, min(img_i.shape[-2:])),
)
laf = torch.cat(
[
laf,
kornia.feature.laf_from_center_scale_ori(xn[:, score.shape[-1] :]),
],
-3,
)
lafs.append(laf)
scores.append(pad_to_length(score, self.conf.max_num_keypoints, -1))
descs.append(pad_to_length(desc, self.conf.max_num_keypoints, -2))
lafs = torch.cat(lafs, 0)
scores = torch.cat(scores, 0)
descs = torch.cat(descs, 0)
keypoints = kornia.feature.get_laf_center(lafs)
scales = kornia.feature.get_laf_scale(lafs)[..., 0]
oris = kornia.feature.get_laf_orientation(lafs)
pred = {
"keypoints": keypoints,
"scales": scales.squeeze(-1),
"oris": oris.squeeze(-1),
"lafs": lafs,
"keypoint_scores": scores,
"descriptors": descs,
}
return pred
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,78 @@
from omegaconf import OmegaConf
import torch.nn.functional as F
from ..base_model import BaseModel
from .. import get_model
# from ...geometry.depth import sample_fmap
to_ctr = OmegaConf.to_container # convert DictConfig to dict
class MixedExtractor(BaseModel):
default_conf = {
"detector": {"name": None},
"descriptor": {"name": None},
"interpolate_descriptors_from": None, # field name
}
required_data_keys = ["image"]
required_cache_keys = []
def _init(self, conf):
if conf.detector.name:
self.detector = get_model(conf.detector.name)(to_ctr(conf.detector))
else:
self.required_data_keys += ["cache"]
self.required_cache_keys += ["keypoints"]
if conf.descriptor.name:
self.descriptor = get_model(conf.descriptor.name)(to_ctr(conf.descriptor))
else:
self.required_data_keys += ["cache"]
self.required_cache_keys += ["descriptors"]
def _forward(self, data):
if self.conf.detector.name:
pred = self.detector(data)
else:
pred = data["cache"]
if self.conf.detector.name:
pred = {**pred, **self.descriptor({**pred, **data})}
if self.conf.interpolate_descriptors_from:
h, w = data["image"].shape[-2:]
kpts = pred["keypoints"]
pts = (kpts / kpts.new_tensor([[w, h]]) * 2 - 1)[:, None]
pred["descriptors"] = (
F.grid_sample(
pred[self.conf.interpolate_descriptors_from],
pts,
align_corners=False,
mode="bilinear",
)
.squeeze(-2)
.transpose(-2, -1)
.contiguous()
)
return pred
def loss(self, pred, data):
losses = {}
metrics = {}
total = 0
for k in ["detector", "descriptor"]:
apply = True
if "apply_loss" in self.conf[k].keys():
apply = self.conf[k].apply_loss
if self.conf[k].name and apply:
try:
losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
except NotImplementedError:
continue
losses = {**losses, **losses_}
metrics = {**metrics, **metrics_}
total = losses_["total"] + total
return {**losses, "total": total}, metrics

View File

@ -0,0 +1,240 @@
import numpy as np
import torch
import pycolmap
from scipy.spatial import KDTree
from omegaconf import OmegaConf
import cv2
from ..base_model import BaseModel
from ..utils.misc import pad_to_length
EPS = 1e-6
def sift_to_rootsift(x):
x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS)
x = np.sqrt(x.clip(min=EPS))
x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS)
return x
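For illustration, RootSIFT as computed above is: L1-normalize, take the element-wise square root, then L2-normalize, so each descriptor ends up with unit L2 norm (a sketch with random data):
import numpy as np

desc = np.random.rand(4, 128).astype(np.float32)
root = sift_to_rootsift(desc)
print(np.linalg.norm(root, axis=-1))   # ~1.0 for every row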
# from OpenGlue
def nms_keypoints(kpts: np.ndarray, responses: np.ndarray, radius: float) -> np.ndarray:
# TODO: add approximate tree
kd_tree = KDTree(kpts)
sorted_idx = np.argsort(-responses)
kpts_to_keep_idx = []
removed_idx = set()
for idx in sorted_idx:
# skip point if it was already removed
if idx in removed_idx:
continue
kpts_to_keep_idx.append(idx)
point = kpts[idx]
neighbors = kd_tree.query_ball_point(point, r=radius)
# Variable `neighbors` contains the `point` itself
removed_idx.update(neighbors)
mask = np.zeros((kpts.shape[0],), dtype=bool)
mask[kpts_to_keep_idx] = True
return mask
def detect_kpts_opencv(
features: cv2.Feature2D, image: np.ndarray, describe: bool = True
) -> np.ndarray:
"""
Detect keypoints using OpenCV Detector.
Optionally, perform NMS and filter top-response keypoints.
Optionally, perform description.
Args:
features: OpenCV based keypoints detector and descriptor
image: Grayscale image of uint8 data type
describe: flag indicating whether to simultaneously compute descriptors
Returns:
spts: Nx4 array of keypoints as (x, y, scale, angle)
responses: N array of detection scores, plus descriptors if describe is True
if describe:
kpts, descriptors = features.detectAndCompute(image, None)
else:
kpts = features.detect(image, None)
kpts = np.array(kpts)
responses = np.array([k.response for k in kpts], dtype=np.float32)
# select all
top_score_idx = ...
pts = np.array([k.pt for k in kpts], dtype=np.float32)
scales = np.array([k.size for k in kpts], dtype=np.float32)
angles = np.array([k.angle for k in kpts], dtype=np.float32)
spts = np.concatenate([pts, scales[..., None], angles[..., None]], -1)
if describe:
return spts[top_score_idx], responses[top_score_idx], descriptors[top_score_idx]
else:
return spts[top_score_idx], responses[top_score_idx]
class SIFT(BaseModel):
default_conf = {
"has_detector": True,
"has_descriptor": True,
"descriptor_dim": 128,
"pycolmap_options": {
"first_octave": 0,
"peak_threshold": 0.005,
"edge_threshold": 10,
},
"rootsift": True,
"nms_radius": None,
"max_num_keypoints": -1,
"max_num_keypoints_val": None,
"force_num_keypoints": False,
"randomize_keypoints_training": False,
"detector": "pycolmap", # ['pycolmap', 'pycolmap_cpu', 'pycolmap_cuda', 'cv2']
"detection_threshold": None,
}
required_data_keys = ["image"]
def _init(self, conf):
self.sift = None # lazy loading
@torch.no_grad()
def extract_features(self, image):
image_np = image.cpu().numpy()[0]
assert image.shape[0] == 1
assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS
detector = str(self.conf.detector)
if self.sift is None and detector.startswith("pycolmap"):
options = OmegaConf.to_container(self.conf.pycolmap_options)
device = (
"auto" if detector == "pycolmap" else detector.replace("pycolmap_", "")
)
if self.conf.rootsift == "rootsift":
options["normalization"] = pycolmap.Normalization.L1_ROOT
else:
options["normalization"] = pycolmap.Normalization.L2
if self.conf.detection_threshold is not None:
options["peak_threshold"] = self.conf.detection_threshold
options["max_num_features"] = self.conf.max_num_keypoints
self.sift = pycolmap.Sift(options=options, device=device)
elif self.sift is None and self.conf.detector == "cv2":
self.sift = cv2.SIFT_create(contrastThreshold=self.conf.detection_threshold)
if detector.startswith("pycolmap"):
keypoints, scores, descriptors = self.sift.extract(image_np)
elif detector == "cv2":
# TODO: Check if opencv keypoints are already in corner convention
keypoints, scores, descriptors = detect_kpts_opencv(
self.sift, (image_np * 255.0).astype(np.uint8)
)
if self.conf.nms_radius is not None:
mask = nms_keypoints(keypoints[:, :2], scores, self.conf.nms_radius)
keypoints = keypoints[mask]
scores = scores[mask]
descriptors = descriptors[mask]
scales = keypoints[:, 2]
oris = np.rad2deg(keypoints[:, 3])
if self.conf.has_descriptor:
# We still renormalize because COLMAP does not normalize well,
# maybe due to numerical errors
if self.conf.rootsift:
descriptors = sift_to_rootsift(descriptors)
descriptors = torch.from_numpy(descriptors)
keypoints = torch.from_numpy(keypoints[:, :2]) # keep only x, y
scales = torch.from_numpy(scales)
oris = torch.from_numpy(oris)
scores = torch.from_numpy(scores)
# Keep the k keypoints with highest score
max_kps = self.conf.max_num_keypoints
# for validation we allow a different number of keypoints
if not self.training and self.conf.max_num_keypoints_val is not None:
max_kps = self.conf.max_num_keypoints_val
if max_kps is not None and max_kps > 0:
if self.conf.randomize_keypoints_training and self.training:
# instead of selecting top-k, sample k by score weights
raise NotImplementedError
elif max_kps < scores.shape[0]:
# TODO: check that the scores from PyCOLMAP are 100% correct,
# follow https://github.com/mihaidusmanu/pycolmap/issues/8
indices = torch.topk(scores, max_kps).indices
keypoints = keypoints[indices]
scales = scales[indices]
oris = oris[indices]
scores = scores[indices]
if self.conf.has_descriptor:
descriptors = descriptors[indices]
if self.conf.force_num_keypoints:
keypoints = pad_to_length(
keypoints,
max_kps,
-2,
mode="random_c",
bounds=(0, min(image.shape[1:])),
)
scores = pad_to_length(scores, max_kps, -1, mode="zeros")
scales = pad_to_length(scales, max_kps, -1, mode="zeros")
oris = pad_to_length(oris, max_kps, -1, mode="zeros")
if self.conf.has_descriptor:
descriptors = pad_to_length(descriptors, max_kps, -2, mode="zeros")
pred = {
"keypoints": keypoints,
"scales": scales,
"oris": oris,
"keypoint_scores": scores,
}
if self.conf.has_descriptor:
pred["descriptors"] = descriptors
return pred
@torch.no_grad()
def _forward(self, data):
pred = {
"keypoints": [],
"scales": [],
"oris": [],
"keypoint_scores": [],
"descriptors": [],
}
image = data["image"]
if image.shape[1] == 3: # RGB
scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
image = (image * scale).sum(1, keepdim=True).cpu()
for k in range(image.shape[0]):
img = image[k]
if "image_size" in data.keys():
# avoid extracting points in padded areas
w, h = data["image_size"][k]
img = img[:, :h, :w]
p = self.extract_features(img)
for k, v in p.items():
pred[k].append(v)
if (image.shape[0] == 1) or self.conf.force_num_keypoints:
pred = {k: torch.stack(pred[k], 0) for k in pred.keys()}
pred = {k: pred[k].to(device=data["image"].device) for k in pred.keys()}
pred["oris"] = torch.deg2rad(pred["oris"])
return pred
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,45 @@
import kornia
import torch
from ..base_model import BaseModel
class KorniaSIFT(BaseModel):
default_conf = {
"has_detector": True,
"has_descriptor": True,
"max_num_keypoints": -1,
"detection_threshold": None,
"rootsift": True,
}
required_data_keys = ["image"]
def _init(self, conf):
self.sift = kornia.feature.SIFTFeature(
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
)
def _forward(self, data):
lafs, scores, descriptors = self.sift(data["image"])
keypoints = kornia.feature.get_laf_center(lafs)
scales = kornia.feature.get_laf_scale(lafs)
oris = kornia.feature.get_laf_orientation(lafs)
pred = {
"keypoints": keypoints, # @TODO: confirm keypoints are in corner convention
"scales": scales,
"oris": oris,
"keypoint_scores": scores,
}
if self.conf.has_descriptor:
pred["descriptors"] = descriptors
pred = {k: pred[k].to(device=data["image"].device) for k in pred.keys()}
pred["scales"] = pred["scales"]
pred["oris"] = torch.deg2rad(pred["oris"])
return pred
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,209 @@
"""PyTorch implementation of the SuperPoint model,
derived from the TensorFlow re-implementation (2018).
Authors: Rémi Pautrat, Paul-Edouard Sarlin
https://github.com/rpautrat/SuperPoint
The implementation of this model and its trained weights are made
available under the MIT license.
"""
import torch.nn as nn
import torch
from collections import OrderedDict
from types import SimpleNamespace
from ..base_model import BaseModel
from ..utils.misc import pad_and_stack
def sample_descriptors(keypoints, descriptors, s: int = 8):
"""Interpolate descriptors at keypoint locations"""
b, c, h, w = descriptors.shape
keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s)
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
descriptors = torch.nn.functional.grid_sample(
descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
)
descriptors = torch.nn.functional.normalize(
descriptors.reshape(b, c, -1), p=2, dim=1
)
return descriptors
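# Illustrative sketch (not part of the original file): shape check for the
# descriptor sampling above, with dummy sizes chosen only for this example.
def _demo_sample_descriptors():
    import torch
    dense = torch.rand(1, 128, 30, 40)  # [B, C, H/s, W/s] dense descriptor map
    kpts = torch.rand(1, 5, 2) * torch.tensor([320.0, 240.0])  # (x, y) in pixels
    desc = sample_descriptors(kpts, dense, s=8)
    assert desc.shape == (1, 128, 5)  # one L2-normalized column per keypoint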
def batched_nms(scores, nms_radius: int):
assert nms_radius >= 0
def max_pool(x):
return torch.nn.functional.max_pool2d(
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
zeros = torch.zeros_like(scores)
max_mask = scores == max_pool(scores)
for _ in range(2):
supp_mask = max_pool(max_mask.float()) > 0
supp_scores = torch.where(supp_mask, zeros, scores)
new_max_mask = supp_scores == max_pool(supp_scores)
max_mask = max_mask | (new_max_mask & (~supp_mask))
return torch.where(max_mask, scores, zeros)
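# Illustrative sketch (not part of the original file): with a 4-pixel NMS
# radius, the weaker of two nearby peaks is suppressed while the stronger one
# keeps its score.
def _demo_batched_nms():
    import torch
    scores = torch.zeros(1, 32, 32)
    scores[0, 10, 10] = 0.9
    scores[0, 12, 12] = 0.5  # within the radius of the stronger peak
    out = batched_nms(scores, nms_radius=4)
    assert out[0, 10, 10] == 0.9 and out[0, 12, 12] == 0.0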
def select_top_k_keypoints(keypoints, scores, k):
if k >= len(keypoints):
return keypoints, scores
scores, indices = torch.topk(scores, k, dim=0, sorted=True)
return keypoints[indices], scores
class VGGBlock(nn.Sequential):
def __init__(self, c_in, c_out, kernel_size, relu=True):
padding = (kernel_size - 1) // 2
conv = nn.Conv2d(
c_in, c_out, kernel_size=kernel_size, stride=1, padding=padding
)
activation = nn.ReLU(inplace=True) if relu else nn.Identity()
bn = nn.BatchNorm2d(c_out, eps=0.001)
super().__init__(
OrderedDict(
[
("conv", conv),
("activation", activation),
("bn", bn),
]
)
)
class SuperPoint(BaseModel):
default_conf = {
"descriptor_dim": 256,
"nms_radius": 4,
"max_num_keypoints": None,
"force_num_keypoints": False,
"detection_threshold": 0.005,
"remove_borders": 4,
"descriptor_dim": 256,
"channels": [64, 64, 128, 128, 256],
"dense_outputs": None,
}
checkpoint_url = "https://github.com/rpautrat/SuperPoint/raw/master/weights/superpoint_v6_from_tf.pth" # noqa: E501
def _init(self, conf):
self.conf = SimpleNamespace(**conf)
self.stride = 2 ** (len(self.conf.channels) - 2)
channels = [1, *self.conf.channels[:-1]]
backbone = []
for i, c in enumerate(channels[1:], 1):
layers = [VGGBlock(channels[i - 1], c, 3), VGGBlock(c, c, 3)]
if i < len(channels) - 1:
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
backbone.append(nn.Sequential(*layers))
self.backbone = nn.Sequential(*backbone)
c = self.conf.channels[-1]
self.detector = nn.Sequential(
VGGBlock(channels[-1], c, 3),
VGGBlock(c, self.stride**2 + 1, 1, relu=False),
)
self.descriptor = nn.Sequential(
VGGBlock(channels[-1], c, 3),
VGGBlock(c, self.conf.descriptor_dim, 1, relu=False),
)
state_dict = torch.hub.load_state_dict_from_url(self.checkpoint_url)
self.load_state_dict(state_dict)
def _forward(self, data):
image = data["image"]
if image.shape[1] == 3: # RGB
scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
image = (image * scale).sum(1, keepdim=True)
features = self.backbone(image)
descriptors_dense = torch.nn.functional.normalize(
self.descriptor(features), p=2, dim=1
)
# Decode the detection scores
scores = self.detector(features)
scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
b, _, h, w = scores.shape
scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, self.stride, self.stride)
scores = scores.permute(0, 1, 3, 2, 4).reshape(
b, h * self.stride, w * self.stride
)
scores = batched_nms(scores, self.conf.nms_radius)
# Discard keypoints near the image borders
if self.conf.remove_borders:
pad = self.conf.remove_borders
scores[:, :pad] = -1
scores[:, :, :pad] = -1
scores[:, -pad:] = -1
scores[:, :, -pad:] = -1
# Extract keypoints
if b > 1:
idxs = torch.where(scores > self.conf.detection_threshold)
mask = idxs[0] == torch.arange(b, device=scores.device)[:, None]
else: # Faster shortcut
scores = scores.squeeze(0)
idxs = torch.where(scores > self.conf.detection_threshold)
# Convert (i, j) to (x, y)
keypoints_all = torch.stack(idxs[-2:], dim=-1).flip(1).float()
scores_all = scores[idxs]
keypoints = []
scores = []
for i in range(b):
if b > 1:
k = keypoints_all[mask[i]]
s = scores_all[mask[i]]
else:
k = keypoints_all
s = scores_all
if self.conf.max_num_keypoints is not None:
k, s = select_top_k_keypoints(k, s, self.conf.max_num_keypoints)
keypoints.append(k)
scores.append(s)
if self.conf.force_num_keypoints:
keypoints = pad_and_stack(
keypoints,
self.conf.max_num_keypoints,
-2,
mode="random_c",
bounds=(
0,
data.get("image_size", torch.tensor(image.shape[-2:])).min().item(),
),
)
scores = pad_and_stack(
scores, self.conf.max_num_keypoints, -1, mode="zeros"
)
else:
keypoints = torch.stack(keypoints, 0)
scores = torch.stack(scores, 0)
if len(keypoints) == 1 or self.conf.force_num_keypoints:
# Batch sampling of the descriptors
desc = sample_descriptors(keypoints, descriptors_dense, self.stride)
else:
desc = [
sample_descriptors(k[None], d[None], self.stride)[0]
for k, d in zip(keypoints, descriptors_dense)
]
pred = {
"keypoints": keypoints + 0.5,
"keypoint_scores": scores,
"descriptors": desc.transpose(-1, -2),
}
if self.conf.dense_outputs:
pred["dense_descriptors"] = descriptors_dense
return pred
def loss(self, pred, data):
raise NotImplementedError
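# Illustrative sketch (not part of the original file): the detector-head
# decoding used in _forward above, on dummy logits. The 65 channels are the
# 8x8 cell scores plus a "no keypoint" dustbin that is dropped after softmax.
def _demo_score_decoding():
    import torch
    stride = 8
    logits = torch.rand(2, stride**2 + 1, 30, 40)  # [B, 65, H/8, W/8]
    scores = torch.nn.functional.softmax(logits, 1)[:, :-1]
    b, _, h, w = scores.shape
    scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, stride, stride)
    scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * stride, w * stride)
    assert scores.shape == (2, 240, 320)  # back to full image resolution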

View File

View File

@ -0,0 +1,105 @@
import numpy as np
import torch
import deeplsd.models.deeplsd_inference as deeplsd_inference
from ..base_model import BaseModel
from ...settings import DATA_PATH
class DeepLSD(BaseModel):
default_conf = {
"min_length": 15,
"max_num_lines": None,
"force_num_lines": False,
"model_conf": {
"detect_lines": True,
"line_detection_params": {
"merge": False,
"grad_nfa": True,
"filtering": "normal",
"grad_thresh": 3,
},
},
}
required_data_keys = ["image"]
def _init(self, conf):
if self.conf.force_num_lines:
assert (
self.conf.max_num_lines is not None
), "Missing max_num_lines parameter"
ckpt = DATA_PATH / "weights/deeplsd_md.tar"
if not ckpt.is_file():
self.download_model(ckpt)
ckpt = torch.load(ckpt, map_location="cpu")
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
self.net.load_state_dict(ckpt["model"])
def download_model(self, path):
import subprocess
if not path.parent.is_dir():
path.parent.mkdir(parents=True, exist_ok=True)
link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
cmd = ["wget", link, "-O", path]
print("Downloading DeepLSD model...")
subprocess.run(cmd, check=True)
def _forward(self, data):
image = data["image"]
lines, line_scores, valid_lines = [], [], []
if image.shape[1] == 3:
# Convert to grayscale
scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
image = (image * scale).sum(1, keepdim=True)
# Forward pass
with torch.no_grad():
segs = self.net({"image": image})["lines"]
# Line scores are the sqrt of the length
for seg in segs:
lengths = np.linalg.norm(seg[:, 0] - seg[:, 1], axis=1)
            keep = lengths >= self.conf.min_length
            segs = seg[keep]
            scores = np.sqrt(lengths[keep])
# Keep the best lines
indices = np.argsort(-scores)
if self.conf.max_num_lines is not None:
indices = indices[: self.conf.max_num_lines]
segs = segs[indices]
scores = scores[indices]
# Pad if necessary
n = len(segs)
valid_mask = np.ones(n, dtype=bool)
if self.conf.force_num_lines:
pad = self.conf.max_num_lines - n
segs = np.concatenate(
[segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
)
scores = np.concatenate(
[scores, np.zeros(pad, dtype=np.float32)], axis=0
)
valid_mask = np.concatenate(
[valid_mask, np.zeros(pad, dtype=bool)], axis=0
)
lines.append(segs)
line_scores.append(scores)
valid_lines.append(valid_mask)
# Batch if possible
if len(image) == 1 or self.conf.force_num_lines:
lines = torch.tensor(lines, dtype=torch.float, device=image.device)
line_scores = torch.tensor(
line_scores, dtype=torch.float, device=image.device
)
valid_lines = torch.tensor(
valid_lines, dtype=torch.bool, device=image.device
)
return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,88 @@
import numpy as np
import torch
from joblib import Parallel, delayed
from pytlsd import lsd
from ..base_model import BaseModel
class LSD(BaseModel):
default_conf = {
"min_length": 15,
"max_num_lines": None,
"force_num_lines": False,
"n_jobs": 4,
}
required_data_keys = ["image"]
def _init(self, conf):
if self.conf.force_num_lines:
assert (
self.conf.max_num_lines is not None
), "Missing max_num_lines parameter"
def detect_lines(self, img):
# Run LSD
segs = lsd(img)
# Filter out keylines that do not meet the minimum length criteria
lengths = np.linalg.norm(segs[:, 2:4] - segs[:, 0:2], axis=1)
to_keep = lengths >= self.conf.min_length
segs, lengths = segs[to_keep], lengths[to_keep]
# Keep the best lines
scores = segs[:, -1] * np.sqrt(lengths)
segs = segs[:, :4].reshape(-1, 2, 2)
indices = np.argsort(-scores)
if self.conf.max_num_lines is not None:
indices = indices[: self.conf.max_num_lines]
segs = segs[indices]
scores = scores[indices]
# Pad if necessary
n = len(segs)
valid_mask = np.ones(n, dtype=bool)
if self.conf.force_num_lines:
pad = self.conf.max_num_lines - n
segs = np.concatenate(
[segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
)
scores = np.concatenate([scores, np.zeros(pad, dtype=np.float32)], axis=0)
valid_mask = np.concatenate([valid_mask, np.zeros(pad, dtype=bool)], axis=0)
return segs, scores, valid_mask
def _forward(self, data):
# Convert to the right data format
image = data["image"]
if image.shape[1] == 3:
# Convert to grayscale
scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
image = (image * scale).sum(1, keepdim=True)
device = image.device
b_size = len(image)
image = np.uint8(image.squeeze(1).cpu().numpy() * 255)
# LSD detection in parallel
if b_size == 1:
lines, line_scores, valid_lines = self.detect_lines(image[0])
lines = [lines]
line_scores = [line_scores]
valid_lines = [valid_lines]
else:
lines, line_scores, valid_lines = zip(
*Parallel(n_jobs=self.conf.n_jobs)(
delayed(self.detect_lines)(img) for img in image
)
)
# Batch if possible
if b_size == 1 or self.conf.force_num_lines:
lines = torch.tensor(lines, dtype=torch.float, device=device)
line_scores = torch.tensor(line_scores, dtype=torch.float, device=device)
valid_lines = torch.tensor(valid_lines, dtype=torch.bool, device=device)
return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}
def loss(self, pred, data):
raise NotImplementedError
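# Illustrative sketch (not part of the original file): the length filter and
# ranking used in detect_lines, on two toy segments. The last column stands in
# for the per-segment detection score returned by LSD (values chosen here).
def _demo_line_filtering():
    import numpy as np
    segs = np.array([[0, 0, 30, 0, 0.9], [0, 0, 5, 0, 0.9]], dtype=np.float32)
    lengths = np.linalg.norm(segs[:, 2:4] - segs[:, 0:2], axis=1)
    keep = lengths >= 15  # default min_length
    scores = segs[keep, -1] * np.sqrt(lengths[keep])
    assert keep.tolist() == [True, False] and scores.shape == (1,)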

View File

@ -0,0 +1,312 @@
import torch
from sklearn.cluster import DBSCAN
from ..base_model import BaseModel
from .. import get_model
def sample_descriptors_corner_conv(keypoints, descriptors, s: int = 8):
"""Interpolate descriptors at keypoint locations"""
b, c, h, w = descriptors.shape
keypoints = keypoints / (keypoints.new_tensor([w, h]) * s)
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
descriptors = torch.nn.functional.grid_sample(
descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
)
descriptors = torch.nn.functional.normalize(
descriptors.reshape(b, c, -1), p=2, dim=1
)
return descriptors
def lines_to_wireframe(
lines, line_scores, all_descs, s, nms_radius, force_num_lines, max_num_lines
):
"""Given a set of lines, their score and dense descriptors,
merge close-by endpoints and compute a wireframe defined by
its junctions and connectivity.
Returns:
junctions: list of [num_junc, 2] tensors listing all wireframe junctions
junc_scores: list of [num_junc] tensors with the junction score
junc_descs: list of [dim, num_junc] tensors with the junction descriptors
connectivity: list of [num_junc, num_junc] bool arrays with True when 2
junctions are connected
new_lines: the new set of [b_size, num_lines, 2, 2] lines
lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the
junctions of each endpoint
num_true_junctions: a list of the number of valid junctions for each image
in the batch, i.e. before filling with random ones
"""
b_size, _, h, w = all_descs.shape
device = lines.device
h, w = h * s, w * s
endpoints = lines.reshape(b_size, -1, 2)
(
junctions,
junc_scores,
connectivity,
new_lines,
lines_junc_idx,
num_true_junctions,
) = ([], [], [], [], [], [])
for bs in range(b_size):
# Cluster the junctions that are close-by
db = DBSCAN(eps=nms_radius, min_samples=1).fit(endpoints[bs].cpu().numpy())
clusters = db.labels_
n_clusters = len(set(clusters))
num_true_junctions.append(n_clusters)
# Compute the average junction and score for each cluster
clusters = torch.tensor(clusters, dtype=torch.long, device=device)
new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, device=device)
new_junc.scatter_reduce_(
0,
clusters[:, None].repeat(1, 2),
endpoints[bs],
reduce="mean",
include_self=False,
)
junctions.append(new_junc)
new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device)
new_scores.scatter_reduce_(
0,
clusters,
torch.repeat_interleave(line_scores[bs], 2),
reduce="mean",
include_self=False,
)
junc_scores.append(new_scores)
# Compute the new lines
new_lines.append(junctions[-1][clusters].reshape(-1, 2, 2))
lines_junc_idx.append(clusters.reshape(-1, 2))
if force_num_lines:
# Add random junctions (with no connectivity)
missing = max_num_lines * 2 - len(junctions[-1])
junctions[-1] = torch.cat(
[
junctions[-1],
torch.rand(missing, 2).to(lines)
* lines.new_tensor([[w - 1, h - 1]]),
],
dim=0,
)
junc_scores[-1] = torch.cat(
[junc_scores[-1], torch.zeros(missing).to(lines)], dim=0
)
junc_connect = torch.eye(max_num_lines * 2, dtype=torch.bool, device=device)
pairs = clusters.reshape(-1, 2) # these pairs are connected by a line
junc_connect[pairs[:, 0], pairs[:, 1]] = True
junc_connect[pairs[:, 1], pairs[:, 0]] = True
connectivity.append(junc_connect)
else:
# Compute the junction connectivity
junc_connect = torch.eye(n_clusters, dtype=torch.bool, device=device)
pairs = clusters.reshape(-1, 2) # these pairs are connected by a line
junc_connect[pairs[:, 0], pairs[:, 1]] = True
junc_connect[pairs[:, 1], pairs[:, 0]] = True
connectivity.append(junc_connect)
junctions = torch.stack(junctions, dim=0)
new_lines = torch.stack(new_lines, dim=0)
lines_junc_idx = torch.stack(lines_junc_idx, dim=0)
# Interpolate the new junction descriptors
junc_descs = sample_descriptors_corner_conv(junctions, all_descs, s).mT
return (
junctions,
junc_scores,
junc_descs,
connectivity,
new_lines,
lines_junc_idx,
num_true_junctions,
)
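# Illustrative sketch (not part of the original file): the DBSCAN endpoint
# clustering that lines_to_wireframe relies on, with toy coordinates. Two
# endpoints closer than eps (the NMS radius) collapse into one junction.
def _demo_endpoint_clustering():
    import numpy as np
    from sklearn.cluster import DBSCAN
    endpoints = np.array([[0.0, 0.0], [10.0, 0.0], [10.5, 0.5], [10.0, 10.0]])
    labels = DBSCAN(eps=3, min_samples=1).fit(endpoints).labels_
    assert len(set(labels)) == 3  # 4 endpoints merged into 3 junctions
    assert labels[1] == labels[2]  # the two close-by endpoints share a cluster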
class WireframeExtractor(BaseModel):
default_conf = {
"point_extractor": {
"name": None,
"trainable": False,
"dense_outputs": True,
"max_num_keypoints": None,
"force_num_keypoints": False,
},
"line_extractor": {
"name": None,
"trainable": False,
"max_num_lines": None,
"force_num_lines": False,
"min_length": 15,
},
"wireframe_params": {
"merge_points": True,
"merge_line_endpoints": True,
"nms_radius": 3,
},
}
required_data_keys = ["image"]
def _init(self, conf):
self.point_extractor = get_model(self.conf.point_extractor.name)(
self.conf.point_extractor
)
self.line_extractor = get_model(self.conf.line_extractor.name)(
self.conf.line_extractor
)
def _forward(self, data):
b_size, _, h, w = data["image"].shape
device = data["image"].device
if (
not self.conf.point_extractor.force_num_keypoints
or not self.conf.line_extractor.force_num_lines
):
assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"
# Line detection
pred = self.line_extractor(data)
if pred["line_scores"].shape[-1] != 0:
pred["line_scores"] /= pred["line_scores"].max(dim=1)[0][:, None] + 1e-8
# Keypoint prediction
pred = {**pred, **self.point_extractor(data)}
assert (
"dense_descriptors" in pred
), "The KP extractor should return dense descriptors"
s_desc = data["image"].shape[2] // pred["dense_descriptors"].shape[2]
# Remove keypoints that are too close to line endpoints
if self.conf.wireframe_params.merge_points:
line_endpts = pred["lines"].reshape(b_size, -1, 2)
dist_pt_lines = torch.norm(
pred["keypoints"][:, :, None] - line_endpts[:, None], dim=-1
)
# For each keypoint, mark it as valid or to remove
pts_to_remove = torch.any(
dist_pt_lines < self.conf.wireframe_params.nms_radius, dim=2
)
if self.conf.point_extractor.force_num_keypoints:
# Replace the points with random ones
num_to_remove = pts_to_remove.int().sum().item()
pred["keypoints"][pts_to_remove] = torch.rand(
num_to_remove, 2, device=device
) * pred["keypoints"].new_tensor([[w - 1, h - 1]])
pred["keypoint_scores"][pts_to_remove] = 0
for bs in range(b_size):
descrs = sample_descriptors_corner_conv(
pred["keypoints"][bs][pts_to_remove[bs]][None],
pred["dense_descriptors"][bs][None],
s_desc,
)
pred["descriptors"][bs][pts_to_remove[bs]] = descrs[0].T
else:
# Simply remove them (we assume batch_size = 1 here)
assert len(pred["keypoints"]) == 1
pred["keypoints"] = pred["keypoints"][0][~pts_to_remove[0]][None]
pred["keypoint_scores"] = pred["keypoint_scores"][0][~pts_to_remove[0]][
None
]
pred["descriptors"] = pred["descriptors"][0][~pts_to_remove[0]][None]
# Connect the lines together to form a wireframe
orig_lines = pred["lines"].clone()
if (
self.conf.wireframe_params.merge_line_endpoints
and len(pred["lines"][0]) > 0
):
# Merge first close-by endpoints to connect lines
(
line_points,
line_pts_scores,
line_descs,
line_association,
pred["lines"],
lines_junc_idx,
n_true_junctions,
) = lines_to_wireframe(
pred["lines"],
pred["line_scores"],
pred["dense_descriptors"],
s=s_desc,
nms_radius=self.conf.wireframe_params.nms_radius,
force_num_lines=self.conf.line_extractor.force_num_lines,
max_num_lines=self.conf.line_extractor.max_num_lines,
)
# Add the keypoints to the junctions and fill the rest with random keypoints
(all_points, all_scores, all_descs, pl_associativity) = [], [], [], []
for bs in range(b_size):
all_points.append(
torch.cat([line_points[bs], pred["keypoints"][bs]], dim=0)
)
all_scores.append(
torch.cat([line_pts_scores[bs], pred["keypoint_scores"][bs]], dim=0)
)
all_descs.append(
torch.cat([line_descs[bs], pred["descriptors"][bs]], dim=0)
)
associativity = torch.eye(
len(all_points[-1]), dtype=torch.bool, device=device
)
associativity[
: n_true_junctions[bs], : n_true_junctions[bs]
] = line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
pl_associativity.append(associativity)
all_points = torch.stack(all_points, dim=0)
all_scores = torch.stack(all_scores, dim=0)
all_descs = torch.stack(all_descs, dim=0)
pl_associativity = torch.stack(pl_associativity, dim=0)
else:
# Lines are independent
all_points = torch.cat(
[pred["lines"].reshape(b_size, -1, 2), pred["keypoints"]], dim=1
)
n_pts = all_points.shape[1]
num_lines = pred["lines"].shape[1]
n_true_junctions = [num_lines * 2] * b_size
all_scores = torch.cat(
[
torch.repeat_interleave(pred["line_scores"], 2, dim=1),
pred["keypoint_scores"],
],
dim=1,
)
line_descs = sample_descriptors_corner_conv(
pred["lines"].reshape(b_size, -1, 2), pred["dense_descriptors"], s_desc
).mT # [B, n_lines * 2, desc_dim]
all_descs = torch.cat([line_descs, pred["descriptors"]], dim=1)
pl_associativity = torch.eye(n_pts, dtype=torch.bool, device=device)[
None
].repeat(b_size, 1, 1)
lines_junc_idx = (
torch.arange(num_lines * 2, device=device)
.reshape(1, -1, 2)
.repeat(b_size, 1, 1)
)
del pred["dense_descriptors"] # Remove dense descriptors to save memory
torch.cuda.empty_cache()
pred["keypoints"] = all_points
pred["keypoint_scores"] = all_scores
pred["descriptors"] = all_descs
pred["pl_associativity"] = pl_associativity
pred["num_junctions"] = torch.tensor(n_true_junctions)
pred["orig_lines"] = orig_lines
pred["lines_junc_idx"] = lines_junc_idx
return pred
def loss(self, pred, data):
raise NotImplementedError
def metrics(self, _pred, _data):
return {}

View File

View File

View File

@ -0,0 +1,81 @@
from ..base_model import BaseModel
from ...geometry.gt_generation import (
gt_matches_from_pose_depth,
gt_line_matches_from_pose_depth,
)
import torch
class DepthMatcher(BaseModel):
default_conf = {
# GT parameters for points
"use_points": True,
"th_positive": 3.0,
"th_negative": 5.0,
"th_epi": None, # add some more epi outliers
"th_consistency": None, # check for projection consistency in px
# GT parameters for lines
"use_lines": False,
"n_line_sampled_pts": 50,
"line_perp_dist_th": 5,
"overlap_th": 0.2,
"min_visibility_th": 0.5,
}
required_data_keys = ["view0", "view1", "T_0to1", "T_1to0"]
def _init(self, conf):
# TODO (iago): Is this just boilerplate code?
if self.conf.use_points:
self.required_data_keys += ["keypoints0", "keypoints1"]
if self.conf.use_lines:
self.required_data_keys += [
"lines0",
"lines1",
"valid_lines0",
"valid_lines1",
]
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def _forward(self, data):
result = {}
if self.conf.use_points:
if "depth_keypoints0" in data:
keys = [
"depth_keypoints0",
"valid_depth_keypoints0",
"depth_keypoints1",
"valid_depth_keypoints1",
]
kw = {k: data[k] for k in keys}
else:
kw = {}
result = gt_matches_from_pose_depth(
data["keypoints0"],
data["keypoints1"],
data,
pos_th=self.conf.th_positive,
neg_th=self.conf.th_negative,
epi_th=self.conf.th_epi,
cc_th=self.conf.th_consistency,
**kw,
)
if self.conf.use_lines:
line_assignment, line_m0, line_m1 = gt_line_matches_from_pose_depth(
data["lines0"],
data["lines1"],
data["valid_lines0"],
data["valid_lines1"],
data,
self.conf.n_line_sampled_pts,
self.conf.line_perp_dist_th,
self.conf.overlap_th,
self.conf.min_visibility_th,
)
result["line_matches0"] = line_m0
result["line_matches1"] = line_m1
result["line_assignment"] = line_assignment
return result
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,778 @@
import logging
import warnings
from copy import deepcopy
from pathlib import Path
import torch
import torch.utils.checkpoint
from torch import nn
from ..base_model import BaseModel
from ..utils.metrics import matcher_metrics
from ...settings import DATA_PATH
warnings.filterwarnings("ignore", category=UserWarning)
ETH_EPS = 1e-8
class GlueStick(BaseModel):
default_conf = {
"input_dim": 256,
"descriptor_dim": 256,
"weights": None,
"version": "v0.1_arxiv",
"keypoint_encoder": [32, 64, 128, 256],
"GNN_layers": ["self", "cross"] * 9,
"num_line_iterations": 1,
"line_attention": False,
"filter_threshold": 0.2,
"checkpointed": False,
"skip_init": False,
"inter_supervision": None,
"loss": {
"nll_weight": 1.0,
"nll_balancing": 0.5,
"inter_supervision": [0.3, 0.6],
},
}
required_data_keys = [
"view0",
"view1",
"keypoints0",
"keypoints1",
"descriptors0",
"descriptors1",
"keypoint_scores0",
"keypoint_scores1",
"lines0",
"lines1",
"lines_junc_idx0",
"lines_junc_idx1",
"line_scores0",
"line_scores1",
]
DEFAULT_LOSS_CONF = {"nll_weight": 1.0, "nll_balancing": 0.5}
url = (
"https://github.com/cvg/GlueStick/releases/download/{}/"
"checkpoint_GlueStick_MD.tar"
)
def _init(self, conf):
if conf.input_dim != conf.descriptor_dim:
self.input_proj = nn.Conv1d(
conf.input_dim, conf.descriptor_dim, kernel_size=1
)
nn.init.constant_(self.input_proj.bias, 0.0)
self.kenc = KeypointEncoder(conf.descriptor_dim, conf.keypoint_encoder)
self.lenc = EndPtEncoder(conf.descriptor_dim, conf.keypoint_encoder)
self.gnn = AttentionalGNN(
conf.descriptor_dim,
conf.GNN_layers,
checkpointed=conf.checkpointed,
inter_supervision=conf.inter_supervision,
num_line_iterations=conf.num_line_iterations,
line_attention=conf.line_attention,
)
self.final_proj = nn.Conv1d(
conf.descriptor_dim, conf.descriptor_dim, kernel_size=1
)
nn.init.constant_(self.final_proj.bias, 0.0)
nn.init.orthogonal_(self.final_proj.weight, gain=1)
self.final_line_proj = nn.Conv1d(
conf.descriptor_dim, conf.descriptor_dim, kernel_size=1
)
nn.init.constant_(self.final_line_proj.bias, 0.0)
nn.init.orthogonal_(self.final_line_proj.weight, gain=1)
if conf.inter_supervision is not None:
self.inter_line_proj = nn.ModuleList(
[
nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1)
for _ in conf.inter_supervision
]
)
self.layer2idx = {}
for i, l in enumerate(conf.inter_supervision):
nn.init.constant_(self.inter_line_proj[i].bias, 0.0)
nn.init.orthogonal_(self.inter_line_proj[i].weight, gain=1)
self.layer2idx[l] = i
bin_score = torch.nn.Parameter(torch.tensor(1.0))
self.register_parameter("bin_score", bin_score)
line_bin_score = torch.nn.Parameter(torch.tensor(1.0))
self.register_parameter("line_bin_score", line_bin_score)
if conf.weights:
assert isinstance(conf.weights, (Path, str))
fname = DATA_PATH / "weights" / f"{conf.weights}_{conf.version}.tar"
fname.parent.mkdir(exist_ok=True, parents=True)
if Path(conf.weights).exists():
logging.info(f'Loading GlueStick model from "{conf.weights}"')
state_dict = torch.load(conf.weights, map_location="cpu")
elif fname.exists():
logging.info(f'Loading GlueStick model from "{fname}"')
state_dict = torch.load(fname, map_location="cpu")
else:
logging.info(
"Loading GlueStick model from " f'"{self.url.format(conf.version)}"'
)
state_dict = torch.hub.load_state_dict_from_url(
self.url.format(conf.version), file_name=fname
)
if "model" in state_dict:
state_dict = {
k.replace("matcher.", ""): v
for k, v in state_dict["model"].items()
if "matcher." in k
}
state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items()
}
self.load_state_dict(state_dict)
def _forward(self, data):
device = data["keypoints0"].device
b_size = len(data["keypoints0"])
image_size0 = (
data["view0"]["image_size"]
if "image_size" in data["view0"]
else data["view0"]["image"].shape
)
image_size1 = (
data["view1"]["image_size"]
if "image_size" in data["view1"]
else data["view1"]["image"].shape
)
pred = {}
desc0, desc1 = data["descriptors0"].mT, data["descriptors1"].mT
kpts0, kpts1 = data["keypoints0"], data["keypoints1"]
n_kpts0, n_kpts1 = kpts0.shape[1], kpts1.shape[1]
n_lines0, n_lines1 = data["lines0"].shape[1], data["lines1"].shape[1]
if n_kpts0 == 0 or n_kpts1 == 0:
# No detected keypoints nor lines
pred["log_assignment"] = torch.zeros(
b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device
)
pred["matches0"] = torch.full(
(b_size, n_kpts0), -1, device=device, dtype=torch.int64
)
pred["matches1"] = torch.full(
(b_size, n_kpts1), -1, device=device, dtype=torch.int64
)
pred["matching_scores0"] = torch.zeros(
(b_size, n_kpts0), device=device, dtype=torch.float32
)
pred["matching_scores1"] = torch.zeros(
(b_size, n_kpts1), device=device, dtype=torch.float32
)
pred["line_log_assignment"] = torch.zeros(
b_size, n_lines0, n_lines1, dtype=torch.float, device=device
)
pred["line_matches0"] = torch.full(
(b_size, n_lines0), -1, device=device, dtype=torch.int64
)
pred["line_matches1"] = torch.full(
(b_size, n_lines1), -1, device=device, dtype=torch.int64
)
pred["line_matching_scores0"] = torch.zeros(
(b_size, n_lines0), device=device, dtype=torch.float32
)
pred["line_matching_scores1"] = torch.zeros(
(b_size, n_kpts1), device=device, dtype=torch.float32
)
return pred
lines0 = data["lines0"].flatten(1, 2)
lines1 = data["lines1"].flatten(1, 2)
# [b_size, num_lines * 2]
lines_junc_idx0 = data["lines_junc_idx0"].flatten(1, 2)
lines_junc_idx1 = data["lines_junc_idx1"].flatten(1, 2)
if self.conf.input_dim != self.conf.descriptor_dim:
desc0 = self.input_proj(desc0)
desc1 = self.input_proj(desc1)
kpts0 = normalize_keypoints(kpts0, image_size0)
kpts1 = normalize_keypoints(kpts1, image_size1)
assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])
if n_lines0 != 0 and n_lines1 != 0:
# Pre-compute the line encodings
lines0 = normalize_keypoints(lines0, image_size0).reshape(
b_size, n_lines0, 2, 2
)
lines1 = normalize_keypoints(lines1, image_size1).reshape(
b_size, n_lines1, 2, 2
)
line_enc0 = self.lenc(lines0, data["line_scores0"])
line_enc1 = self.lenc(lines1, data["line_scores1"])
else:
line_enc0 = torch.zeros(
b_size,
self.conf.descriptor_dim,
n_lines0 * 2,
dtype=torch.float,
device=device,
)
line_enc1 = torch.zeros(
b_size,
self.conf.descriptor_dim,
n_lines1 * 2,
dtype=torch.float,
device=device,
)
desc0, desc1 = self.gnn(
desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
)
# Match all points (KP and line junctions)
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
kp_scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1)
kp_scores = kp_scores / self.conf.descriptor_dim**0.5
kp_scores = log_double_softmax(kp_scores, self.bin_score)
m0, m1, mscores0, mscores1 = self._get_matches(kp_scores)
pred["log_assignment"] = kp_scores
pred["matches0"] = m0
pred["matches1"] = m1
pred["matching_scores0"] = mscores0
pred["matching_scores1"] = mscores1
# Match the lines
if n_lines0 > 0 and n_lines1 > 0:
(
line_scores,
m0_lines,
m1_lines,
mscores0_lines,
mscores1_lines,
raw_line_scores,
) = self._get_line_matches(
desc0[:, :, : 2 * n_lines0],
desc1[:, :, : 2 * n_lines1],
lines_junc_idx0,
lines_junc_idx1,
self.final_line_proj,
)
if self.conf.inter_supervision:
for layer in self.conf.inter_supervision:
(
line_scores_i,
m0_lines_i,
m1_lines_i,
mscores0_lines_i,
mscores1_lines_i,
_,
) = self._get_line_matches(
self.gnn.inter_layers[layer][0][:, :, : 2 * n_lines0],
self.gnn.inter_layers[layer][1][:, :, : 2 * n_lines1],
lines_junc_idx0,
lines_junc_idx1,
self.inter_line_proj[self.layer2idx[layer]],
)
pred[f"line_{layer}_log_assignment"] = line_scores_i
pred[f"line_{layer}_matches0"] = m0_lines_i
pred[f"line_{layer}_matches1"] = m1_lines_i
pred[f"line_{layer}_matching_scores0"] = mscores0_lines_i
pred[f"line_{layer}_matching_scores1"] = mscores1_lines_i
else:
line_scores = torch.zeros(
b_size, n_lines0, n_lines1, dtype=torch.float, device=device
)
m0_lines = torch.full(
(b_size, n_lines0), -1, device=device, dtype=torch.int64
)
m1_lines = torch.full(
(b_size, n_lines1), -1, device=device, dtype=torch.int64
)
mscores0_lines = torch.zeros(
(b_size, n_lines0), device=device, dtype=torch.float32
)
mscores1_lines = torch.zeros(
(b_size, n_lines1), device=device, dtype=torch.float32
)
raw_line_scores = torch.zeros(
b_size, n_lines0, n_lines1, dtype=torch.float, device=device
)
pred["line_log_assignment"] = line_scores
pred["line_matches0"] = m0_lines
pred["line_matches1"] = m1_lines
pred["line_matching_scores0"] = mscores0_lines
pred["line_matching_scores1"] = mscores1_lines
pred["raw_line_scores"] = raw_line_scores
return pred
def _get_matches(self, scores_mat):
max0 = scores_mat[:, :-1, :-1].max(2)
max1 = scores_mat[:, :-1, :-1].max(1)
m0, m1 = max0.indices, max1.indices
mutual0 = arange_like(m0, 1)[None] == m1.gather(1, m0)
mutual1 = arange_like(m1, 1)[None] == m0.gather(1, m1)
zero = scores_mat.new_tensor(0)
mscores0 = torch.where(mutual0, max0.values.exp(), zero)
mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
valid0 = mutual0 & (mscores0 > self.conf.filter_threshold)
valid1 = mutual1 & valid0.gather(1, m1)
m0 = torch.where(valid0, m0, m0.new_tensor(-1))
m1 = torch.where(valid1, m1, m1.new_tensor(-1))
return m0, m1, mscores0, mscores1
def _get_line_matches(
self, ldesc0, ldesc1, lines_junc_idx0, lines_junc_idx1, final_proj
):
mldesc0 = final_proj(ldesc0)
mldesc1 = final_proj(ldesc1)
line_scores = torch.einsum("bdn,bdm->bnm", mldesc0, mldesc1)
line_scores = line_scores / self.conf.descriptor_dim**0.5
# Get the line representation from the junction descriptors
n2_lines0 = lines_junc_idx0.shape[1]
n2_lines1 = lines_junc_idx1.shape[1]
line_scores = torch.gather(
line_scores,
dim=2,
index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1),
)
line_scores = torch.gather(
line_scores,
dim=1,
index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1),
)
line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2, n2_lines1 // 2, 2))
# Match either in one direction or the other
raw_line_scores = 0.5 * torch.maximum(
line_scores[:, :, 0, :, 0] + line_scores[:, :, 1, :, 1],
line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0],
)
line_scores = log_double_softmax(raw_line_scores, self.line_bin_score)
m0_lines, m1_lines, mscores0_lines, mscores1_lines = self._get_matches(
line_scores
)
return (
line_scores,
m0_lines,
m1_lines,
mscores0_lines,
mscores1_lines,
raw_line_scores,
)
def sub_loss(self, pred, data, losses, bin_score, prefix="", layer=-1):
line_suffix = "" if layer == -1 else f"{layer}_"
layer_weight = (
1.0
if layer == -1
else self.conf.loss.inter_supervision[self.layer2idx[layer]]
)
positive = data["gt_" + prefix + "assignment"].float()
num_pos = torch.max(positive.sum((1, 2)), positive.new_tensor(1))
neg0 = (data["gt_" + prefix + "matches0"] == -1).float()
neg1 = (data["gt_" + prefix + "matches1"] == -1).float()
num_neg = torch.max(neg0.sum(1) + neg1.sum(1), neg0.new_tensor(1))
log_assignment = pred[prefix + line_suffix + "log_assignment"]
nll_pos = -(log_assignment[:, :-1, :-1] * positive).sum((1, 2))
nll_pos /= num_pos
nll_neg0 = -(log_assignment[:, :-1, -1] * neg0).sum(1)
nll_neg1 = -(log_assignment[:, -1, :-1] * neg1).sum(1)
nll_neg = (nll_neg0 + nll_neg1) / num_neg
nll = (
self.conf.loss.nll_balancing * nll_pos
+ (1 - self.conf.loss.nll_balancing) * nll_neg
)
losses[prefix + line_suffix + "assignment_nll"] = nll
if self.conf.loss.nll_weight > 0:
losses["total"] += nll * self.conf.loss.nll_weight * layer_weight
# Some statistics
if line_suffix == "":
losses[prefix + "num_matchable"] = num_pos
losses[prefix + "num_unmatchable"] = num_neg
losses[prefix + "sinkhorn_norm"] = (
log_assignment.exp()[:, :-1].sum(2).mean(1)
)
losses[prefix + "bin_score"] = bin_score[None]
return losses
def loss(self, pred, data):
losses = {"total": 0}
# If there are keypoints add their loss terms
if not (data["keypoints0"].shape[1] == 0 or data["keypoints1"].shape[1] == 0):
losses = self.sub_loss(pred, data, losses, self.bin_score, prefix="")
# If there are lines add their loss terms
if (
"lines0" in data
and "lines1" in data
and data["lines0"].shape[1] > 0
and data["lines1"].shape[1] > 0
):
losses = self.sub_loss(
pred, data, losses, self.line_bin_score, prefix="line_"
)
if self.conf.inter_supervision:
for layer in self.conf.inter_supervision:
losses = self.sub_loss(
pred, data, losses, self.line_bin_score, prefix="line_", layer=layer
)
# Compute the metrics
metrics = {}
if not self.training:
if (
"matches0" in pred
and pred["matches0"].shape[1] > 0
and pred["matches1"].shape[1] > 0
):
metrics = {**metrics, **matcher_metrics(pred, data, prefix="")}
if (
"line_matches0" in pred
and data["lines0"].shape[1] > 0
and data["lines1"].shape[1] > 0
):
metrics = {**metrics, **matcher_metrics(pred, data, prefix="line_")}
if self.conf.inter_supervision:
for layer in self.conf.inter_supervision:
inter_metrics = matcher_metrics(
pred, data, prefix=f"line_{layer}_", prefix_gt="line_"
)
metrics = {**metrics, **inter_metrics}
return losses, metrics
def MLP(channels, do_bn=True):
n = len(channels)
layers = []
for i in range(1, n):
        layers.append(
            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)
        )
if i < (n - 1):
if do_bn:
layers.append(nn.BatchNorm1d(channels[i]))
layers.append(nn.ReLU())
return nn.Sequential(*layers)
def normalize_keypoints(kpts, shape_or_size):
if isinstance(shape_or_size, (tuple, list)):
# it"s a shape
h, w = shape_or_size[-2:]
size = kpts.new_tensor([[w, h]])
else:
# it"s a size
assert isinstance(shape_or_size, torch.Tensor)
size = shape_or_size.to(kpts)
c = size / 2
f = size.max(1, keepdim=True).values * 0.7 # somehow we used 0.7 for SG
return (kpts - c[:, None, :]) / f[:, None, :]
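# Illustrative sketch (not part of the original file): for a 640x480 image the
# keypoints are centered and divided by 0.7 * max(W, H) = 448, so the corners
# land near +/-0.71 horizontally rather than +/-1.
def _demo_normalize_keypoints():
    import torch
    kpts = torch.tensor([[[0.0, 0.0], [640.0, 480.0]]])
    out = normalize_keypoints(kpts, (1, 3, 480, 640))  # shape-like input
    assert torch.allclose(out[0, 0], torch.tensor([-320.0, -240.0]) / 448.0)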
class KeypointEncoder(nn.Module):
def __init__(self, feature_dim, layers):
super().__init__()
self.encoder = MLP([3] + list(layers) + [feature_dim], do_bn=True)
nn.init.constant_(self.encoder[-1].bias, 0.0)
def forward(self, kpts, scores):
inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
return self.encoder(torch.cat(inputs, dim=1))
class EndPtEncoder(nn.Module):
def __init__(self, feature_dim, layers):
super().__init__()
self.encoder = MLP([5] + list(layers) + [feature_dim], do_bn=True)
nn.init.constant_(self.encoder[-1].bias, 0.0)
def forward(self, endpoints, scores):
# endpoints should be [B, N, 2, 2]
# output is [B, feature_dim, N * 2]
b_size, n_pts, _, _ = endpoints.shape
assert tuple(endpoints.shape[-2:]) == (2, 2)
endpt_offset = (endpoints[:, :, 1] - endpoints[:, :, 0]).unsqueeze(2)
endpt_offset = torch.cat([endpt_offset, -endpt_offset], dim=2)
endpt_offset = endpt_offset.reshape(b_size, 2 * n_pts, 2).transpose(1, 2)
inputs = [
endpoints.flatten(1, 2).transpose(1, 2),
endpt_offset,
scores.repeat(1, 2).unsqueeze(1),
]
return self.encoder(torch.cat(inputs, dim=1))
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def attention(query, key, value):
dim = query.shape[1]
scores = torch.einsum("bdhn,bdhm->bhnm", query, key) / dim**0.5
prob = torch.nn.functional.softmax(scores, dim=-1)
return torch.einsum("bhnm,bdhm->bdhn", prob, value), prob
class MultiHeadedAttention(nn.Module):
def __init__(self, h, d_model):
super().__init__()
assert d_model % h == 0
self.dim = d_model // h
self.h = h
self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
# self.prob = []
def forward(self, query, key, value):
b = query.size(0)
query, key, value = [
layer(x).view(b, self.dim, self.h, -1)
for layer, x in zip(self.proj, (query, key, value))
]
x, prob = attention(query, key, value)
# self.prob.append(prob.mean(dim=1))
return self.merge(x.contiguous().view(b, self.dim * self.h, -1))
class AttentionalPropagation(nn.Module):
def __init__(self, num_dim, num_heads, skip_init=False):
super().__init__()
self.attn = MultiHeadedAttention(num_heads, num_dim)
self.mlp = MLP([num_dim * 2, num_dim * 2, num_dim], do_bn=True)
nn.init.constant_(self.mlp[-1].bias, 0.0)
if skip_init:
self.register_parameter("scaling", nn.Parameter(torch.tensor(0.0)))
else:
self.scaling = 1.0
def forward(self, x, source):
message = self.attn(x, source, source)
return self.mlp(torch.cat([x, message], dim=1)) * self.scaling
class GNNLayer(nn.Module):
def __init__(self, feature_dim, layer_type, skip_init):
super().__init__()
assert layer_type in ["cross", "self"]
self.type = layer_type
self.update = AttentionalPropagation(feature_dim, 4, skip_init)
def forward(self, desc0, desc1):
if self.type == "cross":
src0, src1 = desc1, desc0
elif self.type == "self":
src0, src1 = desc0, desc1
else:
raise ValueError("Unknown layer type: " + self.type)
# self.update.attn.prob = []
delta0, delta1 = self.update(desc0, src0), self.update(desc1, src1)
desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
return desc0, desc1
class LineLayer(nn.Module):
def __init__(self, feature_dim, line_attention=False):
super().__init__()
self.dim = feature_dim
self.mlp = MLP([self.dim * 3, self.dim * 2, self.dim], do_bn=True)
self.line_attention = line_attention
if line_attention:
self.proj_node = nn.Conv1d(self.dim, self.dim, kernel_size=1)
self.proj_neigh = nn.Conv1d(2 * self.dim, self.dim, kernel_size=1)
def get_endpoint_update(self, ldesc, line_enc, lines_junc_idx):
# ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2]
# and lines_junc_idx [bs, n_lines * 2]
# Create one message per line endpoint
b_size = lines_junc_idx.shape[0]
line_desc = torch.gather(
ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1)
)
line_desc2 = line_desc.reshape(b_size, self.dim, -1, 2).flip([-1])
message = torch.cat(
[line_desc, line_desc2.flatten(2, 3).clone(), line_enc], dim=1
)
return self.mlp(message) # [b_size, D, n_lines * 2]
def get_endpoint_attention(self, ldesc, line_enc, lines_junc_idx):
# ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2]
# and lines_junc_idx [bs, n_lines * 2]
b_size = lines_junc_idx.shape[0]
expanded_lines_junc_idx = lines_junc_idx[:, None].repeat(1, self.dim, 1)
# Query: desc of the current node
query = self.proj_node(ldesc) # [b_size, D, n_junc]
query = torch.gather(query, 2, expanded_lines_junc_idx)
# query is [b_size, D, n_lines * 2]
# Key: combination of neighboring desc and line encodings
line_desc = torch.gather(ldesc, 2, expanded_lines_junc_idx)
line_desc2 = line_desc.reshape(b_size, self.dim, -1, 2).flip([-1])
key = self.proj_neigh(
torch.cat([line_desc2.flatten(2, 3).clone(), line_enc], dim=1)
) # [b_size, D, n_lines * 2]
# Compute the attention weights with a custom softmax per junction
prob = (query * key).sum(dim=1) / self.dim**0.5 # [b_size, n_lines * 2]
prob = torch.exp(prob - prob.max())
denom = torch.zeros_like(ldesc[:, 0]).scatter_reduce_(
dim=1, index=lines_junc_idx, src=prob, reduce="sum", include_self=False
) # [b_size, n_junc]
denom = torch.gather(denom, 1, lines_junc_idx) # [b_size, n_lines * 2]
prob = prob / (denom + ETH_EPS)
return prob # [b_size, n_lines * 2]
def forward(
self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
):
# Gather the endpoint updates
lupdate0 = self.get_endpoint_update(ldesc0, line_enc0, lines_junc_idx0)
lupdate1 = self.get_endpoint_update(ldesc1, line_enc1, lines_junc_idx1)
update0, update1 = torch.zeros_like(ldesc0), torch.zeros_like(ldesc1)
dim = ldesc0.shape[1]
if self.line_attention:
# Compute an attention for each neighbor and do a weighted average
prob0 = self.get_endpoint_attention(ldesc0, line_enc0, lines_junc_idx0)
lupdate0 = lupdate0 * prob0[:, None]
update0 = update0.scatter_reduce_(
dim=2,
index=lines_junc_idx0[:, None].repeat(1, dim, 1),
src=lupdate0,
reduce="sum",
include_self=False,
)
prob1 = self.get_endpoint_attention(ldesc1, line_enc1, lines_junc_idx1)
lupdate1 = lupdate1 * prob1[:, None]
update1 = update1.scatter_reduce_(
dim=2,
index=lines_junc_idx1[:, None].repeat(1, dim, 1),
src=lupdate1,
reduce="sum",
include_self=False,
)
else:
# Average the updates for each junction (requires torch > 1.12)
update0 = update0.scatter_reduce_(
dim=2,
index=lines_junc_idx0[:, None].repeat(1, dim, 1),
src=lupdate0,
reduce="mean",
include_self=False,
)
update1 = update1.scatter_reduce_(
dim=2,
index=lines_junc_idx1[:, None].repeat(1, dim, 1),
src=lupdate1,
reduce="mean",
include_self=False,
)
# Update
ldesc0 = ldesc0 + update0
ldesc1 = ldesc1 + update1
return ldesc0, ldesc1
class AttentionalGNN(nn.Module):
def __init__(
self,
feature_dim,
layer_types,
checkpointed=False,
skip=False,
inter_supervision=None,
num_line_iterations=1,
line_attention=False,
):
super().__init__()
self.checkpointed = checkpointed
self.inter_supervision = inter_supervision
self.num_line_iterations = num_line_iterations
self.inter_layers = {}
self.layers = nn.ModuleList(
[GNNLayer(feature_dim, layer_type, skip) for layer_type in layer_types]
)
self.line_layers = nn.ModuleList(
[
LineLayer(feature_dim, line_attention)
for _ in range(len(layer_types) // 2)
]
)
def forward(
self, desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
):
for i, layer in enumerate(self.layers):
if self.checkpointed:
desc0, desc1 = torch.utils.checkpoint.checkpoint(
layer, desc0, desc1, preserve_rng_state=False
)
else:
desc0, desc1 = layer(desc0, desc1)
if (
layer.type == "self"
and lines_junc_idx0.shape[1] > 0
and lines_junc_idx1.shape[1] > 0
):
# Add line self attention layers after every self layer
for _ in range(self.num_line_iterations):
if self.checkpointed:
desc0, desc1 = torch.utils.checkpoint.checkpoint(
self.line_layers[i // 2],
desc0,
desc1,
line_enc0,
line_enc1,
lines_junc_idx0,
lines_junc_idx1,
preserve_rng_state=False,
)
else:
desc0, desc1 = self.line_layers[i // 2](
desc0,
desc1,
line_enc0,
line_enc1,
lines_junc_idx0,
lines_junc_idx1,
)
# Optionally store the line descriptor at intermediate layers
if (
self.inter_supervision is not None
and (i // 2) in self.inter_supervision
and layer.type == "cross"
):
self.inter_layers[i // 2] = (desc0.clone(), desc1.clone())
return desc0, desc1
def log_double_softmax(scores, bin_score):
b, m, n = scores.shape
bin_ = bin_score[None, None, None]
scores0 = torch.cat([scores, bin_.expand(b, m, 1)], 2)
scores1 = torch.cat([scores, bin_.expand(b, 1, n)], 1)
scores0 = torch.nn.functional.log_softmax(scores0, 2)
scores1 = torch.nn.functional.log_softmax(scores1, 1)
scores = scores.new_full((b, m + 1, n + 1), 0)
scores[:, :m, :n] = (scores0[:, :, :n] + scores1[:, :m, :]) / 2
scores[:, :-1, -1] = scores0[:, :, -1]
scores[:, -1, :-1] = scores1[:, -1, :]
return scores
def arange_like(x, dim):
return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
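# Illustrative sketch (not part of the original file): the dual-softmax
# assignment with a bin score, on random similarities. The output has one
# extra row and column for unmatched points and holds log-probabilities.
def _demo_log_double_softmax():
    import torch
    scores = torch.rand(1, 3, 4)
    out = log_double_softmax(scores, torch.tensor(1.0))
    assert out.shape == (1, 4, 5)
    assert torch.all(out <= 0)  # entries are (averaged) log-probabilities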

View File

@ -0,0 +1,66 @@
from ..base_model import BaseModel
from ...geometry.gt_generation import (
gt_matches_from_homography,
gt_line_matches_from_homography,
)
class HomographyMatcher(BaseModel):
default_conf = {
# GT parameters for points
"use_points": True,
"th_positive": 3.0,
"th_negative": 3.0,
# GT parameters for lines
"use_lines": False,
"n_line_sampled_pts": 50,
"line_perp_dist_th": 5,
"overlap_th": 0.2,
"min_visibility_th": 0.5,
}
required_data_keys = ["H_0to1"]
def _init(self, conf):
# TODO (iago): Is this just boilerplate code?
if self.conf.use_points:
self.required_data_keys += ["keypoints0", "keypoints1"]
if self.conf.use_lines:
self.required_data_keys += [
"lines0",
"lines1",
"valid_lines0",
"valid_lines1",
]
def _forward(self, data):
result = {}
if self.conf.use_points:
result = gt_matches_from_homography(
data["keypoints0"],
data["keypoints1"],
data["H_0to1"],
pos_th=self.conf.th_positive,
neg_th=self.conf.th_negative,
)
if self.conf.use_lines:
line_assignment, line_m0, line_m1 = gt_line_matches_from_homography(
data["lines0"],
data["lines1"],
data["valid_lines0"],
data["valid_lines1"],
data["view0"]["image"].shape,
data["view1"]["image"].shape,
data["H_0to1"],
self.conf.n_line_sampled_pts,
self.conf.line_perp_dist_th,
self.conf.overlap_th,
self.conf.min_visibility_th,
)
result["line_matches0"] = line_m0
result["line_matches1"] = line_m1
result["line_assignment"] = line_assignment
return result
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,65 @@
import kornia
import torch
from ...models import BaseModel
class LoFTRModule(BaseModel):
default_conf = {
"topk": None,
"zero_pad": False,
}
required_data_keys = ["view0", "view1"]
def _init(self, conf):
self.net = kornia.feature.LoFTR(pretrained="outdoor")
def _forward(self, data):
image0 = data["view0"]["image"]
image1 = data["view1"]["image"]
if self.conf.zero_pad:
image0, mask0 = self.zero_pad(image0)
image1, mask1 = self.zero_pad(image1)
            res = self.net(
                {"image0": image0, "image1": image1, "mask0": mask0, "mask1": mask1}
            )
else:
res = self.net({"image0": image0, "image1": image1})
topk = self.conf.topk
if topk is not None and res["confidence"].shape[-1] > topk:
_, top = torch.topk(res["confidence"], topk, -1)
m_kpts0 = res["keypoints0"][None][:, top]
m_kpts1 = res["keypoints1"][None][:, top]
scores = res["confidence"][None][:, top]
else:
m_kpts0 = res["keypoints0"][None]
m_kpts1 = res["keypoints1"][None]
scores = res["confidence"][None]
m0 = torch.arange(0, scores.shape[-1]).to(scores.device)[None]
m1 = torch.arange(0, scores.shape[-1]).to(scores.device)[None]
return {
"matches0": m0,
"matches1": m1,
"matching_scores0": scores,
"keypoints0": m_kpts0,
"keypoints1": m_kpts1,
"keypoint_scores0": scores,
"keypoint_scores1": scores,
"matching_scores1": scores,
}
def zero_pad(self, img):
b, c, h, w = img.shape
        if h == w:
            # no padding needed; return an all-valid mask for consistency
            return img, torch.ones_like(img).squeeze(0).float()
s = max(h, w)
image = torch.zeros((b, c, s, s)).to(img)
image[:, :, :h, :w] = img
mask = torch.zeros_like(image)
mask[:, :, :h, :w] = 1.0
return image, mask.squeeze(0).float()
def loss(self, pred, data):
        raise NotImplementedError

View File

@ -0,0 +1,610 @@
import warnings
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, List, Callable
from torch.utils.checkpoint import checkpoint
from omegaconf import OmegaConf
from ...settings import DATA_PATH
from ..utils.losses import NLLLoss
from ..utils.metrics import matcher_metrics
from pathlib import Path
FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention")
torch.backends.cudnn.deterministic = True
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(
kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
if size is None:
size = 1 + kpts.max(-2).values - kpts.min(-2).values
elif not isinstance(size, torch.Tensor):
size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
size = size.to(kpts)
shift = size / 2
scale = size.max(-1).values / 2
kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
return kpts
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x = x.unflatten(-1, (-1, 2))
x1, x2 = x.unbind(dim=-1)
return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return (t * freqs[0]) + (rotate_half(t) * freqs[1])
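# Illustrative sketch (not part of the original file): the cached rotary
# embedding rotates consecutive feature pairs, so it preserves vector norms.
# The angles below are arbitrary; the model derives them from a learned
# Fourier projection of the keypoint positions.
def _demo_rotary_encoding():
    import torch
    t = torch.randn(1, 4, 7, 64)  # [B, heads, N, head_dim]
    angles = torch.rand(1, 1, 7, 32)  # one angle per feature pair
    freqs = torch.stack([torch.cos(angles), torch.sin(angles)], 0)
    freqs = freqs.repeat_interleave(2, dim=-1)  # matches the encoder layout
    out = apply_cached_rotary_emb(freqs, t)
    assert torch.allclose(out.norm(dim=-1), t.norm(dim=-1), atol=1e-5)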
class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(
        self, M: int, dim: int, F_dim: Optional[int] = None, gamma: float = 1.0
    ) -> None:
super().__init__()
F_dim = F_dim if F_dim is not None else dim
self.gamma = gamma
self.Wr = nn.Linear(M, F_dim // 2, bias=False)
nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""encode position vector"""
projected = self.Wr(x)
cosines, sines = torch.cos(projected), torch.sin(projected)
emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
return emb.repeat_interleave(2, dim=-1)
class TokenConfidence(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
self.loss_fn = nn.BCEWithLogitsLoss(reduction="none")
def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
"""get confidence tokens"""
return (
self.token(desc0.detach()).squeeze(-1),
self.token(desc1.detach()).squeeze(-1),
)
def loss(self, desc0, desc1, la_now, la_final):
logit0 = self.token[0](desc0.detach()).squeeze(-1)
logit1 = self.token[0](desc1.detach()).squeeze(-1)
la_now, la_final = la_now.detach(), la_final.detach()
correct0 = (
la_final[:, :-1, :].max(-1).indices == la_now[:, :-1, :].max(-1).indices
)
correct1 = (
la_final[:, :, :-1].max(-2).indices == la_now[:, :, :-1].max(-2).indices
)
return (
self.loss_fn(logit0, correct0.float()).mean(-1)
+ self.loss_fn(logit1, correct1.float()).mean(-1)
) / 2.0
class Attention(nn.Module):
def __init__(self, allow_flash: bool) -> None:
super().__init__()
if allow_flash and not FLASH_AVAILABLE:
warnings.warn(
"FlashAttention is not available. For optimal speed, "
"consider installing torch >= 2.0 or flash-attn.",
stacklevel=2,
)
self.enable_flash = allow_flash and FLASH_AVAILABLE
if FLASH_AVAILABLE:
torch.backends.cuda.enable_flash_sdp(allow_flash)
def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.enable_flash and q.device.type == "cuda":
# use torch 2.0 scaled_dot_product_attention with flash
if FLASH_AVAILABLE:
args = [x.half().contiguous() for x in [q, k, v]]
v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
return v if mask is None else v.nan_to_num()
elif FLASH_AVAILABLE:
args = [x.contiguous() for x in [q, k, v]]
v = F.scaled_dot_product_attention(*args, attn_mask=mask)
return v if mask is None else v.nan_to_num()
else:
s = q.shape[-1] ** -0.5
sim = torch.einsum("...id,...jd->...ij", q, k) * s
if mask is not None:
                sim = sim.masked_fill(~mask, -float("inf"))
attn = F.softmax(sim, -1)
return torch.einsum("...ij,...jd->...id", attn, v)
class SelfBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
) -> None:
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
assert self.embed_dim % num_heads == 0
self.head_dim = self.embed_dim // num_heads
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.inner_attn = Attention(flash)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.ffn = nn.Sequential(
nn.Linear(2 * embed_dim, 2 * embed_dim),
nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
nn.GELU(),
nn.Linear(2 * embed_dim, embed_dim),
)
def forward(
self,
x: torch.Tensor,
encoding: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv = self.Wqkv(x)
qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
q = apply_cached_rotary_emb(encoding, q)
k = apply_cached_rotary_emb(encoding, k)
context = self.inner_attn(q, k, v, mask=mask)
message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
return x + self.ffn(torch.cat([x, message], -1))
class CrossBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
) -> None:
super().__init__()
self.heads = num_heads
dim_head = embed_dim // num_heads
self.scale = dim_head**-0.5
inner_dim = dim_head * num_heads
self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
self.ffn = nn.Sequential(
nn.Linear(2 * embed_dim, 2 * embed_dim),
nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
nn.GELU(),
nn.Linear(2 * embed_dim, embed_dim),
)
if flash and FLASH_AVAILABLE:
self.flash = Attention(True)
else:
self.flash = None
def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
return func(x0), func(x1)
def forward(
self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> List[torch.Tensor]:
qk0, qk1 = self.map_(self.to_qk, x0, x1)
v0, v1 = self.map_(self.to_v, x0, x1)
qk0, qk1, v0, v1 = map(
lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
(qk0, qk1, v0, v1),
)
if self.flash is not None and qk0.device.type == "cuda":
m0 = self.flash(qk0, qk1, v1, mask)
m1 = self.flash(
qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
)
else:
qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
if mask is not None:
sim = sim.masked_fill(~mask, -float("inf"))
attn01 = F.softmax(sim, dim=-1)
attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
if mask is not None:
m0, m1 = m0.nan_to_num(), m1.nan_to_num()
m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
m0, m1 = self.map_(self.to_out, m0, m1)
x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
return x0, x1
class TransformerLayer(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.self_attn = SelfBlock(*args, **kwargs)
self.cross_attn = CrossBlock(*args, **kwargs)
def forward(
self,
desc0,
desc1,
encoding0,
encoding1,
mask0: Optional[torch.Tensor] = None,
mask1: Optional[torch.Tensor] = None,
):
if mask0 is not None and mask1 is not None:
return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
else:
desc0 = self.self_attn(desc0, encoding0)
desc1 = self.self_attn(desc1, encoding1)
return self.cross_attn(desc0, desc1)
# This part is compiled and allows padding inputs
def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
mask = mask0 & mask1.transpose(-1, -2)
mask0 = mask0 & mask0.transpose(-1, -2)
mask1 = mask1 & mask1.transpose(-1, -2)
desc0 = self.self_attn(desc0, encoding0, mask0)
desc1 = self.self_attn(desc1, encoding1, mask1)
return self.cross_attn(desc0, desc1, mask)
def sigmoid_log_double_softmax(
sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
"""create the log assignment matrix from logits and similarity"""
b, m, n = sim.shape
certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
scores0 = F.log_softmax(sim, 2)
scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
scores = sim.new_full((b, m + 1, n + 1), 0)
scores[:, :m, :n] = scores0 + scores1 + certainties
scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
return scores
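# Illustrative sketch (not part of the original file): shape check of the
# sigmoid-gated double softmax above, with random similarities and
# matchability logits of toy sizes.
def _demo_sigmoid_log_double_softmax():
    import torch
    sim = torch.rand(1, 3, 4)
    z0, z1 = torch.rand(1, 3, 1), torch.rand(1, 4, 1)
    scores = sigmoid_log_double_softmax(sim, z0, z1)
    assert scores.shape == (1, 4, 5)  # extra row/column for unmatched points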
class MatchAssignment(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
self.matchability = nn.Linear(dim, 1, bias=True)
self.final_proj = nn.Linear(dim, dim, bias=True)
def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
"""build assignment matrix from descriptors"""
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
_, _, d = mdesc0.shape
mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
z0 = self.matchability(desc0)
z1 = self.matchability(desc1)
scores = sigmoid_log_double_softmax(sim, z0, z1)
return scores, sim
def get_matchability(self, desc: torch.Tensor):
return torch.sigmoid(self.matchability(desc)).squeeze(-1)
def filter_matches(scores: torch.Tensor, th: float):
"""obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
m0, m1 = max0.indices, max1.indices
indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
mutual0 = indices0 == m1.gather(1, m0)
mutual1 = indices1 == m0.gather(1, m1)
max0_exp = max0.values.exp()
zero = max0_exp.new_tensor(0)
mscores0 = torch.where(mutual0, max0_exp, zero)
mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
valid0 = mutual0 & (mscores0 > th)
valid1 = mutual1 & valid0.gather(1, m1)
m0 = torch.where(valid0, m0, -1)
m1 = torch.where(valid1, m1, -1)
return m0, m1, mscores0, mscores1
class LightGlue(nn.Module):
default_conf = {
"name": "lightglue", # just for interfacing
"input_dim": 256, # input descriptor dimension (autoselected from weights)
"add_scale_ori": False,
"descriptor_dim": 256,
"n_layers": 9,
"num_heads": 4,
"flash": False, # enable FlashAttention if available.
"mp": False, # enable mixed precision
"depth_confidence": -1, # early stopping, disable with -1
"width_confidence": -1, # point pruning, disable with -1
"filter_threshold": 0.0, # match threshold
"checkpointed": False,
"weights": None, # either a path or the name of pretrained weights (disk, ...)
"weights_from_version": "v0.1_arxiv",
"loss": {
"gamma": 1.0,
"fn": "nll",
"nll_balancing": 0.5,
},
}
required_data_keys = ["keypoints0", "keypoints1", "descriptors0", "descriptors1"]
url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
def __init__(self, conf) -> None:
super().__init__()
self.conf = conf = OmegaConf.merge(self.default_conf, conf)
if conf.input_dim != conf.descriptor_dim:
self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
else:
self.input_proj = nn.Identity()
head_dim = conf.descriptor_dim // conf.num_heads
self.posenc = LearnableFourierPositionalEncoding(
2 + 2 * conf.add_scale_ori, head_dim, head_dim
)
h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
self.transformers = nn.ModuleList(
[TransformerLayer(d, h, conf.flash) for _ in range(n)]
)
self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
self.token_confidence = nn.ModuleList(
[TokenConfidence(d) for _ in range(n - 1)]
)
self.loss_fn = NLLLoss(conf.loss)
state_dict = None
if conf.weights is not None:
# weights: a local path, a file under DATA_PATH, or an official weight name
if Path(conf.weights).exists():
state_dict = torch.load(conf.weights, map_location="cpu")
elif (Path(DATA_PATH) / conf.weights).exists():
state_dict = torch.load(
str(DATA_PATH / conf.weights), map_location="cpu"
)
else:
fname = (
f"{conf.weights}_{conf.weights_from_version}".replace(".", "-")
+ ".pth"
)
state_dict = torch.hub.load_state_dict_from_url(
self.url.format(conf.weights_from_version, conf.weights),
file_name=fname,
)
if state_dict:
# rename old state dict entries
for i in range(self.conf.n_layers):
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
self.load_state_dict(state_dict, strict=False)
def compile(self, mode="reduce-overhead"):
if self.conf.width_confidence != -1:
warnings.warn(
"Point pruning is partially disabled for compiled forward.",
stacklevel=2,
)
for i in range(self.conf.n_layers):
self.transformers[i] = torch.compile(
self.transformers[i], mode=mode, fullgraph=True
)
def forward(self, data: dict) -> dict:
for key in self.required_data_keys:
assert key in data, f"Missing key {key} in data"
kpts0, kpts1 = data["keypoints0"], data["keypoints1"]
b, m, _ = kpts0.shape
b, n, _ = kpts1.shape
device = kpts0.device
if "view0" in data.keys() and "view1" in data.keys():
size0 = data["view0"].get("image_size")
size1 = data["view1"].get("image_size")
kpts0 = normalize_keypoints(kpts0, size0).clone()
kpts1 = normalize_keypoints(kpts1, size1).clone()
if self.conf.add_scale_ori:
sc0, o0 = data["scales0"], data["oris0"]
sc1, o1 = data["scales1"], data["oris1"]
kpts0 = torch.cat(
[
kpts0,
sc0 if sc0.dim() == 3 else sc0[..., None],
o0 if o0.dim() == 3 else o0[..., None],
],
-1,
)
kpts1 = torch.cat(
[
kpts1,
sc1 if sc1.dim() == 3 else sc1[..., None],
o1 if o1.dim() == 3 else o1[..., None],
],
-1,
)
desc0 = data["descriptors0"].contiguous()
desc1 = data["descriptors1"].contiguous()
assert desc0.shape[-1] == self.conf.input_dim
assert desc1.shape[-1] == self.conf.input_dim
if torch.is_autocast_enabled():
desc0 = desc0.half()
desc1 = desc1.half()
desc0 = self.input_proj(desc0)
desc1 = self.input_proj(desc1)
# cache positional embeddings
encoding0 = self.posenc(kpts0)
encoding1 = self.posenc(kpts1)
# GNN + final_proj + assignment
do_early_stop = self.conf.depth_confidence > 0 and not self.training
do_point_pruning = self.conf.width_confidence > 0 and not self.training
all_desc0, all_desc1 = [], []
if do_point_pruning:
ind0 = torch.arange(0, m, device=device)[None]
ind1 = torch.arange(0, n, device=device)[None]
# We store the index of the layer at which pruning is detected.
prune0 = torch.ones_like(ind0)
prune1 = torch.ones_like(ind1)
token0, token1 = None, None
for i in range(self.conf.n_layers):
if self.conf.checkpointed and self.training:
desc0, desc1 = checkpoint(
self.transformers[i], desc0, desc1, encoding0, encoding1
)
else:
desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1)
if self.training or i == self.conf.n_layers - 1:
all_desc0.append(desc0)
all_desc1.append(desc1)
continue # no early stopping or adaptive width at last layer
# only for eval
if do_early_stop:
assert b == 1
token0, token1 = self.token_confidence[i](desc0, desc1)
if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
break
if do_point_pruning:
assert b == 1
scores0 = self.log_assignment[i].get_matchability(desc0)
prunemask0 = self.get_pruning_mask(token0, scores0, i)
keep0 = torch.where(prunemask0)[1]
ind0 = ind0.index_select(1, keep0)
desc0 = desc0.index_select(1, keep0)
encoding0 = encoding0.index_select(-2, keep0)
prune0[:, ind0] += 1
scores1 = self.log_assignment[i].get_matchability(desc1)
prunemask1 = self.get_pruning_mask(token1, scores1, i)
keep1 = torch.where(prunemask1)[1]
ind1 = ind1.index_select(1, keep1)
desc1 = desc1.index_select(1, keep1)
encoding1 = encoding1.index_select(-2, keep1)
prune1[:, ind1] += 1
desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
scores, _ = self.log_assignment[i](desc0, desc1)
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
if do_point_pruning:
m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
mscores0_ = torch.zeros((b, m), device=mscores0.device)
mscores1_ = torch.zeros((b, n), device=mscores1.device)
mscores0_[:, ind0] = mscores0
mscores1_[:, ind1] = mscores1
m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
else:
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
pred = {
"matches0": m0,
"matches1": m1,
"matching_scores0": mscores0,
"matching_scores1": mscores1,
"ref_descriptors0": torch.stack(all_desc0, 1),
"ref_descriptors1": torch.stack(all_desc1, 1),
"log_assignment": scores,
"prune0": prune0,
"prune1": prune1,
}
return pred
def confidence_threshold(self, layer_index: int) -> float:
"""scaled confidence threshold"""
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
return np.clip(threshold, 0, 1)
def get_pruning_mask(
self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
) -> torch.Tensor:
"""mask points which should be removed"""
keep = scores > (1 - self.conf.width_confidence)
if confidences is not None: # Low-confidence points are never pruned.
keep |= confidences <= self.confidence_thresholds[layer_index]
return keep
def check_if_stop(
self,
confidences0: torch.Tensor,
confidences1: torch.Tensor,
layer_index: int,
num_points: int,
) -> torch.Tensor:
"""evaluate stopping condition"""
confidences = torch.cat([confidences0, confidences1], -1)
threshold = self.confidence_thresholds[layer_index]
ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
return ratio_confident > self.conf.depth_confidence
def pruning_min_kpts(self, device: torch.device):
if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
return self.pruning_keypoint_thresholds["flash"]
else:
return self.pruning_keypoint_thresholds[device.type]
def loss(self, pred, data):
def loss_params(pred, i):
la, _ = self.log_assignment[i](
pred["ref_descriptors0"][:, i], pred["ref_descriptors1"][:, i]
)
return {
"log_assignment": la,
}
sum_weights = 1.0
nll, gt_weights, loss_metrics = self.loss_fn(loss_params(pred, -1), data)
N = pred["ref_descriptors0"].shape[1]
losses = {"total": nll, "last": nll.clone().detach(), **loss_metrics}
if self.training:
losses["confidence"] = 0.0
# B = pred['log_assignment'].shape[0]
losses["row_norm"] = pred["log_assignment"].exp()[:, :-1].sum(2).mean(1)
for i in range(N - 1):
params_i = loss_params(pred, i)
nll, _, _ = self.loss_fn(params_i, data, weights=gt_weights)
if self.conf.loss.gamma > 0.0:
weight = self.conf.loss.gamma ** (N - i - 1)
else:
weight = i + 1
sum_weights += weight
losses["total"] = losses["total"] + nll * weight
losses["confidence"] += self.token_confidence[i].loss(
pred["ref_descriptors0"][:, i],
pred["ref_descriptors1"][:, i],
params_i["log_assignment"],
pred["log_assignment"],
) / (N - 1)
del params_i
losses["total"] /= sum_weights
# confidences
if self.training:
losses["total"] = losses["total"] + losses["confidence"]
if not self.training:
# add metrics
metrics = matcher_metrics(pred, data)
else:
metrics = {}
return losses, metrics
__main_model__ = LightGlue
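# --- Hedged usage sketch (added for illustration, not part of the original file) ---
# A minimal smoke test, assuming it is appended at the end of this module so that
# LightGlue and the helpers defined above are in scope. The two-layer config, random
# keypoints, and 256-dimensional unit descriptors are placeholders, not real features.
if __name__ == "__main__":
    torch.manual_seed(0)
    matcher = LightGlue({"weights": None, "n_layers": 2}).eval()
    data = {
        "keypoints0": torch.rand(1, 128, 2) * 640,
        "keypoints1": torch.rand(1, 100, 2) * 640,
        "descriptors0": F.normalize(torch.rand(1, 128, 256), dim=-1),
        "descriptors1": F.normalize(torch.rand(1, 100, 256), dim=-1),
    }
    with torch.no_grad():
        pred = matcher(data)
    # matches0[b, i] is the index of the keypoint in image 1 matched to keypoint i
    # of image 0, or -1 if it is unmatched.
    print(pred["matches0"].shape, int((pred["matches0"] > -1).sum()))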

View File

@ -0,0 +1,34 @@
from ..base_model import BaseModel
from lightglue import LightGlue as LightGlue_
from omegaconf import OmegaConf
class LightGlue(BaseModel):
default_conf = {"features": "superpoint", **LightGlue_.default_conf}
required_data_keys = [
"view0",
"keypoints0",
"descriptors0",
"view1",
"keypoints1",
"descriptors1",
]
def _init(self, conf):
dconf = OmegaConf.to_container(conf)
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
# self.net.compile()
def _forward(self, data):
view0 = {
**{k: data[k + "0"] for k in ["keypoints", "descriptors"]},
**data["view0"],
}
view1 = {
**{k: data[k + "1"] for k in ["keypoints", "descriptors"]},
**data["view1"],
}
return self.net({"image0": view0, "image1": view1})
def loss(self, pred, data):
raise NotImplementedError

View File

@ -0,0 +1,96 @@
"""
Nearest neighbor matcher for normalized descriptors.
Optionally apply the mutual check and threshold the distance or ratio.
"""
import torch
import logging
import torch.nn.functional as F
from ..base_model import BaseModel
from ..utils.metrics import matcher_metrics
@torch.no_grad()
def find_nn(sim, ratio_thresh, distance_thresh):
sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True)
dist_nn = 2 * (1 - sim_nn)
mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
if ratio_thresh:
mask = mask & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1])
if distance_thresh:
mask = mask & (dist_nn[..., 0] <= distance_thresh**2)
matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1))
return matches
def mutual_check(m0, m1):
inds0 = torch.arange(m0.shape[-1], device=m0.device)
inds1 = torch.arange(m1.shape[-1], device=m1.device)
loop0 = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0)))
loop1 = torch.gather(m0, -1, torch.where(m1 > -1, m1, m1.new_tensor(0)))
m0_new = torch.where((m0 > -1) & (inds0 == loop0), m0, m0.new_tensor(-1))
m1_new = torch.where((m1 > -1) & (inds1 == loop1), m1, m1.new_tensor(-1))
return m0_new, m1_new
class NearestNeighborMatcher(BaseModel):
default_conf = {
"ratio_thresh": None,
"distance_thresh": None,
"mutual_check": True,
"loss": None,
}
required_data_keys = ["descriptors0", "descriptors1"]
def _init(self, conf):
if conf.loss == "N_pair":
temperature = torch.nn.Parameter(torch.tensor(1.0))
self.register_parameter("temperature", temperature)
def _forward(self, data):
sim = torch.einsum("bnd,bmd->bnm", data["descriptors0"], data["descriptors1"])
matches0 = find_nn(sim, self.conf.ratio_thresh, self.conf.distance_thresh)
matches1 = find_nn(
sim.transpose(1, 2), self.conf.ratio_thresh, self.conf.distance_thresh
)
if self.conf.mutual_check:
matches0, matches1 = mutual_check(matches0, matches1)
b, m, n = sim.shape
la = sim.new_zeros(b, m + 1, n + 1)
la[:, :-1, :-1] = F.log_softmax(sim, -1) + F.log_softmax(sim, -2)
mscores0 = (matches0 > -1).float()
mscores1 = (matches1 > -1).float()
return {
"matches0": matches0,
"matches1": matches1,
"matching_scores0": mscores0,
"matching_scores1": mscores1,
"similarity": sim,
"log_assignment": la,
}
def loss(self, pred, data):
losses = {}
if self.conf.loss == "N_pair":
sim = pred["similarity"]
if torch.any(sim > (1.0 + 1e-6)):
logging.warning(f"Similarity larger than 1, max={sim.max()}")
scores = torch.sqrt(torch.clamp(2 * (1 - sim), min=1e-6))
scores = self.temperature * (2 - scores)
assert not torch.any(torch.isnan(scores)), torch.any(torch.isnan(sim))
prob0 = torch.nn.functional.log_softmax(scores, 2)
prob1 = torch.nn.functional.log_softmax(scores, 1)
assignment = data["gt_assignment"].float()
num = torch.max(assignment.sum((1, 2)), assignment.new_tensor(1))
nll0 = (prob0 * assignment).sum((1, 2)) / num
nll1 = (prob1 * assignment).sum((1, 2)) / num
nll = -(nll0 + nll1) / 2
losses["n_pair_nll"] = losses["total"] = nll
losses["num_matchable"] = num
losses["n_pair_temperature"] = self.temperature[None]
else:
raise NotImplementedError
metrics = {} if self.training else matcher_metrics(pred, data)
return losses, metrics
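# --- Hedged usage sketch (added for illustration, not part of the original file) ---
# A small example of the matching primitives above, assuming it is appended to this
# module; the unit-normalized random descriptors stand in for real features.
if __name__ == "__main__":
    torch.manual_seed(0)
    desc0 = F.normalize(torch.rand(1, 6, 32), dim=-1)
    desc1 = F.normalize(torch.rand(1, 8, 32), dim=-1)
    sim = torch.einsum("bnd,bmd->bnm", desc0, desc1)
    m0 = find_nn(sim, ratio_thresh=0.95, distance_thresh=None)
    m1 = find_nn(sim.transpose(1, 2), ratio_thresh=0.95, distance_thresh=None)
    m0, m1 = mutual_check(m0, m1)  # -1 marks keypoints without a mutual match
    print(m0, m1)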

View File

@ -0,0 +1,98 @@
"""
A two-view sparse feature matching pipeline on triplets.
If a triplet is found, runs the extractor on three images and
then runs matcher/filter/solver for all three pairs.
Losses and metrics get accumulated accordingly.
If no triplet is found, this falls back to two_view_pipeline.py
"""
from .two_view_pipeline import TwoViewPipeline
import torch
from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews
def has_triplet(data):
# view0 and view1 are already checked in required_data_keys
return "view2" in data.keys()
class TripletPipeline(TwoViewPipeline):
default_conf = {"batch_triplets": True, **TwoViewPipeline.default_conf}
def _forward(self, data):
if not has_triplet(data):
return super()._forward(data)
# the two-view outputs are stored in
# pred['0to1'],pred['0to2'], pred['1to2']
assert not self.conf.run_gt_in_forward
pred0 = self.extract_view(data, "0")
pred1 = self.extract_view(data, "1")
pred2 = self.extract_view(data, "2")
pred = {}
pred = {
**{k + "0": v for k, v in pred0.items()},
**{k + "1": v for k, v in pred1.items()},
**{k + "2": v for k, v in pred2.items()},
}
def predict_twoview(pred, data):
# forward pass
if self.conf.matcher.name:
pred = {**pred, **self.matcher({**data, **pred})}
if self.conf.filter.name:
pred = {**pred, **self.filter({**data, **pred})}
if self.conf.solver.name:
pred = {**pred, **self.solver({**data, **pred})}
return pred
if self.conf.batch_triplets:
B = data["image1"].shape[0]
# stack on batch dimension
m_data = stack_twoviews(data)
m_pred = stack_twoviews(pred)
# forward pass
m_pred = predict_twoview(m_pred, m_data)
# unstack
pred = {**pred, **unstack_twoviews(m_pred, B)}
else:
for idx in ["0to1", "0to2", "1to2"]:
m_data = get_twoview(data, idx)
m_pred = get_twoview(pred, idx)
pred[idx] = predict_twoview(m_pred, m_data)
return pred
def loss(self, pred, data):
if not has_triplet(data):
return super().loss(pred, data)
if self.conf.batch_triplets:
m_data = stack_twoviews(data)
m_pred = stack_twoviews(pred)
losses, metrics = super().loss(m_pred, m_data)
else:
losses = {}
metrics = {}
for idx in ["0to1", "0to2", "1to2"]:
data_i = get_twoview(data, idx)
pred_i = pred[idx]
losses_i, metrics_i = super().loss(pred_i, data_i)
for k, v in losses_i.items():
if k in losses.keys():
losses[k] = losses[k] + v
else:
losses[k] = v
for k, v in metrics_i.items():
if k in metrics.keys():
metrics[k] = torch.cat([metrics[k], v], 0)
else:
metrics[k] = v
return losses, metrics

View File

@ -0,0 +1,114 @@
"""
A two-view sparse feature matching pipeline.
This model contains sub-models for each step:
feature extraction, feature matching, outlier filtering, pose estimation.
Each step is optional, and the features or matches can be provided as input.
Default: SuperPoint with nearest neighbor matching.
Convention for the matches: m0[i] is the index of the keypoint in image 1
that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
"""
from omegaconf import OmegaConf
from .base_model import BaseModel
from . import get_model
to_ctr = OmegaConf.to_container # convert DictConfig to dict
class TwoViewPipeline(BaseModel):
default_conf = {
"extractor": {
"name": None,
"trainable": False,
},
"matcher": {"name": None},
"filter": {"name": None},
"solver": {"name": None},
"ground_truth": {"name": None},
"allow_no_extract": False,
"run_gt_in_forward": False,
}
required_data_keys = ["view0", "view1"]
strict_conf = False # need to pass new confs to children models
components = [
"extractor",
"matcher",
"filter",
"solver",
"ground_truth",
]
def _init(self, conf):
if conf.extractor.name:
self.extractor = get_model(conf.extractor.name)(to_ctr(conf.extractor))
if conf.matcher.name:
self.matcher = get_model(conf.matcher.name)(to_ctr(conf.matcher))
if conf.filter.name:
self.filter = get_model(conf.filter.name)(to_ctr(conf.filter))
if conf.solver.name:
self.solver = get_model(conf.solver.name)(to_ctr(conf.solver))
if conf.ground_truth.name:
self.ground_truth = get_model(conf.ground_truth.name)(
to_ctr(conf.ground_truth)
)
def extract_view(self, data, i):
data_i = data[f"view{i}"]
pred_i = data_i.get("cache", {})
skip_extract = len(pred_i) > 0 and self.conf.allow_no_extract
if self.conf.extractor.name and not skip_extract:
pred_i = {**pred_i, **self.extractor(data_i)}
elif self.conf.extractor.name and not self.conf.allow_no_extract:
pred_i = {**pred_i, **self.extractor({**data_i, **pred_i})}
return pred_i
def _forward(self, data):
pred0 = self.extract_view(data, "0")
pred1 = self.extract_view(data, "1")
pred = {
**{k + "0": v for k, v in pred0.items()},
**{k + "1": v for k, v in pred1.items()},
}
if self.conf.matcher.name:
pred = {**pred, **self.matcher({**data, **pred})}
if self.conf.filter.name:
pred = {**pred, **self.filter({**data, **pred})}
if self.conf.solver.name:
pred = {**pred, **self.solver({**data, **pred})}
if self.conf.ground_truth.name and self.conf.run_gt_in_forward:
gt_pred = self.ground_truth({**data, **pred})
pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
return pred
def loss(self, pred, data):
losses = {}
metrics = {}
total = 0
# get labels
if self.conf.ground_truth.name and not self.conf.run_gt_in_forward:
gt_pred = self.ground_truth({**data, **pred})
pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
for k in self.components:
apply = True
if "apply_loss" in self.conf[k].keys():
apply = self.conf[k].apply_loss
if self.conf[k].name and apply:
try:
losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
except NotImplementedError:
continue
losses = {**losses, **losses_}
metrics = {**metrics, **metrics_}
total = losses_["total"] + total
return {**losses, "total": total}, metrics

View File

View File

@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from omegaconf import OmegaConf
def weight_loss(log_assignment, weights, gamma=0.0):
b, m, n = log_assignment.shape
m -= 1
n -= 1
loss_sc = log_assignment * weights
num_neg0 = weights[:, :m, -1].sum(-1).clamp(min=1.0)
num_neg1 = weights[:, -1, :n].sum(-1).clamp(min=1.0)
num_pos = weights[:, :m, :n].sum((-1, -2)).clamp(min=1.0)
nll_pos = -loss_sc[:, :m, :n].sum((-1, -2))
nll_pos /= num_pos.clamp(min=1.0)
nll_neg0 = -loss_sc[:, :m, -1].sum(-1)
nll_neg1 = -loss_sc[:, -1, :n].sum(-1)
nll_neg = (nll_neg0 + nll_neg1) / (num_neg0 + num_neg1)
return nll_pos, nll_neg, num_pos, (num_neg0 + num_neg1) / 2.0
class NLLLoss(nn.Module):
default_conf = {
"nll_balancing": 0.5,
"gamma_f": 0.0, # focal loss
}
def __init__(self, conf):
super().__init__()
self.conf = OmegaConf.merge(self.default_conf, conf)
self.loss_fn = self.nll_loss
def forward(self, pred, data, weights=None):
log_assignment = pred["log_assignment"]
if weights is None:
weights = self.loss_fn(log_assignment, data)
nll_pos, nll_neg, num_pos, num_neg = weight_loss(
log_assignment, weights, gamma=self.conf.gamma_f
)
nll = (
self.conf.nll_balancing * nll_pos + (1 - self.conf.nll_balancing) * nll_neg
)
return (
nll,
weights,
{
"assignment_nll": nll,
"nll_pos": nll_pos,
"nll_neg": nll_neg,
"num_matchable": num_pos,
"num_unmatchable": num_neg,
},
)
def nll_loss(self, log_assignment, data):
m, n = data["gt_matches0"].size(-1), data["gt_matches1"].size(-1)
positive = data["gt_assignment"].float()
neg0 = (data["gt_matches0"] == -1).float()
neg1 = (data["gt_matches1"] == -1).float()
weights = torch.zeros_like(log_assignment)
weights[:, :m, :n] = positive
weights[:, :m, -1] = neg0
weights[:, -1, :n] = neg1
return weights
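# --- Hedged sanity check (added for illustration, not part of the original file) ---
# A toy problem, assuming it is appended to this module: three keypoints per image,
# identity ground-truth assignment, and a uniform log-assignment over the 4x4 grid
# that includes the dustbin row and column.
if __name__ == "__main__":
    data = {
        "gt_assignment": torch.eye(3)[None],
        "gt_matches0": torch.tensor([[0, 1, 2]]),
        "gt_matches1": torch.tensor([[0, 1, 2]]),
    }
    log_assignment = torch.full((1, 4, 4), 1.0 / 16).log()
    nll, weights, extra = NLLLoss({})({"log_assignment": log_assignment}, data)
    print(float(nll), float(extra["num_matchable"]))  # ~1.386 and 3.0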

View File

@ -0,0 +1,50 @@
import torch
@torch.no_grad()
def matcher_metrics(pred, data, prefix="", prefix_gt=None):
def recall(m, gt_m):
mask = (gt_m > -1).float()
return ((m == gt_m) * mask).sum(1) / (1e-8 + mask.sum(1))
def accuracy(m, gt_m):
mask = (gt_m >= -1).float()
return ((m == gt_m) * mask).sum(1) / (1e-8 + mask.sum(1))
def precision(m, gt_m):
mask = ((m > -1) & (gt_m >= -1)).float()
return ((m == gt_m) * mask).sum(1) / (1e-8 + mask.sum(1))
def ranking_ap(m, gt_m, scores):
p_mask = ((m > -1) & (gt_m >= -1)).float()
r_mask = (gt_m > -1).float()
sort_ind = torch.argsort(-scores)
sorted_p_mask = torch.gather(p_mask, -1, sort_ind)
sorted_r_mask = torch.gather(r_mask, -1, sort_ind)
sorted_tp = torch.gather(m == gt_m, -1, sort_ind)
p_pts = torch.cumsum(sorted_tp * sorted_p_mask, -1) / (
1e-8 + torch.cumsum(sorted_p_mask, -1)
)
r_pts = torch.cumsum(sorted_tp * sorted_r_mask, -1) / (
1e-8 + sorted_r_mask.sum(-1)[:, None]
)
r_pts_diff = r_pts[..., 1:] - r_pts[..., :-1]
return torch.sum(r_pts_diff * p_pts[:, None, -1], dim=-1)
if prefix_gt is None:
prefix_gt = prefix
rec = recall(pred[f"{prefix}matches0"], data[f"gt_{prefix_gt}matches0"])
prec = precision(pred[f"{prefix}matches0"], data[f"gt_{prefix_gt}matches0"])
acc = accuracy(pred[f"{prefix}matches0"], data[f"gt_{prefix_gt}matches0"])
ap = ranking_ap(
pred[f"{prefix}matches0"],
data[f"gt_{prefix_gt}matches0"],
pred[f"{prefix}matching_scores0"],
)
metrics = {
f"{prefix}match_recall": rec,
f"{prefix}match_precision": prec,
f"{prefix}accuracy": acc,
f"{prefix}average_precision": ap,
}
return metrics
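# --- Hedged usage sketch (added for illustration, not part of the original file) ---
# A quick illustration with hand-made tensors, assuming it is appended to this module;
# both prediction and ground truth use the usual m0 convention with -1 for unmatched.
if __name__ == "__main__":
    pred = {
        "matches0": torch.tensor([[0, 2, -1]]),
        "matching_scores0": torch.tensor([[0.9, 0.8, 0.1]]),
    }
    data = {"gt_matches0": torch.tensor([[0, 1, -1]])}
    for k, v in matcher_metrics(pred, data).items():
        print(k, float(v))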

View File

@ -0,0 +1,69 @@
import math
from typing import List, Optional, Tuple
import torch
def to_sequence(map):
return map.flatten(-2).transpose(-1, -2)
def to_map(sequence):
    n = sequence.shape[-2]
    e = math.isqrt(n)
    assert e * e == n  # assumes a square feature map
    return sequence.transpose(-1, -2).unflatten(-1, [e, e])
def pad_to_length(
x,
length: int,
pad_dim: int = -2,
mode: str = "zeros", # zeros, ones, random, random_c
bounds: Tuple[Optional[int], Optional[int]] = (None, None),
):
shape = list(x.shape)
d = x.shape[pad_dim]
assert d <= length
if d == length:
return x
shape[pad_dim] = length - d
low, high = bounds
if mode == "zeros":
xn = torch.zeros(*shape, device=x.device, dtype=x.dtype)
elif mode == "ones":
xn = torch.ones(*shape, device=x.device, dtype=x.dtype)
elif mode == "random":
low = low if low is not None else x.min()
high = high if high is not None else x.max()
xn = torch.empty(*shape, device=x.device).uniform_(low, high)
elif mode == "random_c":
low, high = bounds # we use the bounds as fallback for empty seq.
xn = torch.cat(
[
torch.empty(*shape[:-1], 1, device=x.device).uniform_(
x[..., i].min() if d > 0 else low,
x[..., i].max() if d > 0 else high,
)
for i in range(shape[-1])
],
dim=-1,
)
else:
raise ValueError(mode)
return torch.cat([x, xn], dim=pad_dim)
def pad_and_stack(
sequences: List[torch.Tensor],
length: Optional[int] = None,
pad_dim: int = -2,
**kwargs,
):
if length is None:
length = max([x.shape[pad_dim] for x in sequences])
y = torch.stack([pad_to_length(x, length, pad_dim, **kwargs) for x in sequences], 0)
return y
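# --- Hedged usage sketch (added for illustration, not part of the original file) ---
# Pad two keypoint sets of different lengths to a common length and stack them into
# a batch; a sketch assuming it is appended to this module.
if __name__ == "__main__":
    kpts_a = torch.rand(5, 2)
    kpts_b = torch.rand(3, 2)
    batch = pad_and_stack([kpts_a, kpts_b], pad_dim=-2, mode="zeros")
    print(batch.shape)  # torch.Size([2, 5, 2])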

View File

@ -0,0 +1,14 @@
import inspect
from .base_estimator import BaseEstimator
def load_estimator(type, estimator):
module_path = f"{__name__}.{type}.{estimator}"
module = __import__(module_path, fromlist=[""])
classes = inspect.getmembers(module, inspect.isclass)
# Filter classes defined in the module
classes = [c for c in classes if c[1].__module__ == module_path]
# Filter classes inherited from BaseEstimator
classes = [c for c in classes if issubclass(c[1], BaseEstimator)]
assert len(classes) == 1, classes
return classes[0][1]

View File

@ -0,0 +1,32 @@
from omegaconf import OmegaConf
from copy import copy
class BaseEstimator:
base_default_conf = {
"name": "???",
"ransac_th": "???",
}
test_thresholds = [1.0]
required_data_keys = []
strict_conf = False
def __init__(self, conf):
"""Perform some logic and call the _init method of the child model."""
default_conf = OmegaConf.merge(
self.base_default_conf, OmegaConf.create(self.default_conf)
)
if self.strict_conf:
OmegaConf.set_struct(default_conf, True)
if isinstance(conf, dict):
conf = OmegaConf.create(conf)
self.conf = conf = OmegaConf.merge(default_conf, conf)
OmegaConf.set_readonly(conf, True)
OmegaConf.set_struct(conf, True)
self.required_data_keys = copy(self.required_data_keys)
self._init(conf)
def __call__(self, data):
return self._forward(data)
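# --- Hedged subclass sketch (added for illustration, not part of the original file) ---
# Illustrates the contract: children provide default_conf, _init and _forward, while
# the base class merges and freezes the configuration. IdentityEstimator is purely
# hypothetical and returns a trivial result.
if __name__ == "__main__":
    class IdentityEstimator(BaseEstimator):
        default_conf = {"ransac_th": 1.0}
        required_data_keys = []

        def _init(self, conf):
            pass

        def _forward(self, data):
            return {"success": True, "M_0to1": None, "inliers": None}

    est = IdentityEstimator({})
    print(est({}), est.conf.ransac_th)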

View File

@ -0,0 +1,72 @@
import numpy as np
import torch
from homography_est import (
LineSegment,
ransac_line_homography,
ransac_point_homography,
ransac_point_line_homography,
)
from ..base_estimator import BaseEstimator
def H_estimation_hybrid(kpts0=None, kpts1=None, lines0=None, lines1=None, tol_px=5):
"""Estimate a homography from points and lines with hybrid RANSAC.
All features are expected in x-y convention
"""
# Check that we have at least 4 features
n_features = 0
if kpts0 is not None:
n_features += len(kpts0) + len(kpts1)
if lines0 is not None:
n_features += len(lines0) + len(lines1)
if n_features < 4:
return None
if lines0 is None:
# Point-only RANSAC
H = ransac_point_homography(kpts0, kpts1, tol_px, False, [])
elif kpts0 is None:
# Line-only RANSAC
ls0 = [LineSegment(line[0], line[1]) for line in lines0]
ls1 = [LineSegment(line[0], line[1]) for line in lines1]
H = ransac_line_homography(ls0, ls1, tol_px, False, [])
else:
# Point-lines RANSAC
ls0 = [LineSegment(line[0], line[1]) for line in lines0]
ls1 = [LineSegment(line[0], line[1]) for line in lines1]
H = ransac_point_line_homography(kpts0, kpts1, ls0, ls1, tol_px, False, [], [])
if np.abs(H[-1, -1]) > 1e-8:
H /= H[-1, -1]
return H
class PointLineHomographyEstimator(BaseEstimator):
default_conf = {"ransac_th": 2.0, "options": {}}
required_data_keys = ["m_kpts0", "m_kpts1", "m_lines0", "m_lines1"]
def _init(self, conf):
pass
def _forward(self, data):
m_features = {
"kpts0": data["m_kpts1"].numpy() if "m_kpts1" in data else None,
"kpts1": data["m_kpts0"].numpy() if "m_kpts0" in data else None,
"lines0": data["m_lines1"].numpy() if "m_lines1" in data else None,
"lines1": data["m_lines0"].numpy() if "m_lines0" in data else None,
}
feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"]
M = H_estimation_hybrid(**m_features, tol_px=self.conf.ransac_th)
success = M is not None
if not success:
M = torch.eye(3, device=feat.device, dtype=feat.dtype)
else:
M = torch.tensor(M).to(feat)
estimation = {
"success": success,
"M_0to1": M,
}
return estimation

View File

@ -0,0 +1,53 @@
import cv2
import torch
from ..base_estimator import BaseEstimator
class OpenCVHomographyEstimator(BaseEstimator):
default_conf = {
"ransac_th": 3.0,
"options": {"method": "ransac", "max_iters": 3000, "confidence": 0.995},
}
required_data_keys = ["m_kpts0", "m_kpts1"]
def _init(self, conf):
self.solver = {
"ransac": cv2.RANSAC,
"lmeds": cv2.LMEDS,
"rho": cv2.RHO,
"usac": cv2.USAC_DEFAULT,
"usac_fast": cv2.USAC_FAST,
"usac_accurate": cv2.USAC_ACCURATE,
"usac_prosac": cv2.USAC_PROSAC,
"usac_magsac": cv2.USAC_MAGSAC,
}[conf.options.method]
def _forward(self, data):
pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
try:
M, mask = cv2.findHomography(
pts0.numpy(),
pts1.numpy(),
self.solver,
self.conf.ransac_th,
maxIters=self.conf.options.max_iters,
confidence=self.conf.options.confidence,
)
success = M is not None
except cv2.error:
success = False
if not success:
M = torch.eye(3, device=pts0.device, dtype=pts0.dtype)
inl = torch.zeros_like(pts0[:, 0]).bool()
else:
M = torch.tensor(M).to(pts0)
inl = torch.tensor(mask).bool().to(pts0.device)
return {
"success": success,
"M_0to1": M,
"inliers": inl,
}
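# --- Hedged smoke test (added for illustration, not part of the original file) ---
# Assuming it is appended to this module: the correspondences follow a known
# translation-only homography, so RANSAC should recover a matrix close to it.
if __name__ == "__main__":
    torch.manual_seed(0)
    pts0 = torch.rand(50, 2) * 100
    pts1 = pts0 + torch.tensor([5.0, -3.0])  # pure translation
    estimator = OpenCVHomographyEstimator({"ransac_th": 3.0})
    out = estimator({"m_kpts0": pts0, "m_kpts1": pts1})
    print(out["success"], out["M_0to1"])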

View File

@ -0,0 +1,40 @@
import poselib
from omegaconf import OmegaConf
import torch
from ..base_estimator import BaseEstimator
class PoseLibHomographyEstimator(BaseEstimator):
default_conf = {"ransac_th": 2.0, "options": {}}
required_data_keys = ["m_kpts0", "m_kpts1"]
def _init(self, conf):
pass
def _forward(self, data):
pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
M, info = poselib.estimate_homography(
pts0.numpy(),
pts1.numpy(),
{
"max_reproj_error": self.conf.ransac_th,
**OmegaConf.to_container(self.conf.options),
},
)
success = M is not None
if not success:
M = torch.eye(3, device=pts0.device, dtype=pts0.dtype)
inl = torch.zeros_like(pts0[:, 0]).bool()
else:
M = torch.tensor(M).to(pts0)
inl = torch.tensor(info["inliers"]).bool().to(pts0.device)
estimation = {
"success": success,
"M_0to1": M,
"inliers": inl,
}
return estimation

View File

@ -0,0 +1,64 @@
import cv2
import numpy as np
import torch
from ...geometry.wrappers import Pose
from ...geometry.utils import from_homogeneous
from ..base_estimator import BaseEstimator
class OpenCVRelativePoseEstimator(BaseEstimator):
default_conf = {
"ransac_th": 0.5,
"options": {"confidence": 0.99999, "method": "ransac"},
}
required_data_keys = ["m_kpts0", "m_kpts1", "camera0", "camera1"]
def _init(self, conf):
self.solver = {"ransac": cv2.RANSAC, "usac_magsac": cv2.USAC_MAGSAC}[
self.conf.options.method
]
def _forward(self, data):
kpts0, kpts1 = data["m_kpts0"], data["m_kpts1"]
camera0 = data["camera0"]
camera1 = data["camera1"]
M, inl = None, torch.zeros_like(kpts0[:, 0]).bool()
if len(kpts0) >= 5:
f_mean = torch.cat([camera0.f, camera1.f]).mean().item()
norm_thresh = self.conf.ransac_th / f_mean
pts0 = from_homogeneous(camera0.image2cam(kpts0)).cpu().detach().numpy()
pts1 = from_homogeneous(camera1.image2cam(kpts1)).cpu().detach().numpy()
E, mask = cv2.findEssentialMat(
pts0,
pts1,
np.eye(3),
threshold=norm_thresh,
prob=self.conf.options.confidence,
method=self.solver,
)
if E is not None:
best_num_inliers = 0
for _E in np.split(E, len(E) // 3):
n, R, t, _ = cv2.recoverPose(
_E, pts0, pts1, np.eye(3), 1e9, mask=mask
)
if n > best_num_inliers:
best_num_inliers = n
inl = torch.tensor(mask.ravel() > 0)
M = Pose.from_Rt(
torch.tensor(R).to(kpts0), torch.tensor(t[:, 0]).to(kpts0)
)
estimation = {
"success": M is not None,
"M_0to1": M if M is not None else Pose.from_4x4mat(torch.eye(4).to(kpts0)),
"inliers": inl.to(device=kpts0.device),
}
return estimation

View File

@ -0,0 +1,44 @@
import poselib
from omegaconf import OmegaConf
import torch
from ...geometry.wrappers import Pose
from ..base_estimator import BaseEstimator
class PoseLibRelativePoseEstimator(BaseEstimator):
default_conf = {"ransac_th": 2.0, "options": {}}
required_data_keys = ["m_kpts0", "m_kpts1", "camera0", "camera1"]
def _init(self, conf):
pass
def _forward(self, data):
pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
camera0 = data["camera0"]
camera1 = data["camera1"]
M, info = poselib.estimate_relative_pose(
pts0.numpy(),
pts1.numpy(),
camera0.to_cameradict(),
camera1.to_cameradict(),
{
"max_epipolar_error": self.conf.ransac_th,
**OmegaConf.to_container(self.conf.options),
},
)
success = M is not None
if success:
M = Pose.from_Rt(torch.tensor(M.R), torch.tensor(M.t)).to(pts0)
else:
M = Pose.from_4x4mat(torch.eye(4)).to(pts0)
estimation = {
"success": success,
"M_0to1": M,
"inliers": torch.tensor(info.pop("inliers")).to(pts0),
**info,
}
return estimation

View File

@ -0,0 +1,52 @@
import pycolmap
from omegaconf import OmegaConf
import torch
from ...geometry.wrappers import Pose
from ..base_estimator import BaseEstimator
class PycolmapTwoViewEstimator(BaseEstimator):
default_conf = {
"ransac_th": 4.0,
"options": {**pycolmap.TwoViewGeometryOptions().todict()},
}
required_data_keys = ["m_kpts0", "m_kpts1", "camera0", "camera1"]
def _init(self, conf):
opts = OmegaConf.to_container(conf.options)
self.options = pycolmap.TwoViewGeometryOptions(opts)
self.options.ransac.max_error = conf.ransac_th
def _forward(self, data):
pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
camera0 = data["camera0"]
camera1 = data["camera1"]
info = pycolmap.two_view_geometry_estimation(
pts0.numpy(),
pts1.numpy(),
camera0.to_cameradict(),
camera1.to_cameradict(),
self.options,
)
success = info["success"]
if success:
R = pycolmap.qvec_to_rotmat(info["qvec"])
t = info["tvec"]
M = Pose.from_Rt(torch.tensor(R), torch.tensor(t)).to(pts0)
inl = torch.tensor(info.pop("inliers")).to(pts0)
else:
M = Pose.from_4x4mat(torch.eye(4)).to(pts0)
inl = torch.zeros_like(pts0[:, 0]).bool()
estimation = {
"success": success,
"M_0to1": M,
"inliers": inl,
"type": str(
info.get("configuration_type", pycolmap.TwoViewGeometry.UNDEFINED)
),
}
return estimation

Some files were not shown because too many files have changed in this diff.