Initial commit
Co-authored-by: Philipp Lindenberger <phil.lindenberger@gmail.com> Co-authored-by: Iago Suárez <iago_h92@hotmail.com> Co-authored-by: Paul-Edouard Sarlin <paul.edouard.sarlin@gmail.com>main
commit
55c4fbd454
|
@ -0,0 +1,31 @@
|
|||
name: Format and Lint Checks
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '*.py'
|
||||
pull_request:
|
||||
types: [ assigned, opened, synchronize, reopened ]
|
||||
jobs:
|
||||
formatting-check:
|
||||
name: Formatting Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: psf/black@stable
|
||||
with:
|
||||
jupyter: true
|
||||
linting-check:
|
||||
name: Linting Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: 'pip'
|
||||
- run: python -m pip install --upgrade pip
|
||||
- run: python -m pip install .
|
||||
- run: python -m pip install --upgrade flake8
|
||||
- run: python -m flake8 . --exclude build/
|
|
@ -0,0 +1,9 @@
|
|||
.venv
|
||||
/build/
|
||||
*.egg-info
|
||||
*.pyc
|
||||
/.idea/
|
||||
/venv/
|
||||
/data/
|
||||
/outputs/
|
||||
__pycache__
|
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,335 @@
|
|||
# Glue Factory
|
||||
Glue Factory is CVG's library for training and evaluating deep neural networks that extract and match local visual features. It enables you to:
|
||||
- Reproduce the training of state-of-the-art models for point and line matching, like [LightGlue](https://github.com/cvg/LightGlue) and [GlueStick](https://github.com/cvg/GlueStick) (ICCV 2023)
|
||||
- Train these models on multiple datasets using your own local features or lines
|
||||
- Evaluate feature extractors or matchers on standard benchmarks like HPatches or MegaDepth-1500
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/cvg/LightGlue"><img src="docs/lightglue_matches.svg" width="60%"/></a>
|
||||
<a href="https://github.com/cvg/GlueStick"><img src="docs/gluestick_img.svg" width="60%"/></a>
|
||||
<br /><em>Point and line matching with LightGlue and GlueStick.</em>
|
||||
</p>
|
||||
|
||||
## Installation
|
||||
Glue Factory runs with Python 3 and [PyTorch](https://pytorch.org/). The following installs the library and its basic dependencies:
|
||||
```bash
|
||||
git clone https://github.com/cvg/glue-factory
|
||||
cd glue-factory
|
||||
python3 -m pip install -e . # editable mode
|
||||
```
|
||||
Some advanced features might require installing the full set of dependencies:
|
||||
```bash
|
||||
python3 -m pip install -e .[extra]
|
||||
```
|
||||
|
||||
All models and datasets in gluefactory have auto-downloaders, so you can get started right away!
|
||||
|
||||
## License
|
||||
The code and trained models in Glue Factory are released with an Apache-2.0 license. This includes LightGlue trained with an [open version of SuperPoint](https://github.com/rpautrat/SuperPoint). Third-party models that are not compatible with this license, such as SuperPoint (original) and SuperGlue, are provided in `gluefactory_nonfree`, where each model might follow its own, restrictive license.
|
||||
|
||||
## Evaluation
|
||||
|
||||
#### HPatches
|
||||
Running the evaluation commands automatically downloads the dataset, by default to the directory `data/`. You will need about 1.8 GB of free disk space.
|
||||
|
||||
<details>
|
||||
<summary>[Evaluating LightGlue]</summary>
|
||||
|
||||
To evaluate the pre-trained SuperPoint+LightGlue model on HPatches, run:
|
||||
```bash
|
||||
python -m gluefactory.eval.hpatches --conf superpoint+lightglue-official --overwrite
|
||||
```
|
||||
You should expect the following results
|
||||
```
|
||||
{'H_error_dlt@1px': 0.3515,
|
||||
'H_error_dlt@3px': 0.6723,
|
||||
'H_error_dlt@5px': 0.7756,
|
||||
'H_error_ransac@1px': 0.3428,
|
||||
'H_error_ransac@3px': 0.5763,
|
||||
'H_error_ransac@5px': 0.6943,
|
||||
'mnum_keypoints': 1024.0,
|
||||
'mnum_matches': 560.756,
|
||||
'mprec@1px': 0.337,
|
||||
'mprec@3px': 0.89,
|
||||
'mransac_inl': 130.081,
|
||||
'mransac_inl%': 0.217,
|
||||
'ransac_mAA': 0.5378}
|
||||
```
|
||||
|
||||
The default robust estimator is `opencv`, but we strongly recommend using `poselib` instead:
|
||||
```bash
|
||||
python -m gluefactory.eval.hpatches --conf superpoint+lightglue-official --overwrite \
|
||||
eval.estimator=poselib eval.ransac_th=-1
|
||||
```
|
||||
Setting `eval.ransac_th=-1` auto-tunes the RANSAC inlier threshold by running the evaluation with a range of thresholds and reports results for the optimal value.
|
||||
Here are the results as Area Under the Curve (AUC) of the homography error at 1/3/5 pixels:
|
||||
|
||||
| Methods | DLT | [OpenCV](../gluefactory/robust_estimators/homography/opencv.py) | [PoseLib](../gluefactory/robust_estimators/homography/poselib.py) |
|
||||
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue.yaml) | 32.1 / 65.0 / 75.7 | 32.9 / 55.7 / 68.0 | 37.0 / 68.2 / 78.7 |
|
||||
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue.yaml) | 35.1 / 67.2 / 77.6 | 34.2 / 57.9 / 69.9 | 37.1 / 67.4 / 77.8 |
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>[Evaluating GlueStick]</summary>
|
||||
|
||||
To evaluate GlueStick on HPatches, run:
|
||||
```bash
|
||||
python -m gluefactory.eval.hpatches --conf gluefactory/configs/superpoint+lsd+gluestick.yaml --overwrite
|
||||
```
|
||||
You should expect the following results
|
||||
```
|
||||
{"mprec@1px": 0.245,
|
||||
"mprec@3px": 0.838,
|
||||
"mnum_matches": 1290.5,
|
||||
"mnum_keypoints": 2287.5,
|
||||
"mH_error_dlt": null,
|
||||
"H_error_dlt@1px": 0.3355,
|
||||
"H_error_dlt@3px": 0.6637,
|
||||
"H_error_dlt@5px": 0.7713,
|
||||
"H_error_ransac@1px": 0.3915,
|
||||
"H_error_ransac@3px": 0.6972,
|
||||
"H_error_ransac@5px": 0.7955,
|
||||
"H_error_ransac_mAA": 0.62806,
|
||||
"mH_error_ransac": null}
|
||||
```
|
||||
|
||||
Since we use points and lines to solve for the homography, we use a different robust estimator here: [Hest](https://github.com/rpautrat/homography_est/). Here are the results as Area Under the Curve (AUC) of the homography error at 1/3/5 pixels:
|
||||
|
||||
| Methods | DLT | [Hest](gluefactory/robust_estimators/homography/homography_est.py) |
|
||||
| ------------------------------------------------------------ | ------------------ | ------------------ |
|
||||
| [SP + LSD + GlueStick](gluefactory/configs/superpoint+lsd+gluestick.yaml) | 33.6 / 66.4 / 77.1 | 39.2 / 69.7 / 79.6 |
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
#### MegaDepth-1500
|
||||
|
||||
Running the evaluation commands automatically downloads the dataset, which takes about 1.5 GB of disk space.
|
||||
|
||||
<details>
|
||||
<summary>[Evaluating LightGlue]</summary>
|
||||
|
||||
To evaluate the pre-trained SuperPoint+LightGlue model on MegaDepth-1500, run:
|
||||
```bash
|
||||
python -m gluefactory.eval.megadepth1500 --conf superpoint+lightglue-official
|
||||
# or the adaptive variant
|
||||
python -m gluefactory.eval.megadepth1500 --conf superpoint+lightglue-official \
|
||||
model.matcher.{depth_confidence=0.95,width_confidence=0.95}
|
||||
```
|
||||
The first command should print the following results
|
||||
```
|
||||
{'mepi_prec@1e-3': 0.795,
|
||||
'mepi_prec@1e-4': 0.15,
|
||||
'mepi_prec@5e-4': 0.567,
|
||||
'mnum_keypoints': 2048.0,
|
||||
'mnum_matches': 613.287,
|
||||
'mransac_inl': 280.518,
|
||||
'mransac_inl%': 0.442,
|
||||
'rel_pose_error@10°': 0.681,
|
||||
'rel_pose_error@20°': 0.8065,
|
||||
'rel_pose_error@5°': 0.5102,
|
||||
'ransac_mAA': 0.6659}
|
||||
```
|
||||
|
||||
To use the PoseLib estimator:
|
||||
|
||||
```bash
|
||||
python -m gluefactory.eval.megadepth1500 --conf superpoint+lightglue-official \
|
||||
eval.estimator=poselib eval.ransac_th=2.0
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>[Evaluating GlueStick]</summary>
|
||||
|
||||
To evaluate the pre-trained SuperPoint+GlueStick model on MegaDepth-1500, run:
|
||||
```bash
|
||||
python -m gluefactory.eval.megadepth1500 --conf gluefactory/configs/superpoint+lsd+gluestick.yaml
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20 degrees:
|
||||
|
||||
| Methods | [pycolmap](../gluefactory/robust_estimators/relative_pose/pycolmap.py) | [OpenCV](../gluefactory/robust_estimators/relative_pose/opencv.py) | [PoseLib](../gluefactory/robust_estimators/relative_pose/poselib.py) |
|
||||
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue.yaml) | 54.4 / 70.4 / 82.4 | 48.7 / 65.6 / 79.0 | 64.8 / 77.9 / 87.0 |
|
||||
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue.yaml) | 56.7 / 72.4 / 83.7 | 51.0 / 68.1 / 80.7 | 66.8 / 79.3 / 87.9 |
|
||||
| [SuperPoint + GlueStick](../gluefactory/configs/superpoint+lsd+gluestick.yaml) | 53.2 / 69.8 / 81.9 | 46.3 / 64.2 / 78.1 | 64.4 / 77.5 / 86.5 |
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
#### ETH3D
|
||||
|
||||
The dataset will be auto-downloaded if it is not found on disk, and will need about 6 GB of free disk space.
|
||||
|
||||
<details>
|
||||
<summary>[Evaluating GlueStick]</summary>
|
||||
|
||||
To evaluate GlueStick on ETH3D, run:
|
||||
```bash
|
||||
python -m gluefactory.eval.eth3d --conf gluefactory/configs/superpoint+lsd+gluestick.yaml
|
||||
```
|
||||
You should expect the following results
|
||||
```
|
||||
AP: 77.92
|
||||
AP_lines: 69.22
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### Image Matching Challenge 2021
|
||||
Coming soon!
|
||||
|
||||
#### Image Matching Challenge 2023
|
||||
Coming soon!
|
||||
|
||||
#### Visual inspection
|
||||
<details>
|
||||
To inspect the evaluation visually, you can run:
|
||||
|
||||
```bash
|
||||
python -m gluefactory.eval.inspect hpatches superpoint+lightglue-official
|
||||
```
|
||||
|
||||
Click on a point to visualize matches on this pair.
|
||||
|
||||
To compare multiple methods on a dataset:
|
||||
|
||||
```bash
|
||||
python -m gluefactory.eval.inspect hpatches superpoint+lightglue-official superpoint+superglue-official
|
||||
```
|
||||
|
||||
All current benchmarks are supported by the viewer.
|
||||
</details>
|
||||
|
||||
Detailed evaluation instructions can be found [here](./docs/evaluation.md).
|
||||
|
||||
## Training
|
||||
|
||||
We generally follow a two-stage training:
|
||||
1. Pre-train on a large dataset of synthetic homographies applied to internet images. We use the 1M-image distractor set of the Oxford-Paris retrieval dataset. It requires about 450 GB of disk space.
|
||||
2. Fine-tune on the MegaDepth dataset, which is based on PhotoTourism pictures of popular landmarks around the world. It exhibits more complex and realistic appearance and viewpoint changes. It requires about 420 GB of disk space.
|
||||
|
||||
All training commands automatically download the datasets.
|
||||
|
||||
<details>
|
||||
<summary>[Training LightGlue]</summary>
|
||||
|
||||
We show how to train LightGlue with [SuperPoint open](https://github.com/rpautrat/SuperPoint).
|
||||
We first pre-train LightGlue on the homography dataset:
|
||||
```bash
|
||||
python -m gluefactory.train sp+lg_homography \ # experiment name
|
||||
--conf gluefactory/configs/superpoint-open+lightglue_homography.yaml
|
||||
```
|
||||
Feel free to use any other experiment name. By default the checkpoints are written to `outputs/training/`. The default batch size of 128 corresponds to the results reported in the paper and requires 2x 3090 GPUs with 24GB of VRAM each as well as PyTorch >= 2.0 (FlashAttention).
|
||||
Configurations are managed by [OmegaConf](https://omegaconf.readthedocs.io/) so any entry can be overridden from the command line.
|
||||
If you have PyTorch < 2.0 or weaker GPUs, you may thus need to reduce the batch size via:
|
||||
```bash
|
||||
python -m gluefactory.train sp+lg_homography \
|
||||
--conf gluefactory/configs/superpoint-open+lightglue_homography.yaml \
|
||||
data.batch_size=32 # for 1x 1080 GPU
|
||||
```
|
||||
Be aware that this can impact the overall performance. You might need to adjust the learning rate accordingly.
|
||||
|
||||
We then fine-tune the model on the MegaDepth dataset:
|
||||
```bash
|
||||
python -m gluefactory.train sp+lg_megadepth \
|
||||
--conf gluefactory/configs/superpoint-open+lightglue_megadepth.yaml \
|
||||
train.load_experiment=sp+lg_homography
|
||||
```
|
||||
|
||||
Here the default batch size is 32. To speed up training on MegaDepth, we suggest to cache the local features before training (requires around 150 GB of disk space):
|
||||
```bash
|
||||
# extract features
|
||||
python -m gluefactory.scripts.export_megadepth --method sp_open --num_workers 8
|
||||
# run training with cached features
|
||||
python -m gluefactory.train sp+lg_megadepth \
|
||||
--conf gluefactory/configs/superpoint-open+lightglue_megadepth.yaml \
|
||||
train.load_experiment=sp+lg_homography \
|
||||
data.load_features.do=True
|
||||
```
|
||||
|
||||
The model can then be evaluated using its experiment name:
|
||||
```bash
|
||||
python -m gluefactory.eval.megadepth1500 --checkpoint sp+lg_megadepth
|
||||
```
|
||||
|
||||
You can also run all benchmarks after each training epoch with the option `--run_benchmarks`.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>[Training GlueStick]</summary>
|
||||
|
||||
We first pre-train GlueStick on the homography dataset:
|
||||
```bash
|
||||
python -m gluefactory.train gluestick_H --conf gluefactory/configs/superpoint+lsd+gluestick-homography.yaml --distributed
|
||||
```
|
||||
Feel free to use any other experiment name. Configurations are managed by [OmegaConf](https://omegaconf.readthedocs.io/) so any entry can be overridden from the command line.
|
||||
|
||||
We then fine-tune the model on the MegaDepth dataset:
|
||||
```bash
|
||||
python -m gluefactory.train gluestick_MD --conf gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml --distributed
|
||||
```
|
||||
Note that we used the training splits `train_scenes.txt` and `valid_scenes.txt` to train the original model, which contain some overlap with the IMC challenge. The new default splits are now `train_scenes_clean.txt` and `valid_scenes_clean.txt`, without this overlap.
|
||||
|
||||
</details>
|
||||
|
||||
### Available models
|
||||
Glue Factory supports training and evaluating the following deep matchers:
|
||||
| Model | Training? | Evaluation? |
|
||||
| --------- | --------- | ----------- |
|
||||
| [LightGlue](https://github.com/cvg/LightGlue) | ✅ | ✅ |
|
||||
| [GlueStick](https://github.com/cvg/GlueStick) | ✅ | ✅ |
|
||||
| [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) | ✅ | ✅ |
|
||||
| [LoFTR](https://github.com/zju3dv/LoFTR) | ❌ | ✅ |
|
||||
|
||||
Using the following local feature extractors:
|
||||
|
||||
| Model | LightGlue config |
|
||||
| --------- | --------- |
|
||||
| [SuperPoint (open)](https://github.com/rpautrat/SuperPoint) | `superpoint-open+lightglue_{homography,megadepth}.yaml` |
|
||||
| [SuperPoint (official)](https://github.com/magicleap/SuperPointPretrainedNetwork) | ❌ TODO |
|
||||
| SIFT (via [pycolmap](https://github.com/colmap/pycolmap)) | `sift+lightglue_{homography,megadepth}.yaml` |
|
||||
| [ALIKED](https://github.com/Shiaoming/ALIKED) | `aliked+lightglue_{homography,megadepth}.yaml` |
|
||||
| [DISK](https://github.com/cvlab-epfl/disk) | ❌ TODO |
|
||||
| Key.Net + HardNet | ❌ TODO |
|
||||
|
||||
## Coming soon
|
||||
- [ ] More baselines (LoFTR, ASpanFormer, MatchFormer, SGMNet, DKM, RoMa)
|
||||
- [ ] Training deep detectors and descriptors like SuperPoint
|
||||
- [ ] IMC evaluations
|
||||
- [ ] Better documentation
|
||||
|
||||
## BibTeX Citation
|
||||
Please consider citing the following papers if you found this library useful:
|
||||
```bibtex
|
||||
@InProceedings{lindenberger_2023_lightglue,
|
||||
title = {{LightGlue: Local Feature Matching at Light Speed}},
|
||||
author = {Philipp Lindenberger and
|
||||
Paul-Edouard Sarlin and
|
||||
Marc Pollefeys},
|
||||
booktitle = {International Conference on Computer Vision (ICCV)},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
```bibtex
|
||||
@InProceedings{pautrat_suarez_2023_gluestick,
|
||||
title = {{GlueStick: Robust Image Matching by Sticking Points and Lines Together}},
|
||||
author = {R{\'e}mi Pautrat* and
|
||||
Iago Su{\'a}rez* and
|
||||
Yifan Yu and
|
||||
Marc Pollefeys and
|
||||
Viktor Larsson},
|
||||
booktitle = {International Conference on Computer Vision (ICCV)},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,121 @@
|
|||
# Evaluation
|
||||
|
||||
Glue Factory is designed for simple and tight integration between training and evaluation.
|
||||
All benchmarks are designed around one principle: only evaluate on cached results.
|
||||
This enforces reproducible baselines.
|
||||
Therefore, we first export model predictions for each dataset (`export`), and evaluate the cached results in a second pass (`evaluation`).
|
||||
|
||||
### Running an evaluation
|
||||
|
||||
We currently provide evaluation scripts for [MegaDepth-1500](../gluefactory/eval/megadepth1500.py), [HPatches](../gluefactory/eval/hpatches.py), and [ETH3D](../gluefactory/eval/eth3d.py).
|
||||
You can run them with:
|
||||
|
||||
```bash
|
||||
python -m gluefactory.eval.<benchmark_name> --conf "a name in gluefactory/configs/ or path" --checkpoint "and/or a checkpoint name"
|
||||
```
|
||||
|
||||
Each evaluation run is assigned a `tag`, which can (optionally) be customized from the command line with `--tag <your_tag>`.
|
||||
|
||||
To overwrite an experiment, add `--overwrite`. To only overwrite the results of the evaluation loop, add `--overwrite_eval`. We perform config checks to warn the user about non-conforming configurations between runs.
|
||||
|
||||
The following files are written to `outputs/results/<benchmark_name>/<tag>`:
|
||||
|
||||
```yaml
|
||||
conf.yaml # the config which was used
|
||||
predictions.h5 # cached predictions
|
||||
results.h5 # Results for each data point in eval, in the format <metric_name>: List[float]
|
||||
summaries.json # Aggregated results for the entire dataset <agg_metric_name>: float
|
||||
<plots> # some benchmarks add plots as png files here
|
||||
```
|
||||
|
||||
Some datasets further output plots (add `--plot` to the command line).
|
||||
|
||||
<details>
|
||||
<summary>[Configuration]</summary>
|
||||
|
||||
Each evaluation has 3 main configurations:
|
||||
|
||||
```yaml
|
||||
data:
|
||||
... # How to load the data. The user can overwrite this only during "export". The defaults are used in "evaluation".
|
||||
model:
|
||||
... # model configuration: this is only required for "export".
|
||||
eval:
|
||||
... # configuration for the "evaluation" loop, e.g. pose estimators and ransac thresholds.
|
||||
```
|
||||
|
||||
The default configurations can be found in the respective evaluation scripts, e.g. [MegaDepth1500](../gluefactory/eval/megadepth1500.py).
|
||||
|
||||
To run an evaluation with a custom config, we expect them to be in the following format ([example](../gluefactory/configs/superpoint+lightglue.yaml)):
|
||||
|
||||
```yaml
|
||||
model:
|
||||
... # <your model configs>
|
||||
benchmarks:
|
||||
<benchmark_name1>:
|
||||
data:
|
||||
... # <your data configs for "export">
|
||||
model:
|
||||
... # <your benchmark-specific model configs>
|
||||
eval:
|
||||
... # <your evaluation configs, e.g. pose estimators>
|
||||
<benchmark_name2>:
|
||||
... # <same structure as above>
|
||||
```
|
||||
|
||||
The configs are then merged in the following order (taking megadepth1500 as an example):
|
||||
|
||||
```yaml
|
||||
data:
|
||||
default < custom.benchmarks.megadepth1500.data
|
||||
model:
|
||||
default < custom.model < custom.benchmarks.megadepth1500.model
|
||||
eval:
|
||||
default < custom.benchmarks.megadepth1500.eval
|
||||
```
|
||||
|
||||
You can then use the command line to further customize this configuration.
|
||||
|
||||
</details>
|
||||
|
||||
### Robust estimators
|
||||
Gluefactory offers a flexible interface to state-of-the-art [robust estimators](../gluefactory/robust_estimators/) for points and lines.
|
||||
You can configure the estimator in the benchmarks with the following config structure:
|
||||
|
||||
```yaml
|
||||
eval:
|
||||
estimator: <estimator_name> # poselib, opencv, pycolmap, ...
|
||||
ransac_th: 0.5 # run evaluation on fixed threshold
|
||||
#or
|
||||
ransac_th: [0.5, 1.0, 1.5] # test on multiple thresholds, autoselect best
|
||||
<extra configs for the estimator, e.g. max iters, ...>
|
||||
```
|
||||
|
||||
For convenience, most benchmarks convert `eval.ransac_th=-1` to a default range of thresholds.
|
||||
|
||||
> [!NOTE]
|
||||
> Gluefactory follows the corner convention of COLMAP, i.e. the top-left corner of the top-left pixel is (0, 0).
|
||||
|
||||
### Visualization
|
||||
|
||||
We provide a powerful, interactive visualization tool for our benchmarks, based on matplotlib.
|
||||
You can run the visualization (after running the evaluations) with:
|
||||
```bash
|
||||
python -m gluefactory.eval.inspect <benchmark_name> <experiment_name1> <experiment_name2> ...
|
||||
```
|
||||
|
||||
This prints the summaries of each experiment on the respective benchmark and visualizes the data as a scatter plot, where each point is the result of an experiment on a specific data point in the dataset.
|
||||
|
||||
<details>
|
||||
|
||||
- Clicking on one of the data points opens a new frame showing the prediction on this specific data point for all experiments listed.
|
||||
- You can customize the x / y axis from the navigation bar or by clicking `x` or `y`.
|
||||
- Hitting `diff_only` computes the difference between `<experiment_name1>` and all other experiments.
|
||||
- Hovering over a point shows lines to the results of other experiments on the same data.
|
||||
- You can switch the visualization (matches, keypoints, ...) from the navigation bar or by clicking `shift+r`.
|
||||
- Clicking `t` prints a summary of the eval on this data point.
|
||||
- Hitting the `left` or `right` arrows circles between data points. `shift+left` opens an extra window.
|
||||
|
||||
When working on a remote machine (e.g. over ssh), the plots can be forwarded to the browser with the option `--backend webagg`. Note that you need to refresh the page every time you load a new figure (e.g. when clicking on a scatter point). This part requires some more work, and we would highly appreciate any contributions!
|
||||
|
||||
</details>
|
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 375 KiB |
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 603 KiB |
|
@ -0,0 +1,16 @@
|
|||
import logging
|
||||
from .utils.experiments import load_experiment # noqa: F401
|
||||
|
||||
# Shared log format for all gluefactory messages, e.g.:
# [09/20/2023 12:00:00 gluefactory INFO] message
formatter = logging.Formatter(
    fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

# Package-level logger: INFO and above are emitted through the handler above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
# Do not forward records to the root logger, to avoid duplicate messages
# when the embedding application configures its own handlers.
logger.propagate = False

# Expose the package name for use by submodules.
__module_name__ = __name__
|
|
@ -0,0 +1,24 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.aliked
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
matcher:
|
||||
name: matchers.nearest_neighbor_matcher
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,50 @@
|
|||
data:
|
||||
name: homographies
|
||||
data_dir: revisitop1m
|
||||
train_size: 150000
|
||||
val_size: 2000
|
||||
batch_size: 128
|
||||
num_workers: 14
|
||||
homography:
|
||||
difficulty: 0.7
|
||||
max_angle: 45
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.aliked
|
||||
max_num_keypoints: 512
|
||||
detection_threshold: 0.0
|
||||
trainable: False
|
||||
detector:
|
||||
name: null
|
||||
descriptor:
|
||||
name: null
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: false
|
||||
checkpointed: true
|
||||
input_dim: 128
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,70 @@
|
|||
data:
|
||||
name: megadepth
|
||||
preprocessing:
|
||||
resize: 1024
|
||||
side: long
|
||||
square_pad: True
|
||||
train_split: train_scenes_clean.txt
|
||||
train_num_per_scene: 300
|
||||
val_split: valid_scenes_clean.txt
|
||||
val_pairs: valid_pairs.txt
|
||||
min_overlap: 0.1
|
||||
max_overlap: 0.7
|
||||
num_overlap_bins: 3
|
||||
read_depth: true
|
||||
read_image: true
|
||||
batch_size: 32
|
||||
num_workers: 14
|
||||
load_features:
|
||||
do: false # enable this if you have cached predictions
|
||||
path: exports/megadepth-undist-depth-r1024_ALIKED-k2048-n16/{scene}.h5
|
||||
padding_length: 2048
|
||||
padding_fn: pad_local_features
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.aliked
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
trainable: False
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: false
|
||||
checkpointed: true
|
||||
input_dim: 128
|
||||
ground_truth:
|
||||
name: matchers.depth_matcher
|
||||
th_positive: 3
|
||||
th_negative: 5
|
||||
th_epi: 5
|
||||
allow_no_extract: True
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 50
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 1000
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 30
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
dataset_callback_fn: sample_new_items
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024
|
|
@ -0,0 +1,24 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.disk_kornia
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
matcher:
|
||||
name: matchers.nearest_neighbor_matcher
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,28 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.disk_kornia
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
matcher:
|
||||
name: matchers.lightglue_pretrained
|
||||
features: disk
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,28 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.sift
|
||||
detector: pycolmap_cuda
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.00666666
|
||||
nms_radius: -1
|
||||
pycolmap_options:
|
||||
first_octave: -1
|
||||
matcher:
|
||||
name: matchers.nearest_neighbor_matcher
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,48 @@
|
|||
data:
|
||||
name: homographies
|
||||
data_dir: revisitop1m
|
||||
train_size: 150000
|
||||
val_size: 2000
|
||||
batch_size: 64
|
||||
num_workers: 14
|
||||
homography:
|
||||
difficulty: 0.7
|
||||
max_angle: 45
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.sift
|
||||
detector: pycolmap_cuda
|
||||
max_num_keypoints: 1024
|
||||
force_num_keypoints: True
|
||||
detection_threshold: 0.0001
|
||||
trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: false
|
||||
checkpointed: true
|
||||
input_dim: 128
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,74 @@
|
|||
data:
|
||||
name: megadepth
|
||||
preprocessing:
|
||||
resize: 1024
|
||||
side: long
|
||||
square_pad: True
|
||||
train_split: train_scenes_clean.txt
|
||||
train_num_per_scene: 300
|
||||
val_split: valid_scenes_clean.txt
|
||||
val_pairs: valid_pairs.txt
|
||||
min_overlap: 0.1
|
||||
max_overlap: 0.7
|
||||
num_overlap_bins: 3
|
||||
read_depth: true
|
||||
read_image: true
|
||||
batch_size: 32
|
||||
num_workers: 14
|
||||
load_features:
|
||||
do: false # enable this if you have cached predictions
|
||||
path: exports/megadepth-undist-depth-r1024_pycolmap_SIFTGPU-nms3-fixed-k2048/{scene}.h5
|
||||
padding_length: 2048
|
||||
padding_fn: pad_local_features
|
||||
data_keys: ["keypoints", "keypoint_scores", "descriptors", "oris", "scales"]
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.sift
|
||||
detector: pycolmap_cuda
|
||||
max_num_keypoints: 2048
|
||||
force_num_keypoints: True
|
||||
detection_threshold: 0.0001
|
||||
trainable: False
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: false
|
||||
checkpointed: true
|
||||
add_scale_ori: true
|
||||
input_dim: 128
|
||||
ground_truth:
|
||||
name: matchers.depth_matcher
|
||||
th_positive: 3
|
||||
th_negative: 5
|
||||
th_epi: 5
|
||||
allow_no_extract: True
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 50
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 1000
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 30
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
dataset_callback_fn: sample_new_items
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024
|
|
@ -0,0 +1,25 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
matcher:
|
||||
name: matchers.nearest_neighbor_matcher
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 1.0
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,29 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
matcher:
|
||||
name: matchers.lightglue_pretrained
|
||||
features: superpoint
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,73 @@
|
|||
data:
|
||||
name: homographies
|
||||
homography:
|
||||
difficulty: 0.5
|
||||
max_angle: 30
|
||||
patch_shape: [640, 480]
|
||||
photometric:
|
||||
p: 0.75
|
||||
train_size: 900000
|
||||
val_size: 1000
|
||||
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet)
|
||||
num_workers: 15
|
||||
model:
|
||||
name: gluefactory.models.two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory.models.lines.wireframe
|
||||
trainable: False
|
||||
point_extractor:
|
||||
name: gluefactory.models.extractors.superpoint_open
|
||||
# name: disk
|
||||
# chunk: 10
|
||||
max_num_keypoints: 1000
|
||||
force_num_keypoints: true
|
||||
trainable: False
|
||||
line_extractor:
|
||||
name: gluefactory.models.lines.lsd
|
||||
max_num_lines: 250
|
||||
force_num_lines: True
|
||||
min_length: 15
|
||||
trainable: False
|
||||
wireframe_params:
|
||||
merge_points: True
|
||||
merge_line_endpoints: True
|
||||
nms_radius: 4
|
||||
detector:
|
||||
name: null
|
||||
descriptor:
|
||||
name: null
|
||||
ground_truth:
|
||||
name: gluefactory.models.matchers.homography_matcher
|
||||
trainable: False
|
||||
use_points: True
|
||||
use_lines: True
|
||||
th_positive: 3
|
||||
th_negative: 5
|
||||
matcher:
|
||||
name: gluefactory.models.matchers.gluestick
|
||||
input_dim: 256 # 128 for DISK
|
||||
descriptor_dim: 256 # 128 for DISK
|
||||
inter_supervision: [2, 5]
|
||||
GNN_layers: [
|
||||
self, cross, self, cross, self, cross,
|
||||
self, cross, self, cross, self, cross,
|
||||
self, cross, self, cross, self, cross,
|
||||
]
|
||||
checkpointed: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 200
|
||||
log_every_iter: 400
|
||||
eval_every_iter: 700
|
||||
save_every_iter: 1400
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
type: exp # exp or multi_step
|
||||
start: 200e3
|
||||
exp_div_10: 200e3
|
||||
gamma: 0.5
|
||||
step: 50e3
|
||||
n_steps: 4
|
||||
submodules: []
|
||||
# clip_grad: 10 # Use only with mixed precision
|
||||
# load_experiment:
|
|
@ -0,0 +1,69 @@
|
|||
data:
|
||||
name: gluefactory.datasets.megadepth
|
||||
views: 2
|
||||
preprocessing:
|
||||
resize: 640
|
||||
square_pad: True
|
||||
batch_size: 60
|
||||
num_workers: 15
|
||||
model:
|
||||
name: gluefactory.models.two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory.models.lines.wireframe
|
||||
trainable: False
|
||||
point_extractor:
|
||||
name: gluefactory.models.extractors.superpoint_open
|
||||
# name: disk
|
||||
# chunk: 10
|
||||
max_num_keypoints: 1000
|
||||
force_num_keypoints: true
|
||||
trainable: False
|
||||
line_extractor:
|
||||
name: gluefactory.models.lines.lsd
|
||||
max_num_lines: 250
|
||||
force_num_lines: True
|
||||
min_length: 15
|
||||
trainable: False
|
||||
wireframe_params:
|
||||
merge_points: True
|
||||
merge_line_endpoints: True
|
||||
nms_radius: 4
|
||||
detector:
|
||||
name: null
|
||||
descriptor:
|
||||
name: null
|
||||
ground_truth:
|
||||
name: gluefactory.models.matchers.depth_matcher
|
||||
trainable: False
|
||||
use_points: True
|
||||
use_lines: True
|
||||
th_positive: 3
|
||||
th_negative: 5
|
||||
matcher:
|
||||
name: gluefactory.models.matchers.gluestick
|
||||
input_dim: 256 # 128 for DISK
|
||||
descriptor_dim: 256 # 128 for DISK
|
||||
inter_supervision: null
|
||||
GNN_layers: [
|
||||
self, cross, self, cross, self, cross,
|
||||
self, cross, self, cross, self, cross,
|
||||
self, cross, self, cross, self, cross,
|
||||
]
|
||||
checkpointed: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 200
|
||||
log_every_iter: 10
|
||||
eval_every_iter: 100
|
||||
save_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
type: exp # exp or multi_step
|
||||
start: 200e3
|
||||
exp_div_10: 200e3
|
||||
gamma: 0.5
|
||||
step: 50e3
|
||||
n_steps: 4
|
||||
submodules: []
|
||||
# clip_grad: 10 # Use only with mixed precision
|
||||
load_experiment: gluestick_H
|
|
@ -0,0 +1,49 @@
|
|||
model:
|
||||
name: gluefactory.models.two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory.models.lines.wireframe
|
||||
point_extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
trainable: False
|
||||
dense_outputs: True
|
||||
max_num_keypoints: 2048
|
||||
force_num_keypoints: False
|
||||
detection_threshold: 0
|
||||
line_extractor:
|
||||
name: gluefactory.models.lines.lsd
|
||||
trainable: False
|
||||
max_num_lines: 512
|
||||
force_num_lines: False
|
||||
min_length: 15
|
||||
wireframe_params:
|
||||
merge_points: True
|
||||
merge_line_endpoints: True
|
||||
nms_radius: 3
|
||||
matcher:
|
||||
name: gluefactory.models.matchers.gluestick
|
||||
weights: checkpoint_GlueStick_MD # This will download weights from internet
|
||||
|
||||
# ground_truth: # for ETH3D, comment otherwise
|
||||
# name: gluefactory.models.matchers.depth_matcher
|
||||
# use_lines: True
|
||||
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: homography_est
|
||||
ransac_th: -1 # [1., 1.5, 2., 2.5, 3.]
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: poselib
|
||||
ransac_th: -1
|
||||
eth3d:
|
||||
ground_truth:
|
||||
name: gluefactory.models.matchers.depth_matcher
|
||||
use_lines: True
|
||||
eval:
|
||||
plot_methods: [ ] # ['sp+NN', 'sp+sg', 'superpoint+lsd+gluestick']
|
||||
plot_line_methods: [ ] # ['superpoint+lsd+gluestick', 'sp+deeplsd+gs']
|
|
@ -0,0 +1,26 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
matcher:
|
||||
name: gluefactory_nonfree.superglue
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.superpoint_open
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
matcher:
|
||||
name: matchers.nearest_neighbor_matcher
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 1.0
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,47 @@
|
|||
data:
|
||||
name: homographies
|
||||
data_dir: revisitop1m
|
||||
train_size: 150000
|
||||
val_size: 2000
|
||||
batch_size: 128
|
||||
num_workers: 14
|
||||
homography:
|
||||
difficulty: 0.7
|
||||
max_angle: 45
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.superpoint_open
|
||||
max_num_keypoints: 512
|
||||
force_num_keypoints: True
|
||||
detection_threshold: -1
|
||||
nms_radius: 3
|
||||
trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: false
|
||||
checkpointed: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,71 @@
|
|||
data:
|
||||
name: megadepth
|
||||
preprocessing:
|
||||
resize: 1024
|
||||
side: long
|
||||
square_pad: True
|
||||
train_split: train_scenes_clean.txt
|
||||
train_num_per_scene: 300
|
||||
val_split: valid_scenes_clean.txt
|
||||
val_pairs: valid_pairs.txt
|
||||
min_overlap: 0.1
|
||||
max_overlap: 0.7
|
||||
num_overlap_bins: 3
|
||||
read_depth: true
|
||||
read_image: true
|
||||
batch_size: 32
|
||||
num_workers: 14
|
||||
load_features:
|
||||
do: false # enable this if you have cached predictions
|
||||
path: exports/megadepth-undist-depth-r1024_SP-open-k2048-nms3/{scene}.h5
|
||||
padding_length: 2048
|
||||
padding_fn: pad_local_features
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.superpoint_open
|
||||
max_num_keypoints: 2048
|
||||
force_num_keypoints: True
|
||||
detection_threshold: -1
|
||||
nms_radius: 3
|
||||
trainable: False
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: false
|
||||
checkpointed: true
|
||||
ground_truth:
|
||||
name: matchers.depth_matcher
|
||||
th_positive: 3
|
||||
th_negative: 5
|
||||
th_epi: 5
|
||||
allow_no_extract: True
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 50
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 1000
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 30
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
dataset_callback_fn: sample_new_items
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024
|
|
@ -0,0 +1,24 @@
|
|||
import importlib.util
|
||||
from .base_dataset import BaseDataset
|
||||
from ..utils.tools import get_class
|
||||
|
||||
|
||||
def get_dataset(name):
    """Resolve a dataset class by name.

    The name is tried both as an absolute module path and as a submodule of
    this package. For each importable candidate, the dataset class is looked
    up via `get_class`; if that fails, the module's `__main_dataset__`
    attribute is used as a fallback.

    Raises:
        RuntimeError: if no candidate module provides a dataset.
    """
    candidates = [name, f"{__name__}.{name}"]
    for module_path in candidates:
        try:
            found = importlib.util.find_spec(module_path)
        except ModuleNotFoundError:
            found = None
        if found is None:
            continue
        try:
            return get_class(module_path, BaseDataset)
        except AssertionError:
            pass
        module = __import__(module_path, fromlist=[""])
        try:
            return module.__main_dataset__
        except AttributeError as exc:
            print(exc)

    raise RuntimeError(f'Dataset {name} not found in any of [{" ".join(candidates)}]')
|
|
@ -0,0 +1,244 @@
|
|||
from typing import Union
|
||||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
import torch
|
||||
from albumentations.pytorch.transforms import ToTensorV2
|
||||
from omegaconf import OmegaConf
|
||||
import cv2
|
||||
|
||||
|
||||
class IdentityTransform(A.ImageOnlyTransform):
    """No-op albumentations transform: returns the input image unchanged."""

    def apply(self, img, **params):
        # Pass the image through untouched.
        return img

    def get_transform_init_args_names(self):
        # No constructor arguments beyond the base-class defaults.
        return ()
|
||||
|
||||
|
||||
class RandomAdditiveShade(A.ImageOnlyTransform):
    """Albumentations transform adding soft random shading to an image.

    Random ellipses are drawn into a mask, blurred with a large Gaussian
    kernel, and blended into the image with a random transparency, which
    simulates smooth illumination changes.

    Args:
        nb_ellipses: number of ellipses composing the shade mask.
        transparency_limit: range of the blend factor (negative brightens).
        kernel_size_limit: range of the Gaussian blur kernel size, in pixels.
        always_apply, p: standard albumentations probability controls.
    """

    def __init__(
        self,
        nb_ellipses=10,
        transparency_limit=[-0.5, 0.8],
        kernel_size_limit=[150, 350],
        always_apply=False,
        p=0.5,
    ):
        super().__init__(always_apply, p)
        self.nb_ellipses = nb_ellipses
        self.transparency_limit = transparency_limit
        self.kernel_size_limit = kernel_size_limit

    def apply(self, img, **params):
        # Shading is computed in float [0, 255]; convert in and out as needed.
        if img.dtype == np.float32:
            shaded = self._py_additive_shade(img * 255.0)
            shaded /= 255.0
        elif img.dtype == np.uint8:
            shaded = self._py_additive_shade(img.astype(np.float32))
            shaded = shaded.astype(np.uint8)
        else:
            raise NotImplementedError(
                f"Data augmentation not available for type: {img.dtype}"
            )
        return shaded

    def _py_additive_shade(self, img):
        grayscale = len(img.shape) == 2
        if grayscale:
            # Add a trailing channel axis so HxW and HxWxC share one code path.
            # BUGFIX: a leading axis (img[None]) made img.shape[:2] == (1, H),
            # degenerating the ellipse mask to a single row.
            img = img[..., None]
        min_dim = min(img.shape[:2]) / 4
        mask = np.zeros(img.shape[:2], img.dtype)
        for _ in range(self.nb_ellipses):
            ax = int(max(np.random.rand() * min_dim, min_dim / 5))
            ay = int(max(np.random.rand() * min_dim, min_dim / 5))
            max_rad = max(ax, ay)
            x = np.random.randint(max_rad, img.shape[1] - max_rad)  # center
            y = np.random.randint(max_rad, img.shape[0] - max_rad)
            angle = np.random.rand() * 90
            cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)

        transparency = np.random.uniform(*self.transparency_limit)
        ks = np.random.randint(*self.kernel_size_limit)
        if (ks % 2) == 0:  # kernel_size has to be odd
            ks += 1
        mask = cv2.GaussianBlur(mask.astype(np.float32), (ks, ks), 0)
        shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.0)
        out = np.clip(shaded, 0, 255)
        if grayscale:
            # Drop the channel axis added above.
            out = out.squeeze(-1)
        return out

    def get_transform_init_args_names(self):
        return "transparency_limit", "kernel_size_limit", "nb_ellipses"
|
||||
|
||||
|
||||
def kw(entry: Union[float, dict], n=None, **default):
    """Normalize a transform config entry and merge it over defaults.

    A bare number is shorthand for ``{"p": number}``. When a name ``n`` is
    given and present in ``default``, that named sub-config replaces the
    entry. Returns the OmegaConf merge of ``default`` with the entry.
    """
    normalized = entry if isinstance(entry, dict) else {"p": entry}
    normalized = OmegaConf.create(normalized)
    if n is not None:
        normalized = default.get(n, normalized)
    return OmegaConf.merge(default, normalized)
|
||||
|
||||
|
||||
def kwi(entry: Union[float, dict], n=None, **default):
    """Like `kw`, but return a plain dict restricted to the default keys plus "p"."""
    merged = kw(entry, n=n, **default)
    wanted = set(default) | {"p"}
    return {key: merged[key] for key in wanted}
|
||||
|
||||
|
||||
def replay_str(transforms, s="Replay:\n", log_inactive=True):
    """Format an albumentations replay record as a readable string.

    Args:
        transforms: list of replay dicts (possibly nested via "transforms").
        s: prefix / accumulator string.
        log_inactive: if False, skip transforms that were not applied.

    Returns:
        One line per transform: "<class name> <applied>".
    """
    for t in transforms:
        if "transforms" in t.keys():
            # BUGFIX: propagate log_inactive into nested compositions;
            # previously it was dropped, so inactive nested transforms
            # were always printed.
            s = replay_str(t["transforms"], s=s, log_inactive=log_inactive)
        elif t["applied"] or log_inactive:
            s += t["__class_fullname__"] + " " + str(t["applied"]) + "\n"
    return s
|
||||
|
||||
|
||||
class BaseAugmentation(object):
    """Base class for photometric augmentation pipelines built on albumentations.

    Child classes override `_init` to populate `self.transforms`; calling the
    instance applies them to an image (HW or HWC, uint8 or float32).
    """

    # Options shared by all augmentations; a child's `default_conf` is merged
    # on top, and user conf on top of that.
    base_default_conf = {
        "name": "???",
        "shuffle": False,
        "p": 1.0,
        "verbose": False,
        "dtype": "uint8",  # (byte, float)
    }

    default_conf = {}

    def __init__(self, conf=None):
        """Perform some logic and call the _init method of the child model."""
        # BUGFIX: the default used to be a mutable `conf={}`; use a None
        # sentinel instead (backward compatible: None behaves like {}).
        if conf is None:
            conf = {}
        default_conf = OmegaConf.merge(
            OmegaConf.create(self.base_default_conf),
            OmegaConf.create(self.default_conf),
        )
        OmegaConf.set_struct(default_conf, True)
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(self.conf, True)
        self._init(self.conf)

        # NOTE(review): conf is merged a second time after _init, producing a
        # fresh (non-readonly) conf — presumably to re-apply user overrides;
        # confirm intent before changing.
        self.conf = OmegaConf.merge(self.conf, conf)
        if self.conf.verbose:
            # ReplayCompose records which transforms fired, for debugging.
            self.compose = A.ReplayCompose
        else:
            self.compose = A.Compose
        if self.conf.dtype == "uint8":
            self.dtype = np.uint8
            self.preprocess = A.FromFloat(always_apply=True, dtype="uint8")
            self.postprocess = A.ToFloat(always_apply=True)
        elif self.conf.dtype == "float32":
            self.dtype = np.float32
            self.preprocess = A.ToFloat(always_apply=True)
            self.postprocess = IdentityTransform()
        else:
            raise ValueError(f"Unsupported dtype {self.conf.dtype}")
        self.to_tensor = ToTensorV2()

    def _init(self, conf):
        """Child class overwrites this, setting up a list of transforms"""
        self.transforms = []

    def __call__(self, image, return_tensor=False):
        """image as HW or HWC"""
        if isinstance(image, torch.Tensor):
            image = image.cpu().detach().numpy()
        data = {"image": image}
        if image.dtype != self.dtype:
            data = self.preprocess(**data)
        transforms = self.transforms
        if self.conf.shuffle:
            # Apply the transforms in a random order on each call.
            order = [i for i, _ in enumerate(transforms)]
            np.random.shuffle(order)
            transforms = [transforms[i] for i in order]
        transformed = self.compose(transforms, p=self.conf.p)(**data)
        if self.conf.verbose:
            print(replay_str(transformed["replay"]["transforms"]))
        transformed = self.postprocess(**transformed)
        if return_tensor:
            return self.to_tensor(**transformed)["image"]
        else:
            return transformed["image"]
|
||||
|
||||
|
||||
class IdentityAugmentation(BaseAugmentation):
    """Augmentation that leaves images unchanged (a disabled baseline)."""

    default_conf = {}

    def _init(self, conf):
        # A single always-applied no-op transform.
        identity = IdentityTransform(p=1.0)
        self.transforms = [identity]
|
||||
|
||||
|
||||
class DarkAugmentation(BaseAugmentation):
    """Simulate low-light conditions: darkening, blur, noise, and color shifts."""

    default_conf = {"p": 0.75}

    def _init(self, conf):
        # Per-group activation probabilities.
        bright_contr = 0.5
        blur = 0.1
        random_gamma = 0.1
        hue = 0.1

        blur_group = A.OneOf(
            [
                A.Blur(**kwi(blur, p=0.1, blur_limit=(3, 9), n="blur")),
                A.MotionBlur(
                    **kwi(blur, p=0.2, blur_limit=(3, 25), n="motion_blur")
                ),
                A.ISONoise(),
                A.ImageCompression(),
            ],
            **kwi(blur, p=0.1),
        )
        color_group = A.OneOf(
            [
                A.Equalize(),
                A.CLAHE(p=0.2),
                A.ToGray(),
                A.ToSepia(p=0.1),
                A.HueSaturationValue(**kw(hue, val_shift_limit=(-100, -40))),
            ],
            p=0.5,
        )
        self.transforms = [
            A.RandomRain(p=0.2),
            A.RandomBrightnessContrast(
                **kw(
                    bright_contr,
                    brightness_limit=(-0.4, 0.0),
                    contrast_limit=(-0.3, 0.0),
                )
            ),
            blur_group,
            A.RandomGamma(**kw(random_gamma, gamma_limit=(15, 65))),
            color_group,
        ]
|
||||
|
||||
|
||||
class LGAugmentation(BaseAugmentation):
    """Photometric augmentation used for LightGlue-style training (name "lg")."""

    default_conf = {"p": 0.95}

    def _init(self, conf):
        degradations = A.OneOf(
            [
                A.Blur(blur_limit=(3, 9)),
                A.MotionBlur(blur_limit=(3, 25)),
                A.ISONoise(),
                A.ImageCompression(),
            ],
            p=0.1,
        )
        self.transforms = [
            A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
            A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)),
            degradations,
            A.Blur(p=0.1, blur_limit=(3, 9)),
            A.MotionBlur(p=0.1, blur_limit=(3, 25)),
            A.RandomBrightnessContrast(
                p=0.5, brightness_limit=(-0.4, 0.0), contrast_limit=(-0.3, 0.0)
            ),
            A.CLAHE(p=0.2),
        ]
|
||||
|
||||
|
||||
# Registry mapping config names (data.photometric.name in the YAML configs)
# to augmentation classes.
augmentations = {
    "dark": DarkAugmentation,
    "lg": LGAugmentation,
    "identity": IdentityAugmentation,
}
|
|
@ -0,0 +1,205 @@
|
|||
"""
|
||||
Base class for dataset.
|
||||
See mnist.py for an example of dataset.
|
||||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import collections
|
||||
import logging
|
||||
from omegaconf import OmegaConf
|
||||
import omegaconf
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Sampler, get_worker_info
|
||||
from torch.utils.data._utils.collate import (
|
||||
default_collate_err_msg_format,
|
||||
np_str_obj_array_pattern,
|
||||
)
|
||||
|
||||
from ..utils.tensor import string_classes
|
||||
from ..utils.tools import set_num_threads, set_seed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopSampler(Sampler):
    """Sampler that repeatedly cycles through indices 0..loop_size-1.

    `total_size` is truncated down to a multiple of `loop_size` so that
    every loop is complete.
    """

    def __init__(self, loop_size, total_size=None):
        self.loop_size = loop_size
        # BUGFIX: total_size=None used to raise TypeError (None % loop_size);
        # default to a single full loop instead.
        if total_size is None:
            total_size = loop_size
        self.total_size = total_size - (total_size % loop_size)

    def __iter__(self):
        return (i % self.loop_size for i in range(self.total_size))

    def __len__(self):
        return self.total_size
|
||||
|
||||
|
||||
def worker_init_fn(i):
    """Initialize a DataLoader worker: seed it and cap its thread count.

    Uses the dataset's `conf` (seed, num_threads) when available; otherwise
    only restricts the worker to a single thread.
    """
    worker = get_worker_info()
    dataset = worker.dataset
    if not hasattr(dataset, "conf"):
        # No configuration attached: just avoid thread oversubscription.
        set_num_threads(1)
        return
    conf = dataset.conf
    # Distinct but reproducible seed per worker.
    set_seed(worker.id + conf.seed)
    set_num_threads(conf.num_threads)
|
||||
|
||||
|
||||
def collate(batch):
    """Difference with PyTorch default_collate: it can stack of other objects.

    Recursively collates a list of samples: tensors are stacked, numpy arrays
    converted and stacked, numbers tensorized, mappings/sequences/namedtuples
    collated element-wise, and unknown objects are stacked as a last resort.
    """
    if not isinstance(batch, list):  # no batching
        return batch
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            try:
                storage = elem.untyped_storage()._new_shared(numel)  # noqa: F841
            except AttributeError:
                # Older PyTorch versions without Tensor.untyped_storage().
                storage = elem.storage()._new_shared(numel)  # noqa: F841
        return torch.stack(batch, dim=0)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            # Convert to tensors and recurse to hit the stacking branch above.
            return collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        # Strings are returned as-is (a list of strings).
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        # Collate each key across the whole batch.
        return {key: collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    elif elem is None:
        return elem
    else:
        # try to stack anyway in case the object implements stacking.
        return torch.stack(batch, 0)
|
||||
|
||||
|
||||
class BaseDataset(metaclass=ABCMeta):
    """
    What the dataset model is expect to declare:
    default_conf: dictionary of the default configuration of the dataset.
    It overwrites base_default_conf in BaseModel, and it is overwritten by
    the user-provided configuration passed to __init__.
    Configurations can be nested.

    _init(self, conf): initialization method, where conf is the final
    configuration object (also accessible with `self.conf`). Accessing
    unknown configuration entries will raise an error.

    get_dataset(self, split): method that returns an instance of
    torch.utils.data.Dataset corresponding to the requested split string,
    which can be `'train'`, `'val'`, or `'test'`.
    """

    # Entries set to "???" are OmegaConf mandatory values: they must be
    # provided by the child's default_conf or by the user configuration.
    base_default_conf = {
        "name": "???",
        "num_workers": "???",
        "train_batch_size": "???",
        "val_batch_size": "???",
        "test_batch_size": "???",
        "shuffle_training": True,
        "batch_size": 1,
        "num_threads": 1,
        "seed": 0,
        "prefetch_factor": 2,
    }
    default_conf = {}

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        # Merge base defaults with child-class defaults ...
        default_conf = OmegaConf.merge(
            OmegaConf.create(self.base_default_conf),
            OmegaConf.create(self.default_conf),
        )
        # Struct mode: accessing unknown keys raises instead of returning None.
        OmegaConf.set_struct(default_conf, True)
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        # ... then overwrite with the user configuration and freeze the result.
        self.conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(self.conf, True)
        logger.info(f"Creating dataset {self.__class__.__name__}")
        self._init(self.conf)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def get_dataset(self, split):
        """To be implemented by the child class."""
        raise NotImplementedError

    def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False):
        """Return a data loader for a given split."""
        assert split in ["train", "val", "test"]
        dataset = self.get_dataset(split)
        try:
            # Per-split batch size (e.g. train_batch_size) takes precedence;
            # fall back to the generic batch_size if it was left mandatory.
            batch_size = self.conf[split + "_batch_size"]
        except omegaconf.MissingMandatoryValue:
            batch_size = self.conf.batch_size
        num_workers = self.conf.get("num_workers", batch_size)
        if distributed:
            # The DistributedSampler handles shuffling across replicas.
            shuffle = False
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            sampler = None
            if shuffle is None:
                shuffle = split == "train" and self.conf.shuffle_training
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            pin_memory=pinned,
            collate_fn=collate,
            num_workers=num_workers,
            worker_init_fn=worker_init_fn,
            prefetch_factor=self.conf.prefetch_factor,
            # Drop ragged last batch only during training.
            drop_last=True if split == "train" else False,
        )

    def get_overfit_loader(self, split):
        """Return an overfit data loader.
        The training set is composed of a single duplicated batch, while
        the validation and test sets contain a single copy of this same batch.
        This is useful to debug a model and make sure that losses and metrics
        correlate well.
        """
        assert split in ["train", "val", "test"]
        # Always draw from the training set so all splits see the same batch.
        dataset = self.get_dataset("train")
        sampler = LoopSampler(
            self.conf.batch_size,
            len(dataset) if split == "train" else self.conf.batch_size,
        )
        num_workers = self.conf.get("num_workers", self.conf.batch_size)
        return DataLoader(
            dataset,
            batch_size=self.conf.batch_size,
            pin_memory=True,
            num_workers=num_workers,
            sampler=sampler,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
|
|
@ -0,0 +1,254 @@
|
|||
"""
|
||||
ETH3D multi-view benchmark, used for line matching evaluation.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import zipfile
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .utils import scale_intrinsics
|
||||
from ..geometry.wrappers import Camera, Pose
|
||||
from ..settings import DATA_PATH
|
||||
from ..utils.image import ImagePreprocessor, load_image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_cameras(camera_file, scale_factor=None):
    """Read the camera intrinsics from a file in COLMAP format.

    Args:
        camera_file: path to a COLMAP cameras.txt file.
        scale_factor: optional uniform rescaling of the intrinsics, used when
            the images are down/up-sampled by the same factor.

    Returns:
        List of Camera objects (float), one per line in the file.
    """
    with open(camera_file, "r") as f:
        raw_cameras = f.read().rstrip().split("\n")
    # Skip the three comment/header lines of the COLMAP format.
    raw_cameras = raw_cameras[3:]
    cameras = []
    for c in raw_cameras:
        data = c.split(" ")
        # Fields from index 4 on are the pinhole parameters: fx, fy, cx, cy.
        # (Assumes a PINHOLE camera model — TODO confirm for this dataset.)
        fx, fy, cx, cy = np.array(list(map(float, data[4:])))
        K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32)
        if scale_factor is not None:
            K = scale_intrinsics(K, np.array([scale_factor, scale_factor]))
        cameras.append(Camera.from_calibration_matrix(K).float())
    return cameras
|
||||
|
||||
|
||||
def qvec2rotmat(qvec):
    """Convert from quaternions to rotation matrix.

    The quaternion is ordered (w, x, y, z); the result is the standard
    3x3 rotation matrix of that (unit) quaternion.
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
    row0 = [
        1 - 2 * y**2 - 2 * z**2,
        2 * x * y - 2 * w * z,
        2 * z * x + 2 * w * y,
    ]
    row1 = [
        2 * x * y + 2 * w * z,
        1 - 2 * x**2 - 2 * z**2,
        2 * y * z - 2 * w * x,
    ]
    row2 = [
        2 * z * x - 2 * w * y,
        2 * y * z + 2 * w * x,
        1 - 2 * x**2 - 2 * y**2,
    ]
    return np.array([row0, row1, row2])
|
||||
|
||||
|
||||
class ETH3DDataset(BaseDataset):
    """ETH3D multi-view benchmark as co-visible image pairs.

    Builds all pairs of images within each scene that share at least
    ``min_covisibility`` 3D points, with intrinsics, poses, relative poses,
    and ground-truth depth paths. Downloads the data on first use.
    """

    default_conf = {
        "data_dir": "ETH3D_undistorted",
        "grayscale": True,
        "downsize_factor": 8,
        "min_covisibility": 500,  # min number of shared 3D points per pair
        "batch_size": 1,
        "two_view": True,
        "min_overlap": 0.5,
        "max_overlap": 1.0,
        "sort_by_overlap": False,
        "seed": 0,
    }

    def _init(self, conf):
        self.grayscale = conf.grayscale
        self.downsize_factor = conf.downsize_factor

        # Set random seeds
        np.random.seed(conf.seed)
        torch.manual_seed(conf.seed)

        # Auto-download the dataset
        if not (DATA_PATH / conf.data_dir).exists():
            logger.info("Downloading the ETH3D dataset...")
            self.download_eth3d()

        # Form pairs of images from the multiview dataset
        self.img_dir = DATA_PATH / conf.data_dir
        self.data = []
        for folder in self.img_dir.iterdir():
            img_folder = Path(folder, "images", "dslr_images_undistorted")
            depth_folder = Path(folder, "ground_truth_depth/undistorted_depth")
            depth_ext = ".png"
            names = [img.name for img in img_folder.iterdir()]
            names.sort()

            # Read intrinsics and extrinsics data
            # Intrinsics are rescaled to match the downsized images.
            cameras = read_cameras(
                str(Path(folder, "dslr_calibration_undistorted", "cameras.txt")),
                1 / self.downsize_factor,
            )
            name_to_cam_idx = {name: {} for name in names}
            with open(
                str(Path(folder, "dslr_calibration_jpg", "images.txt")), "r"
            ) as f:
                # COLMAP images.txt: skip 4 header lines, then every other
                # line describes an image (the alternate lines list points).
                raw_data = f.read().rstrip().split("\n")[4::2]
            for raw_line in raw_data:
                line = raw_line.split(" ")
                img_name = os.path.basename(line[-1])
                name_to_cam_idx[img_name]["dist_camera_idx"] = int(line[-2])
            T_world_to_camera = {}
            image_visible_points3D = {}
            with open(
                str(Path(folder, "dslr_calibration_undistorted", "images.txt")), "r"
            ) as f:
                lines = f.readlines()[4:]  # Skip the header
                raw_poses = [line.strip("\n").split(" ") for line in lines[::2]]
                raw_points = [line.strip("\n").split(" ") for line in lines[1::2]]
            for raw_pose, raw_pts in zip(raw_poses, raw_points):
                img_name = os.path.basename(raw_pose[-1])
                # Extract the transform from world to camera
                # Fields 1:8 are qw, qx, qy, qz, tx, ty, tz.
                target_extrinsics = list(map(float, raw_pose[1:8]))
                pose = np.eye(4, dtype=np.float32)
                pose[:3, :3] = qvec2rotmat(target_extrinsics[:4])
                pose[:3, 3] = target_extrinsics[4:]
                T_world_to_camera[img_name] = pose
                name_to_cam_idx[img_name]["undist_camera_idx"] = int(raw_pose[-2])
                # Extract the visible 3D points (-1 marks an unmatched feature).
                point3D_ids = [id for id in map(int, raw_pts[2::3]) if id != -1]
                image_visible_points3D[img_name] = set(point3D_ids)

            # Extract the covisibility of each image
            num_imgs = len(names)
            n_covisible_points = np.zeros((num_imgs, num_imgs))
            for i in range(num_imgs - 1):
                for j in range(i + 1, num_imgs):
                    visible_points3D1 = image_visible_points3D[names[i]]
                    visible_points3D2 = image_visible_points3D[names[j]]
                    n_covisible_points[i, j] = len(
                        visible_points3D1 & visible_points3D2
                    )

            # Keep only the pairs with enough covisibility
            valid_pairs = np.where(n_covisible_points >= conf.min_covisibility)
            valid_pairs = np.stack(valid_pairs, axis=1)

            # One entry per valid (i, j) pair with images, cameras, poses,
            # depth paths and the relative transforms in both directions.
            self.data += [
                {
                    "view0": {
                        "name": names[i][:-4],
                        "img_path": str(Path(img_folder, names[i])),
                        "depth_path": str(Path(depth_folder, names[i][:-4]))
                        + depth_ext,
                        "camera": cameras[name_to_cam_idx[names[i]]["dist_camera_idx"]],
                        "T_w2cam": Pose.from_4x4mat(T_world_to_camera[names[i]]),
                    },
                    "view1": {
                        "name": names[j][:-4],
                        "img_path": str(Path(img_folder, names[j])),
                        "depth_path": str(Path(depth_folder, names[j][:-4]))
                        + depth_ext,
                        "camera": cameras[name_to_cam_idx[names[j]]["dist_camera_idx"]],
                        "T_w2cam": Pose.from_4x4mat(T_world_to_camera[names[j]]),
                    },
                    "T_world_to_ref": Pose.from_4x4mat(T_world_to_camera[names[i]]),
                    "T_world_to_target": Pose.from_4x4mat(T_world_to_camera[names[j]]),
                    "T_0to1": Pose.from_4x4mat(
                        np.float32(
                            T_world_to_camera[names[j]]
                            @ np.linalg.inv(T_world_to_camera[names[i]])
                        )
                    ),
                    "T_1to0": Pose.from_4x4mat(
                        np.float32(
                            T_world_to_camera[names[i]]
                            @ np.linalg.inv(T_world_to_camera[names[j]])
                        )
                    ),
                    "n_covisible_points": n_covisible_points[i, j],
                }
                for (i, j) in valid_pairs
            ]

        # Print some info
        print("[Info] Successfully initialized dataset")
        print("\t Name: ETH3D")
        print("----------------------------------------")

    def download_eth3d(self):
        """Download and extract the ETH3D undistorted archive into DATA_PATH."""
        data_dir = DATA_PATH / self.conf.data_dir
        # Work in a temporary directory so a failed download can be retried.
        tmp_dir = data_dir.parent / "ETH3D_tmp"
        if tmp_dir.exists():
            shutil.rmtree(tmp_dir)
        tmp_dir.mkdir(exist_ok=True, parents=True)
        url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/"
        zip_name = "ETH3D_undistorted.zip"
        zip_path = tmp_dir / zip_name
        torch.hub.download_url_to_file(url_base + zip_name, zip_path)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(tmp_dir)
        shutil.move(tmp_dir / zip_name.split(".")[0], data_dir)

    def get_dataset(self, split):
        # NOTE(review): this re-instantiates the dataset (re-running _init)
        # for every split request — presumably acceptable for evaluation only.
        return ETH3DDataset(self.conf)

    def _read_image(self, img_path):
        """Load an image and downsize it by self.downsize_factor."""
        img = load_image(img_path, grayscale=self.grayscale)
        shape = img.shape[-2:]
        # instead of INTER_AREA this does bilinear interpolation with antialiasing
        img_data = ImagePreprocessor({"resize": max(shape) // self.downsize_factor})(
            img
        )
        return img_data

    def read_depth(self, depth_path):
        """Load a 16-bit depth map and convert it to metric float32 (/256)."""
        if self.downsize_factor != 8:
            raise ValueError(
                "Undistorted depth only available for low res"
                + " images(downsize_factor = 8)."
            )
        depth_img = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
        depth_img = depth_img.astype(np.float32) / 256

        return depth_img

    def __getitem__(self, idx):
        """Returns the data associated to a pair of images (reference, target)
        that are co-visible."""
        data = self.data[idx]
        # Load the images
        view0 = data.pop("view0")
        view1 = data.pop("view1")
        view0 = {**view0, **self._read_image(view0["img_path"])}
        view1 = {**view1, **self._read_image(view1["img_path"])}
        view0["scales"] = np.array([1.0, 1]).astype(np.float32)
        view1["scales"] = np.array([1.0, 1]).astype(np.float32)

        # Load the depths
        view0["depth"] = self.read_depth(view0["depth_path"])
        view1["depth"] = self.read_depth(view1["depth_path"])

        outputs = {
            **data,
            "view0": view0,
            "view1": view1,
            "name": f"{view0['name']}_{view1['name']}",
        }

        return outputs

    def __len__(self):
        return len(self.data)
|
|
@ -0,0 +1,311 @@
|
|||
"""
|
||||
Simply load images from a folder or nested folders (does not have any split),
|
||||
and apply homographic adaptations to it. Yields an image pair without border
|
||||
artifacts.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from .augmentations import IdentityAugmentation, augmentations
|
||||
from .base_dataset import BaseDataset
|
||||
from ..settings import DATA_PATH
|
||||
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||
from ..utils.image import read_image
|
||||
from ..geometry.homography import (
|
||||
sample_homography_corners,
|
||||
compute_homography,
|
||||
warp_points,
|
||||
)
|
||||
from ..utils.tools import fork_rng
|
||||
from ..visualization.viz2d import plot_image_grid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sample_homography(img, conf: dict, size: list):
    """Sample a random homography and warp `img` to the target `size`.

    Returns a dict with the warped image, the homography ("H_"), the sampled
    corner coordinates ("coords"), and the output image size.
    """
    width_height = img.shape[:2][::-1]
    H, _, coords, _ = sample_homography_corners(width_height, **conf)
    warped = cv2.warpPerspective(img, H, tuple(size))
    return {
        "image": warped,
        "H_": H.astype(np.float32),
        "coords": coords.astype(np.float32),
        "image_size": np.array(size, dtype=np.float32),
    }
|
||||
|
||||
|
||||
class HomographyDataset(BaseDataset):
    """Dataset of synthetic image pairs related by a random homography.

    Images are found in a folder (optionally restricted by a list), split
    deterministically into train/val, and warped on the fly by `_Dataset`.
    The revisitop1m distractor set is auto-downloaded when missing.
    """

    default_conf = {
        # image search
        "data_dir": "revisitop1m",  # the top-level directory
        "image_dir": "jpg/",  # the subdirectory with the images
        "image_list": "revisitop1m.txt",  # optional: list or filename of list
        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
        # splits
        "train_size": 100,
        "val_size": 10,
        "shuffle_seed": 0,  # or None to skip
        # image loading
        "grayscale": False,
        "triplet": False,
        "right_only": False,  # image0 is orig (rescaled), image1 is right
        "reseed": False,
        "homography": {
            "difficulty": 0.8,
            "translation": 1.0,
            "max_angle": 60,
            "n_angles": 10,
            "patch_shape": [640, 480],
            "min_convexity": 0.05,
        },
        "photometric": {
            "name": "dark",
            "p": 0.75,
            # 'difficulty': 1.0,  # currently unused
        },
        # feature loading
        "load_features": {
            "do": False,
            **CacheLoader.default_conf,
            "collate": False,
            "thresh": 0.0,
            "max_num_keypoints": -1,
            "force_num_keypoints": False,
        },
    }

    def _init(self, conf):
        data_dir = DATA_PATH / conf.data_dir
        if not data_dir.exists():
            if conf.data_dir == "revisitop1m":
                logger.info("Downloading the revisitop1m dataset.")
                # BUGFIX: download_revisitop1m() takes no argument (it derives
                # the target directory from self.conf); the original call
                # passed data_dir and raised a TypeError.
                self.download_revisitop1m()
            else:
                raise FileNotFoundError(data_dir)

        image_dir = data_dir / conf.image_dir
        images = []
        if conf.image_list is None:
            # No list: glob the folder with each pattern.
            glob = [conf.glob] if isinstance(conf.glob, str) else conf.glob
            for g in glob:
                images += list(image_dir.glob("**/" + g))
            if len(images) == 0:
                raise ValueError(f"Cannot find any image in folder: {image_dir}.")
            images = [i.relative_to(image_dir).as_posix() for i in images]
            images = sorted(images)  # for deterministic behavior
            logger.info("Found %d images in folder.", len(images))
        elif isinstance(conf.image_list, (str, Path)):
            # A text file with one relative image path per line.
            image_list = data_dir / conf.image_list
            if not image_list.exists():
                raise FileNotFoundError(f"Cannot find image list {image_list}.")
            images = image_list.read_text().rstrip("\n").split("\n")
            for image in images:
                if not (image_dir / image).exists():
                    raise FileNotFoundError(image_dir / image)
            logger.info("Found %d images in list file.", len(images))
        elif isinstance(conf.image_list, omegaconf.listconfig.ListConfig):
            # An explicit list of image paths in the configuration.
            images = conf.image_list.to_container()
            for image in images:
                if not (image_dir / image).exists():
                    raise FileNotFoundError(image_dir / image)
        else:
            raise ValueError(conf.image_list)

        if conf.shuffle_seed is not None:
            # Deterministic shuffle before the train/val split.
            np.random.RandomState(conf.shuffle_seed).shuffle(images)
        train_images = images[: conf.train_size]
        val_images = images[conf.train_size : conf.train_size + conf.val_size]
        self.images = {"train": train_images, "val": val_images}

    def download_revisitop1m(self):
        """Download and extract the revisitop1m images into DATA_PATH."""
        data_dir = DATA_PATH / self.conf.data_dir
        # Work in a temporary directory so a failed download can be retried.
        tmp_dir = data_dir.parent / "revisitop1m_tmp"
        if tmp_dir.exists():  # The previous download failed.
            shutil.rmtree(tmp_dir)
        image_dir = tmp_dir / self.conf.image_dir
        image_dir.mkdir(exist_ok=True, parents=True)
        num_files = 100
        url_base = "http://ptak.felk.cvut.cz/revisitop/revisitop1m/"
        list_name = "revisitop1m.txt"
        torch.hub.download_url_to_file(url_base + list_name, tmp_dir / list_name)
        for n in tqdm(range(num_files), position=1):
            tar_name = "revisitop1m.{}.tar.gz".format(n + 1)
            tar_path = image_dir / tar_name
            torch.hub.download_url_to_file(url_base + "jpg/" + tar_name, tar_path)
            with tarfile.open(tar_path) as tar:
                tar.extractall(path=image_dir)
            tar_path.unlink()
        shutil.move(tmp_dir, data_dir)

    def get_dataset(self, split):
        return _Dataset(self.conf, self.images[split], split)
|
||||
|
||||
|
||||
class _Dataset(torch.utils.data.Dataset):
    """Torch dataset yielding homography-warped image pairs (or triplets).

    Each item warps one source image twice (or three times) with random
    homographies, applies photometric augmentation, and returns the views
    together with the ground-truth homographies relating them.
    """

    def __init__(self, conf, image_names, split):
        self.conf = conf
        self.split = split
        self.image_names = np.array(image_names)
        self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir

        aug_conf = conf.photometric
        aug_name = aug_conf.name
        assert (
            aug_name in augmentations.keys()
        ), f'{aug_name} not in {" ".join(augmentations.keys())}'
        self.photo_augment = augmentations[aug_name](aug_conf)
        # With right_only, the left view keeps the original appearance.
        self.left_augment = (
            IdentityAugmentation() if conf.right_only else self.photo_augment
        )
        # Identity augmentation doubles as an image-to-tensor converter.
        self.img_to_tensor = IdentityAugmentation()

        if conf.load_features.do:
            self.feature_loader = CacheLoader(conf.load_features)

    def _transform_keypoints(self, features, data):
        """Transform keypoints by a homography, threshold them,
        and potentially keep only the best ones."""
        # Warp points
        features["keypoints"] = warp_points(
            features["keypoints"], data["H_"], inverse=False
        )
        # Keep only points landing inside the warped image bounds.
        h, w = data["image"].shape[1:3]
        valid = (
            (features["keypoints"][:, 0] >= 0)
            & (features["keypoints"][:, 0] <= w - 1)
            & (features["keypoints"][:, 1] >= 0)
            & (features["keypoints"][:, 1] <= h - 1)
        )
        features["keypoints"] = features["keypoints"][valid]

        # Threshold
        if self.conf.load_features.thresh > 0:
            valid = features["keypoint_scores"] >= self.conf.load_features.thresh
            features = {k: v[valid] for k, v in features.items()}

        # Get the top keypoints and pad
        n = self.conf.load_features.max_num_keypoints
        if n > -1:
            inds = np.argsort(-features["keypoint_scores"])
            features = {k: v[inds[:n]] for k, v in features.items()}

            if self.conf.load_features.force_num_keypoints:
                features = pad_local_features(
                    features, self.conf.load_features.max_num_keypoints
                )

        return features

    def __getitem__(self, idx):
        # Optionally make each item deterministic given its index.
        if self.conf.reseed:
            with fork_rng(self.conf.seed + idx, False):
                return self.getitem(idx)
        else:
            return self.getitem(idx)

    def _read_view(self, img, H_conf, ps, left=False):
        """Warp `img` with a sampled homography and augment it into one view."""
        data = sample_homography(img, H_conf, ps)
        if left:
            data["image"] = self.left_augment(data["image"], return_tensor=True)
        else:
            data["image"] = self.photo_augment(data["image"], return_tensor=True)

        # Luma weights for RGB -> grayscale conversion.
        gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
        if self.conf.grayscale:
            data["image"] = (data["image"] * gs).sum(0, keepdim=True)

        if self.conf.load_features.do:
            features = self.feature_loader({k: [v] for k, v in data.items()})
            features = self._transform_keypoints(features, data)
            data["cache"] = features

        return data

    def getitem(self, idx):
        name = self.image_names[idx]
        img = read_image(self.image_dir / name, False)
        if img is None:
            # Unreadable image: fall back to a black placeholder.
            logging.warning("Image %s could not be read.", name)
            img = np.zeros((1024, 1024) + (() if self.conf.grayscale else (3,)))
        img = img.astype(np.float32) / 255.0
        size = img.shape[:2][::-1]
        ps = self.conf.homography.patch_shape

        left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
        if self.conf.right_only:
            # Left view stays (nearly) unwarped; all difficulty on the right.
            left_conf["difficulty"] = 0.0

        data0 = self._read_view(img, left_conf, ps, left=True)
        data1 = self._read_view(img, self.conf.homography, ps, left=False)

        # Homography mapping view0 coordinates to view1 coordinates.
        H = compute_homography(data0["coords"], data1["coords"], [1, 1])

        data = {
            "name": name,
            "original_image_size": np.array(size),
            "H_0to1": H.astype(np.float32),
            "idx": idx,
            "view0": data0,
            "view1": data1,
        }

        if self.conf.triplet:
            # Generate third image
            data2 = self._read_view(img, self.conf.homography, ps, left=False)
            H02 = compute_homography(data0["coords"], data2["coords"], [1, 1])
            H12 = compute_homography(data1["coords"], data2["coords"], [1, 1])

            data = {
                "H_0to2": H02.astype(np.float32),
                "H_1to2": H12.astype(np.float32),
                "view2": data2,
                **data,
            }

        return data

    def __len__(self):
        return len(self.image_names)
|
||||
|
||||
|
||||
def visualize(args):
    """Plot a grid of (view0, view1) training pairs from the CLI config.

    Args:
        args: argparse namespace with num_items, dpi, and dotlist overrides.
    """
    conf = {
        "batch_size": 1,
        "num_workers": 1,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = HomographyDataset(conf)
    loader = dataset.get_data_loader("train")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            # BUGFIX: build a list, not a generator expression — the grid
            # plotter needs a reusable/indexable sequence of images per row.
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    # CLI entry point: visualize sampled homography pairs.
    # Usage: python -m <module> [--num_items N] [--dpi D] [conf.key=value ...]
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_items", type=int, default=8)
    parser.add_argument("--dpi", type=int, default=100)
    parser.add_argument("dotlist", nargs="*")  # OmegaConf dotlist overrides
    args = parser.parse_intermixed_args()
    visualize(args)
|
|
@ -0,0 +1,144 @@
|
|||
"""
|
||||
Simply load images from a folder or nested folders (does not have any split).
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import tarfile
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from ..settings import DATA_PATH
|
||||
from ..utils.image import load_image, ImagePreprocessor
|
||||
from ..utils.tools import fork_rng
|
||||
from ..visualization.viz2d import plot_image_grid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_homography(path):
    """Parse a 3x3 homography from an HPatches ``H_1_x`` text file.

    Each non-empty line holds the entries of one matrix row, separated by
    (possibly repeated) whitespace; blank lines are skipped.

    Args:
        path: path to the homography text file.

    Returns:
        np.ndarray of floats with one row per non-empty line.
    """
    result = []
    with open(path) as f:
        for line in f:
            # str.split() with no argument collapses any run of whitespace
            # (spaces, tabs, the trailing newline) and drops empty fields —
            # replacing the original manual double-space removal, which also
            # did not handle tabs.
            elements = line.split()
            if elements:
                result.append(elements)
    return np.array(result).astype(float)
|
||||
|
||||
|
||||
class HPatches(BaseDataset, torch.utils.data.Dataset):
    """HPatches sequences: reference/query image pairs with GT homographies.

    Each sequence pairs image 1 with images 2-6; sequences starting with
    'i' vary illumination, those starting with 'v' vary viewpoint.
    The dataset is evaluation-only (val/test) and auto-downloads itself.
    """

    default_conf = {
        "preprocessing": ImagePreprocessor.default_conf,
        "data_dir": "hpatches-sequences-release",
        "subset": None,  # 'i', 'v', or None for both
        "ignore_large_images": True,
        "grayscale": False,
    }

    # Large images that were ignored in previous papers
    ignored_scenes = (
        "i_contruction",
        "i_crownnight",
        "i_dc",
        "i_pencils",
        "i_whitebuilding",
        "v_artisans",
        "v_astronautis",
        "v_talent",
    )
    url = "http://icvl.ee.ic.ac.uk/vbalnt/hpatches/hpatches-sequences-release.tar.gz"

    def _init(self, conf):
        # Pairs are yielded one at a time; batching is not supported.
        assert conf.batch_size == 1
        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        self.root = DATA_PATH / conf.data_dir
        if not self.root.exists():
            logger.info("Downloading the HPatches dataset.")
            self.download()
        self.sequences = sorted([x.name for x in self.root.iterdir()])
        if not self.sequences:
            raise ValueError("No image found!")
        self.items = []  # (seq, q_idx, is_illu)
        for seq in self.sequences:
            if conf.ignore_large_images and seq in self.ignored_scenes:
                continue
            # subset filter matches the sequence prefix ('i' or 'v').
            if conf.subset is not None and conf.subset != seq[0]:
                continue
            # Query images 2..6 are each paired with reference image 1.
            for i in range(2, 7):
                self.items.append((seq, i, seq[0] == "i"))

    def download(self):
        """Download and extract the HPatches archive next to DATA_PATH."""
        data_dir = self.root.parent
        data_dir.mkdir(exist_ok=True, parents=True)
        tar_path = data_dir / self.url.rsplit("/", 1)[-1]
        torch.hub.download_url_to_file(self.url, tar_path)
        with tarfile.open(tar_path) as tar:
            tar.extractall(data_dir)
        tar_path.unlink()

    def get_dataset(self, split):
        # Evaluation-only dataset: the object itself serves both splits.
        assert split in ["val", "test"]
        return self

    def _read_image(self, seq: str, idx: int) -> dict:
        """Load and preprocess image `idx`.ppm of sequence `seq`."""
        img = load_image(self.root / seq / f"{idx}.ppm", self.conf.grayscale)
        return self.preprocessor(img)

    def __getitem__(self, idx):
        seq, q_idx, is_illu = self.items[idx]
        data0 = self._read_image(seq, 1)
        data1 = self._read_image(seq, q_idx)
        H = read_homography(self.root / seq / f"H_1_{q_idx}")
        # Adjust the GT homography for the preprocessing (resize) transforms
        # applied to each image.
        H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])
        return {
            "H_0to1": H.astype(np.float32),
            "scene": seq,
            "idx": idx,
            "is_illu": is_illu,
            "name": f"{seq}/{idx}.ppm",
            "view0": data0,
            "view1": data1,
        }

    def __len__(self):
        return len(self.items)
|
||||
|
||||
|
||||
def visualize(args):
    """Plot a grid of (reference, query) HPatches pairs from the CLI config.

    Args:
        args: argparse namespace with num_items, dpi, and dotlist overrides.
    """
    conf = {
        "batch_size": 1,
        "num_workers": 8,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = HPatches(conf)
    loader = dataset.get_data_loader("test")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            # BUGFIX: build a list, not a generator expression — the grid
            # plotter needs a reusable/indexable sequence of images per row.
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    # CLI entry point: visualize HPatches evaluation pairs.
    # Usage: python -m <module> [--num_items N] [--dpi D] [conf.key=value ...]
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_items", type=int, default=8)
    parser.add_argument("--dpi", type=int, default=100)
    parser.add_argument("dotlist", nargs="*")  # OmegaConf dotlist overrides
    args = parser.parse_intermixed_args()
    visualize(args)
|
|
@ -0,0 +1,58 @@
|
|||
"""
|
||||
Simply load images from a folder or nested folders (does not have any split).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import logging
|
||||
import omegaconf
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from ..utils.image import load_image, ImagePreprocessor
|
||||
|
||||
|
||||
class ImageFolder(BaseDataset, torch.utils.data.Dataset):
    """Load images from a folder, a nested folder tree, a list file, or a list.

    `conf.images` is interpreted as: a directory (searched recursively with
    `conf.glob`), a text file with one image path per line, or an explicit
    list of paths. There is no train/val/test split.
    """

    default_conf = {
        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
        "images": "???",  # folder, list file, or explicit list of paths
        "root_folder": "/",  # prefix prepended to relative image paths
        "preprocessing": ImagePreprocessor.default_conf,
    }

    def _init(self, conf):
        self.root = conf.root_folder
        if isinstance(conf.images, str):
            if not Path(conf.images).is_dir():
                # A text file with one image path per line.
                with open(conf.images, "r") as f:
                    self.images = f.read().rstrip("\n").split("\n")
                logging.info(f"Found {len(self.images)} images in list file.")
            else:
                # A directory: search it recursively with every glob pattern.
                self.images = []
                glob = [conf.glob] if isinstance(conf.glob, str) else conf.glob
                for g in glob:
                    self.images += list(Path(conf.images).glob("**/" + g))
                if len(self.images) == 0:
                    raise ValueError(
                        f"Could not find any image in folder: {conf.images}."
                    )
                self.images = [i.relative_to(conf.images) for i in self.images]
                self.root = conf.images
                logging.info(f"Found {len(self.images)} images in folder.")
        elif isinstance(conf.images, omegaconf.listconfig.ListConfig):
            # FIX: ListConfig has no public `to_container` method; the
            # supported API is OmegaConf.to_container(cfg).
            self.images = omegaconf.OmegaConf.to_container(conf.images)
        else:
            raise ValueError(conf.images)

        self.preprocessor = ImagePreprocessor(conf.preprocessing)

    def get_dataset(self, split):
        # The dataset is split-agnostic.
        return self

    def __getitem__(self, idx):
        path = self.images[idx]
        # FIX: resolve relative entries against the root folder; absolute
        # paths are preserved when joined onto the default root "/" because
        # pathlib drops the left operand for absolute right operands.
        img = load_image(Path(self.root) / path)
        data = {"name": str(path), **self.preprocessor(img)}
        return data

    def __len__(self):
        return len(self.images)
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
Simply load images from a folder or nested folders (does not have any split).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import numpy as np
|
||||
from .base_dataset import BaseDataset
|
||||
from ..utils.image import load_image, ImagePreprocessor
|
||||
|
||||
from ..settings import DATA_PATH
|
||||
from ..geometry.wrappers import Camera, Pose
|
||||
|
||||
|
||||
def names_to_pair(name0, name1, separator="/"):
    """Build a canonical pair identifier from two image names.

    Slashes inside each name are flattened to dashes so the separator
    remains unambiguous.
    """
    flat0 = name0.replace("/", "-")
    flat1 = name1.replace("/", "-")
    return flat0 + separator + flat1
|
||||
|
||||
|
||||
def parse_homography(homography_elems) -> np.ndarray:
    """Parse 9 string elements into a 3x3 float32 homography matrix.

    FIX: the return annotation previously claimed `Camera`, but the function
    returns a plain ndarray.
    """
    return (
        np.array([float(x) for x in homography_elems[:9]])
        .reshape(3, 3)
        .astype(np.float32)
    )
|
||||
|
||||
|
||||
def parse_camera(calib_elems) -> Camera:
    """Build a Camera from 9 string elements forming a 3x3 calibration matrix."""
    values = [float(x) for x in calib_elems[:9]]
    K = np.array(values).reshape(3, 3).astype(np.float32)
    return Camera.from_calibration_matrix(K)
|
||||
|
||||
|
||||
def parse_relative_pose(pose_elems) -> Pose:
    """Build a Pose from 12 string elements: a flattened 3x3 R then a 3-vector t."""
    rot_elems = pose_elems[:9]
    trans_elems = pose_elems[9:12]
    R = np.array([float(x) for x in rot_elems]).reshape(3, 3).astype(np.float32)
    t = np.array([float(x) for x in trans_elems]).astype(np.float32)
    return Pose.from_Rt(R, t)
|
||||
|
||||
|
||||
class ImagePairs(BaseDataset, torch.utils.data.Dataset):
    """Dataset of image pairs listed in a text file, one pair per line.

    Each line starts with two image names. Depending on `extra_data`, it may
    be followed by two flattened 3x3 calibration matrices and a flattened
    relative pose [R | t] (`relative_pose`), or by a flattened 3x3 homography
    (`homography`).
    """

    default_conf = {
        "pairs": "???",  # ToDo: add image folder interface
        "root": "???",  # image names in the pairs file are relative to this
        "preprocessing": ImagePreprocessor.default_conf,
        "extra_data": None,  # relative_pose, homography
    }

    def _init(self, conf):
        # Accept either an existing path or a path relative to DATA_PATH.
        pair_f = (
            Path(conf.pairs) if Path(conf.pairs).exists() else DATA_PATH / conf.pairs
        )
        with open(str(pair_f), "r") as f:
            self.items = [line.rstrip() for line in f]
        self.preprocessor = ImagePreprocessor(conf.preprocessing)

    def get_dataset(self, split):
        # The pair list is split-agnostic: the same dataset serves any split.
        return self

    def _read_view(self, name):
        # Load and preprocess one image, resolved under DATA_PATH / root.
        path = DATA_PATH / self.conf.root / name
        img = load_image(path)
        return self.preprocessor(img)

    def __getitem__(self, idx):
        line = self.items[idx]
        pair_data = line.split(" ")
        name0, name1 = pair_data[:2]
        data0 = self._read_view(name0)
        data1 = self._read_view(name1)

        data = {
            "view0": data0,
            "view1": data1,
        }
        if self.conf.extra_data == "relative_pose":
            # Elements 2:11 and 11:20 are the two 3x3 calibration matrices;
            # 20:32 is the relative pose. Intrinsics are rescaled to match
            # the preprocessing resize of each view.
            data["view0"]["camera"] = parse_camera(pair_data[2:11]).scale(
                data0["scales"]
            )
            data["view1"]["camera"] = parse_camera(pair_data[11:20]).scale(
                data1["scales"]
            )
            data["T_0to1"] = parse_relative_pose(pair_data[20:32])
        elif self.conf.extra_data == "homography":
            # Conjugate the raw homography with the per-view preprocessing
            # transforms so it maps between the preprocessed images.
            data["H_0to1"] = (
                data1["transform"]
                @ parse_homography(pair_data[2:11])
                @ np.linalg.inv(data0["transform"])
            )
        else:
            assert (
                self.conf.extra_data is None
            ), f"Unknown extra data format {self.conf.extra_data}"

        data["name"] = names_to_pair(name0, name1)
        return data

    def __len__(self):
        return len(self.items)
|
|
@ -0,0 +1,514 @@
|
|||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from collections.abc import Iterable
|
||||
import tarfile
|
||||
import shutil
|
||||
|
||||
import h5py
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from .base_dataset import BaseDataset
|
||||
from .utils import (
|
||||
scale_intrinsics,
|
||||
rotate_intrinsics,
|
||||
rotate_pose_inplane,
|
||||
)
|
||||
from ..geometry.wrappers import Camera, Pose
|
||||
from ..models.cache_loader import CacheLoader
|
||||
from ..utils.tools import fork_rng
|
||||
from ..utils.image import load_image, ImagePreprocessor
|
||||
from ..settings import DATA_PATH
|
||||
from ..visualization.viz2d import plot_image_grid, plot_heatmaps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
scene_lists_path = Path(__file__).parent / "megadepth_scene_lists"
|
||||
|
||||
|
||||
def sample_n(data, num, seed=None):
    """Pick `num` rows of `data` at random without replacement.

    If `data` already has at most `num` entries, it is returned unchanged.
    """
    if len(data) <= num:
        return data
    rng = np.random.RandomState(seed)
    keep = rng.choice(len(data), num, replace=False)
    return data[keep]
|
||||
|
||||
|
||||
class MegaDepth(BaseDataset):
    """MegaDepth dataset of co-visible image pairs or triplets with depth.

    Downloads the ETH-hosted undistorted images, depth maps, and scene-info
    files on first use, then dispatches to `_PairDataset` or
    `_TripletDataset` depending on `conf.views`.
    """

    default_conf = {
        # paths
        "data_dir": "megadepth/",
        "depth_subpath": "depth_undistorted/",
        "image_subpath": "Undistorted_SfM/",
        "info_dir": "scene_info/",  # @TODO: intrinsics problem?
        # Training
        "train_split": "train_scenes_clean.txt",
        "train_num_per_scene": 500,
        # Validation
        "val_split": "valid_scenes_clean.txt",
        "val_num_per_scene": None,
        "val_pairs": None,
        # Test
        "test_split": "test_scenes_clean.txt",
        "test_num_per_scene": None,
        "test_pairs": None,
        # data sampling
        "views": 2,
        "min_overlap": 0.3,  # only with D2-Net format
        "max_overlap": 1.0,  # only with D2-Net format
        "num_overlap_bins": 1,
        "sort_by_overlap": False,
        "triplet_enforce_overlap": False,  # only with views==3
        # image options
        "read_depth": True,
        "read_image": True,
        "grayscale": False,
        "preprocessing": ImagePreprocessor.default_conf,
        "p_rotate": 0.0,  # probability to rotate image by +/- 90°
        "reseed": False,
        "seed": 0,
        # features from cache
        "load_features": {
            "do": False,
            **CacheLoader.default_conf,
            "collate": False,
        },
    }

    def _init(self, conf):
        # Trigger a one-time download if the data directory is missing.
        if not (DATA_PATH / conf.data_dir).exists():
            logger.info("Downloading the MegaDepth dataset.")
            self.download()

    def download(self):
        # Download into a temporary directory and move it into place only at
        # the end, so an interrupted download is detected and restarted.
        data_dir = DATA_PATH / self.conf.data_dir
        tmp_dir = data_dir.parent / "megadepth_tmp"
        if tmp_dir.exists():  # The previous download failed.
            shutil.rmtree(tmp_dir)
        tmp_dir.mkdir(exist_ok=True, parents=True)
        url_base = "https://cvg-data.inf.ethz.ch/megadepth/"
        for tar_name, out_name in (
            ("Undistorted_SfM.tar.gz", self.conf.image_subpath),
            ("depth_undistorted.tar.gz", self.conf.depth_subpath),
            ("scene_info.tar.gz", self.conf.info_dir),
        ):
            tar_path = tmp_dir / tar_name
            torch.hub.download_url_to_file(url_base + tar_name, tar_path)
            with tarfile.open(tar_path) as tar:
                tar.extractall(path=tmp_dir)
            tar_path.unlink()
            # Rename the extracted top-level folder to the configured subpath.
            shutil.move(tmp_dir / tar_name.split(".")[0], tmp_dir / out_name)
        shutil.move(tmp_dir, data_dir)

    def get_dataset(self, split):
        # views==3 yields triplets; views in {1, 2} yields the pair dataset
        # (which also handles single views).
        assert self.conf.views in [1, 2, 3]
        if self.conf.views == 3:
            return _TripletDataset(self.conf, split)
        else:
            return _PairDataset(self.conf, split)
|
||||
|
||||
|
||||
class _PairDataset(torch.utils.data.Dataset):
    """Single- or two-view MegaDepth dataset (conf.views in {1, 2}).

    Loads per-scene metadata (image/depth paths, poses, intrinsics) from the
    D2-Net-style `.npz` scene-info files, then samples items with
    `sample_new_items`.
    """

    def __init__(self, conf, split, load_sample=True):
        self.root = DATA_PATH / conf.data_dir
        assert self.root.exists(), self.root
        self.split = split
        self.conf = conf

        # The split may be given as a scene-list file name or directly as an
        # iterable of scene names.
        split_conf = conf[split + "_split"]
        if isinstance(split_conf, (str, Path)):
            scenes_path = scene_lists_path / split_conf
            scenes = scenes_path.read_text().rstrip("\n").split("\n")
        elif isinstance(split_conf, Iterable):
            scenes = list(split_conf)
        else:
            raise ValueError(f"Unknown split configuration: {split_conf}.")
        scenes = sorted(set(scenes))

        if conf.load_features.do:
            self.feature_loader = CacheLoader(conf.load_features)

        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        # Per-scene metadata, keyed by scene name.
        self.images = {}
        self.depths = {}
        self.poses = {}
        self.intrinsics = {}
        self.valid = {}

        # load metadata
        self.info_dir = self.root / self.conf.info_dir
        self.scenes = []
        for scene in scenes:
            path = self.info_dir / (scene + ".npz")
            try:
                info = np.load(str(path), allow_pickle=True)
            except Exception:
                # Missing or corrupted scene info: skip the scene entirely.
                logger.warning(
                    "Cannot load scene info for scene %s at %s.", scene, path
                )
                continue
            self.images[scene] = info["image_paths"]
            self.depths[scene] = info["depth_paths"]
            self.poses[scene] = info["poses"]
            self.intrinsics[scene] = info["intrinsics"]
            self.scenes.append(scene)

        if load_sample:
            self.sample_new_items(conf.seed)
            assert len(self.items) > 0

    def sample_new_items(self, seed):
        """(Re)populate `self.items` with sampled views or pairs.

        Item format: (scene, idx) for single views, or
        (scene, idx0, idx1, overlap) for pairs.
        """
        logger.info("Sampling new %s data with seed %d.", self.split, seed)
        self.items = []
        split = self.split
        # A pair (num_pos, num_neg) enables additional negative sampling.
        num_per_scene = self.conf[self.split + "_num_per_scene"]
        if isinstance(num_per_scene, Iterable):
            num_pos, num_neg = num_per_scene
        else:
            num_pos = num_per_scene
            num_neg = None
        if split != "train" and self.conf[split + "_pairs"] is not None:
            # Fixed validation or test pairs
            assert num_pos is None
            assert num_neg is None
            assert self.conf.views == 2
            pairs_path = scene_lists_path / self.conf[split + "_pairs"]
            for line in pairs_path.read_text().rstrip("\n").split("\n"):
                im0, im1 = line.split(" ")
                scene = im0.split("/")[0]
                assert im1.split("/")[0] == scene
                im0, im1 = [self.conf.image_subpath + im for im in [im0, im1]]
                assert im0 in self.images[scene]
                assert im1 in self.images[scene]
                idx0 = np.where(self.images[scene] == im0)[0][0]
                idx1 = np.where(self.images[scene] == im1)[0][0]
                self.items.append((scene, idx0, idx1, 1.0))
        elif self.conf.views == 1:
            # Single-view items: every image with either an image or a depth
            # path present, optionally subsampled per scene.
            for scene in self.scenes:
                if scene not in self.images:
                    continue
                valid = (self.images[scene] != None) | (  # noqa: E711
                    self.depths[scene] != None  # noqa: E711
                )
                ids = np.where(valid)[0]
                if num_pos and len(ids) > num_pos:
                    ids = np.random.RandomState(seed).choice(
                        ids, num_pos, replace=False
                    )
                ids = [(scene, i) for i in ids]
                self.items.extend(ids)
        else:
            # Pair sampling from the per-scene overlap matrix.
            for scene in self.scenes:
                path = self.info_dir / (scene + ".npz")
                assert path.exists(), path
                info = np.load(str(path), allow_pickle=True)
                # Only keep images with both an image and a depth path.
                valid = (self.images[scene] != None) & (  # noqa: E711
                    self.depths[scene] != None  # noqa: E711
                )
                ind = np.where(valid)[0]
                mat = info["overlap_matrix"][valid][:, valid]

                if num_pos is not None:
                    # Sample a subset of pairs, binned by overlap.
                    num_bins = self.conf.num_overlap_bins
                    assert num_bins > 0
                    bin_width = (
                        self.conf.max_overlap - self.conf.min_overlap
                    ) / num_bins
                    num_per_bin = num_pos // num_bins
                    pairs_all = []
                    for k in range(num_bins):
                        bin_min = self.conf.min_overlap + k * bin_width
                        bin_max = bin_min + bin_width
                        pairs_bin = (mat > bin_min) & (mat <= bin_max)
                        pairs_bin = np.stack(np.where(pairs_bin), -1)
                        pairs_all.append(pairs_bin)
                    # Skip bins with too few samples
                    has_enough_samples = [len(p) >= num_per_bin * 2 for p in pairs_all]
                    # Redistribute the budget over the bins that are kept.
                    num_per_bin_2 = num_pos // max(1, sum(has_enough_samples))
                    pairs = []
                    for pairs_bin, keep in zip(pairs_all, has_enough_samples):
                        if keep:
                            pairs.append(sample_n(pairs_bin, num_per_bin_2, seed))
                    pairs = np.concatenate(pairs, 0)
                else:
                    # Keep every pair within the configured overlap range.
                    pairs = (mat > self.conf.min_overlap) & (
                        mat <= self.conf.max_overlap
                    )
                    pairs = np.stack(np.where(pairs), -1)

                # Map matrix indices back to the original image indices.
                pairs = [(scene, ind[i], ind[j], mat[i, j]) for i, j in pairs]
                if num_neg is not None:
                    # Additionally sample non-overlapping (negative) pairs.
                    neg_pairs = np.stack(np.where(mat <= 0.0), -1)
                    neg_pairs = sample_n(neg_pairs, num_neg, seed)
                    pairs += [(scene, ind[i], ind[j], mat[i, j]) for i, j in neg_pairs]
                self.items.extend(pairs)
        if self.conf.views == 2 and self.conf.sort_by_overlap:
            self.items.sort(key=lambda i: i[-1], reverse=True)
        else:
            np.random.RandomState(seed).shuffle(self.items)

    def _read_view(self, scene, idx):
        """Load one view: image (or placeholder), depth, camera, pose, cache."""
        path = self.root / self.images[scene][idx]

        # read pose data
        K = self.intrinsics[scene][idx].astype(np.float32, copy=False)
        T = self.poses[scene][idx].astype(np.float32, copy=False)

        # read image
        if self.conf.read_image:
            img = load_image(self.root / self.images[scene][idx], self.conf.grayscale)
        else:
            # Only the image size is needed: build a zero tensor of the same
            # shape without decoding the pixels.
            size = PIL.Image.open(path).size[::-1]
            img = torch.zeros(
                [3 - 2 * int(self.conf.grayscale), size[0], size[1]]
            ).float()

        # read depth
        if self.conf.read_depth:
            depth_path = (
                self.root / self.conf.depth_subpath / scene / (path.stem + ".h5")
            )
            with h5py.File(str(depth_path), "r") as f:
                depth = f["/depth"].__array__().astype(np.float32, copy=False)
                depth = torch.Tensor(depth)[None]
            assert depth.shape[-2:] == img.shape[-2:]
        else:
            depth = None

        # add random rotations
        do_rotate = self.conf.p_rotate > 0.0 and self.split == "train"
        if do_rotate:
            p = self.conf.p_rotate
            k = 0
            if np.random.rand() < p:
                # k is +/- 1, i.e. a rotation by +/- 90 degrees.
                k = np.random.choice(2, 1, replace=False)[0] * 2 - 1
                img = np.rot90(img, k=-k, axes=(-2, -1))
                if self.conf.read_depth:
                    depth = np.rot90(depth, k=-k, axes=(-2, -1)).copy()
                K = rotate_intrinsics(K, img.shape, k + 2)
                T = rotate_pose_inplane(T, k + 2)

        name = path.name

        data = self.preprocessor(img)
        if depth is not None:
            data["depth"] = self.preprocessor(depth, interpolation="nearest")["image"][
                0
            ]
        # Keep the intrinsics consistent with the preprocessing resize.
        K = scale_intrinsics(K, data["scales"])

        data = {
            "name": name,
            "scene": scene,
            "T_w2cam": Pose.from_4x4mat(T),
            "depth": depth,
            "camera": Camera.from_calibration_matrix(K).float(),
            **data,
        }

        if self.conf.load_features.do:
            features = self.feature_loader({k: [v] for k, v in data.items()})
            # NOTE: `k` is only bound when do_rotate is True; the short-circuit
            # `and` keeps this safe when rotation is disabled.
            if do_rotate and k != 0:
                # Rotate cached keypoints to match the rotated image.
                # ang = np.deg2rad(k * 90.)
                kpts = features["keypoints"].copy()
                x, y = kpts[:, 0].copy(), kpts[:, 1].copy()
                w, h = data["image_size"]
                if k == 1:
                    kpts[:, 0] = w - y
                    kpts[:, 1] = x
                elif k == -1:
                    kpts[:, 0] = y
                    kpts[:, 1] = h - x

                else:
                    raise ValueError
                features["keypoints"] = kpts

            data = {"cache": features, **data}
        return data

    def __getitem__(self, idx):
        # Optionally reseed the RNG per item for reproducible augmentation.
        if self.conf.reseed:
            with fork_rng(self.conf.seed + idx, False):
                return self.getitem(idx)
        else:
            return self.getitem(idx)

    def getitem(self, idx):
        if self.conf.views == 2:
            # `idx` may also be a pre-formed (scene, idx0, idx1, overlap) list.
            if isinstance(idx, list):
                scene, idx0, idx1, overlap = idx
            else:
                scene, idx0, idx1, overlap = self.items[idx]
            data0 = self._read_view(scene, idx0)
            data1 = self._read_view(scene, idx1)
            data = {
                "view0": data0,
                "view1": data1,
            }
            data["T_0to1"] = data1["T_w2cam"] @ data0["T_w2cam"].inv()
            data["T_1to0"] = data0["T_w2cam"] @ data1["T_w2cam"].inv()
            data["overlap_0to1"] = overlap
            data["name"] = f"{scene}/{data0['name']}_{data1['name']}"
        else:
            assert self.conf.views == 1
            scene, idx0 = self.items[idx]
            data = self._read_view(scene, idx0)
        data["scene"] = scene
        data["idx"] = idx
        return data

    def __len__(self):
        return len(self.items)
|
||||
|
||||
|
||||
class _TripletDataset(_PairDataset):
    """Sample triplets of co-visible views (used when conf.views == 3)."""

    def sample_new_items(self, seed):
        """(Re)populate `self.items` with triplets.

        Item format: (scene, idx0, idx1, idx2, overlap01, overlap02,
        overlap12).
        """
        logging.info("Sampling new triplets with seed %d", seed)
        self.items = []
        split = self.split
        num = self.conf[self.split + "_num_per_scene"]
        if split != "train" and self.conf[split + "_pairs"] is not None:
            # Fixed triplets from a text file: one "im0 im1 im2" per line.
            if Path(self.conf[split + "_pairs"]).exists():
                pairs_path = Path(self.conf[split + "_pairs"])
            else:
                pairs_path = DATA_PATH / "configs" / self.conf[split + "_pairs"]
            for line in pairs_path.read_text().rstrip("\n").split("\n"):
                im0, im1, im2 = line.split(" ")
                assert im0[:4] == im1[:4]
                scene = im1[:4]
                # FIX: take scalar indices, consistent with _PairDataset;
                # a bare np.where(...) returns a tuple of arrays, which
                # breaks downstream indexing in _read_view.
                idx0 = np.where(self.images[scene] == im0)[0][0]
                idx1 = np.where(self.images[scene] == im1)[0][0]
                idx2 = np.where(self.images[scene] == im2)[0][0]
                self.items.append((scene, idx0, idx1, idx2, 1.0, 1.0, 1.0))
        else:
            for scene in self.scenes:
                path = self.info_dir / (scene + ".npz")
                assert path.exists(), path
                info = np.load(str(path), allow_pickle=True)
                if self.conf.num_overlap_bins > 1:
                    raise NotImplementedError("TODO")
                # FIX: the parent class stores depth paths in `self.depths`;
                # `self.depth[scene]` raised AttributeError.
                valid = (self.images[scene] != None) & (  # noqa: E711
                    self.depths[scene] != None  # noqa: E711
                )
                ind = np.where(valid)[0]
                mat = info["overlap_matrix"][valid][:, valid]
                good = (mat > self.conf.min_overlap) & (mat <= self.conf.max_overlap)
                triplets = []
                if self.conf.triplet_enforce_overlap:
                    # Require all three pairs of the triplet to overlap.
                    pairs = np.stack(np.where(good), -1)
                    for i0, i1 in pairs:
                        for i2 in pairs[pairs[:, 0] == i0, 1]:
                            if good[i1, i2]:
                                triplets.append((i0, i1, i2))
                    if len(triplets) > num:
                        # FIX: the seeded random selection was previously
                        # discarded by an immediate `selected = range(num)`,
                        # which always kept the first `num` triplets.
                        selected = np.random.RandomState(seed).choice(
                            len(triplets), num, replace=False
                        )
                        triplets = np.array(triplets)[selected]
                else:
                    # we first enforce that each row has >1 pairs
                    non_unique = good.sum(-1) > 1
                    ind_r = np.where(non_unique)[0]
                    good = good[non_unique]
                    pairs = np.stack(np.where(good), -1)
                    if len(pairs) > num:
                        selected = np.random.RandomState(seed).choice(
                            len(pairs), num, replace=False
                        )
                        pairs = pairs[selected]
                    for idx, (k, i) in enumerate(pairs):
                        # We now sample a j from row k s.t. i != j
                        possible_j = np.where(good[k])[0]
                        possible_j = possible_j[possible_j != i]
                        selected = np.random.RandomState(seed + idx).choice(
                            len(possible_j), 1, replace=False
                        )[0]
                        triplets.append((ind_r[k], i, possible_j[selected]))
                    triplets = [
                        (scene, ind[k], ind[i], ind[j], mat[k, i], mat[k, j], mat[i, j])
                        for k, i, j in triplets
                    ]
                self.items.extend(triplets)
        np.random.RandomState(seed).shuffle(self.items)

    def __getitem__(self, idx):
        scene, idx0, idx1, idx2, overlap01, overlap02, overlap12 = self.items[idx]
        data0 = self._read_view(scene, idx0)
        data1 = self._read_view(scene, idx1)
        data2 = self._read_view(scene, idx2)
        data = {
            "view0": data0,
            "view1": data1,
            "view2": data2,
        }
        # Relative poses between every ordered pair of views.
        data["T_0to1"] = data1["T_w2cam"] @ data0["T_w2cam"].inv()
        data["T_0to2"] = data2["T_w2cam"] @ data0["T_w2cam"].inv()
        data["T_1to2"] = data2["T_w2cam"] @ data1["T_w2cam"].inv()
        data["T_1to0"] = data0["T_w2cam"] @ data1["T_w2cam"].inv()
        data["T_2to0"] = data0["T_w2cam"] @ data2["T_w2cam"].inv()
        data["T_2to1"] = data1["T_w2cam"] @ data2["T_w2cam"].inv()

        data["overlap_0to1"] = overlap01
        data["overlap_0to2"] = overlap02
        data["overlap_1to2"] = overlap12
        data["scene"] = scene
        data["name"] = f"{scene}/{data0['name']}_{data1['name']}_{data2['name']}"
        return data

    def __len__(self):
        return len(self.items)
|
||||
|
||||
|
||||
def visualize(args):
    """Plot a few MegaDepth items (images with depth overlays).

    `args` must provide `split`, `num_items`, `dpi`, and `dotlist`
    (OmegaConf CLI overrides).
    """
    conf = {
        "min_overlap": 0.1,
        "max_overlap": 0.7,
        "num_overlap_bins": 3,
        "sort_by_overlap": False,
        "train_num_per_scene": 5,
        "batch_size": 1,
        "num_workers": 0,
        "prefetch_factor": None,
        "val_num_per_scene": None,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = MegaDepth(conf)
    loader = dataset.get_data_loader(args.split)
    # FIX: the format string was missing its %d placeholder, which made
    # logging raise "not all arguments converted" on every call.
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images, depths = [], []
        for _, data in zip(range(args.num_items), loader):
            images.append(
                [
                    data[f"view{i}"]["image"][0].permute(1, 2, 0)
                    for i in range(dataset.conf.views)
                ]
            )
            depths.append(
                [data[f"view{i}"]["depth"][0] for i in range(dataset.conf.views)]
            )

    axes = plot_image_grid(images, dpi=args.dpi)
    for i in range(len(images)):
        plot_heatmaps(depths[i], axes=axes[i])
    plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run this file as a module to visualize a few sampled items.
    from .. import logger  # overwrite the logger

    parser = argparse.ArgumentParser()
    parser.add_argument("--split", type=str, default="val")  # which data split
    parser.add_argument("--num_items", type=int, default=4)  # items to display
    parser.add_argument("--dpi", type=int, default=100)  # figure resolution
    parser.add_argument("dotlist", nargs="*")  # OmegaConf dotlist overrides
    args = parser.parse_intermixed_args()
    visualize(args)
|
|
@ -0,0 +1,8 @@
|
|||
0008
|
||||
0019
|
||||
0021
|
||||
0024
|
||||
0025
|
||||
0032
|
||||
0063
|
||||
1589
|
|
@ -0,0 +1,118 @@
|
|||
0000
|
||||
0001
|
||||
0002
|
||||
0003
|
||||
0004
|
||||
0005
|
||||
0007
|
||||
0008
|
||||
0011
|
||||
0012
|
||||
0013
|
||||
0015
|
||||
0017
|
||||
0019
|
||||
0020
|
||||
0021
|
||||
0022
|
||||
0023
|
||||
0024
|
||||
0025
|
||||
0026
|
||||
0027
|
||||
0032
|
||||
0035
|
||||
0036
|
||||
0037
|
||||
0039
|
||||
0042
|
||||
0043
|
||||
0046
|
||||
0048
|
||||
0050
|
||||
0056
|
||||
0057
|
||||
0060
|
||||
0061
|
||||
0063
|
||||
0065
|
||||
0070
|
||||
0080
|
||||
0083
|
||||
0086
|
||||
0087
|
||||
0092
|
||||
0095
|
||||
0098
|
||||
0100
|
||||
0101
|
||||
0103
|
||||
0104
|
||||
0105
|
||||
0107
|
||||
0115
|
||||
0117
|
||||
0122
|
||||
0130
|
||||
0137
|
||||
0143
|
||||
0147
|
||||
0148
|
||||
0149
|
||||
0150
|
||||
0156
|
||||
0160
|
||||
0176
|
||||
0183
|
||||
0189
|
||||
0190
|
||||
0200
|
||||
0214
|
||||
0224
|
||||
0235
|
||||
0237
|
||||
0240
|
||||
0243
|
||||
0258
|
||||
0265
|
||||
0269
|
||||
0299
|
||||
0312
|
||||
0326
|
||||
0327
|
||||
0331
|
||||
0335
|
||||
0341
|
||||
0348
|
||||
0366
|
||||
0377
|
||||
0380
|
||||
0394
|
||||
0407
|
||||
0411
|
||||
0430
|
||||
0446
|
||||
0455
|
||||
0472
|
||||
0474
|
||||
0476
|
||||
0478
|
||||
0493
|
||||
0494
|
||||
0496
|
||||
0505
|
||||
0559
|
||||
0733
|
||||
0860
|
||||
1017
|
||||
1589
|
||||
4541
|
||||
5004
|
||||
5005
|
||||
5006
|
||||
5007
|
||||
5009
|
||||
5010
|
||||
5012
|
||||
5013
|
||||
5017
|
|
@ -0,0 +1,153 @@
|
|||
0001
|
||||
0003
|
||||
0004
|
||||
0005
|
||||
0007
|
||||
0012
|
||||
0013
|
||||
0016
|
||||
0017
|
||||
0023
|
||||
0026
|
||||
0027
|
||||
0034
|
||||
0035
|
||||
0036
|
||||
0037
|
||||
0039
|
||||
0041
|
||||
0042
|
||||
0043
|
||||
0044
|
||||
0046
|
||||
0047
|
||||
0048
|
||||
0049
|
||||
0056
|
||||
0057
|
||||
0058
|
||||
0060
|
||||
0061
|
||||
0062
|
||||
0064
|
||||
0065
|
||||
0067
|
||||
0070
|
||||
0071
|
||||
0076
|
||||
0078
|
||||
0080
|
||||
0083
|
||||
0086
|
||||
0087
|
||||
0090
|
||||
0094
|
||||
0095
|
||||
0098
|
||||
0099
|
||||
0100
|
||||
0101
|
||||
0102
|
||||
0104
|
||||
0107
|
||||
0115
|
||||
0117
|
||||
0122
|
||||
0129
|
||||
0130
|
||||
0137
|
||||
0141
|
||||
0147
|
||||
0148
|
||||
0149
|
||||
0150
|
||||
0151
|
||||
0156
|
||||
0160
|
||||
0162
|
||||
0175
|
||||
0181
|
||||
0183
|
||||
0185
|
||||
0186
|
||||
0189
|
||||
0190
|
||||
0197
|
||||
0200
|
||||
0204
|
||||
0205
|
||||
0212
|
||||
0214
|
||||
0217
|
||||
0223
|
||||
0224
|
||||
0231
|
||||
0235
|
||||
0237
|
||||
0238
|
||||
0240
|
||||
0243
|
||||
0252
|
||||
0257
|
||||
0258
|
||||
0269
|
||||
0271
|
||||
0275
|
||||
0277
|
||||
0281
|
||||
0285
|
||||
0286
|
||||
0290
|
||||
0294
|
||||
0299
|
||||
0303
|
||||
0306
|
||||
0307
|
||||
0312
|
||||
0323
|
||||
0326
|
||||
0327
|
||||
0331
|
||||
0335
|
||||
0341
|
||||
0348
|
||||
0360
|
||||
0377
|
||||
0380
|
||||
0387
|
||||
0389
|
||||
0394
|
||||
0402
|
||||
0406
|
||||
0407
|
||||
0411
|
||||
0446
|
||||
0455
|
||||
0472
|
||||
0476
|
||||
0478
|
||||
0482
|
||||
0493
|
||||
0496
|
||||
0505
|
||||
0559
|
||||
0733
|
||||
0768
|
||||
1017
|
||||
3346
|
||||
5000
|
||||
5001
|
||||
5002
|
||||
5003
|
||||
5004
|
||||
5005
|
||||
5006
|
||||
5007
|
||||
5008
|
||||
5009
|
||||
5010
|
||||
5011
|
||||
5012
|
||||
5013
|
||||
5017
|
||||
5018
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,77 @@
|
|||
0016
|
||||
0033
|
||||
0034
|
||||
0041
|
||||
0044
|
||||
0047
|
||||
0049
|
||||
0058
|
||||
0062
|
||||
0064
|
||||
0067
|
||||
0071
|
||||
0076
|
||||
0078
|
||||
0090
|
||||
0094
|
||||
0099
|
||||
0102
|
||||
0121
|
||||
0129
|
||||
0133
|
||||
0141
|
||||
0151
|
||||
0162
|
||||
0168
|
||||
0175
|
||||
0177
|
||||
0178
|
||||
0181
|
||||
0185
|
||||
0186
|
||||
0197
|
||||
0204
|
||||
0205
|
||||
0209
|
||||
0212
|
||||
0217
|
||||
0223
|
||||
0229
|
||||
0231
|
||||
0238
|
||||
0252
|
||||
0257
|
||||
0271
|
||||
0275
|
||||
0277
|
||||
0281
|
||||
0285
|
||||
0286
|
||||
0290
|
||||
0294
|
||||
0303
|
||||
0306
|
||||
0307
|
||||
0323
|
||||
0349
|
||||
0360
|
||||
0387
|
||||
0389
|
||||
0402
|
||||
0406
|
||||
0412
|
||||
0443
|
||||
0482
|
||||
0768
|
||||
1001
|
||||
3346
|
||||
5000
|
||||
5001
|
||||
5002
|
||||
5003
|
||||
5008
|
||||
5011
|
||||
5014
|
||||
5015
|
||||
5016
|
||||
5018
|
|
@ -0,0 +1,2 @@
|
|||
0015
|
||||
0022
|
|
@ -0,0 +1,131 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def read_image(path, grayscale=False):
    """Read an image from path as RGB or grayscale"""
    flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    img = cv2.imread(str(path), flags)
    if img is None:
        raise IOError(f"Could not read image at {path}.")
    if grayscale:
        return img
    # OpenCV decodes color images as BGR; flip the channel order to RGB.
    return img[..., ::-1]
|
||||
|
||||
|
||||
def numpy_image_to_torch(image):
    """Normalize the image tensor and reorder the dimensions."""
    if image.ndim == 2:
        chw = image[None]  # add a channel axis
    elif image.ndim == 3:
        chw = image.transpose((2, 0, 1))  # HxWxC to CxHxW
    else:
        raise ValueError(f"Not an image: {image.shape}")
    return torch.tensor(chw / 255.0, dtype=torch.float)
|
||||
|
||||
|
||||
def rotate_intrinsics(K, image_shape, rot):
    """Rotate a 3x3 intrinsic matrix by `rot` * 90 degrees.

    image_shape is the shape of the image after rotation.

    FIX: rot % 4 == 0 previously fell through to the rot == 3 branch and
    returned wrong intrinsics; it now returns K unchanged.
    """
    assert rot <= 3
    # For odd rotations the image's height and width are swapped.
    h, w = image_shape[:2][:: -1 if (rot % 2) else 1]
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    rot = rot % 4
    if rot == 0:
        return K
    if rot == 1:
        return np.array(
            [[fy, 0.0, cy], [0.0, fx, w - cx], [0.0, 0.0, 1.0]], dtype=K.dtype
        )
    elif rot == 2:
        return np.array(
            [[fx, 0.0, w - cx], [0.0, fy, h - cy], [0.0, 0.0, 1.0]],
            dtype=K.dtype,
        )
    else:  # if rot == 3:
        return np.array(
            [[fy, 0.0, h - cy], [0.0, fx, cx], [0.0, 0.0, 1.0]], dtype=K.dtype
        )
|
||||
|
||||
|
||||
def rotate_pose_inplane(i_T_w, rot):
    """Apply an in-plane (z-axis) rotation of rot * 90 degrees to a 4x4 pose."""
    angles_deg = (0, 270, 180, 90)
    r = np.deg2rad(angles_deg[rot])
    c, s = np.cos(r), np.sin(r)
    rot44 = np.array(
        [
            [c, -s, 0.0, 0.0],
            [s, c, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
    return np.dot(rot44, i_T_w)
|
||||
|
||||
|
||||
def scale_intrinsics(K, scales):
    """Scale intrinsics after resizing the corresponding image."""
    factors = np.concatenate([scales, [1.0]])
    S = np.diag(factors).astype(K.dtype, copy=False)
    return np.dot(S, K)
|
||||
|
||||
|
||||
def get_divisible_wh(w, h, df=None):
    """Round width and height down to the nearest multiple of df (no-op if None)."""
    if df is None:
        return w, h
    return int(w // df * df), int(h // df * df)
|
||||
|
||||
|
||||
def resize(image, size, fn=None, interp="linear", df=None):
    """Resize an image to a fixed size, or according to max or min edge."""
    h, w = image.shape[:2]
    if isinstance(size, int):
        # `fn` (min or max) selects which edge the target size refers to.
        factor = size / fn(h, w)
        h_new = int(round(h * factor))
        w_new = int(round(w * factor))
        w_new, h_new = get_divisible_wh(w_new, h_new, df)
    elif isinstance(size, (tuple, list)):
        h_new, w_new = size
    else:
        raise ValueError(f"Incorrect new size: {size}")
    # The effective per-axis scale, after any divisibility adjustment.
    scale = (w_new / w, h_new / h)
    interp_modes = {
        "linear": cv2.INTER_LINEAR,
        "cubic": cv2.INTER_CUBIC,
        "nearest": cv2.INTER_NEAREST,
        "area": cv2.INTER_AREA,
    }
    mode = interp_modes[interp]
    return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
|
||||
|
||||
|
||||
def crop(image, size, random=True, other=None, K=None, return_bbox=False):
    """Random or deterministic crop of an image, adjust depth and intrinsics."""
    h, w = image.shape[:2]
    if isinstance(size, int):
        h_new = w_new = size
    else:
        h_new, w_new = size
    top = np.random.randint(0, h - h_new + 1) if random else 0
    left = np.random.randint(0, w - w_new + 1) if random else 0
    out = [image[top : top + h_new, left : left + w_new]]
    if other is not None:
        out.append(other[top : top + h_new, left : left + w_new])
    if K is not None:
        # Shift the principal point by the crop offset (mutates K in place).
        K[0, 2] -= left
        K[1, 2] -= top
        out.append(K)
    if return_bbox:
        out.append((top, top + h_new, left, left + w_new))
    return out
|
||||
|
||||
|
||||
def zero_pad(size, *images):
    """Zero-pad each image to a ``size`` x ``size`` square.

    ``None`` entries are passed through unchanged; each image is placed in
    the top-left corner of a zero canvas of its own dtype.
    """
    padded_images = []
    for img in images:
        if img is None:
            padded_images.append(None)
        else:
            h, w = img.shape[:2]
            canvas = np.zeros((size, size) + img.shape[2:], dtype=img.dtype)
            canvas[:h, :w] = img
            padded_images.append(canvas)
    return padded_images
|
|
@ -0,0 +1,19 @@
|
|||
import torch
|
||||
from ..utils.tools import get_class
|
||||
from .eval_pipeline import EvalPipeline
|
||||
|
||||
|
||||
def get_benchmark(benchmark):
    """Resolve a benchmark name to its pipeline class.

    Looks up the class inside the submodule ``<this package>.<benchmark>``
    via ``get_class``, which requires it to subclass ``EvalPipeline``.
    """
    return get_class(f"{__name__}.{benchmark}", EvalPipeline)
|
||||
|
||||
|
||||
@torch.no_grad()
def run_benchmark(benchmark, eval_conf, experiment_dir, model=None):
    """Run a named benchmark end-to-end. This overwrites existing benchmarks.

    Args:
        benchmark: benchmark submodule name, resolved by ``get_benchmark``.
        eval_conf: configuration passed to the pipeline constructor.
        experiment_dir: output directory (created if missing).
        model: optional pre-loaded model; otherwise the pipeline loads one.
    """
    experiment_dir.mkdir(exist_ok=True, parents=True)
    bm = get_benchmark(benchmark)

    pipeline = bm(eval_conf)
    # Force both prediction re-export and re-evaluation (hence "overwrites").
    return pipeline.run(
        experiment_dir, model=model, overwrite=True, overwrite_eval=True
    )
|
|
@ -0,0 +1,215 @@
|
|||
import torch
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
import matplotlib.pyplot as plt
|
||||
import resource
|
||||
from collections import defaultdict
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
from .io import (
|
||||
parse_eval_args,
|
||||
load_model,
|
||||
get_eval_parser,
|
||||
)
|
||||
|
||||
from .eval_pipeline import EvalPipeline, load_eval
|
||||
|
||||
from ..utils.export_predictions import export_predictions
|
||||
from .utils import get_tp_fp_pts, aggregate_pr_results
|
||||
from ..settings import EVAL_PATH
|
||||
from ..models.cache_loader import CacheLoader
|
||||
from ..datasets import get_dataset
|
||||
|
||||
|
||||
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
def eval_dataset(loader, pred_file, suffix="", conf=None):
    """Accumulate TP/FP statistics over all pairs and aggregate PR metrics.

    Args:
        loader: dataloader yielding one eval pair per batch.
        pred_file: path to the cached-prediction HDF5 file.
        suffix: "" for point matches, "_lines" for line matches (selects
            which prediction keys are read).
        conf: optional eval configuration. Accepted for compatibility with
            callers that pass ``conf=...`` (e.g. the line-matching call in
            ``ETH3DPipeline.run_eval``, which previously raised a TypeError
            because this parameter did not exist); currently unused.

    Returns:
        Aggregated precision/recall results from ``aggregate_pr_results``.
    """
    results = defaultdict(list)
    results["num_pos" + suffix] = 0
    cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
    for data in tqdm(loader):
        pred = cache_loader(data)

        # Select point- or line-matching predictions and sort them by
        # descending confidence so the PR curve can be accumulated.
        if suffix == "":
            scores = pred["matching_scores0"].numpy()
            sort_indices = np.argsort(scores)[::-1]
            gt_matches = pred["gt_matches0"].numpy()[sort_indices]
            pred_matches = pred["matches0"].numpy()[sort_indices]
        else:
            scores = pred["line_matching_scores0"].numpy()
            sort_indices = np.argsort(scores)[::-1]
            gt_matches = pred["gt_line_matches0"].numpy()[sort_indices]
            pred_matches = pred["line_matches0"].numpy()[sort_indices]
        scores = scores[sort_indices]

        tp, fp, scores, num_pos = get_tp_fp_pts(pred_matches, gt_matches, scores)
        results["tp" + suffix].append(tp)
        results["fp" + suffix].append(fp)
        results["scores" + suffix].append(scores)
        results["num_pos" + suffix] += num_pos

    # Aggregate the results
    return aggregate_pr_results(results, suffix=suffix)
|
||||
|
||||
|
||||
class ETH3DPipeline(EvalPipeline):
    """Matching precision/recall evaluation on the ETH3D dataset."""

    default_conf = {
        "data": {
            "name": "eth3d",
            "batch_size": 1,
            "train_batch_size": 1,
            "val_batch_size": 1,
            "test_batch_size": 1,
            "num_workers": 16,
        },
        "model": {
            "name": "gluefactory.models.two_view_pipeline",
            "ground_truth": {
                "name": "gluefactory.models.matchers.depth_matcher",
                "use_lines": False,
            },
            # The GT matcher must run during forward so gt_matches0 is exported.
            "run_gt_in_forward": True,
        },
        "eval": {"plot_methods": [], "plot_line_methods": [], "eval_lines": False},
    }

    # Prediction keys cached for every pair.
    export_keys = [
        "gt_matches0",
        "matches0",
        "matching_scores0",
    ]

    # Line outputs only exist for line-capable models.
    optional_export_keys = [
        "gt_line_matches0",
        "line_matches0",
        "line_matching_scores0",
    ]

    def get_dataloader(self, data_conf=None):
        """Return the ETH3D test dataloader (default config if none given)."""
        data_conf = data_conf if data_conf is not None else self.default_conf["data"]
        dataset = get_dataset("eth3d")(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export (or reuse) the cached predictions for every eval pair."""
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Compute PR metrics for points, and for lines if enabled."""
        eval_conf = self.conf.eval
        r = eval_dataset(loader, pred_file)
        if self.conf.eval.eval_lines:
            # NOTE(review): eval_dataset's signature as written does not
            # declare a `conf` parameter — confirm it accepts this keyword.
            r.update(eval_dataset(loader, pred_file, conf=eval_conf, suffix="_lines"))
        s = {}

        # No scalar summaries or figures: all metrics are in the raw results.
        return s, {}, r
|
||||
|
||||
|
||||
def plot_pr_curve(
    models_name, results, dst_file="eth3d_pr_curve.pdf", title=None, suffix=""
):
    """Plot precision-recall curves for several methods and save the figure.

    Args:
        models_name: iterable of method names (keys into ``results``).
        results: dict mapping method name to its metrics; expected to hold
            "AP", "curve_recall" and "curve_precision" (each plus ``suffix``).
        dst_file: output path of the saved figure.
        title: optional plot title.
        suffix: metric-key suffix, e.g. "_lines" for line matches.
    """
    plt.figure()
    # Draw iso-F-score level sets as faint background reference curves.
    f_scores = np.linspace(0.2, 0.9, num=8)
    for f_score in f_scores:
        x = np.linspace(0.01, 1)
        # Precision as a function of recall at constant F-score.
        y = f_score * x / (2 * x - f_score)
        plt.plot(x[y >= 0], y[y >= 0], color=[0, 0.5, 0], alpha=0.3)
        plt.annotate(
            "f={0:0.1}".format(f_score),
            xy=(0.9, y[45] + 0.02),
            alpha=0.4,
            fontsize=14,
        )

    plt.rcParams.update({"font.size": 12})
    # plt.rc('legend', fontsize=10)
    plt.grid(True)
    plt.axis([0.0, 1.0, 0.0, 1.0])
    plt.xticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.xlabel("Recall", fontsize=18)
    plt.ylabel("Precision", fontsize=18)
    plt.yticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.ylim([0.3, 1.0])
    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]
    # One PR curve per method, labeled with its average precision.
    for m, c in zip(models_name, colors):
        sAP_string = f'{m}: {results[m]["AP" + suffix]:.1f}'
        plt.plot(
            results[m]["curve_recall" + suffix],
            results[m]["curve_precision" + suffix],
            label=sAP_string,
            color=c,
        )

    plt.legend(fontsize=16, loc="lower right")
    if title:
        plt.title(title)

    plt.tight_layout(pad=0.5)
    print(f"Saving plot to: {dst_file}")
    plt.savefig(dst_file)
    plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The benchmark name is the file stem (e.g. "eth3d").
    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(ETH3DPipeline.default_conf)

    # mingle paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(
        dataset_name,
        args,
        "configs/",
        default_conf,
    )

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = ETH3DPipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
    )

    # print results
    for k, v in r.items():
        if k.startswith("AP"):
            print(f"{k}: {v:.2f}")

    if args.plot:
        # Compare PR curves across the configured methods.
        results = {}
        for m in conf.eval.plot_methods:
            exp_dir = output_dir / m
            results[m] = load_eval(exp_dir)[1]

        plot_pr_curve(conf.eval.plot_methods, results, dst_file="eth3d_pr_curve.pdf")
        if conf.eval.eval_lines:
            for m in conf.eval.plot_line_methods:
                exp_dir = output_dir / m
                results[m] = load_eval(exp_dir)[1]
            plot_pr_curve(
                conf.eval.plot_line_methods,
                results,
                dst_file="eth3d_pr_curve_lines.pdf",
                suffix="_lines",
            )
|
|
@ -0,0 +1,108 @@
|
|||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import json
|
||||
import h5py
|
||||
|
||||
|
||||
def load_eval(dir):
    """Load cached evaluation outputs from an experiment directory.

    Args:
        dir: experiment directory containing results.h5 and summaries.json.

    Returns:
        (summaries, results): ``summaries`` is a dict of scalar metrics
        read from summaries.json (JSON ``null`` mapped to NaN); ``results``
        maps metric names to per-datapoint arrays from results.h5.
    """
    results = {}
    with h5py.File(str(dir / "results.h5"), "r") as hfile:
        for k in hfile.keys():
            r = np.array(hfile[k])
            # Skip high-dimensional datasets; only per-datapoint arrays
            # are returned as results.
            if len(r.shape) < 3:
                results[k] = r
    # summaries.json is the source of truth for the summaries. (The HDF5
    # attributes duplicate it; the old code read them and then immediately
    # discarded that dict — that dead read has been removed.)
    with open(dir / "summaries.json", "r") as f:
        s = json.load(f)
    summaries = {k: v if v is not None else np.nan for k, v in s.items()}
    return summaries, results
|
||||
|
||||
|
||||
def save_eval(dir, summaries, figures, results):
    """Persist eval outputs: results.h5, summaries.json, and figure PNGs.

    Args:
        dir: destination experiment directory (must exist).
        summaries: dict of scalar (or list) metrics.
        figures: dict name -> matplotlib figure.
        results: dict name -> per-datapoint values.
    """
    with h5py.File(str(dir / "results.h5"), "w") as hfile:
        for k, v in results.items():
            arr = np.array(v)
            # Non-numeric arrays (e.g. pair names) are stored as objects.
            if not np.issubdtype(arr.dtype, np.number):
                arr = arr.astype("object")
            hfile.create_dataset(k, data=arr)
        # just to be safe, not used in practice
        for k, v in summaries.items():
            hfile.attrs[k] = v
    # JSON cannot represent NaN/inf: map non-finite scalars to null.
    s = {
        k: float(v) if np.isfinite(v) else None
        for k, v in summaries.items()
        if not isinstance(v, list)
    }
    s = {**s, **{k: v for k, v in summaries.items() if isinstance(v, list)}}
    with open(dir / "summaries.json", "w") as f:
        json.dump(s, f, indent=4)

    for fig_name, fig in figures.items():
        fig.savefig(dir / f"{fig_name}.png")
|
||||
|
||||
|
||||
def exists_eval(dir):
    """Return True iff both eval output files are present in ``dir``."""
    required = ("results.h5", "summaries.json")
    return all((dir / fname).exists() for fname in required)
|
||||
|
||||
|
||||
class EvalPipeline:
    """Base class for export+evaluate benchmark pipelines.

    Subclasses provide the dataloader, the prediction export, and the
    evaluation itself; ``run`` orchestrates caching and overwriting.
    """

    default_conf = {}

    # Prediction keys that must be exported for every datapoint.
    export_keys = []
    # Keys exported only when the model produces them.
    optional_export_keys = []

    def __init__(self, conf):
        """Assumes"""
        # Merge the subclass defaults with the user-provided config.
        self.default_conf = OmegaConf.create(self.default_conf)
        self.conf = OmegaConf.merge(self.default_conf, conf)
        self._init(self.conf)

    def _init(self, conf):
        # Optional subclass hook (e.g. download data); no-op by default.
        pass

    @classmethod
    def get_dataloader(self, data_conf=None):
        """Returns a data loader with samples for each eval datapoint"""
        # NOTE(review): declared as a classmethod but the first parameter
        # is named `self`, and subclasses override it inconsistently —
        # confirm whether classmethod semantics are intended.
        raise NotImplementedError

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint"""
        raise NotImplementedError

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions"""
        raise NotImplementedError

    def run(self, experiment_dir, model=None, overwrite=False, overwrite_eval=False):
        """Run export+eval loop"""
        # Refuse to silently reuse results produced under a different config.
        self.save_conf(
            experiment_dir, overwrite=overwrite, overwrite_eval=overwrite_eval
        )
        pred_file = self.get_predictions(
            experiment_dir, model=model, overwrite=overwrite
        )

        f = {}
        if not exists_eval(experiment_dir) or overwrite_eval or overwrite:
            s, f, r = self.run_eval(self.get_dataloader(), pred_file)
            save_eval(experiment_dir, s, f, r)
        # Always reload from disk so cached and fresh runs return the
        # same types; figures are only available on a fresh run.
        s, r = load_eval(experiment_dir)
        return s, f, r

    def save_conf(self, experiment_dir, overwrite=False, overwrite_eval=False):
        # store config
        conf_output_path = experiment_dir / "conf.yaml"
        if conf_output_path.exists():
            saved_conf = OmegaConf.load(conf_output_path)
            if (saved_conf.data != self.conf.data) or (
                saved_conf.model != self.conf.model
            ):
                assert (
                    overwrite
                ), "configs changed, add --overwrite to rerun experiment with new conf"
            if saved_conf.eval != self.conf.eval:
                assert (
                    overwrite or overwrite_eval
                ), "eval configs changed, add --overwrite_eval to rerun evaluation"
        OmegaConf.save(self.conf, experiment_dir / "conf.yaml")
|
|
@ -0,0 +1,211 @@
|
|||
import torch
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
from pprint import pprint
|
||||
import matplotlib.pyplot as plt
|
||||
import resource
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from ..visualization.viz2d import plot_cumulative
|
||||
|
||||
from .io import (
|
||||
parse_eval_args,
|
||||
load_model,
|
||||
get_eval_parser,
|
||||
)
|
||||
from ..utils.export_predictions import export_predictions
|
||||
from ..settings import EVAL_PATH
|
||||
from ..models.cache_loader import CacheLoader
|
||||
from ..datasets import get_dataset
|
||||
from .utils import (
|
||||
eval_homography_robust,
|
||||
eval_poses,
|
||||
eval_matches_homography,
|
||||
eval_homography_dlt,
|
||||
)
|
||||
from ..utils.tools import AUCMetric
|
||||
|
||||
from .eval_pipeline import EvalPipeline
|
||||
|
||||
|
||||
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
class HPatchesPipeline(EvalPipeline):
    """Homography-estimation evaluation on the HPatches sequences."""

    default_conf = {
        "data": {
            "batch_size": 1,
            "name": "hpatches",
            "num_workers": 16,
            "preprocessing": {
                "resize": 480,  # we also resize during eval to have comparable metrics
                "side": "short",
            },
        },
        "model": {
            "ground_truth": {
                "name": None,  # remove gt matches
            }
        },
        "eval": {
            "estimator": "poselib",
            "ransac_th": 1.0,  # -1 runs a bunch of thresholds and selects the best
        },
    }
    # Keypoint/match outputs cached for every image pair.
    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]

    # Line outputs only exist for line-capable models.
    optional_export_keys = [
        "lines0",
        "lines1",
        "orig_lines0",
        "orig_lines1",
        "line_matches0",
        "line_matches1",
        "line_matching_scores0",
        "line_matching_scores1",
    ]

    def _init(self, conf):
        pass

    @classmethod
    def get_dataloader(self, data_conf=None):
        """Return the HPatches test dataloader (default config if none)."""
        data_conf = data_conf if data_conf else self.default_conf["data"]
        dataset = get_dataset("hpatches")(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export (or reuse) the cached predictions for every pair."""
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Evaluate DLT and robust homography errors on cached predictions."""
        assert pred_file.exists()
        results = defaultdict(list)

        conf = self.conf.eval

        # ransac_th <= 0 sweeps several thresholds; the best is kept below.
        test_thresholds = (
            ([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
            if not isinstance(conf.ransac_th, Iterable)
            else conf.ransac_th
        )
        pose_results = defaultdict(lambda: defaultdict(list))
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
        for i, data in enumerate(tqdm(loader)):
            pred = cache_loader(data)
            # add custom evaluations here
            if "keypoints0" in pred:
                results_i = eval_matches_homography(data, pred, {})
                results_i = {**results_i, **eval_homography_dlt(data, pred)}
            else:
                results_i = {}
            for th in test_thresholds:
                pose_results_i = eval_homography_robust(
                    data,
                    pred,
                    {"estimator": conf.estimator, "ransac_th": th},
                )
                [pose_results[th][k].append(v) for k, v in pose_results_i.items()]

            # we also store the names for later reference
            results_i["names"] = data["name"][0]
            results_i["scenes"] = data["scene"][0]

            for k, v in results_i.items():
                results[k].append(v)

        # summarize results as a dict[str, float]
        # you can also add your custom evaluations here
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            if not np.issubdtype(np.array(v).dtype, np.number):
                continue
            summaries[f"m{k}"] = round(np.median(arr), 3)

        # AUC of the robust homography error, best threshold selected.
        auc_ths = [1, 3, 5]
        best_pose_results, best_th = eval_poses(
            pose_results, auc_ths=auc_ths, key="H_error_ransac", unit="px"
        )
        if "H_error_dlt" in results.keys():
            dlt_aucs = AUCMetric(auc_ths, results["H_error_dlt"]).compute()
            for i, ath in enumerate(auc_ths):
                summaries[f"H_error_dlt@{ath}px"] = dlt_aucs[i]

        results = {**results, **pose_results[best_th]}
        summaries = {
            **summaries,
            **best_pose_results,
        }

        figures = {
            "homography_recall": plot_cumulative(
                {
                    "DLT": results["H_error_dlt"],
                    self.conf.eval.estimator: results["H_error_ransac"],
                },
                [0, 10],
                unit="px",
                title="Homography ",
            )
        }

        return summaries, figures, results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The benchmark name is the file stem (e.g. "hpatches").
    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(HPatchesPipeline.default_conf)

    # mingle paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(
        dataset_name,
        args,
        "configs/",
        default_conf,
    )

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = HPatchesPipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
    )

    # print results
    pprint(s)
    if args.plot:
        for name, fig in f.items():
            fig.canvas.manager.set_window_title(name)
        plt.show()
|
|
@ -0,0 +1,61 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
from pprint import pprint
|
||||
from collections import defaultdict
|
||||
|
||||
from ..settings import EVAL_PATH
|
||||
from ..visualization.global_frame import GlobalFrame
|
||||
from ..visualization.two_view_frame import TwoViewFrame
|
||||
from . import get_benchmark
|
||||
from .eval_pipeline import load_eval
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Interactive inspection tool: load the eval results of one or more
    # experiments for a benchmark and browse them in a matplotlib GUI.
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", type=str)
    parser.add_argument("--x", type=str, default=None)
    parser.add_argument("--y", type=str, default=None)
    parser.add_argument("--backend", type=str, default=None)
    parser.add_argument(
        "--default_plot", type=str, default=TwoViewFrame.default_conf["default"]
    )

    # Positional experiment names to compare.
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_intermixed_args()

    output_dir = Path(EVAL_PATH, args.benchmark)

    results = {}
    summaries = defaultdict(dict)

    predictions = {}

    if args.backend:
        matplotlib.use(args.backend)

    bm = get_benchmark(args.benchmark)
    loader = bm.get_dataloader()

    # Load cached summaries/results for every requested experiment.
    for name in args.dotlist:
        experiment_dir = output_dir / name
        pred_file = experiment_dir / "predictions.h5"
        s, results[name] = load_eval(experiment_dir)
        predictions[name] = pred_file
        for k, v in s.items():
            summaries[k][name] = v

    pprint(summaries)

    plt.close("all")

    frame = GlobalFrame(
        {"child": {"default": args.default_plot}, **vars(args)},
        results,
        loader,
        predictions,
        child_frame=TwoViewFrame,
    )
    frame.draw()
    plt.show()
|
|
@ -0,0 +1,103 @@
|
|||
import pkg_resources
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from omegaconf import OmegaConf
|
||||
import argparse
|
||||
from pprint import pprint
|
||||
|
||||
from ..models import get_model
|
||||
from ..utils.experiments import load_experiment
|
||||
from ..settings import TRAINING_PATH
|
||||
|
||||
|
||||
def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path:
    """Resolve a config name or filesystem path to a concrete YAML path.

    Args:
        name_or_path: either the stem of a packaged default config, or a
            path to an existing YAML file; ``None`` returns ``None``.
        defaults: package-relative directory of the default configs.

    Raises:
        FileNotFoundError: if the argument is neither a known default nor
            an existing path.
    """
    # NOTE(review): pkg_resources is deprecated in modern setuptools;
    # consider importlib.resources when the minimum Python allows it.
    default_configs = {}
    for c in pkg_resources.resource_listdir("gluefactory", str(defaults)):
        if c.endswith(".yaml"):
            default_configs[Path(c).stem] = Path(
                pkg_resources.resource_filename("gluefactory", defaults + c)
            )
    if name_or_path is None:
        return None
    if name_or_path in default_configs:
        return default_configs[name_or_path]
    path = Path(name_or_path)
    if not path.exists():
        raise FileNotFoundError(
            f"Cannot find the config file: {name_or_path}. "
            f"Not in the default configs {list(default_configs.keys())} "
            "and not an existing path."
        )
    return Path(path)
|
||||
|
||||
|
||||
def extract_benchmark_conf(conf, benchmark):
    """Extract the model config plus benchmark-specific overrides.

    Returns a new OmegaConf holding ``conf.model`` merged with the entry
    for ``benchmark`` under ``conf.benchmarks`` (if any).
    """
    mconf = OmegaConf.create({"model": conf.get("model", {})})
    if "benchmarks" not in conf.keys():
        return mconf
    return OmegaConf.merge(mconf, conf.benchmarks.get(benchmark, {}))
|
||||
|
||||
|
||||
def parse_eval_args(benchmark, args, configs_path, default=None):
    """Build the evaluation config and experiment name from CLI arguments.

    Merge order, lowest to highest precedence: `default`, the checkpoint's
    training config (for non-.tar checkpoints), the named YAML config, and
    the CLI dotlist.

    Args:
        benchmark: benchmark name, used to select per-benchmark overrides.
        args: namespace produced by ``get_eval_parser()``.
        configs_path: resource directory searched for named YAML configs.
        default: optional default config merged in with lowest priority.

    Returns:
        (name, conf): the experiment tag and the merged OmegaConf config.

    Raises:
        ValueError: if no tag, config, or checkpoint identifies the run.
    """
    conf = {"data": {}, "model": {}, "eval": {}}
    if args.conf:
        conf_path = parse_config_path(args.conf, configs_path)
        custom_conf = OmegaConf.load(conf_path)
        conf = extract_benchmark_conf(OmegaConf.merge(conf, custom_conf), benchmark)
        # Default the tag to the config file stem.
        args.tag = (
            args.tag if args.tag is not None else conf_path.name.replace(".yaml", "")
        )

    cli_conf = OmegaConf.from_cli(args.dotlist)
    conf = OmegaConf.merge(conf, cli_conf)
    conf.checkpoint = args.checkpoint if args.checkpoint else conf.get("checkpoint")

    if conf.checkpoint and not conf.checkpoint.endswith(".tar"):
        # Checkpoint given as an experiment name: inherit its training config.
        checkpoint_conf = OmegaConf.load(
            TRAINING_PATH / conf.checkpoint / "config.yaml"
        )
        conf = OmegaConf.merge(extract_benchmark_conf(checkpoint_conf, benchmark), conf)

    if default:
        conf = OmegaConf.merge(default, conf)

    # Derive the experiment name used as the output sub-directory.
    if args.tag is not None:
        name = args.tag
    elif args.conf and conf.checkpoint:
        name = f"{args.conf}_{conf.checkpoint}"
    elif args.conf:
        name = args.conf
    elif conf.checkpoint:
        name = conf.checkpoint
    else:
        # Previously `name` was left unbound here, crashing below with an
        # uninformative NameError.
        raise ValueError(
            "Cannot name the experiment: provide --tag, --conf or --checkpoint."
        )
    if len(args.dotlist) > 0 and not args.tag:
        name = name + "_" + ":".join(args.dotlist)
    print("Running benchmark:", benchmark)
    print("Experiment tag:", name)
    print("Config:")
    pprint(OmegaConf.to_container(conf))
    return name, conf
|
||||
|
||||
|
||||
def load_model(model_conf, checkpoint):
    """Load an eval-mode model, from a training checkpoint if given.

    Args:
        model_conf: model configuration (merged into the experiment conf
            when a checkpoint is provided).
        checkpoint: experiment/checkpoint identifier, or falsy for none.
    """
    if checkpoint:
        model = load_experiment(checkpoint, conf=model_conf).eval()
    else:
        # No weights: instantiate the two-view pipeline from config alone.
        model = get_model("two_view_pipeline")(model_conf).eval()
    return model
|
||||
|
||||
|
||||
def get_eval_parser():
    """Build the CLI parser shared by all evaluation scripts.

    Provides string options (--tag/--checkpoint/--conf), boolean flags
    (--overwrite/--overwrite_eval/--plot), and a free-form positional
    dotlist of OmegaConf ``key=value`` overrides.
    """
    parser = argparse.ArgumentParser()
    for opt in ("--tag", "--checkpoint", "--conf"):
        parser.add_argument(opt, type=str, default=None)
    for flag in ("--overwrite", "--overwrite_eval", "--plot"):
        parser.add_argument(flag, action="store_true")
    parser.add_argument("dotlist", nargs="*")
    return parser
|
|
@ -0,0 +1,191 @@
|
|||
import torch
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
from pprint import pprint
|
||||
import matplotlib.pyplot as plt
|
||||
import resource
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from tqdm import tqdm
|
||||
import zipfile
|
||||
import numpy as np
|
||||
from ..visualization.viz2d import plot_cumulative
|
||||
from .io import (
|
||||
parse_eval_args,
|
||||
load_model,
|
||||
get_eval_parser,
|
||||
)
|
||||
from ..utils.export_predictions import export_predictions
|
||||
from ..settings import EVAL_PATH, DATA_PATH
|
||||
from ..models.cache_loader import CacheLoader
|
||||
from ..datasets import get_dataset
|
||||
from .eval_pipeline import EvalPipeline
|
||||
|
||||
from .utils import eval_relative_pose_robust, eval_poses, eval_matches_epipolar
|
||||
|
||||
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
class MegaDepth1500Pipeline(EvalPipeline):
    """Relative-pose evaluation on the MegaDepth-1500 pair list."""

    default_conf = {
        "data": {
            "name": "image_pairs",
            "pairs": "megadepth1500/pairs_calibrated.txt",
            "root": "megadepth1500/images/",
            "extra_data": "relative_pose",
            "preprocessing": {
                "side": "long",
            },
        },
        "model": {
            "ground_truth": {
                "name": None,  # remove gt matches
            }
        },
        "eval": {
            "estimator": "poselib",
            "ransac_th": 1.0,  # -1 runs a bunch of thresholds and selects the best
        },
    }

    # Keypoint/match outputs cached for every pair.
    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]
    optional_export_keys = []

    def _init(self, conf):
        # Download and unpack the dataset on first use.
        if not (DATA_PATH / "megadepth1500").exists():
            url = "https://cvg-data.inf.ethz.ch/megadepth/megadepth1500.zip"
            zip_path = DATA_PATH / url.rsplit("/", 1)[-1]
            torch.hub.download_url_to_file(url, zip_path)
            with zipfile.ZipFile(zip_path) as zip:
                zip.extractall(DATA_PATH)
            zip_path.unlink()

    @classmethod
    def get_dataloader(self, data_conf=None):
        """Returns a data loader with samples for each eval datapoint"""
        data_conf = data_conf if data_conf else self.default_conf["data"]
        dataset = get_dataset(data_conf["name"])(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint"""
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions"""
        conf = self.conf.eval
        results = defaultdict(list)
        # ransac_th <= 0 sweeps several thresholds; best is selected below.
        test_thresholds = (
            ([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
            if not isinstance(conf.ransac_th, Iterable)
            else conf.ransac_th
        )
        pose_results = defaultdict(lambda: defaultdict(list))
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
        for i, data in enumerate(tqdm(loader)):
            pred = cache_loader(data)
            # add custom evaluations here
            results_i = eval_matches_epipolar(data, pred)
            for th in test_thresholds:
                pose_results_i = eval_relative_pose_robust(
                    data,
                    pred,
                    {"estimator": conf.estimator, "ransac_th": th},
                )
                [pose_results[th][k].append(v) for k, v in pose_results_i.items()]

            # we also store the names for later reference
            results_i["names"] = data["name"][0]
            if "scene" in data.keys():
                results_i["scenes"] = data["scene"][0]

            for k, v in results_i.items():
                results[k].append(v)

        # summarize results as a dict[str, float]
        # you can also add your custom evaluations here
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            if not np.issubdtype(np.array(v).dtype, np.number):
                continue
            summaries[f"m{k}"] = round(np.mean(arr), 3)

        best_pose_results, best_th = eval_poses(
            pose_results, auc_ths=[5, 10, 20], key="rel_pose_error"
        )
        results = {**results, **pose_results[best_th]}
        summaries = {
            **summaries,
            **best_pose_results,
        }

        figures = {
            "pose_recall": plot_cumulative(
                {self.conf.eval.estimator: results["rel_pose_error"]},
                [0, 30],
                unit="°",
                title="Pose ",
            )
        }

        return summaries, figures, results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The benchmark name is the file stem (e.g. "megadepth1500").
    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(MegaDepth1500Pipeline.default_conf)

    # mingle paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(
        dataset_name,
        args,
        "configs/",
        default_conf,
    )

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = MegaDepth1500Pipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir,
        overwrite=args.overwrite,
        overwrite_eval=args.overwrite_eval,
    )

    # Print the scalar summaries.
    pprint(s)

    if args.plot:
        for name, fig in f.items():
            fig.canvas.manager.set_window_title(name)
        plt.show()
|
|
@ -0,0 +1,254 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import kornia
|
||||
from ..geometry.epipolar import relative_pose_error, generalized_epi_dist
|
||||
from ..geometry.homography import sym_homography_error, homography_corner_error
|
||||
from ..geometry.gt_generation import IGNORE_FEATURE
|
||||
from ..utils.tools import AUCMetric
|
||||
from ..robust_estimators import load_estimator
|
||||
|
||||
|
||||
def check_keys_recursive(d, pattern):
    """Assert that ``d`` contains every key described by ``pattern``.

    ``pattern`` is either an iterable of keys that must all be present in
    ``d``, or a dict mapping keys to nested patterns that are checked
    recursively on the corresponding sub-mapping.

    Raises:
        AssertionError: if a required key is missing (with the key named;
            previously the dict branch raised a bare KeyError and used a
            set comprehension purely for its side effects).
    """
    if isinstance(pattern, dict):
        for k, v in pattern.items():
            assert k in d.keys(), f"Missing key: {k}"
            check_keys_recursive(d[k], v)
    else:
        for k in pattern:
            assert k in d.keys(), f"Missing key: {k}"
|
||||
|
||||
def get_matches_scores(kpts0, kpts1, matches0, mscores0):
    """Gather matched keypoint pairs and their confidence scores.

    ``matches0`` holds, for each keypoint in image 0, the index of its
    match in image 1, or -1 when unmatched.

    Returns:
        (pts0, pts1, scores) for the matched subset only.
    """
    valid = matches0 > -1
    pts0 = kpts0[valid]
    pts1 = kpts1[matches0[valid]]
    return pts0, pts1, mscores0[valid]
|
||||
|
||||
|
||||
def eval_matches_epipolar(data: dict, pred: dict) -> dict:
    """Compute epipolar match metrics for one image pair.

    Returns match precision at three normalized epipolar-distance
    thresholds plus the match and average keypoint counts.
    """
    check_keys_recursive(data, ["view0", "view1", "T_0to1"])
    check_keys_recursive(
        pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
    )

    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0, scores0 = pred["matches0"], pred["matching_scores0"]
    pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)

    results = {}

    # match metrics
    n_epi_err = generalized_epi_dist(
        pts0[None],
        pts1[None],
        data["view0"]["camera"],
        data["view1"]["camera"],
        data["T_0to1"],
        False,
        essential=True,
    )[0]
    # Precision = fraction of matches below each epipolar-error threshold.
    results["epi_prec@1e-4"] = (n_epi_err < 1e-4).float().mean()
    results["epi_prec@5e-4"] = (n_epi_err < 5e-4).float().mean()
    results["epi_prec@1e-3"] = (n_epi_err < 1e-3).float().mean()

    results["num_matches"] = pts0.shape[0]
    results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0

    return results
|
||||
|
||||
|
||||
def eval_matches_homography(data: dict, pred: dict, conf) -> dict:
    """Evaluate predicted matches of one pair under a known homography.

    Args:
        data: batch dict holding the ground-truth homography "H_0to1".
        pred: prediction dict with "keypoints0/1", "matches0",
            "matching_scores0".
        conf: unused here; kept for a uniform evaluator signature.

    Returns:
        dict with match precision at 1 and 3 pixels (symmetric reprojection
        error), the number of matches, and the mean keypoint count.
    """
    check_keys_recursive(data, ["H_0to1"])
    check_keys_recursive(
        pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
    )

    H_gt = data["H_0to1"]
    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0, scores0 = pred["matches0"], pred["matching_scores0"]
    pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
    # H_gt[0]: drop the batch dimension (single-pair evaluation).
    err = sym_homography_error(pts0, pts1, H_gt[0])
    results = {}
    # nan_to_num guards the mean over zero matches (empty tensor -> NaN).
    results["prec@1px"] = (err < 1).float().mean().nan_to_num().item()
    results["prec@3px"] = (err < 3).float().mean().nan_to_num().item()
    results["num_matches"] = pts0.shape[0]
    results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0

    return results
|
||||
|
||||
|
||||
def eval_relative_pose_robust(data, pred, conf):
    """Estimate a relative pose from predicted matches and score it.

    Runs the robust estimator selected by `conf["estimator"]` on the matched
    keypoints and reports the pose error (max of rotation/translation angular
    error) plus inlier statistics.
    """
    check_keys_recursive(data, ["view0", "view1", "T_0to1"])
    check_keys_recursive(
        pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
    )

    # [0]: single-pair evaluation, drop the batch dimension.
    T_gt = data["T_0to1"][0]
    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0, scores0 = pred["matches0"], pred["matching_scores0"]
    pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)

    results = {}

    estimator = load_estimator("relative_pose", conf["estimator"])(conf)
    data_ = {
        "m_kpts0": pts0,
        "m_kpts1": pts1,
        "camera0": data["view0"]["camera"][0],
        "camera1": data["view1"]["camera"][0],
    }
    est = estimator(data_)

    if not est["success"]:
        # Failed estimation counts as an infinitely bad pose with no inliers.
        results["rel_pose_error"] = float("inf")
        results["ransac_inl"] = 0
        results["ransac_inl%"] = 0
    else:
        # R, t, inl = ret
        M = est["M_0to1"]
        R, t = M.numpy()
        inl = est["inliers"].numpy()
        # NOTE(review): relative_pose_error returns (t_err, R_err), so these
        # names are swapped; harmless here since only max() is used.
        r_error, t_error = relative_pose_error(T_gt, R, t)
        results["rel_pose_error"] = max(r_error, t_error)
        results["ransac_inl"] = np.sum(inl)
        results["ransac_inl%"] = np.mean(inl)

    return results
|
||||
|
||||
|
||||
def eval_homography_robust(data, pred, conf):
    """Estimate a homography from predicted matches (points and/or lines)
    and score it with the corner error against the ground truth.

    Args:
        data: batch dict with "H_0to1" and "view0" (for the image size).
        pred: prediction dict; point matches ("keypoints0", "matches0", ...)
            and/or line matches ("lines0", "line_matches0", ...).
        conf: estimator configuration; `conf["estimator"]` selects the
            robust homography estimator.
    """
    H_gt = data["H_0to1"]
    estimator = load_estimator("homography", conf["estimator"])(conf)

    data_ = {}
    if "keypoints0" in pred:
        kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
        m0, scores0 = pred["matches0"], pred["matching_scores0"]
        pts0, pts1, _ = get_matches_scores(kp0, kp1, m0, scores0)
        data_["m_kpts0"] = pts0
        data_["m_kpts1"] = pts1
    if "lines0" in pred:
        # Prefer the original (unrefined) lines when available.
        if "orig_lines0" in pred:
            lines0 = pred["orig_lines0"]
            lines1 = pred["orig_lines1"]
        else:
            lines0 = pred["lines0"]
            lines1 = pred["lines1"]
        m_lines0, m_lines1, _ = get_matches_scores(
            lines0, lines1, pred["line_matches0"], pred["line_matching_scores0"]
        )
        data_["m_lines0"] = m_lines0
        data_["m_lines1"] = m_lines1

    est = estimator(data_)
    if est["success"]:
        M = est["M_0to1"]
        # Mean reprojection error of the image corners under M vs H_gt.
        error_r = homography_corner_error(M, H_gt, data["view0"]["image_size"]).item()
    else:
        error_r = float("inf")

    results = {}
    results["H_error_ransac"] = error_r
    if "inliers" in est:
        inl = est["inliers"]
        results["ransac_inl"] = inl.float().sum().item()
        results["ransac_inl%"] = inl.float().sum().item() / max(len(inl), 1)

    return results
|
||||
|
||||
|
||||
def eval_homography_dlt(data, pred, *args):
    """Fit a homography to the predicted matches with weighted DLT (no
    RANSAC) and report its corner error against the ground truth.

    Extra positional args are accepted and ignored so the function matches
    the signature of the other evaluators.
    """
    H_gt = data["H_0to1"]
    # Fallback value when DLT cannot be computed: an "infinite" homography
    # yields an infinite corner error below.
    H_inf = torch.ones_like(H_gt) * float("inf")

    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0, scores0 = pred["matches0"], pred["matching_scores0"]
    pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
    results = {}
    try:
        # Matching scores are used as per-correspondence weights.
        Hdlt = kornia.geometry.homography.find_homography_dlt(
            pts0[None], pts1[None], scores[None].to(pts0)
        )[0]
    except AssertionError:
        # kornia asserts on degenerate input (e.g. too few matches).
        Hdlt = H_inf

    error_dlt = homography_corner_error(Hdlt, H_gt, data["view0"]["image_size"])
    results["H_error_dlt"] = error_dlt.item()

    return results
|
||||
|
||||
|
||||
def eval_poses(pose_results, auc_ths, key, unit="°"):
    """Summarize pose errors over one or several RANSAC thresholds.

    Args:
        pose_results: dict mapping a RANSAC threshold to a dict of metric
            lists; `results_i[key]` holds the per-pair pose errors.
        auc_ths: error thresholds (in `unit`) at which the AUC is evaluated.
        key: name of the pose-error metric to aggregate.
        unit: unit suffix used in the summary keys (default: degrees).

    Returns:
        (summaries, best_th): summary metrics at the best threshold, and the
        threshold achieving the highest mean AUC (mAA).
    """
    pose_aucs = {}
    best_th = -1
    for th, results_i in pose_results.items():
        pose_aucs[th] = AUCMetric(auc_ths, results_i[key]).compute()
    mAAs = {k: np.mean(v) for k, v in pose_aucs.items()}
    best_th = max(mAAs, key=mAAs.get)

    # Bug fix: the original condition was `len(pose_aucs) > -1`, which is
    # always true. Only print the comparison when more than one RANSAC
    # setup was actually tested.
    if len(pose_aucs) > 1:
        print("Tested ransac setup with following results:")
        print("AUC", pose_aucs)
        print("mAA", mAAs)
        print("best threshold =", best_th)

    summaries = {}

    for i, ath in enumerate(auc_ths):
        summaries[f"{key}@{ath}{unit}"] = pose_aucs[best_th][i]
    summaries[f"{key}_mAA"] = mAAs[best_th]

    # Median of every numeric metric at the best threshold.
    for k, v in pose_results[best_th].items():
        arr = np.array(v)
        if not np.issubdtype(arr.dtype, np.number):
            # Skip non-numeric entries (e.g. names/strings).
            continue
        summaries[f"m{k}"] = round(np.median(arr), 3)
    return summaries, best_th
|
||||
|
||||
|
||||
def get_tp_fp_pts(pred_matches, gt_matches, pred_scores):
    """
    Computes the True Positives (TP), False positives (FP), the score associated
    to each match and the number of positives for a set of matches.
    """
    assert pred_matches.shape == pred_scores.shape
    # Drop features flagged as "ignore" in the ground truth.
    keep = gt_matches != IGNORE_FEATURE
    pred_matches = pred_matches[keep]
    gt_matches = gt_matches[keep]
    pred_scores = pred_scores[keep]
    # Ground-truth positives are features with a valid GT match (!= -1).
    num_pos = np.sum(gt_matches != -1)
    predicted = pred_matches != -1
    correct = pred_matches[predicted] == gt_matches[predicted]
    return correct, ~correct, pred_scores[predicted], num_pos
|
||||
|
||||
|
||||
def AP(tp, fp):
    """Average precision from cumulative TP/FP curves (VOC-style).

    `tp` and `fp` are the cumulative true/false-positive rates, already
    sorted by descending score. The precision curve is made monotonically
    decreasing (its upper envelope) before integrating over recall.
    """
    rec = np.concatenate(([0.0], tp, [1.0]))
    prec = np.concatenate(([0.0], tp / np.maximum(tp + fp, 1e-9), [0.0]))
    # Upper envelope: each precision becomes the max of itself and everything
    # to its right (equivalent to the classic backward max loop).
    prec = np.maximum.accumulate(prec[::-1])[::-1]
    # Integrate precision only where recall actually changes.
    steps = np.where(rec[1:] != rec[:-1])[0]
    return np.sum((rec[steps + 1] - rec[steps]) * prec[steps + 1])
|
||||
|
||||
|
||||
def aggregate_pr_results(results, suffix=""):
    """Merge per-image TP/FP lists into precision/recall curves and AP.

    Expects `results` to hold lists under "tp", "fp", "scores" and a count
    under "num_pos" (each with `suffix` appended). Returns the recall and
    precision curves plus the average precision in percent.
    """
    tps = np.concatenate(results["tp" + suffix], axis=0)
    fps = np.concatenate(results["fp" + suffix], axis=0)
    scores = np.concatenate(results["scores" + suffix], axis=0)
    n_gt = max(results["num_pos" + suffix], 1)

    # Accumulate in order of descending confidence.
    order = np.argsort(scores)[::-1]
    cum_tp = np.cumsum(tps[order]) / n_gt
    cum_fp = np.cumsum(fps[order]) / n_gt

    out = {
        "curve_recall" + suffix: cum_tp,
        "curve_precision" + suffix: cum_tp / np.maximum(cum_tp + cum_fp, 1e-9),
        "AP" + suffix: AP(cum_tp, cum_fp) * 100,
    }
    return out
|
|
@ -0,0 +1,88 @@
|
|||
import torch
|
||||
import kornia
|
||||
|
||||
from .utils import get_image_coords
|
||||
from .wrappers import Camera
|
||||
|
||||
|
||||
def sample_fmap(pts, fmap):
    """Bilinearly sample a feature map at 2D pixel locations.

    Args:
        pts: (B, N, 2) points in pixel coordinates (x, y).
        fmap: (B, C, H, W) feature map.

    Returns:
        (B, N, C) sampled features. Where bilinear interpolation yields NaN
        (e.g. a NaN neighbor), the nearest-neighbor sample is used instead.
    """
    height, width = fmap.shape[-2:]
    # Normalize pixel coordinates to [-1, 1] as expected by grid_sample.
    grid = (pts / pts.new_tensor([[width, height]]) * 2 - 1)[:, None]
    # @TODO: This might still be a source of noise --> bilinear interpolation dangerous
    lin = torch.nn.functional.grid_sample(
        fmap, grid, align_corners=False, mode="bilinear"
    )
    nn = torch.nn.functional.grid_sample(
        fmap, grid, align_corners=False, mode="nearest"
    )
    out = torch.where(torch.isnan(lin), nn, lin)
    return out[:, :, 0].permute(0, 2, 1)
|
||||
|
||||
|
||||
def sample_depth(pts, depth_):
    """Sample a depth map at the given 2D points.

    Zero-depth pixels are treated as invalid: they are replaced by NaN before
    interpolation so invalidity propagates instead of biasing the result.
    Returns the interpolated depths and a boolean validity mask.
    """
    with_nans = torch.where(depth_ > 0, depth_, depth_.new_tensor(float("nan")))
    interp = sample_fmap(pts, with_nans[:, None]).squeeze(-1)
    valid = (~torch.isnan(interp)) & (interp > 0)
    return interp, valid
|
||||
|
||||
|
||||
def sample_normals_from_depth(pts, depth, K):
    """Estimate surface normals from a depth map and sample them at `pts`.

    Normals are computed with kornia from the depth map and intrinsics `K`;
    pixels without depth are zeroed before sampling. Returns the sampled
    normals and a validity mask.
    """
    depth_map = depth[:, None]
    normal_map = kornia.geometry.depth.depth_to_normals(depth_map, K)
    # Zero out normals at pixels with no depth measurement.
    normal_map = torch.where(depth_map > 0, normal_map, 0.0)
    interp = sample_fmap(pts, normal_map)
    valid = (~torch.isnan(interp)) & (interp > 0)
    return interp, valid
|
||||
|
||||
|
||||
def project(
    kpi,
    di,
    depthj,
    camera_i,
    camera_j,
    T_itoj,
    validi,
    ccth=None,
    sample_depth_fun=sample_depth,
    sample_depth_kwargs=None,
):
    """Project 2D points of view i with known depth into view j.

    Args:
        kpi: (..., N, 2) pixel coordinates in view i.
        di: (..., N) depths of those points.
        depthj: depth map of view j, or None (only used for the cycle check).
        camera_i, camera_j: cameras of the two views.
        T_itoj: transform from view i to view j.
        validi: (..., N) boolean validity of the input points.
        ccth: optional squared-pixel threshold for the cycle-consistency
            check; if None (or depthj is None), the check is skipped.
        sample_depth_fun / sample_depth_kwargs: how to sample depthj.

    Returns:
        (kpi_j, visible): the projected points in view j and a boolean mask
        of points that are valid (and, if requested, cycle-consistent).
    """
    if sample_depth_kwargs is None:
        sample_depth_kwargs = {}

    # Back-project to 3D in camera i, move to camera j, and re-project.
    kpi_3d_i = camera_i.image2cam(kpi)
    kpi_3d_i = kpi_3d_i * di[..., None]
    kpi_3d_j = T_itoj.transform(kpi_3d_i)
    kpi_j, validj = camera_j.cam2image(kpi_3d_j)
    # di_j = kpi_3d_j[..., -1]
    validi = validi & validj
    if depthj is None or ccth is None:
        return kpi_j, validi & validj
    else:
        # circle consistency: project back to view i using the depth sampled
        # in view j, and require the round trip to land within ccth (squared
        # pixels) of the starting point.
        dj, validj = sample_depth_fun(kpi_j, depthj, **sample_depth_kwargs)
        kpi_j_3d_j = camera_j.image2cam(kpi_j) * dj[..., None]
        kpi_j_i, validj_i = camera_i.cam2image(T_itoj.inv().transform(kpi_j_3d_j))
        consistent = ((kpi - kpi_j_i) ** 2).sum(-1) < ccth
        visible = validi & consistent & validj_i & validj
        # visible = validi
        return kpi_j, visible
|
||||
|
||||
|
||||
def dense_warp_consistency(
    depthi: torch.Tensor,
    depthj: torch.Tensor,
    T_itoj: torch.Tensor,
    camerai: Camera,
    cameraj: Camera,
    **kwargs,
):
    """Warp every pixel of view i into view j and report validity.

    Args:
        depthi: (..., Hi, Wi) depth of the source view.
        depthj: (..., Hj, Wj) depth of the target view.
        T_itoj: transform from view i to view j.
        camerai, cameraj: cameras of the two views.
        **kwargs: forwarded to `project` (e.g. `ccth` for cycle consistency).

    Returns:
        Warped pixel coordinates of shape (..., Hi, Wi, 2) and the per-pixel
        validity mask of shape (..., Hi, Wi).
    """
    kpi = get_image_coords(depthi).flatten(-3, -2)
    di = depthi.flatten(
        -2,
    )
    validi = di > 0
    kpir, validir = project(kpi, di, depthj, camerai, cameraj, T_itoj, validi, **kwargs)

    # Bug fix: both kpir and validir have one entry per pixel of view i, so
    # the validity mask must be unflattened with depthi's spatial shape. The
    # original used depthj's shape, which only worked when both depth maps
    # happened to have identical sizes.
    return (
        kpir.unflatten(-2, depthi.shape[-2:]),
        validir.unflatten(-1, depthi.shape[-2:]),
    )
|
|
@ -0,0 +1,161 @@
|
|||
import torch
|
||||
from .utils import skew_symmetric, to_homogeneous
|
||||
from .wrappers import Pose, Camera
|
||||
import numpy as np
|
||||
|
||||
|
||||
def T_to_E(T: Pose):
    """Convert batched poses (..., 4, 4) to batched essential matrices.

    Uses the standard construction E = [t]_x @ R.
    """
    return skew_symmetric(T.t) @ T.R
|
||||
|
||||
|
||||
def T_to_F(cam0: Camera, cam1: Camera, T_0to1: Pose):
    """Fundamental matrix from a relative pose and two pinhole cameras."""
    return E_to_F(cam0, cam1, T_to_E(T_0to1))
|
||||
|
||||
|
||||
def E_to_F(cam0: Camera, cam1: Camera, E: torch.Tensor):
    """Convert an essential matrix to a fundamental matrix.

    Computes F = K1^-T @ E @ K0^-1; only valid for pinhole cameras
    (6 intrinsic parameters).
    """
    assert cam0._data.shape[-1] == 6, "only pinhole cameras supported"
    assert cam1._data.shape[-1] == 6, "only pinhole cameras supported"
    K0 = cam0.calibration_matrix()
    K1 = cam1.calibration_matrix()
    return K1.inverse().transpose(-1, -2) @ E @ K0.inverse()
|
||||
|
||||
|
||||
def F_to_E(cam0: Camera, cam1: Camera, F: torch.Tensor):
    """Convert a fundamental matrix to an essential matrix.

    Computes E = K1^T @ F @ K0; only valid for pinhole cameras
    (6 intrinsic parameters).
    """
    assert cam0._data.shape[-1] == 6, "only pinhole cameras supported"
    assert cam1._data.shape[-1] == 6, "only pinhole cameras supported"
    K0 = cam0.calibration_matrix()
    K1 = cam1.calibration_matrix()
    return K1.transpose(-1, -2) @ F @ K0
|
||||
|
||||
|
||||
def sym_epipolar_distance(p0, p1, E, squared=True):
    """Compute batched symmetric epipolar distances.
    Args:
        p0, p1: batched tensors of N 2D points of size (..., N, 2).
        E: essential matrices from camera 0 to camera 1, size (..., 3, 3).
        squared: if True return the squared distance, else the distance.
    Returns:
        The symmetric epipolar distance of each point-pair: (..., N).
    """
    assert p0.shape[-2] == p1.shape[-2]
    if p0.shape[-2] == 0:
        # No points: return an empty tensor with matching dtype/device.
        return torch.zeros(p0.shape[:-1]).to(p0)
    # Lift to homogeneous coordinates when needed.
    if p0.shape[-1] != 3:
        p0 = to_homogeneous(p0)
    if p1.shape[-1] != 3:
        p1 = to_homogeneous(p1)
    residual = torch.einsum("...ni,...ij,...nj->...n", p1, E, p0)
    line1 = torch.einsum("...ij,...nj->...ni", E, p0)  # epipolar line of p0 in image 1
    line0 = torch.einsum("...ij,...ni->...nj", E, p1)  # epipolar line of p1 in image 0
    n1 = (line1[..., 0] ** 2 + line1[..., 1] ** 2).clamp(min=1e-6)
    n0 = (line0[..., 0] ** 2 + line0[..., 1] ** 2).clamp(min=1e-6)
    if squared:
        return residual**2 * (1 / n1 + 1 / n0)
    return residual.abs() * (1 / n1.sqrt() + 1 / n0.sqrt()) / 2
|
||||
|
||||
|
||||
def sym_epipolar_distance_all(p0, p1, E, eps=1e-15):
    """Pairwise symmetric epipolar distance between all points of two sets.

    Args:
        p0: (..., N, 2 or 3) points in image 0.
        p1: (..., M, 2 or 3) points in image 1.
        E: (..., 3, 3) essential (or fundamental) matrices.
        eps: numerical floor for the line-norm denominators.

    Returns:
        (..., N, M) distances for every pair (p0_n, p1_m).
    """
    if p0.shape[-1] != 3:
        p0 = to_homogeneous(p0)
    if p1.shape[-1] != 3:
        p1 = to_homogeneous(p1)
    residual = torch.einsum("...mi,...ij,...nj->...nm", p1, E, p0).abs()
    lines1 = torch.einsum("...ij,...nj->...ni", E, p0)  # lines of p0 in image 1
    lines0 = torch.einsum("...ij,...mi->...mj", E, p1)  # lines of p1 in image 0
    norm1 = (lines1[..., None, 0] ** 2 + lines1[..., None, 1] ** 2 + eps).sqrt()
    norm0 = (
        lines0[..., None, :, 0] ** 2 + lines0[..., None, :, 1] ** 2 + eps
    ).sqrt()
    return (residual / norm1 + residual / norm0) / 2
|
||||
|
||||
|
||||
def generalized_epi_dist(
    kpts0, kpts1, cam0: Camera, cam1: Camera, T_0to1: Pose, all=True, essential=True
):
    """Epipolar distance between keypoints of two calibrated views.

    Args:
        kpts0, kpts1: (..., N, 2) / (..., M, 2) keypoints in pixels.
        cam0, cam1: cameras of the two views.
        T_0to1: relative pose from view 0 to view 1.
        all: if True, return the full (..., N, M) pairwise distance matrix,
            otherwise the per-pair distance of aligned points (..., N).
        essential: if True, work in normalized camera coordinates with the
            essential matrix; otherwise use the fundamental matrix in pixel
            space (pinhole cameras only).
    """
    if essential:
        E = T_to_E(T_0to1)
        p0 = cam0.image2cam(kpts0)
        p1 = cam1.image2cam(kpts1)
        if all:
            # Bug fix: sym_epipolar_distance_all takes no `agg` parameter;
            # the original call passed agg="max", raising a TypeError.
            return sym_epipolar_distance_all(p0, p1, E)
        else:
            return sym_epipolar_distance(p0, p1, E, squared=False)
    else:
        assert cam0._data.shape[-1] == 6
        assert cam1._data.shape[-1] == 6
        K0, K1 = cam0.calibration_matrix(), cam1.calibration_matrix()
        # F = K1^-T E K0^-1, computed inline to avoid the pinhole asserts twice.
        F = K1.inverse().transpose(-1, -2) @ T_to_E(T_0to1) @ K0.inverse()
        if all:
            return sym_epipolar_distance_all(kpts0, kpts1, F)
        else:
            return sym_epipolar_distance(kpts0, kpts1, F, squared=False)
|
||||
|
||||
|
||||
def decompose_essential_matrix(E):
    """Decompose an essential matrix into its two candidate rotations and
    the translation direction (the classic SVD-based four-fold ambiguity:
    the valid pose is one of (R1, t), (R1, -t), (R2, t), (R2, -t)).
    """
    # decompose matrix by its singular values
    U, _, V = torch.svd(E)
    Vt = V.transpose(-2, -1)

    mask = torch.ones_like(E)
    mask[..., -1:] *= -1.0  # fill last column with negative values

    maskt = mask.transpose(-2, -1)

    # avoid singularities: flip a column/row so that U and Vt are proper
    # rotations (determinant +1)
    U = torch.where((torch.det(U) < 0.0)[..., None, None], U * mask, U)
    Vt = torch.where((torch.det(Vt) < 0.0)[..., None, None], Vt * maskt, Vt)

    # Build W = [[0, -1, 0], [1, 0, 0], [0, 0, 1]] from the skew matrix of
    # (0, 0, 1) with the bottom-right entry set to 1.
    W = skew_symmetric(E.new_tensor([[0, 0, 1]]))
    W[..., 2, 2] += 1.0

    # reconstruct rotations and retrieve translation vector
    U_W_Vt = U @ W @ Vt
    U_Wt_Vt = U @ W.transpose(-2, -1) @ Vt

    # return values: translation direction is the last column of U.
    R1 = U_W_Vt
    R2 = U_Wt_Vt
    T = U[..., -1]
    return R1, R2, T
|
||||
|
||||
|
||||
# pose errors
|
||||
# TODO: port to torch and batch
|
||||
def angle_error_mat(R1, R2):
    """Geodesic angle (degrees) between two 3x3 rotation matrices."""
    trace = np.trace(np.dot(R1.T, R2))
    cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)  # guard numerical noise
    return np.rad2deg(np.abs(np.arccos(cos_angle)))
|
||||
|
||||
|
||||
def angle_error_vec(v1, v2):
    """Angle (degrees) between two vectors, robust to rounding errors."""
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    cos_angle = np.clip(np.dot(v1, v2) / denom, -1.0, 1.0)
    return np.rad2deg(np.arccos(cos_angle))
|
||||
|
||||
|
||||
def compute_pose_error(T_0to1, R, t):
    """Rotation and translation angular errors of an estimated pose.

    Args:
        T_0to1: (4, 4) ground-truth transformation matrix.
        R: (3, 3) estimated rotation.
        t: (3,) estimated translation direction.

    Returns:
        (error_t, error_R) in degrees; the translation error accounts for
        the sign ambiguity of essential-matrix decomposition.
    """
    R_gt, t_gt = T_0to1[:3, :3], T_0to1[:3, 3]
    # Angle between the translation directions (inlined angle_error_vec).
    norm_prod = np.linalg.norm(t) * np.linalg.norm(t_gt)
    error_t = np.rad2deg(
        np.arccos(np.clip(np.dot(t, t_gt) / norm_prod, -1.0, 1.0))
    )
    error_t = np.minimum(error_t, 180 - error_t)  # ambiguity of E estimation
    # Angle between the rotation matrices (inlined angle_error_mat).
    cos_r = np.clip((np.trace(np.dot(R.T, R_gt)) - 1) / 2, -1.0, 1.0)
    error_R = np.rad2deg(np.abs(np.arccos(cos_r)))
    return error_t, error_R
|
||||
|
||||
|
||||
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
    """Angular rotation and translation errors w.r.t. a ground-truth pose.

    Args:
        T_0to1: ground-truth pose; its `.numpy()` yields (R_gt, t_gt).
        R: (3, 3) estimated rotation (numpy).
        t: (3,) estimated translation direction (numpy).
        ignore_gt_t_thr: if the GT translation norm is below this, the
            translation error is forced to 0 (near-pure rotation is
            ill-conditioned for direction comparison).

    Returns:
        (t_err, R_err) in degrees.
    """
    # angle error between 2 vectors
    R_gt, t_gt = T_0to1.numpy()
    n = np.linalg.norm(t) * np.linalg.norm(t_gt)
    t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
    t_err = np.minimum(t_err, 180 - t_err)  # handle E ambiguity
    if np.linalg.norm(t_gt) < ignore_gt_t_thr:  # pure rotation is challenging
        t_err = 0

    # angle error between 2 rotation matrices
    cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
    cos = np.clip(cos, -1.0, 1.0)  # handle numerical errors
    R_err = np.rad2deg(np.abs(np.arccos(cos)))

    return t_err, R_err
|
|
@ -0,0 +1,558 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from .homography import warp_points_torch
|
||||
from .epipolar import T_to_E, sym_epipolar_distance_all
|
||||
from .depth import sample_depth, project
|
||||
|
||||
IGNORE_FEATURE = -2
|
||||
UNMATCHED_FEATURE = -1
|
||||
|
||||
|
||||
@torch.no_grad()
def gt_matches_from_pose_depth(
    kp0, kp1, data, pos_th=3, neg_th=5, epi_th=None, cc_th=None, **kw
):
    """Ground-truth keypoint matches from relative pose and depth maps.

    Keypoints of each view are back-projected with their sampled depth,
    warped into the other view, and mutual nearest neighbors within
    `pos_th` pixels become positive matches. Points whose reprojection is
    farther than `neg_th` from everything are labeled unmatched (-1);
    undecidable points are labeled ignore (-2).

    Returns a dict with the assignment matrix, match indices/scores, and
    per-keypoint depth/projection/visibility byproducts.

    NOTE(review): the empty-keypoints early return yields a *tuple*
    (assignment, m0, m1) instead of the dict returned by the main path —
    callers handling empty inputs should be aware; confirm before relying
    on a uniform return type.
    """
    if kp0.shape[1] == 0 or kp1.shape[1] == 0:
        b_size, n_kp0 = kp0.shape[:2]
        n_kp1 = kp1.shape[1]
        assignment = torch.zeros(
            b_size, n_kp0, n_kp1, dtype=torch.bool, device=kp0.device
        )
        m0 = -torch.ones_like(kp0[:, :, 0]).long()
        m1 = -torch.ones_like(kp1[:, :, 0]).long()
        return assignment, m0, m1
    camera0, camera1 = data["view0"]["camera"], data["view1"]["camera"]
    T_0to1, T_1to0 = data["T_0to1"], data["T_1to0"]

    depth0 = data["view0"].get("depth")
    depth1 = data["view1"].get("depth")
    # Use precomputed keypoint depths when the caller provides them,
    # otherwise sample the dense depth maps at the keypoint locations.
    if "depth_keypoints0" in kw and "depth_keypoints1" in kw:
        d0, valid0 = kw["depth_keypoints0"], kw["valid_depth_keypoints0"]
        d1, valid1 = kw["depth_keypoints1"], kw["valid_depth_keypoints1"]
    else:
        assert depth0 is not None
        assert depth1 is not None
        d0, valid0 = sample_depth(kp0, depth0)
        d1, valid1 = sample_depth(kp1, depth1)

    # Warp keypoints into the other view (with optional cycle consistency).
    kp0_1, visible0 = project(
        kp0, d0, depth1, camera0, camera1, T_0to1, valid0, ccth=cc_th
    )
    kp1_0, visible1 = project(
        kp1, d1, depth0, camera1, camera0, T_1to0, valid1, ccth=cc_th
    )
    mask_visible = visible0.unsqueeze(-1) & visible1.unsqueeze(-2)

    # build a distance matrix of size [... x M x N]
    dist0 = torch.sum((kp0_1.unsqueeze(-2) - kp1.unsqueeze(-3)) ** 2, -1)
    dist1 = torch.sum((kp0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1)
    dist = torch.max(dist0, dist1)
    inf = dist.new_tensor(float("inf"))
    dist = torch.where(mask_visible, dist, inf)

    min0 = dist.min(-1).indices
    min1 = dist.min(-2).indices

    # Mutual nearest neighbors: a pair is positive only if each point is the
    # other's closest reprojection and both directions agree within pos_th.
    ismin0 = torch.zeros(dist.shape, dtype=torch.bool, device=dist.device)
    ismin1 = ismin0.clone()
    ismin0.scatter_(-1, min0.unsqueeze(-1), value=1)
    ismin1.scatter_(-2, min1.unsqueeze(-2), value=1)
    positive = ismin0 & ismin1 & (dist < pos_th**2)

    negative0 = (dist0.min(-1).values > neg_th**2) & valid0
    negative1 = (dist1.min(-2).values > neg_th**2) & valid1

    # pack the indices of positive matches
    # if -1: unmatched point
    # if -2: ignore point
    unmatched = min0.new_tensor(UNMATCHED_FEATURE)
    ignore = min0.new_tensor(IGNORE_FEATURE)
    m0 = torch.where(positive.any(-1), min0, ignore)
    m1 = torch.where(positive.any(-2), min1, ignore)
    m0 = torch.where(negative0, unmatched, m0)
    m1 = torch.where(negative1, unmatched, m1)

    # Fundamental matrix F = K1^-T E K0^-1, used for the epipolar reward and
    # the optional epipolar-based negatives below.
    F = (
        camera1.calibration_matrix().inverse().transpose(-1, -2)
        @ T_to_E(T_0to1)
        @ camera0.calibration_matrix().inverse()
    )
    epi_dist = sym_epipolar_distance_all(kp0, kp1, F)

    # Add some more unmatched points using epipolar geometry
    if epi_th is not None:
        mask_ignore = (m0.unsqueeze(-1) == ignore) & (m1.unsqueeze(-2) == ignore)
        epi_dist = torch.where(mask_ignore, epi_dist, inf)
        exclude0 = epi_dist.min(-1).values > neg_th
        exclude1 = epi_dist.min(-2).values > neg_th
        # Points with no valid depth but far from every epipolar line can
        # safely be labeled unmatched instead of ignored.
        m0 = torch.where((~valid0) & exclude0, ignore.new_tensor(-1), m0)
        m1 = torch.where((~valid1) & exclude1, ignore.new_tensor(-1), m1)

    return {
        "assignment": positive,
        "reward": (dist < pos_th**2).float() - (epi_dist > neg_th).float(),
        "matches0": m0,
        "matches1": m1,
        "matching_scores0": (m0 > -1).float(),
        "matching_scores1": (m1 > -1).float(),
        "depth_keypoints0": d0,
        "depth_keypoints1": d1,
        "proj_0to1": kp0_1,
        "proj_1to0": kp1_0,
        "visible0": visible0,
        "visible1": visible1,
    }
|
||||
|
||||
|
||||
@torch.no_grad()
def gt_matches_from_homography(kp0, kp1, H, pos_th=3, neg_th=6, **kw):
    """Ground-truth keypoint matches under a known homography.

    Keypoints are warped into the other view with H (and its inverse);
    mutual nearest neighbors within `pos_th` pixels become positives,
    points farther than `neg_th` from everything become unmatched (-1),
    the rest are labeled ignore (-2).

    NOTE(review): the empty-keypoints early return yields a *tuple*
    (assignment, m0, m1) instead of the dict returned by the main path —
    confirm before relying on a uniform return type.
    """
    if kp0.shape[1] == 0 or kp1.shape[1] == 0:
        b_size, n_kp0 = kp0.shape[:2]
        n_kp1 = kp1.shape[1]
        assignment = torch.zeros(
            b_size, n_kp0, n_kp1, dtype=torch.bool, device=kp0.device
        )
        m0 = -torch.ones_like(kp0[:, :, 0]).long()
        m1 = -torch.ones_like(kp1[:, :, 0]).long()
        return assignment, m0, m1
    kp0_1 = warp_points_torch(kp0, H, inverse=False)
    kp1_0 = warp_points_torch(kp1, H, inverse=True)

    # build a distance matrix of size [... x M x N]
    dist0 = torch.sum((kp0_1.unsqueeze(-2) - kp1.unsqueeze(-3)) ** 2, -1)
    dist1 = torch.sum((kp0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1)
    dist = torch.max(dist0, dist1)

    reward = (dist < pos_th**2).float() - (dist > neg_th**2).float()

    min0 = dist.min(-1).indices
    min1 = dist.min(-2).indices

    # Mutual nearest neighbors within pos_th pixels are positives.
    ismin0 = torch.zeros(dist.shape, dtype=torch.bool, device=dist.device)
    ismin1 = ismin0.clone()
    ismin0.scatter_(-1, min0.unsqueeze(-1), value=1)
    ismin1.scatter_(-2, min1.unsqueeze(-2), value=1)
    positive = ismin0 & ismin1 & (dist < pos_th**2)

    negative0 = dist0.min(-1).values > neg_th**2
    negative1 = dist1.min(-2).values > neg_th**2

    # pack the indices of positive matches
    # if -1: unmatched point
    # if -2: ignore point
    unmatched = min0.new_tensor(UNMATCHED_FEATURE)
    ignore = min0.new_tensor(IGNORE_FEATURE)
    m0 = torch.where(positive.any(-1), min0, ignore)
    m1 = torch.where(positive.any(-2), min1, ignore)
    m0 = torch.where(negative0, unmatched, m0)
    m1 = torch.where(negative1, unmatched, m1)

    return {
        "assignment": positive,
        "reward": reward,
        "matches0": m0,
        "matches1": m1,
        "matching_scores0": (m0 > -1).float(),
        "matching_scores1": (m1 > -1).float(),
        "proj_0to1": kp0_1,
        "proj_1to0": kp1_0,
    }
|
||||
|
||||
|
||||
def sample_pts(lines, npts):
    """Sample `npts` equally-spaced points along each line segment.

    Args:
        lines: (..., 4) segments encoded as (x1, y1, x2, y2).
        npts: number of samples per segment, endpoints included.

    Returns:
        (..., npts, 2) points ordered from the start to the end point.
    """
    step = (lines[..., 2:4] - lines[..., :2]) / (npts - 1)
    offsets = step[..., np.newaxis].expand(step.shape + (npts,))
    pts = lines[..., :2, np.newaxis] + offsets * torch.arange(npts).to(lines)
    return torch.transpose(pts, -1, -2)
|
||||
|
||||
|
||||
def torch_perp_dist(segs2d, points_2d):
    """Perpendicular distance from sampled points to 2D segments.

    Args:
        segs2d: (batch, nsegs0, 4) segments as (x1, y1, x2, y2).
        points_2d: (batch, nsegs1, n_sampled_pts, 2) points.

    Returns:
        (distances, overlaping): per point, the absolute perpendicular
        distance to each segment's supporting line, and a boolean mask of
        points whose projection falls within the segment's extent.

    NOTE(review): points are centered on the segment's *end* point
    (``segs2d[..., 2:]``), consistent with the sign test below; `.half()`
    is used to cut memory, which trades precision — confirm acceptable.
    """
    # Check batch size and segments format
    assert segs2d.shape[0] == points_2d.shape[0]
    assert segs2d.shape[-1] == 4
    dir = segs2d[..., 2:] - segs2d[..., :2]
    sizes = torch.norm(dir, dim=-1).half()
    norm_dir = dir / torch.unsqueeze(sizes, dim=-1)
    # middle_ptn = 0.5 * (segs2d[..., 2:] + segs2d[..., :2])
    # centered [batch, nsegs0, nsegs1, n_sampled_pts, 2]
    centered = points_2d[:, None] - segs2d[..., None, None, 2:]

    # Per-segment 2D rotation aligning the segment direction with the x axis.
    R = torch.cat(
        [
            norm_dir[..., 0, None],
            norm_dir[..., 1, None],
            -norm_dir[..., 1, None],
            norm_dir[..., 0, None],
        ],
        dim=2,
    ).reshape((len(segs2d), -1, 2, 2))
    # Try to reduce the memory consumption by using float16 type
    if centered.is_cuda:
        centered, R = centered.half(), R.half()
    # R: [batch, nsegs0, 2, 2] , centered: [batch, nsegs1, n_sampled_pts, 2]
    # -> [batch, nsegs0, nsegs1, n_sampled_pts, 2]
    rotated = torch.einsum("bdji,bdepi->bdepj", R, centered)

    # In the rotated frame, x <= 0 and |x| <= length means the point projects
    # onto the segment (between start and end).
    overlaping = (rotated[..., 0] <= 0) & (
        torch.abs(rotated[..., 0]) <= sizes[..., None, None]
    )

    # The rotated y coordinate is the signed perpendicular distance.
    return torch.abs(rotated[..., 1]), overlaping
|
||||
|
||||
|
||||
@torch.no_grad()
def gt_line_matches_from_pose_depth(
    pred_lines0,
    pred_lines1,
    valid_lines0,
    valid_lines1,
    data,
    npts=50,
    dist_th=5,
    overlap_th=0.2,
    min_visibility_th=0.5,
):
    """Compute ground truth line matches and label the remaining the lines as:
    - UNMATCHED: if reprojection is outside the image
    or far away from any other line.
    - IGNORE: if a line has not enough valid depth pixels along itself
    or it is labeled as invalid.

    Lines are discretized into `npts` points, reprojected into the other
    view using depth and pose, and two lines match when a large-enough
    fraction of reprojected points lies close to (and overlapping with) the
    other line in both directions. A one-to-one assignment is then obtained
    with the Hungarian algorithm.

    Returns (positive, m0, m1): the boolean assignment matrix and per-line
    match indices (UNMATCHED_FEATURE / IGNORE_FEATURE for the rest).
    """
    lines0 = pred_lines0.clone()
    lines1 = pred_lines1.clone()

    # Degenerate case: one of the views has no lines.
    if pred_lines0.shape[1] == 0 or pred_lines1.shape[1] == 0:
        bsize, nlines0, nlines1 = (
            pred_lines0.shape[0],
            pred_lines0.shape[1],
            pred_lines1.shape[1],
        )
        positive = torch.zeros(
            (bsize, nlines0, nlines1), dtype=torch.bool, device=pred_lines0.device
        )
        m0 = torch.full((bsize, nlines0), -1, device=pred_lines0.device)
        m1 = torch.full((bsize, nlines1), -1, device=pred_lines0.device)
        return positive, m0, m1

    # Normalize line encodings to (..., 4) = (x1, y1, x2, y2): accept either
    # endpoint pairs (..., 2, 2) or sampled polylines (4D, keep endpoints).
    if lines0.shape[-2:] == (2, 2):
        lines0 = torch.flatten(lines0, -2)
    elif lines0.dim() == 4:
        lines0 = torch.cat([lines0[:, :, 0], lines0[:, :, -1]], dim=2)
    if lines1.shape[-2:] == (2, 2):
        lines1 = torch.flatten(lines1, -2)
    elif lines1.dim() == 4:
        lines1 = torch.cat([lines1[:, :, 0], lines1[:, :, -1]], dim=2)
    b_size, n_lines0, _ = lines0.shape
    b_size, n_lines1, _ = lines1.shape
    h0, w0 = data["view0"]["depth"][0].shape
    h1, w1 = data["view1"]["depth"][0].shape

    # Clamp endpoints inside the depth maps.
    lines0 = torch.min(
        torch.max(lines0, torch.zeros_like(lines0)),
        lines0.new_tensor([w0 - 1, h0 - 1, w0 - 1, h0 - 1], dtype=torch.float),
    )
    lines1 = torch.min(
        torch.max(lines1, torch.zeros_like(lines1)),
        lines1.new_tensor([w1 - 1, h1 - 1, w1 - 1, h1 - 1], dtype=torch.float),
    )

    # Sample points along each line
    pts0 = sample_pts(lines0, npts).reshape(b_size, n_lines0 * npts, 2)
    pts1 = sample_pts(lines1, npts).reshape(b_size, n_lines1 * npts, 2)

    # Sample depth and valid points
    d0, valid0_pts0 = sample_depth(pts0, data["view0"]["depth"])
    d1, valid1_pts1 = sample_depth(pts1, data["view1"]["depth"])

    # Reproject to the other view
    pts0_1, visible0 = project(
        pts0,
        d0,
        data["view1"]["depth"],
        data["view0"]["camera"],
        data["view1"]["camera"],
        data["T_0to1"],
        valid0_pts0,
    )
    pts1_0, visible1 = project(
        pts1,
        d1,
        data["view0"]["depth"],
        data["view1"]["camera"],
        data["view0"]["camera"],
        data["T_1to0"],
        valid1_pts1,
    )

    h0, w0 = data["view0"]["image"].shape[-2:]
    h1, w1 = data["view1"]["image"].shape[-2:]
    # If a line has less than min_visibility_th inside the image is considered OUTSIDE
    pts_out_of0 = (pts1_0 < 0).any(-1) | (
        pts1_0 >= torch.tensor([w0, h0]).to(pts1_0)
    ).any(-1)
    pts_out_of0 = pts_out_of0.reshape(b_size, n_lines1, npts).float()
    out_of0 = pts_out_of0.mean(dim=-1) >= (1 - min_visibility_th)
    pts_out_of1 = (pts0_1 < 0).any(-1) | (
        pts0_1 >= torch.tensor([w1, h1]).to(pts0_1)
    ).any(-1)
    pts_out_of1 = pts_out_of1.reshape(b_size, n_lines0, npts).float()
    out_of1 = pts_out_of1.mean(dim=-1) >= (1 - min_visibility_th)

    # visible0 is [bs, nl0 * npts]
    pts0_1 = pts0_1.reshape(b_size, n_lines0, npts, 2)
    pts1_0 = pts1_0.reshape(b_size, n_lines1, npts, 2)

    # Count, in each direction, the reprojected points that are close to and
    # overlapping with the candidate line (intermediates deleted eagerly to
    # bound the [bs, nl0, nl1, npts] memory footprint).
    perp_dists0, overlaping0 = torch_perp_dist(lines0, pts1_0)
    close_points0 = (perp_dists0 < dist_th) & overlaping0  # [bs, nl0, nl1, npts]
    del perp_dists0, overlaping0
    close_points0 = close_points0 * visible1.reshape(b_size, 1, n_lines1, npts)

    perp_dists1, overlaping1 = torch_perp_dist(lines1, pts0_1)
    close_points1 = (perp_dists1 < dist_th) & overlaping1  # [bs, nl1, nl0, npts]
    del perp_dists1, overlaping1
    close_points1 = close_points1 * visible0.reshape(b_size, 1, n_lines0, npts)
    torch.cuda.empty_cache()

    # For each segment detected in 0, how many sampled points from
    # reprojected segments 1 are close
    num_close_pts0 = close_points0.sum(dim=-1)  # [bs, nl0, nl1]

    # num_close_pts0_t = num_close_pts0.transpose(-1, -2)
    # For each segment detected in 1, how many sampled points from
    # reprojected segments 0 are close
    num_close_pts1 = close_points1.sum(dim=-1)
    num_close_pts1_t = num_close_pts1.transpose(-1, -2)  # [bs, nl1, nl0]
    num_close_pts = num_close_pts0 * num_close_pts1_t
    # A pair is "close" when the overlap fraction exceeds overlap_th of the
    # visible points in both directions.
    mask_close = (
        num_close_pts1_t
        > visible0.reshape(b_size, n_lines0, npts).float().sum(-1)[:, :, None]
        * overlap_th
    ) & (
        num_close_pts0
        > visible1.reshape(b_size, n_lines1, npts).float().sum(-1)[:, None] * overlap_th
    )
    # mask_close = (num_close_pts1_t > npts * overlap_th) & (
    #     num_close_pts0 > npts * overlap_th)

    # Define the unmatched lines
    unmatched0 = torch.all(~mask_close, dim=2) | out_of1
    unmatched1 = torch.all(~mask_close, dim=1) | out_of0

    # Define the lines to ignore
    ignore0 = (
        valid0_pts0.reshape(b_size, n_lines0, npts).float().mean(dim=-1)
        < min_visibility_th
    ) | ~valid_lines0
    ignore1 = (
        valid1_pts1.reshape(b_size, n_lines1, npts).float().mean(dim=-1)
        < min_visibility_th
    ) | ~valid_lines1

    cost = -num_close_pts.clone()
    # High score for unmatched and non-valid lines
    cost[unmatched0] = 1e6
    cost[ignore0] = 1e6
    # TODO: Is it reasonable to forbid the matching with a segment because it
    # has not GT depth?
    cost = cost.transpose(1, 2)
    cost[unmatched1] = 1e6
    cost[ignore1] = 1e6
    cost = cost.transpose(1, 2)

    # For each row, returns the col of max number of points
    # (Hungarian algorithm on the negated close-point counts, per batch item).
    assignation = np.array(
        [linear_sum_assignment(C) for C in cost.detach().cpu().numpy()]
    )
    assignation = torch.tensor(assignation).to(num_close_pts)
    # Set ignore and unmatched labels
    unmatched = assignation.new_tensor(UNMATCHED_FEATURE)
    ignore = assignation.new_tensor(IGNORE_FEATURE)

    positive = num_close_pts.new_zeros(num_close_pts.shape, dtype=torch.bool)
    all_in_batch = (
        torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten()
    )
    positive[
        all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()
    ] = True

    m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
    m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
    m1 = assignation.new_full((b_size, n_lines1), unmatched, dtype=torch.long)
    m1.scatter_(-1, assignation[:, 1], assignation[:, 0])

    # Keep only assignments that also pass the overlap criterion.
    positive = positive & mask_close
    # Remove values to be ignored or unmatched
    positive[unmatched0] = False
    positive[ignore0] = False
    positive = positive.transpose(1, 2)
    positive[unmatched1] = False
    positive[ignore1] = False
    positive = positive.transpose(1, 2)
    m0[~positive.any(-1)] = unmatched
    m0[unmatched0] = unmatched
    m0[ignore0] = ignore
    m1[~positive.any(-2)] = unmatched
    m1[unmatched1] = unmatched
    m1[ignore1] = ignore

    if num_close_pts.numel() == 0:
        no_matches = torch.zeros(positive.shape[0], 0).to(positive)
        return positive, no_matches, no_matches

    return positive, m0, m1
|
||||
|
||||
|
||||
@torch.no_grad()
def gt_line_matches_from_homography(
    pred_lines0,
    pred_lines1,
    valid_lines0,
    valid_lines1,
    shape0,
    shape1,
    H,
    npts=50,
    dist_th=5,
    overlap_th=0.2,
    min_visibility_th=0.2,
):
    """Compute ground truth line matches under a homography and label the
    remaining lines as:
    - UNMATCHED: if the reprojection is outside the image or far away from
      any other line.
    - IGNORE: if a line is labeled as invalid.

    Args:
        pred_lines0, pred_lines1: predicted line segments, either flattened
            (B, N, 4), endpoint pairs (B, N, 2, 2), or sampled points along
            the line (B, N, S, 2) — only first/last points are kept then.
        valid_lines0, valid_lines1: (B, N) boolean validity masks.
        shape0, shape1: image shapes; only the last two dims (H, W) are used.
        H: homography mapping image 0 to image 1 (batched or not — whatever
           `warp_points_torch` accepts).
        npts: number of points sampled along each line.
        dist_th: max perpendicular distance (px) for a sampled point to count
            as close to a line.
        overlap_th: min fraction of close sampled points for two lines to be
            considered overlapping.
        min_visibility_th: min fraction of reprojected points that must land
            inside the image for a line to be considered visible.

    Returns:
        positive: (B, N0, N1) boolean assignment matrix.
        m0, m1: (B, N0) / (B, N1) index maps (UNMATCHED/IGNORE sentinels for
            non-matched lines).
    """
    h0, w0 = shape0[-2:]
    h1, w1 = shape1[-2:]
    # Work on copies and normalize every input format to flat (B, N, 4)
    lines0 = pred_lines0.clone()
    lines1 = pred_lines1.clone()
    if lines0.shape[-2:] == (2, 2):
        lines0 = torch.flatten(lines0, -2)
    elif lines0.dim() == 4:
        # Sampled-points format: keep only the two extremities
        lines0 = torch.cat([lines0[:, :, 0], lines0[:, :, -1]], dim=2)
    if lines1.shape[-2:] == (2, 2):
        lines1 = torch.flatten(lines1, -2)
    elif lines1.dim() == 4:
        lines1 = torch.cat([lines1[:, :, 0], lines1[:, :, -1]], dim=2)
    b_size, n_lines0, _ = lines0.shape
    b_size, n_lines1, _ = lines1.shape

    # Clamp endpoints into the image so the sampling below stays in bounds
    lines0 = torch.min(
        torch.max(lines0, torch.zeros_like(lines0)),
        lines0.new_tensor([w0 - 1, h0 - 1, w0 - 1, h0 - 1], dtype=torch.float),
    )
    lines1 = torch.min(
        torch.max(lines1, torch.zeros_like(lines1)),
        lines1.new_tensor([w1 - 1, h1 - 1, w1 - 1, h1 - 1], dtype=torch.float),
    )

    # Sample points along each line
    # NOTE(review): `sample_pts` is defined elsewhere in this module —
    # presumably returns (B, N, npts, 2); confirm.
    pts0 = sample_pts(lines0, npts).reshape(b_size, n_lines0 * npts, 2)
    pts1 = sample_pts(lines1, npts).reshape(b_size, n_lines1 * npts, 2)

    # Project the points to the other image
    pts0_1 = warp_points_torch(pts0, H, inverse=False)
    pts1_0 = warp_points_torch(pts1, H, inverse=True)
    pts0_1 = pts0_1.reshape(b_size, n_lines0, npts, 2)
    pts1_0 = pts1_0.reshape(b_size, n_lines1, npts, 2)

    # A line with less than min_visibility_th of its points inside the image
    # is considered OUTSIDE
    pts_out_of0 = (pts1_0 < 0).any(-1) | (
        pts1_0 >= torch.tensor([w0, h0]).to(pts1_0)
    ).any(-1)
    pts_out_of0 = pts_out_of0.reshape(b_size, n_lines1, npts).float()
    out_of0 = pts_out_of0.mean(dim=-1) >= (1 - min_visibility_th)
    pts_out_of1 = (pts0_1 < 0).any(-1) | (
        pts0_1 >= torch.tensor([w1, h1]).to(pts0_1)
    ).any(-1)
    pts_out_of1 = pts_out_of1.reshape(b_size, n_lines0, npts).float()
    out_of1 = pts_out_of1.mean(dim=-1) >= (1 - min_visibility_th)

    # NOTE(review): `torch_perp_dist` is defined elsewhere — presumably the
    # per-point perpendicular distance plus an on-segment overlap mask; confirm.
    perp_dists0, overlaping0 = torch_perp_dist(lines0, pts1_0)
    close_points0 = (perp_dists0 < dist_th) & overlaping0  # [bs, nl0, nl1, npts]
    # Free the large intermediates before allocating the next pair
    del perp_dists0, overlaping0

    perp_dists1, overlaping1 = torch_perp_dist(lines1, pts0_1)
    close_points1 = (perp_dists1 < dist_th) & overlaping1  # [bs, nl1, nl0, npts]
    del perp_dists1, overlaping1
    torch.cuda.empty_cache()

    # For each segment detected in 0,
    # how many sampled points from reprojected segments 1 are close
    num_close_pts0 = close_points0.sum(dim=-1)  # [bs, nl0, nl1]
    # num_close_pts0_t = num_close_pts0.transpose(-1, -2)
    # For each segment detected in 1,
    # how many sampled points from reprojected segments 0 are close
    num_close_pts1 = close_points1.sum(dim=-1)
    num_close_pts1_t = num_close_pts1.transpose(-1, -2)  # [bs, nl1, nl0]

    # Symmetric overlap score: product of the two directional counts
    num_close_pts = num_close_pts0 * num_close_pts1_t
    mask_close = (
        (num_close_pts1_t > npts * overlap_th)
        & (num_close_pts0 > npts * overlap_th)
        & ~out_of0.unsqueeze(1)
        & ~out_of1.unsqueeze(-1)
    )

    # Define the unmatched lines: overlap with nothing, or out of the image
    unmatched0 = torch.all(~mask_close, dim=2) | out_of1
    unmatched1 = torch.all(~mask_close, dim=1) | out_of0

    # Define the lines to ignore
    ignore0 = ~valid_lines0
    ignore1 = ~valid_lines1

    # Negate so the Hungarian solver (a minimizer) prefers high overlap
    cost = -num_close_pts.clone()
    # High cost for unmatched and non-valid lines, applied on both axes
    cost[unmatched0] = 1e6
    cost[ignore0] = 1e6
    cost = cost.transpose(1, 2)
    cost[unmatched1] = 1e6
    cost[ignore1] = 1e6
    cost = cost.transpose(1, 2)
    # Optimal one-to-one assignment per batch element (rows of lines0 to
    # columns of lines1)
    assignation = np.array(
        [linear_sum_assignment(C) for C in cost.detach().cpu().numpy()]
    )
    # Match device and (integer) dtype of the overlap counts
    assignation = torch.tensor(assignation).to(num_close_pts)

    # Set unmatched labels
    # NOTE(review): UNMATCHED_FEATURE / IGNORE_FEATURE are module-level
    # sentinels defined elsewhere in this file.
    unmatched = assignation.new_tensor(UNMATCHED_FEATURE)
    ignore = assignation.new_tensor(IGNORE_FEATURE)

    positive = num_close_pts.new_zeros(num_close_pts.shape, dtype=torch.bool)
    # TODO Do with a single and beautiful call
    # for b in range(b_size):
    #     positive[b][assignation[b, 0], assignation[b, 1]] = True
    positive[
        torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten(),
        assignation[:, 0].flatten(),
        assignation[:, 1].flatten(),
    ] = True

    # Index maps in both directions, initialized to the unmatched sentinel
    m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
    m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
    m1 = assignation.new_full((b_size, n_lines1), unmatched, dtype=torch.long)
    m1.scatter_(-1, assignation[:, 1], assignation[:, 0])

    # Keep only assignments that also pass the overlap/visibility test
    positive = positive & mask_close
    # Remove values to be ignored or unmatched (both axes)
    positive[unmatched0] = False
    positive[ignore0] = False
    positive = positive.transpose(1, 2)
    positive[unmatched1] = False
    positive[ignore1] = False
    positive = positive.transpose(1, 2)
    m0[~positive.any(-1)] = unmatched
    m0[unmatched0] = unmatched
    m0[ignore0] = ignore
    m1[~positive.any(-2)] = unmatched
    m1[unmatched1] = unmatched
    m1[ignore1] = ignore

    # Degenerate case: no line pairs at all
    if num_close_pts.numel() == 0:
        no_matches = torch.zeros(positive.shape[0], 0).to(positive)
        return positive, no_matches, no_matches

    return positive, m0, m1
|
|
@ -0,0 +1,340 @@
|
|||
from typing import Tuple
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils import to_homogeneous, from_homogeneous
|
||||
|
||||
|
||||
def flat2mat(H):
    """Expand a flattened 8-parameter homography of shape (1, 8) into the
    full 3x3 matrix, appending the implicit trailing 1."""
    padded = np.concatenate([H, np.ones_like(H[:, :1])], axis=1)
    return padded.reshape(3, 3)
|
||||
|
||||
|
||||
# Homography creation
|
||||
|
||||
|
||||
def create_center_patch(shape, patch_shape=None):
    """Return the 4 corners of a patch of size `patch_shape` centered in an
    image of size `shape` (both (width, height)), ordered
    [left-bottom, left-top, right-top, right-bottom]."""
    if patch_shape is None:
        patch_shape = shape
    w, h = shape
    pw, ph = patch_shape
    x0 = int((w - pw) / 2)
    y0 = int((h - ph) / 2)
    x1 = int((w + pw) / 2)
    y1 = int((h + ph) / 2)
    return np.array([[x0, y0], [x0, y1], [x1, y1], [x1, y0]])
|
||||
|
||||
|
||||
def check_convex(patch, min_convexity=0.05):
    """Checks if given polygon vertices [N,2] form a convex shape.

    Convex here means every consecutive edge pair turns the same way with a
    cross product of at most -min_convexity (clockwise in image coords).
    """
    prev_pts = np.roll(patch, 1, axis=0)
    next_pts = np.roll(patch, -1, axis=0)
    d_in = patch - prev_pts
    d_out = next_pts - patch
    # z-component of the 2D cross product at every vertex
    cross = d_in[:, 0] * d_out[:, 1] - d_out[:, 0] * d_in[:, 1]
    return bool(np.all(cross <= -min_convexity))
|
||||
|
||||
|
||||
def sample_homography_corners(
    shape,
    patch_shape,
    difficulty=1.0,
    translation=0.4,
    n_angles=10,
    max_angle=90,
    min_convexity=0.05,
    rng=np.random,
):
    """Sample a random homography by perturbing the 4 image corners.

    Args:
        shape: source image size (width, height).
        patch_shape: target patch size (width, height).
        difficulty: in [0, 1]; 0 keeps the full image, 1 allows the strongest
            perturbation.
        translation: scale of the random translation (0 disables it).
        n_angles: number of candidate rotation angles to try.
        max_angle: max rotation magnitude in degrees.
        min_convexity: convexity margin required of the sampled quadrilateral.
        rng: numpy random generator/module used for all sampling.

    Returns:
        (H, full, warped, patch_shape): the 3x3 homography, the source image
        corners, those corners warped by H, and the patch shape.
    """
    max_angle = max_angle / 180.0 * math.pi
    width, height = shape
    # Smallest admissible inner patch: shrinks as difficulty grows
    pwidth, pheight = width * (1 - difficulty), height * (1 - difficulty)
    min_pts1 = create_center_patch(shape, (pwidth, pheight))
    full = create_center_patch(shape)
    pts2 = create_center_patch(patch_shape)
    # Per-corner admissible offset range (signed, pointing inwards)
    scale = min_pts1 - full
    found_valid = False
    cnt = -1
    # Rejection-sample corner offsets until the quadrilateral is convex
    while not found_valid:
        offsets = rng.uniform(0.0, 1.0, size=(4, 2)) * scale
        pts1 = full + offsets
        found_valid = check_convex(pts1 / np.array(shape), min_convexity)
        cnt += 1  # NOTE(review): `cnt` is never read afterwards

    # re-center on the inner patch center
    pts1 = pts1 - np.mean(pts1, axis=0, keepdims=True)
    pts1 = pts1 + np.mean(min_pts1, axis=0, keepdims=True)

    # Rotation: try shuffled candidate angles and keep the first one whose
    # rotated corners stay inside the image (angle 0 is candidate index 0)
    if n_angles > 0 and difficulty > 0:
        angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
        rng.shuffle(angles)
        # NOTE(review): duplicated shuffle — harmless but consumes extra RNG
        # state; confirm whether one call was intended.
        rng.shuffle(angles)
        angles = np.concatenate([[0.0], angles], axis=0)

        center = np.mean(pts1, axis=0, keepdims=True)
        rot_mat = np.reshape(
            np.stack(
                [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)],
                axis=1,
            ),
            [-1, 2, 2],
        )
        # Rotate all candidate copies of the corners about their center
        rotated = (
            np.matmul(
                np.tile(np.expand_dims(pts1 - center, axis=0), [n_angles + 1, 1, 1]),
                rot_mat,
            )
            + center
        )

        # NOTE(review): there are n_angles + 1 candidates but the last one
        # (index n_angles) is never tried — confirm whether the range should
        # be range(1, n_angles + 1).
        for idx in range(1, n_angles):
            warped_points = rotated[idx] / np.array(shape)
            if np.all((warped_points >= 0.0) & (warped_points < 1.0)):
                pts1 = rotated[idx]
                break

    # Translation: random shift that keeps all corners inside the image
    if translation > 0:
        min_trans = -np.min(pts1, axis=0)
        max_trans = shape - np.max(pts1, axis=0)
        trans = rng.uniform(min_trans, max_trans)[None]
        pts1 += trans * translation * difficulty

    # Homography mapping the sampled quadrilateral onto the target patch
    H = compute_homography(pts1, pts2, [1.0, 1.0])
    warped = warp_points(full, H, inverse=False)
    return H, full, warped, patch_shape
|
||||
|
||||
|
||||
def compute_homography(pts1_, pts2_, shape):
    """Compute the homography matrix from 4 point correspondences.

    Solves the 8x8 DLT system with the bottom-right entry fixed to 1.
    `shape` rescales the (normalized) input points; note its reversed
    [y, x] convention.
    """
    # Rescale to actual size (different convention: [y, x])
    size = np.array(shape[::-1], dtype=np.float32)[None]
    pts1 = pts1_ * size
    pts2 = pts2_ * size

    # Two equations per correspondence: one for x, one for y
    rows = []
    rhs = []
    for (px, py), (qx, qy) in zip(pts1, pts2):
        rows.append([px, py, 1, 0, 0, 0, -px * qx, -py * qx])
        rows.append([0, 0, 0, px, py, 1, -px * qy, -py * qy])
        rhs.extend([qx, qy])
    a_mat = np.stack(rows, axis=0)
    b_vec = np.array(rhs)[:, None]

    # Solve for the 8 free parameters and append the implicit trailing 1
    h8 = np.transpose(np.linalg.solve(a_mat, b_vec))
    return np.reshape(np.concatenate([h8, np.ones_like(h8[:, :1])], axis=1), [3, 3])
|
||||
|
||||
|
||||
# Point warping utils
|
||||
|
||||
|
||||
def warp_points(points, homography, inverse=True):
    """
    Warp a list of points with the INVERSE of the given homography.
    The inverse is used to be coherent with tf.contrib.image.transform
    Arguments:
        points: list of N points, shape (N, 2).
        homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
    Returns: a Tensor of shape (N, 2) or (B, N, 2) (depending on whether the
        homography is batched) containing the new coordinates of the warps.
    """
    batched = len(homography.shape) != 2
    H = homography if batched else homography[None]

    # Homogeneous coordinates, then apply the (possibly inverted) transform
    ones = np.ones([points.shape[0], 1], dtype=np.float32)
    homog = np.concatenate([points, ones], -1)
    mat = np.transpose(np.linalg.inv(H) if inverse else H)
    projected = np.tensordot(homog, mat, axes=[[1], [0]])

    projected = np.transpose(projected, [2, 0, 1])
    # Avoid division by a vanishing homogeneous coordinate
    near_zero = np.abs(projected[:, :, 2]) < 1e-8
    projected[near_zero, 2] = 1e-8
    warped = projected[:, :, :2] / projected[:, :, 2:]

    return warped if batched else warped[0]
|
||||
|
||||
|
||||
def warp_points_torch(points, H, inverse=True):
    """
    Warp a list of points with the INVERSE of the given homography.
    The inverse is used to be coherent with tf.contrib.image.transform
    Arguments:
        points: batched list of N points, shape (B, N, 2).
        H: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
    Returns: a Tensor of shape (B, N, 2) containing the new coordinates of
        the warped points.
    """
    # Optionally invert, then transpose so points can right-multiply
    if inverse:
        H = torch.inverse(H)
    transform = H.transpose(-2, -1)

    # Homogeneous coordinates, batched matrix product, then de-homogenize
    homog = to_homogeneous(points)
    projected = torch.einsum("...nj,...ji->...ni", homog, transform)
    return from_homogeneous(projected, eps=1e-5)
|
||||
|
||||
|
||||
# Line warping utils
|
||||
|
||||
|
||||
def seg_equation(segs):
    """Return the homogeneous line equation ax + by + c = 0 of each segment
    (..., 2, 2), normalized so that a^2 + b^2 = 1."""
    # Cross product of the two homogeneous endpoints gives the joining line
    p_start = to_homogeneous(segs[..., 0, :])
    p_end = to_homogeneous(segs[..., 1, :])
    lines = torch.cross(p_start, p_end, dim=-1)
    # Normalize by the direction magnitude; zero means coincident endpoints
    norm = torch.sqrt(lines[..., 0] ** 2 + lines[..., 1] ** 2)[..., None]
    assert torch.all(
        norm > 0
    ), "Error: trying to compute the equation of a line with a single point"
    return lines / norm
|
||||
|
||||
|
||||
def is_inside_img(pts: torch.Tensor, img_shape: Tuple[int, int]):
    """Boolean mask of the points (..., 2) lying inside [0, W) x [0, H)
    and having finite coordinates."""
    h, w = img_shape
    in_x = (pts[..., 0] >= 0) & (pts[..., 0] < w)
    in_y = (pts[..., 1] >= 0) & (pts[..., 1] < h)
    finite = ~torch.isinf(pts).any(dim=-1)
    return in_x & in_y & finite
|
||||
|
||||
|
||||
def shrink_segs_to_img(segs: torch.Tensor, img_shape: Tuple[int, int]) -> torch.Tensor:
    """
    Shrink an array of segments to fit inside the image.

    Each endpoint lying outside the image is replaced by the intersection of
    the segment's supporting line with the corresponding image border
    (x = 0, x = W - EPS, y = 0 or y = H - EPS), whenever that intersection
    falls inside the image.

    :param segs: The tensor of segments with shape (N, 2, 2)
    :param img_shape: The image shape in format (H, W)
    """
    EPS = 1e-4
    device = segs.device
    w, h = img_shape[1], img_shape[0]
    # Work on a copy: endpoints are overwritten in place below
    segs = segs.clone()
    eqs = seg_equation(segs)
    # Left (x = 0) and top (y = 0) borders as homogeneous lines
    x0, y0 = torch.tensor([1.0, 0, 0.0], device=device), torch.tensor(
        [0.0, 1, 0], device=device
    )
    x0 = x0.repeat(eqs.shape[:-1] + (1,))
    y0 = y0.repeat(eqs.shape[:-1] + (1,))
    # Intersection of each supporting line with those borders
    pt_x0s = torch.cross(eqs, x0, dim=-1)
    pt_x0s = pt_x0s[..., :-1] / pt_x0s[..., None, -1]
    pt_x0s_valid = is_inside_img(pt_x0s, img_shape)
    pt_y0s = torch.cross(eqs, y0, dim=-1)
    pt_y0s = pt_y0s[..., :-1] / pt_y0s[..., None, -1]
    pt_y0s_valid = is_inside_img(pt_y0s, img_shape)

    # Right (x = w - EPS) and bottom (y = h - EPS) borders; EPS keeps the
    # replacement points strictly inside the image
    xW = torch.tensor([1.0, 0, EPS - w], device=device)
    yH = torch.tensor([0.0, 1, EPS - h], device=device)
    xW = xW.repeat(eqs.shape[:-1] + (1,))
    yH = yH.repeat(eqs.shape[:-1] + (1,))
    pt_xWs = torch.cross(eqs, xW, dim=-1)
    pt_xWs = pt_xWs[..., :-1] / pt_xWs[..., None, -1]
    pt_xWs_valid = is_inside_img(pt_xWs, img_shape)
    pt_yHs = torch.cross(eqs, yH, dim=-1)
    pt_yHs = pt_yHs[..., :-1] / pt_yHs[..., None, -1]
    pt_yHs_valid = is_inside_img(pt_yHs, img_shape)

    # If the X coordinate of the first endpoint is out
    mask = (segs[..., 0, 0] < 0) & pt_x0s_valid
    segs[mask, 0, :] = pt_x0s[mask]
    mask = (segs[..., 0, 0] > (w - 1)) & pt_xWs_valid
    segs[mask, 0, :] = pt_xWs[mask]
    # If the X coordinate of the second endpoint is out
    mask = (segs[..., 1, 0] < 0) & pt_x0s_valid
    segs[mask, 1, :] = pt_x0s[mask]
    # Fixed: was `segs[:, 1, 0]`, inconsistent with the `...` indexing used by
    # every other mask (and wrong for inputs with extra leading dims).
    mask = (segs[..., 1, 0] > (w - 1)) & pt_xWs_valid
    segs[mask, 1, :] = pt_xWs[mask]
    # If the Y coordinate of the first endpoint is out
    mask = (segs[..., 0, 1] < 0) & pt_y0s_valid
    segs[mask, 0, :] = pt_y0s[mask]
    mask = (segs[..., 0, 1] > (h - 1)) & pt_yHs_valid
    segs[mask, 0, :] = pt_yHs[mask]
    # If the Y coordinate of the second endpoint is out
    mask = (segs[..., 1, 1] < 0) & pt_y0s_valid
    segs[mask, 1, :] = pt_y0s[mask]
    mask = (segs[..., 1, 1] > (h - 1)) & pt_yHs_valid
    segs[mask, 1, :] = pt_yHs[mask]

    # Sanity check (NOTE: stripped when Python runs with -O)
    assert (
        torch.all(segs >= 0)
        and torch.all(segs[..., 0] < w)
        and torch.all(segs[..., 1] < h)
    )
    return segs
|
||||
|
||||
|
||||
def warp_lines_torch(
    lines, H, inverse=True, dst_shape: Tuple[int, int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    :param lines: A tensor of shape (B, N, 2, 2)
        where B is the batch size, N the number of lines.
    :param H: The homography used to convert the lines.
        batched or not (shapes (B, 3, 3) and (3, 3) respectively).
    :param inverse: Whether to apply H or the inverse of H
    :param dst_shape: If provided, lines are trimmed to be inside the image
    :return: the warped lines and a per-line validity mask.
    """
    device = lines.device
    batch_size = len(lines)
    # Warp all endpoints at once, then restore the (B, N, 2, 2) layout
    warped = warp_points_torch(lines.reshape(batch_size, -1, 2), H, inverse).reshape(
        lines.shape
    )

    if dst_shape is None:
        return warped, torch.ones(warped.shape[:-2], dtype=torch.bool, device=device)

    # Per-endpoint out-of-image mask (dst_shape is (H, W), hence the flip)
    endpoint_out = torch.any(
        (warped < 0) | (warped >= torch.tensor(dst_shape[::-1], device=device)), -1
    )
    # A line survives unless BOTH endpoints are outside
    valid = ~endpoint_out.all(-1)
    # Lines with exactly one endpoint outside get trimmed to the border
    needs_trim = valid & endpoint_out.any(-1)

    for idx in range(batch_size):
        sel = needs_trim[idx]
        warped[idx][sel] = shrink_segs_to_img(warped[idx][sel], dst_shape)

    return warped, valid
|
||||
|
||||
|
||||
# Homography evaluation utils
|
||||
|
||||
|
||||
def sym_homography_error(kpts0, kpts1, T_0to1):
    """Symmetric transfer error of matched keypoints under homography T_0to1:
    mean of the forward (0->1) and backward (1->0) reprojection distances."""
    forward = from_homogeneous(to_homogeneous(kpts0) @ T_0to1.transpose(-1, -2))
    err_forward = ((forward - kpts1) ** 2).sum(-1).sqrt()

    # pinverse tolerates (near-)singular homographies
    backward = from_homogeneous(
        to_homogeneous(kpts1) @ torch.pinverse(T_0to1.transpose(-1, -2))
    )
    err_backward = ((backward - kpts0) ** 2).sum(-1).sqrt()

    return (err_forward + err_backward) / 2.0
|
||||
|
||||
|
||||
def sym_homography_error_all(kpts0, kpts1, H):
    """All-pairs symmetric transfer distance matrix of size [... x M x N]
    between kpts0 warped by H and kpts1 (and vice versa)."""
    proj0 = warp_points_torch(kpts0, H, inverse=False)
    proj1 = warp_points_torch(kpts1, H, inverse=True)

    # Pairwise Euclidean distances in each direction
    d_forward = ((proj0.unsqueeze(-2) - kpts1.unsqueeze(-3)) ** 2).sum(-1).sqrt()
    d_backward = ((kpts0.unsqueeze(-2) - proj1.unsqueeze(-3)) ** 2).sum(-1).sqrt()
    return (d_forward + d_backward) / 2.0
|
||||
|
||||
|
||||
def homography_corner_error(T, T_gt, image_size):
    """Mean distance between the four image corners warped by the estimated
    homography T and by the ground truth T_gt."""
    W, H = image_size[:, 0], image_size[:, 1]
    corners = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
    # Warp the corners with both homographies and compare
    corners_h = to_homogeneous(corners)
    warped_gt = from_homogeneous(corners_h @ T_gt.transpose(-1, -2))
    warped_est = from_homogeneous(corners_h @ T.transpose(-1, -2))
    dist = ((warped_est - warped_gt) ** 2).sum(-1).sqrt()
    return dist.mean(-1)
|
|
@ -0,0 +1,166 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_homogeneous(points):
    """Convert N-dimensional points to homogeneous coordinates.
    Args:
        points: torch.Tensor or numpy.ndarray with size (..., N).
    Returns:
        A torch.Tensor or numpy.ndarray with size (..., N+1).
    """
    if isinstance(points, torch.Tensor):
        ones = points.new_ones(points.shape[:-1] + (1,))
        return torch.cat([points, ones], dim=-1)
    if isinstance(points, np.ndarray):
        ones = np.ones(points.shape[:-1] + (1,), dtype=points.dtype)
        return np.concatenate([points, ones], axis=-1)
    raise ValueError
|
||||
|
||||
|
||||
def from_homogeneous(points, eps=0.0):
    """Remove the homogeneous dimension of N-dimensional points.
    Args:
        points: torch.Tensor or numpy.ndarray with size (..., N+1).
        eps: offset added to the denominator to avoid division by zero.
    Returns:
        A torch.Tensor or numpy ndarray with size (..., N).
    """
    denom = points[..., -1:] + eps
    return points[..., :-1] / denom
|
||||
|
||||
|
||||
def batched_eye_like(x: torch.Tensor, n: int):
    """Create a batch of identity matrices.
    Args:
        x: a reference torch.Tensor whose batch dimension will be copied.
        n: the size of each identity matrix.
    Returns:
        A torch.Tensor of size (B, n, n), with same dtype and device as x.
    """
    eye = torch.eye(n).to(x)
    # expand creates a broadcast view; clone materializes independent copies
    return eye.unsqueeze(0).expand(len(x), n, n).clone()
|
||||
|
||||
|
||||
def skew_symmetric(v):
    """Create a skew-symmetric matrix from a (batched) vector of size (..., 3),
    i.e. the matrix form of the cross product with v."""
    zero = torch.zeros_like(v[..., 0])
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    rows = [zero, -vz, vy, vz, zero, -vx, -vy, vx, zero]
    return torch.stack(rows, dim=-1).reshape(v.shape[:-1] + (3, 3))
|
||||
|
||||
|
||||
def transform_points(T, points):
    """Apply a homogeneous transform T (..., D+1, D+1) to points (..., D)."""
    homog = to_homogeneous(points)
    return from_homogeneous(homog @ T.transpose(-1, -2))
|
||||
|
||||
|
||||
def is_inside(pts, shape):
    """Mask of points strictly inside (0, shape), with shape batched (B, D)."""
    above_zero = (pts > 0).all(-1)
    below_bound = (pts < shape[:, None]).all(-1)
    return above_zero & below_bound
|
||||
|
||||
|
||||
def so3exp_map(w, eps: float = 1e-7):
    """Compute rotation matrices from batched twists (Rodrigues' formula).
    Args:
        w: batched 3D axis-angle vectors of size (..., 3).
        eps: angle threshold below which the first-order approximation is used.
    Returns:
        A batch of rotation matrices of size (..., 3, 3).
    """
    theta = w.norm(p=2, dim=-1, keepdim=True)
    near_zero = theta < eps
    # Avoid dividing by ~0 when normalizing the axis
    safe_theta = torch.where(near_zero, torch.ones_like(theta), theta)
    W = skew_symmetric(w / safe_theta)
    theta = theta[..., None]  # ... x 1 x 1
    rot = W * torch.sin(theta) + (W @ W) * (1 - torch.cos(theta))
    # First-order Taylor approximation for tiny angles
    rot = torch.where(near_zero[..., None], W, rot)
    return torch.eye(3).to(W) + rot
|
||||
|
||||
|
||||
@torch.jit.script
def distort_points(pts, dist):
    """Distort normalized 2D coordinates
    and check for validity of the distortion model.

    pts: (..., N, 2) normalized image coordinates.
    dist: (..., K) distortion coefficients — [k1, k2] radial, optionally
        followed by two tangential coefficients.
    Returns the distorted points and a boolean validity mask.
    """
    dist = dist.unsqueeze(-2)  # add point dimension
    ndist = dist.shape[-1]
    undist = pts
    valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool)
    if ndist > 0:
        # Radial distortion: x * (1 + k1 r^2 + k2 r^4)
        k1, k2 = dist[..., :2].split(1, -1)
        r2 = torch.sum(pts**2, -1, keepdim=True)
        radial = k1 * r2 + k2 * r2**2
        undist = undist + pts * radial

        # The distortion model is supposedly only valid within the image
        # boundaries. Because of the negative radial distortion, points that
        # are far outside of the boundaries might actually be mapped back
        # within the image. To account for this, we discard points that are
        # beyond the inflection point of the distortion model,
        # e.g. such that d(r + k_1 r^3 + k2 r^5)/dr = 0
        limited = ((k2 > 0) & ((9 * k1**2 - 20 * k2) > 0)) | ((k2 <= 0) & (k1 > 0))
        limit = torch.abs(
            torch.where(
                k2 > 0,
                (torch.sqrt(9 * k1**2 - 20 * k2) - 3 * k1) / (10 * k2),
                1 / (3 * k1),
            )
        )
        valid = valid & torch.squeeze(~limited | (r2 < limit), -1)

        if ndist > 2:
            # Tangential distortion (two extra coefficients)
            p12 = dist[..., 2:]
            p21 = p12.flip(-1)
            uv = torch.prod(pts, -1, keepdim=True)
            undist = undist + 2 * p12 * uv + p21 * (r2 + 2 * pts**2)
            # TODO: handle tangential boundaries

    return undist, valid
|
||||
|
||||
|
||||
@torch.jit.script
def J_distort_points(pts, dist):
    """Jacobian of `distort_points` w.r.t. the input points.

    pts: (..., N, 2) normalized coordinates; dist: (..., K) coefficients,
    laid out as in `distort_points`. Returns a (..., N, 2, 2) Jacobian.
    """
    dist = dist.unsqueeze(-2)  # add point dimension
    ndist = dist.shape[-1]

    # Diagonal (dx'/dx, dy'/dy) and anti-diagonal (dx'/dy, dy'/dx) terms
    J_diag = torch.ones_like(pts)
    J_cross = torch.zeros_like(pts)
    if ndist > 0:
        # Radial part
        k1, k2 = dist[..., :2].split(1, -1)
        r2 = torch.sum(pts**2, -1, keepdim=True)
        uv = torch.prod(pts, -1, keepdim=True)
        radial = k1 * r2 + k2 * r2**2
        # d(radial)/d(r2) scaled: derivative of k1 r^2 + k2 r^4 w.r.t. x, y
        d_radial = 2 * k1 + 4 * k2 * r2
        J_diag += radial + (pts**2) * d_radial
        J_cross += uv * d_radial

        if ndist > 2:
            # Tangential part
            p12 = dist[..., 2:]
            p21 = p12.flip(-1)
            J_diag += 2 * p12 * pts.flip(-1) + 6 * p21 * pts
            J_cross += 2 * p12 * pts + 2 * p21 * pts.flip(-1)

    # Assemble [[J_diag_x, J_cross_x], [J_cross_y, J_diag_y]]
    J = torch.diag_embed(J_diag) + torch.diag_embed(J_cross).flip(-1)
    return J
|
||||
|
||||
|
||||
def get_image_coords(img):
    """Return a (1, H, W, 2) grid of pixel-center (x, y) coordinates
    (i.e. integer pixel indices + 0.5) for an image of shape (..., H, W)."""
    h, w = img.shape[-2:]
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=torch.float32, device=img.device),
        torch.arange(w, dtype=torch.float32, device=img.device),
        indexing="ij",
    )
    # (x, y) order, channel-last, plus half-pixel offset for pixel centers
    grid = torch.stack((xs, ys), dim=0).permute(1, 2, 0)
    return grid[None] + 0.5
|
|
@ -0,0 +1,424 @@
|
|||
"""
|
||||
Convenience classes for an SE3 pose and a pinhole Camera with lens distortion.
|
||||
Based on PyTorch tensors: differentiable, batched, with GPU support.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
from typing import Union, Tuple, List, Dict, NamedTuple, Optional
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .utils import (
|
||||
distort_points,
|
||||
J_distort_points,
|
||||
skew_symmetric,
|
||||
so3exp_map,
|
||||
to_homogeneous,
|
||||
)
|
||||
|
||||
|
||||
def autocast(func):
    """Cast the inputs of a TensorWrapper method to PyTorch tensors
    if they are numpy arrays. Use the device and dtype of the wrapper.

    Works both on instance methods (`self` is a TensorWrapper) and, when
    stacked under @classmethod, on alternate constructors (`self` is then
    the class itself).
    """

    @functools.wraps(func)
    def wrap(self, *args):
        device = torch.device("cpu")
        dtype = None
        if isinstance(self, TensorWrapper):
            if self._data is not None:
                # Reuse the wrapper's own device/dtype for the cast
                device = self.device
                dtype = self.dtype
        elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
            # Neither an instance nor a TensorWrapper subclass: misuse
            raise ValueError(self)

        # Only numpy arrays are converted; tensors pass through untouched
        cast_args = []
        for arg in args:
            if isinstance(arg, np.ndarray):
                arg = torch.from_numpy(arg)
                arg = arg.to(device=device, dtype=dtype)
            cast_args.append(arg)
        return func(self, *cast_args)

    return wrap
|
||||
|
||||
|
||||
class TensorWrapper:
    """Base class wrapping a single torch.Tensor whose LAST dimension packs
    the object's parameters; all leading dimensions are batch dimensions.
    Provides tensor-like device/dtype conversion and batch indexing."""

    # Underlying storage tensor (set by __init__)
    _data = None

    @autocast
    def __init__(self, data: torch.Tensor):
        self._data = data

    @property
    def shape(self):
        # Batch shape only: the trailing dim holds the packed parameters
        return self._data.shape[:-1]

    @property
    def device(self):
        return self._data.device

    @property
    def dtype(self):
        return self._data.dtype

    def __getitem__(self, index):
        # Index the batch dimensions; returns a wrapper of the same class
        return self.__class__(self._data[index])

    def __setitem__(self, index, item):
        # NOTE(review): `item.data` assumes `item` is a torch.Tensor (or at
        # least exposes `.data`); a TensorWrapper stores `._data`, not
        # `.data` — confirm the intended argument type.
        self._data[index] = item.data

    def to(self, *args, **kwargs):
        return self.__class__(self._data.to(*args, **kwargs))

    def cpu(self):
        return self.__class__(self._data.cpu())

    def cuda(self):
        return self.__class__(self._data.cuda())

    def pin_memory(self):
        return self.__class__(self._data.pin_memory())

    def float(self):
        return self.__class__(self._data.float())

    def double(self):
        return self.__class__(self._data.double())

    def detach(self):
        return self.__class__(self._data.detach())

    @classmethod
    def stack(cls, objects: List, dim=0, *, out=None):
        """Stack several wrappers of the same class along a new batch dim."""
        data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Intercept torch.* calls on wrapper objects: only torch.stack is
        # supported; everything else falls back to NotImplemented.
        if kwargs is None:
            kwargs = {}
        if func is torch.stack:
            return self.stack(*args, **kwargs)
        else:
            return NotImplemented
|
||||
|
||||
|
||||
class Pose(TensorWrapper):
    """SE(3) pose wrapper. The trailing data dimension packs 12 values:
    the row-major flattened 3x3 rotation (9) followed by the translation (3)."""

    def __init__(self, data: torch.Tensor):
        assert data.shape[-1] == 12
        super().__init__(data)

    @classmethod
    @autocast
    def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):
        """Pose from a rotation matrix and translation vector.
        Accepts numpy arrays or PyTorch tensors.

        Args:
            R: rotation matrix with shape (..., 3, 3).
            t: translation vector with shape (..., 3).
        """
        assert R.shape[-2:] == (3, 3)
        assert t.shape[-1] == 3
        assert R.shape[:-2] == t.shape[:-1]
        # Flatten R row-major and append t: (..., 12)
        data = torch.cat([R.flatten(start_dim=-2), t], -1)
        return cls(data)

    @classmethod
    @autocast
    def from_aa(cls, aa: torch.Tensor, t: torch.Tensor):
        """Pose from an axis-angle rotation vector and translation vector.
        Accepts numpy arrays or PyTorch tensors.

        Args:
            aa: axis-angle rotation vector with shape (..., 3).
            t: translation vector with shape (..., 3).
        """
        assert aa.shape[-1] == 3
        assert t.shape[-1] == 3
        assert aa.shape[:-1] == t.shape[:-1]
        return cls.from_Rt(so3exp_map(aa), t)

    @classmethod
    def from_4x4mat(cls, T: torch.Tensor):
        """Pose from an SE(3) transformation matrix.

        Args:
            T: transformation matrix with shape (..., 4, 4).
        """
        assert T.shape[-2:] == (4, 4)
        R, t = T[..., :3, :3], T[..., :3, 3]
        return cls.from_Rt(R, t)

    @classmethod
    def from_colmap(cls, image: NamedTuple):
        """Pose from a COLMAP Image (uses its qvec2rotmat() and tvec)."""
        return cls.from_Rt(image.qvec2rotmat(), image.tvec)

    @property
    def R(self) -> torch.Tensor:
        """Underlying rotation matrix with shape (..., 3, 3)."""
        rvec = self._data[..., :9]
        return rvec.reshape(rvec.shape[:-1] + (3, 3))

    @property
    def t(self) -> torch.Tensor:
        """Underlying translation vector with shape (..., 3)."""
        return self._data[..., -3:]

    def inv(self) -> "Pose":
        """Invert an SE(3) pose: (R, t) -> (R^T, -R^T t)."""
        R = self.R.transpose(-1, -2)
        t = -(R @ self.t.unsqueeze(-1)).squeeze(-1)
        return self.__class__.from_Rt(R, t)

    def compose(self, other: "Pose") -> "Pose":
        """Chain two SE(3) poses: T_B2C.compose(T_A2B) -> T_A2C."""
        R = self.R @ other.R
        t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1)
        return self.__class__.from_Rt(R, t)

    @autocast
    def transform(self, p3d: torch.Tensor) -> torch.Tensor:
        """Transform a set of 3D points: returns R p + t for each point.

        Args:
            p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
        """
        assert p3d.shape[-1] == 3
        # assert p3d.shape[:-2] == self.shape  # allow broadcasting
        return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2)

    def __mul__(self, p3D: torch.Tensor) -> torch.Tensor:
        """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B."""
        return self.transform(p3D)

    def __matmul__(
        self, other: Union["Pose", torch.Tensor]
    ) -> Union["Pose", torch.Tensor]:
        """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B.
        or chain two SE(3) poses: T_B2C @ T_A2B -> T_A2C."""
        if isinstance(other, self.__class__):
            return self.compose(other)
        else:
            return self.transform(other)

    @autocast
    def J_transform(self, p3d_out: torch.Tensor):
        """Jacobian of a transformed point w.r.t. the 6-DoF pose parameters
        (translation first, then rotation); shape N x 3 x 6:
        # [[1,0,0,0,-pz,py],
        # [0,1,0,pz,0,-px],
        # [0,0,1,-py,px,0]]
        """
        J_t = torch.diag_embed(torch.ones_like(p3d_out))
        J_rot = -skew_symmetric(p3d_out)
        J = torch.cat([J_t, J_rot], dim=-1)
        return J  # N x 3 x 6

    def numpy(self) -> Tuple[np.ndarray]:
        """Return (R, t) as numpy arrays (tensors must live on CPU)."""
        return self.R.numpy(), self.t.numpy()

    def magnitude(self) -> Tuple[torch.Tensor]:
        """Magnitude of the SE(3) transformation.
        Returns:
            dr: rotation angle in degrees.
            dt: translation distance in meters.
        """
        # Rotation angle from the trace of R, clamped for numerical safety
        trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)
        cos = torch.clamp((trace - 1) / 2, -1, 1)
        dr = torch.acos(cos).abs() / math.pi * 180
        dt = torch.norm(self.t, dim=-1)
        return dr, dt

    def __repr__(self):
        return f"Pose: {self.shape} {self.dtype} {self.device}"
|
||||
|
||||
|
||||
class Camera(TensorWrapper):
    """Pinhole camera with optional radial/OpenCV distortion, stored as a
    packed tensor: (width, height, fx, fy, cx, cy) + 0/2/4 distortion params."""

    eps = 1e-4  # minimum depth used to reject/clamp points near the camera plane

    def __init__(self, data: torch.Tensor):
        # Last dimension selects the model: 6=PINHOLE, 8=RADIAL, 10=OPENCV.
        assert data.shape[-1] in {6, 8, 10}
        super().__init__(data)

    @classmethod
    def from_colmap(cls, camera: Union[Dict, NamedTuple]):
        """Camera from a COLMAP Camera tuple or dictionary.
        We use the corner-convention from COLMAP (center of top left pixel is (0.5, 0.5))
        """
        if isinstance(camera, tuple):
            camera = camera._asdict()

        model = camera["model"]
        params = camera["params"]

        if model in ["OPENCV", "PINHOLE", "RADIAL"]:
            (fx, fy, cx, cy), params = np.split(params, [4])
        elif model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL"]:
            (f, cx, cy), params = np.split(params, [3])
            fx = fy = f
            if model == "SIMPLE_RADIAL":
                # Pad with a second radial coefficient set to zero.
                params = np.r_[params, 0.0]
        else:
            raise NotImplementedError(model)

        data = np.r_[camera["width"], camera["height"], fx, fy, cx, cy, params]
        return cls(data)

    @classmethod
    @autocast
    def from_calibration_matrix(cls, K: torch.Tensor):
        """Build a distortion-free camera from a 3x3 intrinsic matrix K.

        The image size is inferred as (2*cx, 2*cy), i.e. the principal point
        is assumed to be the image center.
        """
        cx, cy = K[..., 0, 2], K[..., 1, 2]
        fx, fy = K[..., 0, 0], K[..., 1, 1]
        data = torch.stack([2 * cx, 2 * cy, fx, fy, cx, cy], -1)
        return cls(data)

    @autocast
    def calibration_matrix(self):
        """Return the 3x3 intrinsic matrix K with shape (..., 3, 3)."""
        K = torch.zeros(
            *self._data.shape[:-1],
            3,
            3,
            device=self._data.device,
            dtype=self._data.dtype,
        )
        K[..., 0, 2] = self._data[..., 4]
        K[..., 1, 2] = self._data[..., 5]
        K[..., 0, 0] = self._data[..., 2]
        K[..., 1, 1] = self._data[..., 3]
        K[..., 2, 2] = 1.0
        return K

    @property
    def size(self) -> torch.Tensor:
        """Size (width height) of the images, with shape (..., 2)."""
        return self._data[..., :2]

    @property
    def f(self) -> torch.Tensor:
        """Focal lengths (fx, fy) with shape (..., 2)."""
        return self._data[..., 2:4]

    @property
    def c(self) -> torch.Tensor:
        """Principal points (cx, cy) with shape (..., 2)."""
        return self._data[..., 4:6]

    @property
    def dist(self) -> torch.Tensor:
        """Distortion parameters, with shape (..., {0, 2, 4})."""
        return self._data[..., 6:]

    @autocast
    def scale(self, scales: torch.Tensor):
        """Update the camera parameters after resizing an image."""
        s = scales
        # Size, focals and principal point all scale with the image.
        data = torch.cat([self.size * s, self.f * s, self.c * s, self.dist], -1)
        return self.__class__(data)

    def crop(self, left_top: Tuple[float], size: Tuple[int]):
        """Update the camera parameters after cropping an image."""
        left_top = self._data.new_tensor(left_top)
        size = self._data.new_tensor(size)
        # Only the principal point shifts; focals are unchanged by a crop.
        data = torch.cat([size, self.f, self.c - left_top, self.dist], -1)
        return self.__class__(data)

    @autocast
    def in_image(self, p2d: torch.Tensor):
        """Check if 2D points are within the image boundaries."""
        assert p2d.shape[-1] == 2
        # assert p2d.shape[:-2] == self.shape # allow broadcasting
        size = self.size.unsqueeze(-2)
        valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
        return valid

    @autocast
    def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project 3D points into the camera plane and check for visibility."""
        z = p3d[..., -1]
        valid = z > self.eps  # points behind or too close to the camera are invalid
        z = z.clamp(min=self.eps)  # avoid division by ~0 for invalid points
        p2d = p3d[..., :-1] / z.unsqueeze(-1)
        return p2d, valid

    def J_project(self, p3d: torch.Tensor):
        """Jacobian of the pinhole projection w.r.t. the 3D point."""
        x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
        zero = torch.zeros_like(z)
        z = z.clamp(min=self.eps)
        J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
        J = J.reshape(p3d.shape[:-1] + (2, 3))
        return J  # N x 2 x 3

    @autocast
    def distort(self, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Distort normalized 2D coordinates
        and check for validity of the distortion model.
        """
        assert pts.shape[-1] == 2
        # assert pts.shape[:-2] == self.shape # allow broadcasting
        return distort_points(pts, self.dist)

    def J_distort(self, pts: torch.Tensor):
        """Jacobian of the distortion w.r.t. the normalized 2D points."""
        return J_distort_points(pts, self.dist)  # N x 2 x 2

    @autocast
    def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
        """Convert normalized 2D coordinates into pixel coordinates."""
        return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)

    @autocast
    def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
        """Convert pixel coordinates into normalized 2D coordinates."""
        return (p2d - self.c.unsqueeze(-2)) / self.f.unsqueeze(-2)

    def J_denormalize(self):
        """Jacobian of `denormalize` w.r.t. the normalized points: diag(fx, fy)."""
        return torch.diag_embed(self.f).unsqueeze(-3)  # 1 x 2 x 2

    @autocast
    def cam2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Transform 3D points into 2D pixel coordinates."""
        p2d, visible = self.project(p3d)
        p2d, mask = self.distort(p2d)
        p2d = self.denormalize(p2d)
        valid = visible & mask & self.in_image(p2d)
        return p2d, valid

    def J_world2image(self, p3d: torch.Tensor):
        """Chain-rule Jacobian of `cam2image` w.r.t. the 3D point."""
        p2d_dist, valid = self.project(p3d)
        J = self.J_denormalize() @ self.J_distort(p2d_dist) @ self.J_project(p3d)
        return J, valid

    @autocast
    def image2cam(self, p2d: torch.Tensor) -> torch.Tensor:
        """Convert 2D pixel coordinates to 3D points with z=1"""
        assert self._data.shape
        p2d = self.normalize(p2d)
        # iterative undistortion
        # NOTE(review): undistortion is not applied here — points are only
        # normalized and lifted; confirm callers expect distortion-free input.
        return to_homogeneous(p2d)

    def to_cameradict(self, camera_model: Optional[str] = None) -> List[Dict]:
        """Export as a list of COLMAP-style camera dictionaries.

        Args:
            camera_model: optional explicit COLMAP model name; inferred from
                the parameter count (6/8/10) when None.
        """
        data = self._data.clone()
        if data.dim() == 1:
            data = data.unsqueeze(0)
        assert data.dim() == 2
        b, d = data.shape
        if camera_model is None:
            camera_model = {6: "PINHOLE", 8: "RADIAL", 10: "OPENCV"}[d]
        cameras = []
        for i in range(b):
            if camera_model.startswith("SIMPLE_"):
                # SIMPLE_* models use one shared focal: (f, cx, cy[, k]).
                params = [x.item() for x in data[i, 3 : min(d, 7)]]
            else:
                params = [x.item() for x in data[i, 2:]]
            cameras.append(
                {
                    "model": camera_model,
                    "width": int(data[i, 0].item()),
                    "height": int(data[i, 1].item()),
                    "params": params,
                }
            )
        return cameras if self._data.dim() == 2 else cameras[0]

    def __repr__(self):
        """Short human-readable summary for debugging."""
        return f"Camera {self.shape} {self.dtype} {self.device}"
|
|
@ -0,0 +1,29 @@
|
|||
import importlib.util
|
||||
from .base_model import BaseModel
|
||||
from ..utils.tools import get_class
|
||||
|
||||
|
||||
def get_model(name):
    """Resolve a model class by name.

    The name is tried as an absolute module path first, then relative to this
    package, including the legacy ``extractors``/``matchers`` sub-packages.

    Raises:
        RuntimeError: if no candidate module provides a model class.
    """
    import_paths = [
        name,
        f"{__name__}.{name}",
        f"{__name__}.extractors.{name}",  # backward compatibility
        f"{__name__}.matchers.{name}",  # backward compatibility
    ]
    for path in import_paths:
        try:
            found = importlib.util.find_spec(path) is not None
        except ModuleNotFoundError:
            found = False
        if not found:
            continue
        try:
            return get_class(path, BaseModel)
        except AssertionError:
            # No unique BaseModel subclass in the module: fall back to an
            # explicitly declared `__main_model__` attribute.
            mod = __import__(path, fromlist=[""])
            try:
                return mod.__main_model__
            except AttributeError as exc:
                print(exc)

    raise RuntimeError(f'Model {name} not found in any of [{" ".join(import_paths)}]')
|
|
@ -0,0 +1,29 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..base_model import BaseModel
|
||||
|
||||
|
||||
class DinoV2(BaseModel):
    """DINOv2 backbone loaded from torch hub, used as a dense feature
    extractor and global-descriptor provider."""

    default_conf = {"weights": "dinov2_vits14", "allow_resize": False}
    required_data_keys = ["image"]

    def _init(self, conf):
        # Downloads the hub weights on first use.
        self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)

    def _forward(self, data):
        img = data["image"]
        if self.conf.allow_resize:
            # DINOv2 requires spatial sizes divisible by its patch size (14).
            # F.interpolate replaces the deprecated F.upsample (same behavior,
            # default nearest-neighbor mode).
            img = F.interpolate(img, [int(x // 14 * 14) for x in img.shape[-2:]])
        desc, cls_token = self.net.get_intermediate_layers(
            img, n=1, return_class_token=True, reshape=True
        )[0]

        return {
            "features": desc,  # B x C x H' x W' dense feature map
            "global_descriptor": cls_token,  # class token per image
            "descriptors": desc.flatten(-2).transpose(-2, -1),  # B x N x C
        }

    def loss(self, pred, data):
        # Inference-only model: no training objective is defined.
        raise NotImplementedError
|
|
@ -0,0 +1,126 @@
|
|||
"""
|
||||
Base class for trainable models.
|
||||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import omegaconf
|
||||
from omegaconf import OmegaConf
|
||||
from torch import nn
|
||||
from copy import copy
|
||||
|
||||
|
||||
class MetaModel(ABCMeta):
    """Metaclass that merges the `default_conf` of all base classes into a
    single `base_default_conf` injected into the new class namespace."""

    def __prepare__(name, bases, **kwds):
        # NOTE(review): unusual signature — `__prepare__` is looked up as a
        # plain function on the metaclass and called with (name, bases), so
        # `name` receives the class name here; it works but reads oddly.
        total_conf = OmegaConf.create()
        for base in bases:
            # Later bases and a base's own `default_conf` override earlier
            # entries via OmegaConf.merge.
            for key in ("base_default_conf", "default_conf"):
                update = getattr(base, key, {})
                if isinstance(update, dict):
                    update = OmegaConf.create(update)
                total_conf = OmegaConf.merge(total_conf, update)
        return dict(base_default_conf=total_conf)
|
||||
|
||||
|
||||
class BaseModel(nn.Module, metaclass=MetaModel):
    """
    What the child model is expected to declare:
    default_conf: dictionary of the default configuration of the model.
    It recursively updates the default_conf of all parent classes, and
    it is updated by the user-provided configuration passed to __init__.
    Configurations can be nested.

    required_data_keys: list of expected keys in the input data dictionary.

    strict_conf (optional): boolean. If false, BaseModel does not raise
    an error when the user provides an unknown configuration entry.

    _init(self, conf): initialization method, where conf is the final
    configuration object (also accessible with `self.conf`). Accessing
    unknown configuration entries will raise an error.

    _forward(self, data): method that returns a dictionary of batched
    prediction tensors based on a dictionary of batched input data tensors.

    loss(self, pred, data): method that returns a dictionary of losses,
    computed from model predictions and input data. Each loss is a batch
    of scalars, i.e. a torch.Tensor of shape (B,).
    The total loss to be optimized has the key `'total'`.

    metrics(self, pred, data): method that returns a dictionary of metrics,
    each as a batch of scalars.
    """

    default_conf = {
        "name": None,
        "trainable": True,  # if false: do not optimize this model parameters
        "freeze_batch_normalization": False,  # use test-time statistics
        "timeit": False,  # time forward pass
    }
    required_data_keys = []
    strict_conf = False

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        super().__init__()
        # Merge inherited defaults (collected by MetaModel) with this class's
        # own defaults before applying the user-provided conf.
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            # Reject unknown configuration entries.
            OmegaConf.set_struct(default_conf, True)

        # fixme: backward compatibility
        if "pad" in conf and "pad" not in default_conf:  # backward compat.
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf["interpolation"] = {"pad": conf.pop("pad")}

        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = conf = OmegaConf.merge(default_conf, conf)
        # Freeze the final conf: read-only and closed to new keys.
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        # Copy so child classes can extend it without mutating the class attr.
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode, keeping BatchNorm in eval if configured."""
        super().train(mode)

        def freeze_bn(module):
            # Keep test-time statistics for all BatchNorm variants.
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check the data and call the _forward method of the child model."""

        def recursive_key_check(expected, given):
            # `expected` may be a flat list of keys or a nested dict of keys.
            for key in expected:
                assert key in given, f"Missing key {key} in data"
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError
|
|
@ -0,0 +1,129 @@
|
|||
import torch
|
||||
import string
|
||||
import h5py
|
||||
|
||||
from .base_model import BaseModel
|
||||
from ..settings import DATA_PATH
|
||||
from ..datasets.base_dataset import collate
|
||||
from ..utils.tensor import batch_to_device
|
||||
from .utils.misc import pad_to_length
|
||||
|
||||
|
||||
def pad_local_features(pred: dict, seq_l: int):
    """Pad per-image local-feature tensors to a fixed length `seq_l` so that
    predictions of different images can be batched together."""
    pred["keypoints"] = pad_to_length(
        pred["keypoints"],
        seq_l,
        -2,
        mode="random_c",
    )
    if "keypoint_scores" in pred:
        pred["keypoint_scores"] = pad_to_length(
            pred["keypoint_scores"], seq_l, -1, mode="zeros"
        )
    if "descriptors" in pred:
        pred["descriptors"] = pad_to_length(
            pred["descriptors"], seq_l, -2, mode="random"
        )
    # Scalar per-keypoint attributes are padded with zeros.
    for key in ("scales", "oris"):
        if key in pred:
            pred[key] = pad_to_length(pred[key], seq_l, -1, mode="zeros")
    return pred
|
||||
|
||||
|
||||
def pad_line_features(pred, seq_l: int = None):
    """Pad line features to a fixed length; not implemented yet."""
    raise NotImplementedError
|
||||
|
||||
|
||||
def recursive_load(grp, pkeys):
    """Load the requested keys of an HDF5 group as torch tensors.

    Datasets are converted with `torch.from_numpy`; sub-groups are loaded
    recursively with all of their own keys.

    Args:
        grp: an open `h5py.Group`.
        pkeys: iterable of key names to read from `grp`.
    Returns:
        dict mapping each key to a tensor or a nested dict.
    """
    out = {}
    for k in pkeys:
        entry = grp[k]
        if isinstance(entry, h5py.Dataset):
            out[k] = torch.from_numpy(entry.__array__())
        else:
            # Bug fix: recurse with the keys of the sub-group itself,
            # not the keys of the parent group.
            out[k] = recursive_load(entry, list(entry.keys()))
    return out
|
||||
|
||||
|
||||
class CacheLoader(BaseModel):
    """Load cached predictions (e.g. pre-extracted features) from HDF5 files
    keyed by the sample name."""

    default_conf = {
        "path": "???",  # can be a format string like exports/{scene}/
        "data_keys": None,  # load all keys
        "device": None,  # load to same device as data
        "trainable": False,
        "add_data_path": True,  # prepend DATA_PATH to the file path
        "collate": True,  # collate per-sample predictions into a batch
        "scale": ["keypoints", "lines", "orig_lines"],  # keys rescaled by data scales
        "padding_fn": None,  # name of a padding function, e.g. pad_local_features
        "padding_length": None,  # required for batching!
        "numeric_type": "float32",  # [None, "float16", "float32", "float64"]
    }

    required_data_keys = ["name"]  # we need an identifier

    def _init(self, conf):
        self.hfiles = {}
        self.padding_fn = conf.padding_fn
        if self.padding_fn is not None:
            # SECURITY NOTE(review): eval of a config-provided string — fine
            # for trusted configs, but never feed untrusted input here.
            self.padding_fn = eval(self.padding_fn)
        # Map the configured numeric type to a torch dtype (None = keep as-is).
        self.numeric_dtype = {
            None: None,
            "float16": torch.float16,
            "float32": torch.float32,
            "float64": torch.float64,
        }[conf.numeric_type]

    def _forward(self, data):
        preds = []
        device = self.conf.device
        if not device:
            # Infer the target device from the input tensors (fallback: CPU).
            devices = set(
                [v.device for v in data.values() if isinstance(v, torch.Tensor)]
            )
            if len(devices) == 0:
                device = "cpu"
            else:
                assert len(devices) == 1
                device = devices.pop()

        # Names of format-string fields in the path, filled from `data`.
        var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]
        for i, name in enumerate(data["name"]):
            fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
            if self.conf.add_data_path:
                fpath = DATA_PATH / fpath
            hfile = h5py.File(str(fpath), "r")
            grp = hfile[name]
            pkeys = (
                self.conf.data_keys if self.conf.data_keys is not None else grp.keys()
            )
            pred = recursive_load(grp, pkeys)
            if self.numeric_dtype is not None:
                # Cast floating-point tensors to the configured precision;
                # leave integer tensors and non-tensors untouched.
                pred = {
                    k: v
                    if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
                    else v.to(dtype=self.numeric_dtype)
                    for k, v in pred.items()
                }
            pred = batch_to_device(pred, device)
            for k, v in pred.items():
                for pattern in self.conf.scale:
                    if k.startswith(pattern):
                        # Rescale cached geometry to the resized image, using
                        # the scales of the matching view (suffix selects it).
                        view_idx = k.replace(pattern, "")
                        scales = (
                            data["scales"]
                            if len(view_idx) == 0
                            else data[f"view{view_idx}"]["scales"]
                        )
                        pred[k] = pred[k] * scales[i]
            # use this function to fix number of keypoints etc.
            if self.padding_fn is not None:
                pred = self.padding_fn(pred, self.conf.padding_length)
            preds.append(pred)
            hfile.close()
        if self.conf.collate:
            return batch_to_device(collate(preds), device)
        else:
            assert len(preds) == 1
            return batch_to_device(preds[0], device)

    def loss(self, pred, data):
        # Cache loading has no training objective.
        raise NotImplementedError
|
|
@ -0,0 +1,785 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.models import resnet
|
||||
from typing import Optional, Callable
|
||||
from torch.nn.modules.utils import _pair
|
||||
import torchvision
|
||||
|
||||
from gluefactory.models.base_model import BaseModel
|
||||
|
||||
# coordinates system
|
||||
# ------------------------------> [ x: range=-1.0~1.0; w: range=0~W ]
|
||||
# | -----------------------------
|
||||
# | | |
|
||||
# | | |
|
||||
# | | |
|
||||
# | | image |
|
||||
# | | |
|
||||
# | | |
|
||||
# | | |
|
||||
# | |---------------------------|
|
||||
# v
|
||||
# [ y: range=-1.0~1.0; h: range=0~H ]
|
||||
|
||||
|
||||
def get_patches(
    tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
) -> torch.Tensor:
    """Extract `ps` x `ps` patches from a feature map around given points.

    Args:
        tensor: feature map of shape (C, H, W).
        required_corners: (N, 2) pixel coordinates (x, y) of patch centers.
        ps: patch size in pixels.
    Returns:
        Tensor of shape (N, C, ps, ps).
    """
    c, h, w = tensor.shape
    corner = (required_corners - ps / 2 + 1).long()
    # Clamp so every patch lies fully inside the image.
    corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
    corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
    offset = torch.arange(0, ps)

    # torch.meshgrid only accepts `indexing` from version 1.10 on. Compare
    # parsed numeric components instead of raw strings: a lexicographic
    # comparison would wrongly classify e.g. "1.9" >= "1.10" as True.
    version = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
    kw = {"indexing": "ij"} if version >= (1, 10) else {}
    x, y = torch.meshgrid(offset, offset, **kw)
    patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
    patches = patches.to(corner) + corner[None, None]
    pts = patches.reshape(-1, 2)
    # Index (H, W, C) with (y, x) pairs, then fold back into patches.
    sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
    sampled = sampled.reshape(ps, ps, -1, c)
    assert sampled.shape[:3] == patches.shape[:3]
    return sampled.permute(2, 3, 0, 1)
|
||||
|
||||
|
||||
def simple_nms(scores: torch.Tensor, nms_radius: int):
    """Fast Non-maximum suppression to remove nearby points."""

    def neighborhood_max(t: torch.Tensor) -> torch.Tensor:
        # Max over a (2r+1)x(2r+1) window, preserving the spatial size.
        return torch.nn.functional.max_pool2d(
            t, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
        )

    zeros = torch.zeros_like(scores)
    max_mask = scores == neighborhood_max(scores)

    # Two refinement rounds: suppress neighborhoods of current maxima, then
    # re-detect maxima among the remaining (non-suppressed) scores.
    for _ in range(2):
        supp_mask = neighborhood_max(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == neighborhood_max(supp_scores)
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)
|
||||
|
||||
|
||||
class DKD(nn.Module):
    """Differentiable keypoint detection head (from ALIKED): NMS-based hard
    detection plus optional soft-argmax sub-pixel refinement."""

    def __init__(
        self,
        radius: int = 2,
        top_k: int = 0,
        scores_th: float = 0.2,
        n_limit: int = 20000,
    ):
        """
        Args:
            radius: soft detection radius, kernel size is (2 * radius + 1)
            top_k: top_k > 0: return top k keypoints
            scores_th: top_k <= 0 threshold mode:
                scores_th > 0: return keypoints with scores>scores_th
                else: return keypoints with scores > scores.mean()
            n_limit: max number of keypoint in threshold mode
        """
        super().__init__()
        self.radius = radius
        self.top_k = top_k
        self.scores_th = scores_th
        self.n_limit = n_limit
        self.kernel_size = 2 * self.radius + 1
        self.temperature = 0.1  # tuned temperature
        self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
        # local xy grid
        x = torch.linspace(-self.radius, self.radius, self.kernel_size)
        # (kernel_size*kernel_size) x 2 : (w,h)
        # NOTE(review): string comparison of versions is fragile
        # ("1.9" >= "1.10" is True lexicographically) — consider parsing.
        kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
        self.hw_grid = (
            torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
        )

    def forward(
        self,
        scores_map: torch.Tensor,
        sub_pixel: bool = True,
        image_size: Optional[torch.Tensor] = None,
    ):
        """
        :param scores_map: Bx1xHxW
        :param descriptor_map: BxCxHxW
        :param sub_pixel: whether to use sub-pixel keypoint detection
        :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1
        """
        b, c, h, w = scores_map.shape
        scores_nograd = scores_map.detach()
        nms_scores = simple_nms(scores_nograd, self.radius)

        # remove border
        nms_scores[:, :, : self.radius, :] = 0
        nms_scores[:, :, :, : self.radius] = 0
        if image_size is not None:
            # Per-image valid region: zero out beyond the true image extent.
            for i in range(scores_map.shape[0]):
                w, h = image_size[i].long()
                nms_scores[i, :, h.item() - self.radius :, :] = 0
                nms_scores[i, :, :, w.item() - self.radius :] = 0
        else:
            nms_scores[:, :, -self.radius :, :] = 0
            nms_scores[:, :, :, -self.radius :] = 0

        # detect keypoints without grad
        if self.top_k > 0:
            topk = torch.topk(nms_scores.view(b, -1), self.top_k)
            indices_keypoints = [topk.indices[i] for i in range(b)]  # B x top_k
        else:
            if self.scores_th > 0:
                masks = nms_scores > self.scores_th
                if masks.sum() == 0:
                    # Fallback when nothing clears the fixed threshold.
                    th = scores_nograd.reshape(b, -1).mean(dim=1)  # th = self.scores_th
                    masks = nms_scores > th.reshape(b, 1, 1, 1)
            else:
                th = scores_nograd.reshape(b, -1).mean(dim=1)  # th = self.scores_th
                masks = nms_scores > th.reshape(b, 1, 1, 1)
            masks = masks.reshape(b, -1)

            indices_keypoints = []  # list, B x (any size)
            scores_view = scores_nograd.reshape(b, -1)
            for mask, scores in zip(masks, scores_view):
                indices = mask.nonzero()[:, 0]
                if len(indices) > self.n_limit:
                    # Keep only the n_limit highest-scoring detections.
                    kpts_sc = scores[indices]
                    sort_idx = kpts_sc.sort(descending=True)[1]
                    sel_idx = sort_idx[: self.n_limit]
                    indices = indices[sel_idx]
                indices_keypoints.append(indices)

        wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)

        keypoints = []
        scoredispersitys = []
        kptscores = []
        if sub_pixel:
            # detect soft keypoints with grad backpropagation
            patches = self.unfold(scores_map)  # B x (kernel**2) x (H*W)
            self.hw_grid = self.hw_grid.to(scores_map)  # to device
            for b_idx in range(b):
                patch = patches[b_idx].t()  # (H*W) x (kernel**2)
                indices_kpt = indices_keypoints[
                    b_idx
                ]  # one dimension vector, say its size is M
                patch_scores = patch[indices_kpt]  # M x (kernel**2)
                keypoints_xy_nms = torch.stack(
                    [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
                    dim=1,
                )  # Mx2

                # max is detached to prevent undesired backprop loops in the graph
                max_v = patch_scores.max(dim=1).values.detach()[:, None]
                x_exp = (
                    (patch_scores - max_v) / self.temperature
                ).exp()  # M * (kernel**2), in [0, 1]

                # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
                xy_residual = (
                    x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
                )  # Soft-argmax, Mx2

                # Dispersion of the soft-argmax distribution around its mean,
                # used as an uncertainty/peakiness measure.
                hw_grid_dist2 = (
                    torch.norm(
                        (self.hw_grid[None, :, :] - xy_residual[:, None, :])
                        / self.radius,
                        dim=-1,
                    )
                    ** 2
                )
                scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)

                # compute result keypoints
                keypoints_xy = keypoints_xy_nms + xy_residual
                keypoints_xy = keypoints_xy / wh * 2 - 1  # (w,h) -> (-1~1,-1~1)

                # Differentiable score lookup at the refined location.
                kptscore = torch.nn.functional.grid_sample(
                    scores_map[b_idx].unsqueeze(0),
                    keypoints_xy.view(1, 1, -1, 2),
                    mode="bilinear",
                    align_corners=True,
                )[
                    0, 0, 0, :
                ]  # CxN

                keypoints.append(keypoints_xy)
                scoredispersitys.append(scoredispersity)
                kptscores.append(kptscore)
        else:
            for b_idx in range(b):
                indices_kpt = indices_keypoints[
                    b_idx
                ]  # one dimension vector, say its size is M
                # To avoid warning: UserWarning: __floordiv__ is deprecated
                keypoints_xy_nms = torch.stack(
                    [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
                    dim=1,
                )  # Mx2
                keypoints_xy = keypoints_xy_nms / wh * 2 - 1  # (w,h) -> (-1~1,-1~1)
                kptscore = torch.nn.functional.grid_sample(
                    scores_map[b_idx].unsqueeze(0),
                    keypoints_xy.view(1, 1, -1, 2),
                    mode="bilinear",
                    align_corners=True,
                )[
                    0, 0, 0, :
                ]  # CxN
                keypoints.append(keypoints_xy)
                scoredispersitys.append(kptscore)  # for jit.script compatability
                kptscores.append(kptscore)

        return keypoints, scoredispersitys, kptscores
|
||||
|
||||
|
||||
class InputPadder(object):
    """Pads images such that dimensions are divisible by 8"""

    def __init__(self, h: int, w: int, divis_by: int = 8):
        self.ht = h
        self.wd = w
        # Amount needed to reach the next multiple of divis_by (0 if aligned).
        extra_h = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
        extra_w = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
        # [left, right, top, bottom], split as evenly as possible.
        self._pad = [
            extra_w // 2,
            extra_w - extra_w // 2,
            extra_h // 2,
            extra_h - extra_h // 2,
        ]

    def pad(self, x: torch.Tensor):
        """Replicate-pad a (B, C, H, W) tensor up to the aligned size."""
        assert x.ndim == 4
        return F.pad(x, self._pad, mode="replicate")

    def unpad(self, x: torch.Tensor):
        """Crop a padded (B, C, H, W) tensor back to the original size."""
        assert x.ndim == 4
        height, width = x.shape[-2], x.shape[-1]
        left, right, top, bottom = self._pad
        return x[..., top : height - bottom, left : width - right]
|
||||
|
||||
|
||||
class DeformableConv2d(nn.Module):
    """Deformable 2D convolution whose sampling offsets (and, optionally,
    a modulation mask) are predicted from the input features."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        mask=False,
    ):
        super(DeformableConv2d, self).__init__()

        self.padding = padding
        self.mask = mask

        # Two offset channels per sampling location, plus one mask channel
        # per location when modulation is enabled.
        k2 = kernel_size * kernel_size
        self.channel_num = (3 if mask else 2) * k2
        self.offset_conv = nn.Conv2d(
            in_channels,
            self.channel_num,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=True,
        )

        self.regular_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, x):
        height, width = x.shape[2:]
        # Cap offsets to a quarter of the larger spatial extent.
        offset_limit = max(height, width) / 4.0

        predicted = self.offset_conv(x)
        if self.mask:
            off_a, off_b, modulation = torch.chunk(predicted, 3, dim=1)
            offset = torch.cat((off_a, off_b), dim=1)
            modulation = torch.sigmoid(modulation)
        else:
            offset = predicted
            modulation = None
        offset = offset.clamp(-offset_limit, offset_limit)
        return torchvision.ops.deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            padding=self.padding,
            mask=modulation,
        )
|
||||
|
||||
|
||||
def get_conv(
    inplanes,
    planes,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
    conv_type="conv",
    mask=False,
):
    """Build a regular or deformable convolution layer.

    Args:
        conv_type: "conv" for nn.Conv2d, "dcn" for DeformableConv2d.
        mask: enable modulation mask (deformable convolutions only).
    Raises:
        TypeError: if `conv_type` is neither "conv" nor "dcn".
    """
    if conv_type == "conv":
        return nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
    if conv_type == "dcn":
        return DeformableConv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=_pair(padding),
            bias=bias,
            mask=mask,
        )
    raise TypeError
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
    """Two conv3x3 + norm + activation stages: channels in -> out -> out."""

    def __init__(
        self,
        in_channels,
        out_channels,
        gate: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        conv_type: str = "conv",
        mask: bool = False,
    ):
        super().__init__()
        self.gate = nn.ReLU(inplace=True) if gate is None else gate
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = get_conv(
            in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
        )
        self.bn1 = norm_layer(out_channels)
        self.conv2 = get_conv(
            out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
        )
        self.bn2 = norm_layer(out_channels)

    def forward(self, x):
        """Apply conv1 -> bn1 -> gate, then conv2 -> bn2 -> gate."""
        out = self.gate(self.bn1(self.conv1(x)))  # B x in_channels x H x W
        out = self.gate(self.bn2(self.conv2(out)))  # B x out_channels x H x W
        return out
|
||||
|
||||
|
||||
# modified based on torchvision\models\resnet.py#27->BasicBlock
|
||||
class ResBlock(nn.Module):
    """Residual block (modeled on torchvision's BasicBlock) with a
    configurable convolution type and activation."""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        gate: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        conv_type: str = "conv",
        mask: bool = False,
    ) -> None:
        super(ResBlock, self).__init__()
        self.gate = nn.ReLU(inplace=True) if gate is None else gate
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("ResBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in ResBlock")
        # Both self.conv1 and self.downsample layers
        # downsample the input when stride != 1
        self.conv1 = get_conv(
            inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
        )
        self.bn1 = norm_layer(planes)
        self.conv2 = get_conv(
            planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
        )
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Residual forward: gate(bn2(conv2(gate(bn1(conv1(x))))) + skip)."""
        skip = x if self.downsample is None else self.downsample(x)

        out = self.gate(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += skip
        return self.gate(out)
|
||||
|
||||
|
||||
class SDDH(nn.Module):
    """Deformable descriptor head (ALIKED): for each keypoint, predicts
    `n_pos` sampling offsets from a local patch, samples the feature map at
    those deformable positions, and aggregates the samples into a single
    L2-normalized descriptor.
    """

    def __init__(
        self,
        dims: int,
        kernel_size: int = 3,
        n_pos: int = 8,
        gate=nn.ReLU(),
        conv2D=False,
        mask=False,
    ):
        """
        Args:
            dims: feature/descriptor dimensionality.
            kernel_size: patch size used to estimate offsets around each keypoint.
            n_pos: number of deformable sample positions per keypoint.
            gate: activation used inside the offset estimator.
            conv2D: if True aggregate samples with a 1x1 conv (`convM`);
                otherwise with learned per-position weights (`agg_weights`).
            mask: if True, predict an extra per-position channel used as a
                soft sample weight.
        """
        super(SDDH, self).__init__()
        self.kernel_size = kernel_size
        self.n_pos = n_pos
        self.conv2D = conv2D
        self.mask = mask

        self.get_patches_func = get_patches

        # estimate offsets: (x, y [, mask]) per sample position
        self.channel_num = 3 * n_pos if mask else 2 * n_pos
        self.offset_conv = nn.Sequential(
            nn.Conv2d(
                dims,
                self.channel_num,
                kernel_size=kernel_size,
                stride=1,
                padding=0,
                bias=True,
            ),
            gate,
            nn.Conv2d(
                self.channel_num,
                self.channel_num,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
            ),
        )

        # sampled feature conv
        self.sf_conv = nn.Conv2d(
            dims, dims, kernel_size=1, stride=1, padding=0, bias=False
        )

        # convM
        if not conv2D:
            # deformable desc weights: one (dims x dims) projection per position
            agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
            self.register_parameter("agg_weights", agg_weights)
        else:
            self.convM = nn.Conv2d(
                dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
            )

    def forward(self, x, keypoints):
        """Compute one descriptor per keypoint.

        Args:
            x: feature map [B, C, H, W].
            keypoints: list of per-image keypoints [[N_kpts, 2], ...],
                (x, y) order, normalized to [-1, 1].

        Returns:
            (descriptors, offsets): lists of per-image tensors
            [N_kpts, C] and [N_kpts, n_pos, 2] (the latter for visualization).
        """
        # x: [B,C,H,W]
        # keypoints: list, [[N_kpts,2], ...] (w,h)
        b, c, h, w = x.shape
        wh = torch.tensor([[w - 1, h - 1]], device=x.device)
        # offsets are clamped to a quarter of the larger image dimension
        max_offset = max(h, w) / 4.0

        offsets = []
        descriptors = []
        # get offsets for each keypoint
        for ib in range(b):
            xi, kptsi = x[ib], keypoints[ib]
            # de-normalize [-1, 1] -> pixel coordinates
            kptsi_wh = (kptsi / 2 + 0.5) * wh
            N_kpts = len(kptsi)

            if self.kernel_size > 1:
                patch = self.get_patches_func(
                    xi, kptsi_wh.long(), self.kernel_size
                )  # [N_kpts, C, K, K]
            else:
                # K == 1: read the feature vector at the keypoint directly
                kptsi_wh_long = kptsi_wh.long()
                patch = (
                    xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
                    .permute(1, 0)
                    .reshape(N_kpts, c, 1, 1)
                )

            offset = self.offset_conv(patch).clamp(
                -max_offset, max_offset
            )  # [N_kpts, 2*n_pos, 1, 1]
            if self.mask:
                offset = (
                    offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
                )  # [N_kpts, n_pos, 3]
                offset = offset[:, :, :-1]  # [N_kpts, n_pos, 2]
                # NOTE(review): `offset` was already sliced to 2 channels above,
                # so this sigmoid reads the y-offset, not the third (mask)
                # channel. Matches the upstream ALIKED code, but looks like a
                # statement-ordering bug — confirm against the reference.
                mask_weight = torch.sigmoid(offset[:, :, -1])  # [N_kpts, n_pos]
            else:
                offset = (
                    offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
                )  # [N_kpts, n_pos, 2]
            offsets.append(offset)  # for visualization

            # get sample positions
            pos = kptsi_wh.unsqueeze(1) + offset  # [N_kpts, n_pos, 2]
            # normalize back to [-1, 1] for grid_sample
            pos = 2.0 * pos / wh[None] - 1
            pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)

            # sample features
            features = F.grid_sample(
                xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
            )  # [1,C,(N_kpts*n_pos),1]
            features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
                1, 0, 2, 3
            )  # [N_kpts, C, n_pos, 1]
            if self.mask:
                # soft-weight each sampled position
                features = torch.einsum("ncpo,np->ncpo", features, mask_weight)

            features = torch.selu_(self.sf_conv(features)).squeeze(
                -1
            )  # [N_kpts, C, n_pos]
            # convM: aggregate the n_pos samples into one descriptor
            if not self.conv2D:
                descs = torch.einsum(
                    "ncp,pcd->nd", features, self.agg_weights
                )  # [N_kpts, C]
            else:
                features = features.reshape(N_kpts, -1)[
                    :, :, None, None
                ]  # [N_kpts, C*n_pos, 1, 1]
                descs = self.convM(features).squeeze()  # [N_kpts, C]

            # normalize
            descs = F.normalize(descs, p=2.0, dim=1)
            descriptors.append(descs)

        return descriptors, offsets
|
||||
|
||||
|
||||
class ALIKED(BaseModel):
    """ALIKED local feature extractor: a multi-scale convolutional encoder,
    a score head decoded by differentiable keypoint detection (DKD), and the
    deformable SDDH descriptor head.

    Pretrained weights are fetched from the official ALIKED repository, so
    all submodule attribute names below are part of the checkpoint contract.
    """

    default_conf = {
        "model_name": "aliked-n16",  # one of the keys of `cfgs` below
        "max_num_keypoints": -1,
        "detection_threshold": 0.2,
        "force_num_keypoints": False,
        "pretrained": True,
        "nms_radius": 2,
    }

    checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"

    # hard cap on detections when max_num_keypoints <= 0
    n_limit_max = 20000

    # per-variant widths c1..c4, descriptor dim, SDDH kernel size K and
    # number of deformable positions M
    cfgs = {
        "aliked-t16": {
            "c1": 8,
            "c2": 16,
            "c3": 32,
            "c4": 64,
            "dim": 64,
            "K": 3,
            "M": 16,
        },
        "aliked-n16": {
            "c1": 16,
            "c2": 32,
            "c3": 64,
            "c4": 128,
            "dim": 128,
            "K": 3,
            "M": 16,
        },
        "aliked-n16rot": {
            "c1": 16,
            "c2": 32,
            "c3": 64,
            "c4": 128,
            "dim": 128,
            "K": 3,
            "M": 16,
        },
        "aliked-n32": {
            "c1": 16,
            "c2": 32,
            "c3": 64,
            "c4": 128,
            "dim": 128,
            "K": 3,
            "M": 32,
        },
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        if conf.force_num_keypoints:
            # a fixed keypoint count is only well-defined with top-k selection
            assert conf.detection_threshold <= 0 and conf.max_num_keypoints > 0
        # get configurations
        # NOTE: relies on the insertion order (c1, c2, c3, c4, dim, K, M)
        # of the cfg dicts above.
        c1, c2, c3, c4, dim, K, M = [v for _, v in self.cfgs[conf.model_name].items()]
        # first two stages use plain convs, last two deformable convs
        conv_types = ["conv", "conv", "dcn", "dcn"]
        conv2D = False
        mask = False

        # build model
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
        self.norm = nn.BatchNorm2d
        self.gate = nn.SELU(inplace=True)
        self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
        self.block2 = ResBlock(
            c1,
            c2,
            1,
            nn.Conv2d(c1, c2, 1),  # 1x1 conv matches channels on the skip path
            gate=self.gate,
            norm_layer=self.norm,
            conv_type=conv_types[1],
        )
        self.block3 = ResBlock(
            c2,
            c3,
            1,
            nn.Conv2d(c2, c3, 1),
            gate=self.gate,
            norm_layer=self.norm,
            conv_type=conv_types[2],
            mask=mask,
        )
        self.block4 = ResBlock(
            c3,
            c4,
            1,
            nn.Conv2d(c3, c4, 1),
            gate=self.gate,
            norm_layer=self.norm,
            conv_type=conv_types[3],
            mask=mask,
        )
        # 1x1 convs projecting each stage to dim//4 channels before fusion
        self.conv1 = resnet.conv1x1(c1, dim // 4)
        self.conv2 = resnet.conv1x1(c2, dim // 4)
        self.conv3 = resnet.conv1x1(c3, dim // 4)
        self.conv4 = resnet.conv1x1(dim, dim // 4)
        self.upsample2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )
        self.upsample4 = nn.Upsample(
            scale_factor=4, mode="bilinear", align_corners=True
        )
        self.upsample8 = nn.Upsample(
            scale_factor=8, mode="bilinear", align_corners=True
        )
        self.upsample32 = nn.Upsample(
            scale_factor=32, mode="bilinear", align_corners=True
        )
        # small conv head producing the 1-channel keypoint score map
        self.score_head = nn.Sequential(
            resnet.conv1x1(dim, 8),
            self.gate,
            resnet.conv3x3(8, 4),
            self.gate,
            resnet.conv3x3(4, 4),
            self.gate,
            resnet.conv3x3(4, 1),
        )
        self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
        # keypoint decoding: NMS + thresholding or top-k, capped by n_limit
        self.dkd = DKD(
            radius=conf.nms_radius,
            top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
            scores_th=conf.detection_threshold,
            n_limit=conf.max_num_keypoints
            if conf.max_num_keypoints > 0
            else self.n_limit_max,
        )

        # load pretrained
        if conf.pretrained:
            state_dict = torch.hub.load_state_dict_from_url(
                self.checkpoint_url.format(conf.model_name), map_location="cpu"
            )
            self.load_state_dict(state_dict, strict=True)

    def extract_dense_map(self, image):
        """Run the encoder and heads on a full image.

        Returns:
            feature_map: L2-normalized dense descriptors, B x dim x H x W.
            score_map: keypoint scores in (0, 1), B x 1 x H x W.
        """
        # Pads images such that dimensions are divisible by
        div_by = 2**5
        padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
        image = padder.pad(image)

        # ================================== feature encoder
        x1 = self.block1(image)  # B x c1 x H x W
        x2 = self.pool2(x1)
        x2 = self.block2(x2)  # B x c2 x H/2 x W/2
        x3 = self.pool4(x2)
        x3 = self.block3(x3)  # B x c3 x H/8 x W/8
        x4 = self.pool4(x3)
        x4 = self.block4(x4)  # B x dim x H/32 x W/32
        # ================================== feature aggregation
        x1 = self.gate(self.conv1(x1))  # B x dim//4 x H x W
        x2 = self.gate(self.conv2(x2))  # B x dim//4 x H//2 x W//2
        x3 = self.gate(self.conv3(x3))  # B x dim//4 x H//8 x W//8
        x4 = self.gate(self.conv4(x4))  # B x dim//4 x H//32 x W//32
        x2_up = self.upsample2(x2)  # B x dim//4 x H x W
        x3_up = self.upsample8(x3)  # B x dim//4 x H x W
        x4_up = self.upsample32(x4)  # B x dim//4 x H x W
        x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
        # ================================== score head
        score_map = torch.sigmoid(self.score_head(x1234))
        feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)

        # Unpads images
        feature_map = padder.unpad(feature_map)
        score_map = padder.unpad(score_map)

        return feature_map, score_map

    def _forward(self, data):
        """Detect keypoints, then describe them with the SDDH head.

        Keypoints from DKD are in normalized [-1, 1] coordinates and are
        converted back to pixel coordinates for the output.
        """
        image = data["image"]
        feature_map, score_map = self.extract_dense_map(image)
        keypoints, kptscores, scoredispersitys = self.dkd(
            score_map, image_size=data.get("image_size")
        )
        descriptors, offsets = self.desc_head(feature_map, keypoints)

        _, _, h, w = image.shape
        wh = torch.tensor([w, h], device=image.device)
        # no padding required,
        # we can set detection_threshold=-1 and conf.max_num_keypoints
        return {
            "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0,  # B N 2
            "descriptors": torch.stack(descriptors),  # B N D
            "keypoint_scores": torch.stack(kptscores),  # B N
            "score_dispersity": torch.stack(scoredispersitys),
            "score_map": score_map,  # Bx1xHxW
        }

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,107 @@
|
|||
import torch
|
||||
import kornia
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..utils.misc import pad_and_stack
|
||||
|
||||
|
||||
class DISK(BaseModel):
    """DISK detector/descriptor, wrapping the pretrained kornia model.

    Images are processed in chunks of `conf.chunk` to bound VRAM usage
    during training.
    """

    default_conf = {
        "weights": "depth",
        "dense_outputs": False,
        "max_num_keypoints": None,
        "desc_dim": 128,
        "nms_window_size": 5,
        "detection_threshold": 0.0,
        "force_num_keypoints": False,
        "pad_if_not_divisible": True,
        "chunk": 4,  # for reduced VRAM in training
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        self.model = kornia.feature.DISK.from_pretrained(conf.weights)

    def _get_dense_outputs(self, images):
        """Run the dense heads and decode keypoints.

        Returns per-image kornia Features (keypoints merged with their
        descriptors) plus the dense descriptor maps.
        """
        B = images.shape[0]
        if self.conf.pad_if_not_divisible:
            # DISK requires spatial dimensions divisible by 16
            h, w = images.shape[2:]
            pd_h = 16 - h % 16 if h % 16 > 0 else 0
            pd_w = 16 - w % 16 if w % 16 > 0 else 0
            images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)

        heatmaps, descriptors = self.model.heatmap_and_dense_descriptors(images)
        if self.conf.pad_if_not_divisible:
            # crop the outputs back to the original resolution
            heatmaps = heatmaps[..., :h, :w]
            descriptors = descriptors[..., :h, :w]

        keypoints = kornia.feature.disk.detector.heatmap_to_keypoints(
            heatmaps,
            n=self.conf.max_num_keypoints,
            window_size=self.conf.nms_window_size,
            score_threshold=self.conf.detection_threshold,
        )

        features = []
        for i in range(B):
            features.append(keypoints[i].merge_with_descriptors(descriptors[i]))

        return features, descriptors

    def _forward(self, data):
        image = data["image"]

        keypoints, scores, descriptors = [], [], []
        if self.conf.dense_outputs:
            dense_descriptors = []
        chunk = self.conf.chunk
        for i in range(0, image.shape[0], chunk):
            # Bug fix: this used to slice `image[: min(image.shape[0], i + chunk)]`,
            # which re-processed all earlier chunks on every iteration and
            # produced more per-image feature lists than there are images.
            im = image[i : i + chunk]
            if self.conf.dense_outputs:
                features, d_descriptors = self._get_dense_outputs(im)
                dense_descriptors.append(d_descriptors)
            else:
                features = self.model(
                    im,
                    n=self.conf.max_num_keypoints,
                    window_size=self.conf.nms_window_size,
                    score_threshold=self.conf.detection_threshold,
                    pad_if_not_divisible=self.conf.pad_if_not_divisible,
                )
            keypoints += [f.keypoints for f in features]
            scores += [f.detection_scores for f in features]
            descriptors += [f.descriptors for f in features]
            del features

        if self.conf.force_num_keypoints:
            # pad to target_length so all images share one fixed keypoint count
            target_length = self.conf.max_num_keypoints
            keypoints = pad_and_stack(
                keypoints,
                target_length,
                -2,
                mode="random_c",
                bounds=(
                    0,
                    data.get("image_size", torch.tensor(image.shape[-2:])).min().item(),
                ),
            )
            scores = pad_and_stack(scores, target_length, -1, mode="zeros")
            descriptors = pad_and_stack(descriptors, target_length, -2, mode="zeros")
        else:
            keypoints = torch.stack(keypoints, 0)
            scores = torch.stack(scores, 0)
            descriptors = torch.stack(descriptors, 0)

        pred = {
            # +0.5: convert to the corner pixel-coordinate convention
            "keypoints": keypoints.to(image) + 0.5,
            "keypoint_scores": scores.to(image),
            "descriptors": descriptors.to(image),
        }
        if self.conf.dense_outputs:
            pred["dense_descriptors"] = torch.cat(dense_descriptors, 0)
        return pred

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,59 @@
|
|||
import torch
|
||||
import math
|
||||
|
||||
from ..base_model import BaseModel
|
||||
|
||||
|
||||
def to_sequence(map):
    """Flatten a spatial map (..., C, H, W) into a sequence (..., H*W, C)."""
    flat = map.flatten(-2)  # (..., C, H*W)
    return flat.transpose(-1, -2)
|
||||
|
||||
|
||||
def to_map(sequence):
    """Inverse of `to_sequence`: reshape (..., N, C) to a square map (..., C, e, e).

    The sequence length N must be a perfect square; e = sqrt(N).

    Fixes: the original computed the reshaped tensor but never returned it
    (the function always returned None), and duplicated the square assert.
    """
    n = sequence.shape[-2]
    e = math.isqrt(n)
    assert e * e == n  # sequence length must be a perfect square
    return sequence.transpose(-1, -2).unflatten(-1, [e, e])
|
||||
|
||||
|
||||
class GridExtractor(BaseModel):
    """Dummy extractor that returns the centers of a regular grid of cells
    as keypoints (no detection, no descriptors)."""

    default_conf = {"cell_size": 14}
    required_data_keys = ["image"]

    def _init(self, conf):
        pass

    def _forward(self, data):
        image = data["image"]
        b, c, h, w = image.shape
        cell = self.conf.cell_size

        # integer cell indices over the image, (row, col) grids
        ys, xs = torch.meshgrid(
            torch.arange(h // cell, dtype=torch.float32, device=image.device),
            torch.arange(w // cell, dtype=torch.float32, device=image.device),
            indexing="ij",
        )
        # channel order (x, y); scale to pixel units and shift to cell centers
        cgrid = (
            torch.stack((xs, ys), dim=0)[None].repeat(b, 1, 1, 1) * cell + cell / 2
        )

        return {
            "grid": cgrid + 0.5,
            "keypoints": to_sequence(cgrid) + 0.5,
        }

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,73 @@
|
|||
import torch
|
||||
import kornia
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..utils.misc import pad_to_length
|
||||
|
||||
|
||||
class KeyNetAffNetHardNet(BaseModel):
    """KeyNet detector + AffNet shape estimation + HardNet descriptors,
    wrapping the kornia `KeyNetHardNet` pipeline. Runs one image at a time
    and pads every image to a fixed number of keypoints.
    """

    default_conf = {
        "max_num_keypoints": None,
        "desc_dim": 128,
        "upright": False,
        "scale_laf": 1.0,
        "chunk": 4,  # for reduced VRAM in training
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        self.model = kornia.feature.KeyNetHardNet(
            num_features=conf.max_num_keypoints,
            upright=conf.upright,
            scale_laf=conf.scale_laf,
        )

    def _forward(self, data):
        image = data["image"]
        if image.shape[1] == 3:  # RGB
            # Rec.601 luma weights for RGB -> grayscale
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        lafs, scores, descs = [], [], []
        im_size = data.get("image_size")
        for i in range(image.shape[0]):
            img_i = image[i : i + 1, :1]
            if im_size is not None:
                # crop away padded borders before detection
                img_i = img_i[:, :, : im_size[i, 1], : im_size[i, 0]]
            laf, score, desc = self.model(img_i)
            # pad keypoint centers to the target count with random in-bounds
            # points ("random_c" mode)
            xn = pad_to_length(
                kornia.feature.get_laf_center(laf),
                self.conf.max_num_keypoints,
                pad_dim=-2,
                mode="random_c",
                bounds=(0, min(img_i.shape[-2:])),
            )
            # xn[:, score.shape[-1]:] are only the padded centers; turn them
            # into default LAFs and append them after the real detections
            laf = torch.cat(
                [
                    laf,
                    kornia.feature.laf_from_center_scale_ori(xn[:, score.shape[-1] :]),
                ],
                -3,
            )
            lafs.append(laf)
            scores.append(pad_to_length(score, self.conf.max_num_keypoints, -1))
            descs.append(pad_to_length(desc, self.conf.max_num_keypoints, -2))

        # all images now share the same keypoint count -> batch them
        lafs = torch.cat(lafs, 0)
        scores = torch.cat(scores, 0)
        descs = torch.cat(descs, 0)
        keypoints = kornia.feature.get_laf_center(lafs)
        scales = kornia.feature.get_laf_scale(lafs)[..., 0]
        oris = kornia.feature.get_laf_orientation(lafs)
        pred = {
            "keypoints": keypoints,
            "scales": scales.squeeze(-1),
            "oris": oris.squeeze(-1),
            "lafs": lafs,
            "keypoint_scores": scores,
            "descriptors": descs,
        }

        return pred

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,78 @@
|
|||
from omegaconf import OmegaConf
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from .. import get_model
|
||||
|
||||
# from ...geometry.depth import sample_fmap
|
||||
|
||||
to_ctr = OmegaConf.to_container # convert DictConfig to dict
|
||||
|
||||
|
||||
class MixedExtractor(BaseModel):
    """Combines an independent detector and descriptor into one extractor.

    Either component may be disabled (name=None), in which case the
    corresponding outputs must be supplied through data["cache"].
    """

    default_conf = {
        "detector": {"name": None},
        "descriptor": {"name": None},
        "interpolate_descriptors_from": None,  # field name
    }

    required_data_keys = ["image"]
    required_cache_keys = []

    def _init(self, conf):
        # Use `+` (new list) instead of `+=`: augmented assignment on the
        # class-level lists mutated them in place, leaking required keys
        # across instances.
        if conf.detector.name:
            self.detector = get_model(conf.detector.name)(to_ctr(conf.detector))
        else:
            self.required_data_keys = self.required_data_keys + ["cache"]
            self.required_cache_keys = self.required_cache_keys + ["keypoints"]

        if conf.descriptor.name:
            self.descriptor = get_model(conf.descriptor.name)(to_ctr(conf.descriptor))
        else:
            self.required_data_keys = self.required_data_keys + ["cache"]
            self.required_cache_keys = self.required_cache_keys + ["descriptors"]

    def _forward(self, data):
        if self.conf.detector.name:
            pred = self.detector(data)
        else:
            pred = data["cache"]
        # Bug fix: this was gated on `self.conf.detector.name`, so a
        # detector-only config called the missing `self.descriptor`
        # (AttributeError) and a descriptor-only config never ran it.
        if self.conf.descriptor.name:
            pred = {**pred, **self.descriptor({**pred, **data})}

        if self.conf.interpolate_descriptors_from:
            # bilinearly sample a dense feature map at the keypoint locations
            h, w = data["image"].shape[-2:]
            kpts = pred["keypoints"]
            pts = (kpts / kpts.new_tensor([[w, h]]) * 2 - 1)[:, None]
            pred["descriptors"] = (
                F.grid_sample(
                    pred[self.conf.interpolate_descriptors_from],
                    pts,
                    align_corners=False,
                    mode="bilinear",
                )
                .squeeze(-2)
                .transpose(-2, -1)
                .contiguous()
            )

        return pred

    def loss(self, pred, data):
        """Sum the losses of both components, skipping any that is disabled,
        has `apply_loss` set to False, or does not implement a loss."""
        losses = {}
        metrics = {}
        total = 0

        for k in ["detector", "descriptor"]:
            apply = True
            if "apply_loss" in self.conf[k].keys():
                apply = self.conf[k].apply_loss
            if self.conf[k].name and apply:
                try:
                    losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
                except NotImplementedError:
                    continue
                losses = {**losses, **losses_}
                metrics = {**metrics, **metrics_}
                total = losses_["total"] + total
        return {**losses, "total": total}, metrics
|
|
@ -0,0 +1,240 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import pycolmap
|
||||
from scipy.spatial import KDTree
|
||||
from omegaconf import OmegaConf
|
||||
import cv2
|
||||
|
||||
from ..base_model import BaseModel
|
||||
|
||||
from ..utils.misc import pad_to_length
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def sift_to_rootsift(x):
    """Convert SIFT descriptors to RootSIFT: L1-normalize, take the
    element-wise square root, then L2-normalize (last axis)."""
    l1 = np.linalg.norm(x, ord=1, axis=-1, keepdims=True)
    rooted = np.sqrt((x / (l1 + EPS)).clip(min=EPS))
    l2 = np.linalg.norm(rooted, axis=-1, keepdims=True)
    return rooted / (l2 + EPS)
|
||||
|
||||
|
||||
# from OpenGlue
|
||||
def nms_keypoints(kpts: np.ndarray, responses: np.ndarray, radius: float) -> np.ndarray:
|
||||
# TODO: add approximate tree
|
||||
kd_tree = KDTree(kpts)
|
||||
|
||||
sorted_idx = np.argsort(-responses)
|
||||
kpts_to_keep_idx = []
|
||||
removed_idx = set()
|
||||
|
||||
for idx in sorted_idx:
|
||||
# skip point if it was already removed
|
||||
if idx in removed_idx:
|
||||
continue
|
||||
|
||||
kpts_to_keep_idx.append(idx)
|
||||
point = kpts[idx]
|
||||
neighbors = kd_tree.query_ball_point(point, r=radius)
|
||||
# Variable `neighbors` contains the `point` itself
|
||||
removed_idx.update(neighbors)
|
||||
|
||||
mask = np.zeros((kpts.shape[0],), dtype=bool)
|
||||
mask[kpts_to_keep_idx] = True
|
||||
return mask
|
||||
|
||||
|
||||
def detect_kpts_opencv(
    features: cv2.Feature2D, image: np.ndarray, describe: bool = True
) -> np.ndarray:
    """
    Detect keypoints using OpenCV Detector.
    Optionally, perform NMS and filter top-response keypoints.
    Optionally, perform description.
    Args:
        features: OpenCV based keypoints detector and descriptor
        image: Grayscale image of uint8 data type
        describe: flag indicating whether to simultaneously compute descriptors
    Returns:
        (spts, responses[, descriptors]): keypoints stacked as
        (x, y, scale, angle) rows, their responses, and — when `describe`
        is True — the descriptor matrix.
    """
    if describe:
        kpts, descriptors = features.detectAndCompute(image, None)
    else:
        kpts = features.detect(image, None)
        kpts = np.array(kpts)

    responses = np.array([k.response for k in kpts], dtype=np.float32)

    # select all: Ellipsis indexing keeps every row; left as a hook for a
    # future top-k filter
    top_score_idx = ...
    pts = np.array([k.pt for k in kpts], dtype=np.float32)
    scales = np.array([k.size for k in kpts], dtype=np.float32)
    angles = np.array([k.angle for k in kpts], dtype=np.float32)
    # pack (x, y, scale, angle) per keypoint
    spts = np.concatenate([pts, scales[..., None], angles[..., None]], -1)

    if describe:
        return spts[top_score_idx], responses[top_score_idx], descriptors[top_score_idx]
    else:
        return spts[top_score_idx], responses[top_score_idx]
|
||||
|
||||
|
||||
class SIFT(BaseModel):
    """(Root)SIFT keypoint detector/descriptor backed by pycolmap (CPU, CUDA
    or auto) or OpenCV. The backend object is created lazily on first use."""

    default_conf = {
        "has_detector": True,
        "has_descriptor": True,
        "descriptor_dim": 128,
        "pycolmap_options": {
            "first_octave": 0,
            "peak_threshold": 0.005,
            "edge_threshold": 10,
        },
        "rootsift": True,
        "nms_radius": None,
        "max_num_keypoints": -1,
        "max_num_keypoints_val": None,
        "force_num_keypoints": False,
        "randomize_keypoints_training": False,
        "detector": "pycolmap",  # ['pycolmap', 'pycolmap_cpu', 'pycolmap_cuda', 'cv2']
        "detection_threshold": None,
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        self.sift = None  # lazy loading

    @torch.no_grad()
    def extract_features(self, image):
        """Extract SIFT features from one grayscale image tensor of shape
        (1, H, W) with values in [0, 1]. Returns a dict of CPU tensors:
        keypoints (x, y), scales, oris (degrees here; converted to radians
        in `_forward`), keypoint_scores, and optionally descriptors."""
        image_np = image.cpu().numpy()[0]
        assert image.shape[0] == 1
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS

        detector = str(self.conf.detector)

        # lazily construct the backend on first call
        if self.sift is None and detector.startswith("pycolmap"):
            options = OmegaConf.to_container(self.conf.pycolmap_options)
            device = (
                "auto" if detector == "pycolmap" else detector.replace("pycolmap_", "")
            )
            # NOTE(review): `rootsift` defaults to a bool, so comparing it to
            # the string "rootsift" is always False and L2 normalization is
            # always selected here; descriptors are re-normalized to RootSIFT
            # below when `conf.rootsift` is truthy — confirm this is intended.
            if self.conf.rootsift == "rootsift":
                options["normalization"] = pycolmap.Normalization.L1_ROOT
            else:
                options["normalization"] = pycolmap.Normalization.L2
            if self.conf.detection_threshold is not None:
                options["peak_threshold"] = self.conf.detection_threshold
            options["max_num_features"] = self.conf.max_num_keypoints
            self.sift = pycolmap.Sift(options=options, device=device)
        elif self.sift is None and self.conf.detector == "cv2":
            self.sift = cv2.SIFT_create(contrastThreshold=self.conf.detection_threshold)

        if detector.startswith("pycolmap"):
            keypoints, scores, descriptors = self.sift.extract(image_np)
        elif detector == "cv2":
            # TODO: Check if opencv keypoints are already in corner convention
            keypoints, scores, descriptors = detect_kpts_opencv(
                self.sift, (image_np * 255.0).astype(np.uint8)
            )

        if self.conf.nms_radius is not None:
            # spatial NMS on (x, y) only
            mask = nms_keypoints(keypoints[:, :2], scores, self.conf.nms_radius)
            keypoints = keypoints[mask]
            scores = scores[mask]
            descriptors = descriptors[mask]

        # keypoints rows are (x, y, scale, angle)
        scales = keypoints[:, 2]
        oris = np.rad2deg(keypoints[:, 3])

        if self.conf.has_descriptor:
            # We still renormalize because COLMAP does not normalize well,
            # maybe due to numerical errors
            if self.conf.rootsift:
                descriptors = sift_to_rootsift(descriptors)
            descriptors = torch.from_numpy(descriptors)
        keypoints = torch.from_numpy(keypoints[:, :2])  # keep only x, y
        scales = torch.from_numpy(scales)
        oris = torch.from_numpy(oris)
        scores = torch.from_numpy(scores)

        # Keep the k keypoints with highest score
        max_kps = self.conf.max_num_keypoints

        # for val we allow different
        if not self.training and self.conf.max_num_keypoints_val is not None:
            max_kps = self.conf.max_num_keypoints_val

        if max_kps is not None and max_kps > 0:
            if self.conf.randomize_keypoints_training and self.training:
                # instead of selecting top-k, sample k by score weights
                raise NotImplementedError
            elif max_kps < scores.shape[0]:
                # TODO: check that the scores from PyCOLMAP are 100% correct,
                # follow https://github.com/mihaidusmanu/pycolmap/issues/8
                indices = torch.topk(scores, max_kps).indices
                keypoints = keypoints[indices]
                scales = scales[indices]
                oris = oris[indices]
                scores = scores[indices]
                if self.conf.has_descriptor:
                    descriptors = descriptors[indices]

            if self.conf.force_num_keypoints:
                # pad each field up to max_kps; keypoints get random
                # in-bounds coordinates, everything else zeros
                keypoints = pad_to_length(
                    keypoints,
                    max_kps,
                    -2,
                    mode="random_c",
                    bounds=(0, min(image.shape[1:])),
                )
                scores = pad_to_length(scores, max_kps, -1, mode="zeros")
                scales = pad_to_length(scales, max_kps, -1, mode="zeros")
                oris = pad_to_length(oris, max_kps, -1, mode="zeros")
                if self.conf.has_descriptor:
                    descriptors = pad_to_length(descriptors, max_kps, -2, mode="zeros")

        pred = {
            "keypoints": keypoints,
            "scales": scales,
            "oris": oris,
            "keypoint_scores": scores,
        }

        if self.conf.has_descriptor:
            pred["descriptors"] = descriptors
        return pred

    @torch.no_grad()
    def _forward(self, data):
        """Run extraction per image in the batch and stack the results."""
        pred = {
            "keypoints": [],
            "scales": [],
            "oris": [],
            "keypoint_scores": [],
            "descriptors": [],
        }

        image = data["image"]
        if image.shape[1] == 3:  # RGB
            # Rec.601 luma weights for RGB -> grayscale
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True).cpu()

        for k in range(image.shape[0]):
            img = image[k]
            if "image_size" in data.keys():
                # avoid extracting points in padded areas
                w, h = data["image_size"][k]
                img = img[:, :h, :w]
            p = self.extract_features(img)
            # NOTE: the inner loop reuses the name `k` (dict key shadows the
            # batch index); harmless since the outer `k` is re-bound next
            # iteration, but easy to misread.
            for k, v in p.items():
                pred[k].append(v)

        if (image.shape[0] == 1) or self.conf.force_num_keypoints:
            pred = {k: torch.stack(pred[k], 0) for k in pred.keys()}

        # NOTE(review): with batch > 1 and force_num_keypoints=False the
        # entries are still Python lists here, and list.to() below would
        # fail — presumably callers only use batch 1 in that mode; confirm.
        pred = {k: pred[k].to(device=data["image"].device) for k in pred.keys()}

        # angles were stored in degrees by extract_features
        pred["oris"] = torch.deg2rad(pred["oris"])
        return pred

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,45 @@
|
|||
import kornia
|
||||
import torch
|
||||
|
||||
from ..base_model import BaseModel
|
||||
|
||||
|
||||
class KorniaSIFT(BaseModel):
    """SIFT detector/descriptor wrapping `kornia.feature.SIFTFeature`."""

    default_conf = {
        "has_detector": True,
        "has_descriptor": True,
        "max_num_keypoints": -1,
        "detection_threshold": None,
        "rootsift": True,
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        self.sift = kornia.feature.SIFTFeature(
            num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
        )

    def _forward(self, data):
        # decompose the local affine frames into centers, scales, orientations
        lafs, scores, descriptors = self.sift(data["image"])
        keypoints = kornia.feature.get_laf_center(lafs)
        scales = kornia.feature.get_laf_scale(lafs)
        oris = kornia.feature.get_laf_orientation(lafs)
        pred = {
            "keypoints": keypoints,  # @TODO: confirm keypoints are in corner convention
            "scales": scales,
            "oris": oris,
            "keypoint_scores": scores,
        }

        if self.conf.has_descriptor:
            pred["descriptors"] = descriptors

        pred = {k: pred[k].to(device=data["image"].device) for k in pred.keys()}

        # NOTE(review): this line is a no-op (unlike other extractors, the
        # trailing LAF dimensions of "scales" are not squeezed here) — confirm
        # downstream consumers expect the un-squeezed shape.
        pred["scales"] = pred["scales"]
        pred["oris"] = torch.deg2rad(pred["oris"])
        return pred

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,209 @@
|
|||
"""PyTorch implementation of the SuperPoint model,
|
||||
derived from the TensorFlow re-implementation (2018).
|
||||
Authors: Rémi Pautrat, Paul-Edouard Sarlin
|
||||
https://github.com/rpautrat/SuperPoint
|
||||
The implementation of this model and its trained weights are made
|
||||
available under the MIT license.
|
||||
"""
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from types import SimpleNamespace
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..utils.misc import pad_and_stack
|
||||
|
||||
|
||||
def sample_descriptors(keypoints, descriptors, s: int = 8):
    """Interpolate descriptors at keypoint locations.

    Bilinearly samples the dense map `descriptors` (b, c, h, w) at the given
    pixel coordinates `keypoints` (b, n, 2), where the map has stride `s`
    with respect to the image, and L2-normalizes the result along channels.
    Returns a tensor of shape (b, c, n).
    """
    b, c, h, w = descriptors.shape
    # map pixel coords (corner convention, hence +0.5) to grid_sample's (-1, 1)
    extent = keypoints.new_tensor([w, h]) * s
    grid = ((keypoints + 0.5) / extent) * 2 - 1
    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.view(b, 1, -1, 2), mode="bilinear", align_corners=False
    )
    return torch.nn.functional.normalize(sampled.reshape(b, c, -1), p=2, dim=1)
|
||||
|
||||
|
||||
def batched_nms(scores, nms_radius: int):
|
||||
assert nms_radius >= 0
|
||||
|
||||
def max_pool(x):
|
||||
return torch.nn.functional.max_pool2d(
|
||||
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
|
||||
)
|
||||
|
||||
zeros = torch.zeros_like(scores)
|
||||
max_mask = scores == max_pool(scores)
|
||||
for _ in range(2):
|
||||
supp_mask = max_pool(max_mask.float()) > 0
|
||||
supp_scores = torch.where(supp_mask, zeros, scores)
|
||||
new_max_mask = supp_scores == max_pool(supp_scores)
|
||||
max_mask = max_mask | (new_max_mask & (~supp_mask))
|
||||
return torch.where(max_mask, scores, zeros)
|
||||
|
||||
|
||||
def select_top_k_keypoints(keypoints, scores, k):
|
||||
if k >= len(keypoints):
|
||||
return keypoints, scores
|
||||
scores, indices = torch.topk(scores, k, dim=0, sorted=True)
|
||||
return keypoints[indices], scores
|
||||
|
||||
|
||||
class VGGBlock(nn.Sequential):
    """Conv -> activation -> BatchNorm block used by the SuperPoint backbone.

    Submodule names ("conv", "activation", "bn") are fixed: they are keyed
    by the pretrained checkpoint.
    """

    def __init__(self, c_in, c_out, kernel_size, relu=True):
        # "same" padding for odd kernel sizes
        super().__init__(
            OrderedDict(
                conv=nn.Conv2d(
                    c_in,
                    c_out,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=(kernel_size - 1) // 2,
                ),
                activation=nn.ReLU(inplace=True) if relu else nn.Identity(),
                bn=nn.BatchNorm2d(c_out, eps=0.001),
            )
        )
|
||||
|
||||
|
||||
class SuperPoint(BaseModel):
    """SuperPoint keypoint detector and descriptor with a VGG-style backbone.

    Predicts sparse keypoints, their detection scores, and L2-normalized
    descriptors sampled at the keypoint locations (optionally also the dense
    descriptor map). Weights are downloaded from the re-trained checkpoint of
    rpautrat/SuperPoint.

    Fixes over the previous version:
    - ``default_conf`` contained a duplicate ``"descriptor_dim"`` key (the
      second silently overrode the first); only one entry is kept.
    - with batch size > 1 and ``force_num_keypoints`` disabled, descriptors
      are a per-image list; they are now transposed per element instead of
      calling ``.transpose`` on the list (which raised AttributeError).
    """

    default_conf = {
        "descriptor_dim": 256,
        "nms_radius": 4,
        "max_num_keypoints": None,
        "force_num_keypoints": False,
        "detection_threshold": 0.005,
        "remove_borders": 4,
        "channels": [64, 64, 128, 128, 256],
        "dense_outputs": None,
    }

    checkpoint_url = "https://github.com/rpautrat/SuperPoint/raw/master/weights/superpoint_v6_from_tf.pth"  # noqa: E501

    def _init(self, conf):
        self.conf = SimpleNamespace(**conf)
        # Total downsampling factor: one 2x max-pool after every backbone
        # stage except the last.
        self.stride = 2 ** (len(self.conf.channels) - 2)
        channels = [1, *self.conf.channels[:-1]]

        backbone = []
        for i, c in enumerate(channels[1:], 1):
            layers = [VGGBlock(channels[i - 1], c, 3), VGGBlock(c, c, 3)]
            if i < len(channels) - 1:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            backbone.append(nn.Sequential(*layers))
        self.backbone = nn.Sequential(*backbone)

        c = self.conf.channels[-1]
        # Detection head: stride**2 cell scores + 1 "no keypoint" dustbin.
        self.detector = nn.Sequential(
            VGGBlock(channels[-1], c, 3),
            VGGBlock(c, self.stride**2 + 1, 1, relu=False),
        )
        self.descriptor = nn.Sequential(
            VGGBlock(channels[-1], c, 3),
            VGGBlock(c, self.conf.descriptor_dim, 1, relu=False),
        )

        state_dict = torch.hub.load_state_dict_from_url(self.checkpoint_url)
        self.load_state_dict(state_dict)

    def _forward(self, data):
        image = data["image"]
        if image.shape[1] == 3:  # RGB -> grayscale (ITU-R BT.601 weights)
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        features = self.backbone(image)
        descriptors_dense = torch.nn.functional.normalize(
            self.descriptor(features), p=2, dim=1
        )

        # Decode the detection scores: softmax over the stride**2 + 1 channels,
        # drop the dustbin, and unfold the per-cell scores to full resolution.
        scores = self.detector(features)
        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
        b, _, h, w = scores.shape
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, self.stride, self.stride)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(
            b, h * self.stride, w * self.stride
        )
        scores = batched_nms(scores, self.conf.nms_radius)

        # Discard keypoints near the image borders by pushing their score
        # below any detection threshold.
        if self.conf.remove_borders:
            pad = self.conf.remove_borders
            scores[:, :pad] = -1
            scores[:, :, :pad] = -1
            scores[:, -pad:] = -1
            scores[:, :, -pad:] = -1

        # Extract keypoints above the detection threshold.
        if b > 1:
            idxs = torch.where(scores > self.conf.detection_threshold)
            # mask[i] selects the detections belonging to batch element i.
            mask = idxs[0] == torch.arange(b, device=scores.device)[:, None]
        else:  # Faster shortcut for the common single-image case
            scores = scores.squeeze(0)
            idxs = torch.where(scores > self.conf.detection_threshold)

        # Convert row/column indices (i, j) to (x, y) coordinates.
        keypoints_all = torch.stack(idxs[-2:], dim=-1).flip(1).float()
        scores_all = scores[idxs]

        keypoints = []
        scores = []
        for i in range(b):
            if b > 1:
                k = keypoints_all[mask[i]]
                s = scores_all[mask[i]]
            else:
                k = keypoints_all
                s = scores_all
            if self.conf.max_num_keypoints is not None:
                k, s = select_top_k_keypoints(k, s, self.conf.max_num_keypoints)

            keypoints.append(k)
            scores.append(s)

        if self.conf.force_num_keypoints:
            # Pad each image to exactly max_num_keypoints with random points
            # (zero score) so the batch can be stacked.
            keypoints = pad_and_stack(
                keypoints,
                self.conf.max_num_keypoints,
                -2,
                mode="random_c",
                bounds=(
                    0,
                    data.get("image_size", torch.tensor(image.shape[-2:])).min().item(),
                ),
            )
            scores = pad_and_stack(
                scores, self.conf.max_num_keypoints, -1, mode="zeros"
            )
        else:
            keypoints = torch.stack(keypoints, 0)
            scores = torch.stack(scores, 0)

        if len(keypoints) == 1 or self.conf.force_num_keypoints:
            # Batch sampling of the descriptors.
            desc = sample_descriptors(keypoints, descriptors_dense, self.stride)
            desc = desc.transpose(-1, -2)
        else:
            # Variable number of keypoints per image: sample one image at a
            # time and transpose each [c, n] tensor to [n, c].
            desc = [
                sample_descriptors(k[None], d[None], self.stride)[0].transpose(-1, -2)
                for k, d in zip(keypoints, descriptors_dense)
            ]

        pred = {
            # +0.5 converts back to the cell-center pixel convention.
            "keypoints": keypoints + 0.5,
            "keypoint_scores": scores,
            "descriptors": desc,
        }
        if self.conf.dense_outputs:
            pred["dense_descriptors"] = descriptors_dense

        return pred

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,105 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import deeplsd.models.deeplsd_inference as deeplsd_inference
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ...settings import DATA_PATH
|
||||
|
||||
|
||||
class DeepLSD(BaseModel):
    """Line segment detector wrapping the pre-trained DeepLSD network.

    Downloads the checkpoint on first use, filters detected segments by
    length, optionally keeps only the best ``max_num_lines``, and can pad to
    a fixed number of lines for batched training.
    """

    default_conf = {
        # Minimum segment length in pixels; shorter detections are dropped.
        "min_length": 15,
        # Keep at most this many lines per image (None = keep all).
        "max_num_lines": None,
        # Pad every image to exactly max_num_lines (requires it to be set).
        "force_num_lines": False,
        "model_conf": {
            "detect_lines": True,
            "line_detection_params": {
                "merge": False,
                "grad_nfa": True,
                "filtering": "normal",
                "grad_thresh": 3,
            },
        },
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        # Padding to a fixed count is only meaningful with a known target.
        if self.conf.force_num_lines:
            assert (
                self.conf.max_num_lines is not None
            ), "Missing max_num_lines parameter"
        ckpt = DATA_PATH / "weights/deeplsd_md.tar"
        if not ckpt.is_file():
            self.download_model(ckpt)
        ckpt = torch.load(ckpt, map_location="cpu")
        self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
        self.net.load_state_dict(ckpt["model"])

    def download_model(self, path):
        """Fetch the DeepLSD checkpoint into ``path`` via wget."""
        import subprocess

        if not path.parent.is_dir():
            path.parent.mkdir(parents=True, exist_ok=True)
        link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
        cmd = ["wget", link, "-O", path]
        print("Downloading DeepLSD model...")
        subprocess.run(cmd, check=True)

    def _forward(self, data):
        image = data["image"]
        lines, line_scores, valid_lines = [], [], []
        if image.shape[1] == 3:
            # Convert to grayscale
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)

        # Forward pass
        with torch.no_grad():
            segs = self.net({"image": image})["lines"]

        # Line scores are the sqrt of the length
        # NOTE(review): the loop rebinds the name `segs` inside its own body.
        # This is safe (the iterator holds the original list) but fragile —
        # consider renaming the inner variable.
        for seg in segs:
            lengths = np.linalg.norm(seg[:, 0] - seg[:, 1], axis=1)
            segs = seg[lengths >= self.conf.min_length]
            scores = np.sqrt(lengths[lengths >= self.conf.min_length])

            # Keep the best lines
            indices = np.argsort(-scores)
            if self.conf.max_num_lines is not None:
                indices = indices[: self.conf.max_num_lines]
                segs = segs[indices]
                scores = scores[indices]

            # Pad if necessary
            n = len(segs)
            valid_mask = np.ones(n, dtype=bool)
            if self.conf.force_num_lines:
                # Padded entries get zero segments/scores and are marked
                # invalid in valid_mask.
                pad = self.conf.max_num_lines - n
                segs = np.concatenate(
                    [segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
                )
                scores = np.concatenate(
                    [scores, np.zeros(pad, dtype=np.float32)], axis=0
                )
                valid_mask = np.concatenate(
                    [valid_mask, np.zeros(pad, dtype=bool)], axis=0
                )

            lines.append(segs)
            line_scores.append(scores)
            valid_lines.append(valid_mask)

        # Batch if possible (single image, or all images padded to the same
        # count); otherwise the outputs stay as per-image lists.
        if len(image) == 1 or self.conf.force_num_lines:
            lines = torch.tensor(lines, dtype=torch.float, device=image.device)
            line_scores = torch.tensor(
                line_scores, dtype=torch.float, device=image.device
            )
            valid_lines = torch.tensor(
                valid_lines, dtype=torch.bool, device=image.device
            )

        return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from joblib import Parallel, delayed
|
||||
from pytlsd import lsd
|
||||
|
||||
from ..base_model import BaseModel
|
||||
|
||||
|
||||
class LSD(BaseModel):
    """Classical Line Segment Detector (pytlsd) with length filtering,
    optional top-k selection, padding to a fixed count, and parallel
    per-image detection via joblib."""

    default_conf = {
        # Minimum segment length in pixels.
        "min_length": 15,
        # Keep at most this many lines per image (None = keep all).
        "max_num_lines": None,
        # Pad every image to exactly max_num_lines (requires it to be set).
        "force_num_lines": False,
        # Number of joblib workers for batched detection.
        "n_jobs": 4,
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        if self.conf.force_num_lines:
            assert (
                self.conf.max_num_lines is not None
            ), "Missing max_num_lines parameter"

    def detect_lines(self, img):
        """Detect line segments in a single uint8 grayscale image.

        Returns (segs [n, 2, 2], scores [n], valid_mask [n]) as numpy arrays.
        """
        # Run LSD
        segs = lsd(img)

        # Filter out keylines that do not meet the minimum length criteria
        lengths = np.linalg.norm(segs[:, 2:4] - segs[:, 0:2], axis=1)
        to_keep = lengths >= self.conf.min_length
        segs, lengths = segs[to_keep], lengths[to_keep]

        # Keep the best lines: score = LSD confidence (last column) weighted
        # by sqrt of the segment length.
        scores = segs[:, -1] * np.sqrt(lengths)
        segs = segs[:, :4].reshape(-1, 2, 2)
        indices = np.argsort(-scores)
        if self.conf.max_num_lines is not None:
            indices = indices[: self.conf.max_num_lines]
            segs = segs[indices]
            scores = scores[indices]

        # Pad if necessary; padded entries are zeros and marked invalid.
        n = len(segs)
        valid_mask = np.ones(n, dtype=bool)
        if self.conf.force_num_lines:
            pad = self.conf.max_num_lines - n
            segs = np.concatenate(
                [segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
            )
            scores = np.concatenate([scores, np.zeros(pad, dtype=np.float32)], axis=0)
            valid_mask = np.concatenate([valid_mask, np.zeros(pad, dtype=bool)], axis=0)

        return segs, scores, valid_mask

    def _forward(self, data):
        # Convert to the right data format
        image = data["image"]
        if image.shape[1] == 3:
            # Convert to grayscale
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        device = image.device
        b_size = len(image)
        # pytlsd expects uint8 arrays; assumes the input is in [0, 1].
        image = np.uint8(image.squeeze(1).cpu().numpy() * 255)

        # LSD detection in parallel
        if b_size == 1:
            lines, line_scores, valid_lines = self.detect_lines(image[0])
            lines = [lines]
            line_scores = [line_scores]
            valid_lines = [valid_lines]
        else:
            lines, line_scores, valid_lines = zip(
                *Parallel(n_jobs=self.conf.n_jobs)(
                    delayed(self.detect_lines)(img) for img in image
                )
            )

        # Batch if possible (single image or fixed-count padding); otherwise
        # the outputs stay as per-image sequences of numpy arrays.
        if b_size == 1 or self.conf.force_num_lines:
            lines = torch.tensor(lines, dtype=torch.float, device=device)
            line_scores = torch.tensor(line_scores, dtype=torch.float, device=device)
            valid_lines = torch.tensor(valid_lines, dtype=torch.bool, device=device)

        return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,312 @@
|
|||
import torch
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from .. import get_model
|
||||
|
||||
|
||||
def sample_descriptors_corner_conv(keypoints, descriptors, s: int = 8):
    """Bilinearly interpolate dense descriptors at keypoint locations.

    Unlike ``sample_descriptors``, no half-pixel offset is applied to the
    coordinates (corner-aligned convolution convention).

    Args:
        keypoints: ``[b, n, 2]`` (x, y) pixel coordinates.
        descriptors: ``[b, c, h, w]`` dense descriptor map at stride ``s``.
        s: downsampling stride between the image and the descriptor map.

    Returns:
        ``[b, c, n]`` L2-normalized descriptors.
    """
    F = torch.nn.functional
    b, c, h, w = descriptors.shape
    # Rescale to [-1, 1] for grid_sample, without the +0.5 cell-center shift.
    grid = keypoints / (keypoints.new_tensor([w, h]) * s) * 2.0 - 1.0
    sampled = F.grid_sample(
        descriptors, grid.view(b, 1, -1, 2), mode="bilinear", align_corners=False
    )
    return F.normalize(sampled.reshape(b, c, -1), p=2, dim=1)
|
||||
|
||||
|
||||
def lines_to_wireframe(
    lines, line_scores, all_descs, s, nms_radius, force_num_lines, max_num_lines
):
    """Given a set of lines, their score and dense descriptors,
    merge close-by endpoints and compute a wireframe defined by
    its junctions and connectivity.

    Args:
        lines: [b_size, num_lines, 2, 2] segment endpoints in pixels.
        line_scores: [b_size, num_lines] per-line scores.
        all_descs: [b_size, dim, h, w] dense descriptor map at stride s.
        s: stride between the image and the descriptor map.
        nms_radius: DBSCAN eps (in pixels) used to merge nearby endpoints.
        force_num_lines: if True, pad junctions to max_num_lines * 2.
        max_num_lines: target line count when force_num_lines is set.

    Returns:
        junctions: [b_size, num_junc, 2] tensor of all wireframe junctions
            (stacked; requires equal counts per image unless padded).
        junc_scores: list of [num_junc] tensors with the junction score.
        junc_descs: [b_size, num_junc, dim] junction descriptors.
        connectivity: list of [num_junc, num_junc] bool arrays with True when 2
            junctions are connected.
        new_lines: the new set of [b_size, num_lines, 2, 2] lines.
        lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the
            junctions of each endpoint.
        num_true_junctions: a list of the number of valid junctions for each image
            in the batch, i.e. before filling with random ones.
    """
    b_size, _, h, w = all_descs.shape
    device = lines.device
    # Convert descriptor-map size back to image size.
    h, w = h * s, w * s
    endpoints = lines.reshape(b_size, -1, 2)

    (
        junctions,
        junc_scores,
        connectivity,
        new_lines,
        lines_junc_idx,
        num_true_junctions,
    ) = ([], [], [], [], [], [])
    for bs in range(b_size):
        # Cluster the junctions that are close-by.
        # min_samples=1 ensures every endpoint belongs to some cluster.
        db = DBSCAN(eps=nms_radius, min_samples=1).fit(endpoints[bs].cpu().numpy())
        clusters = db.labels_
        n_clusters = len(set(clusters))
        num_true_junctions.append(n_clusters)

        # Compute the average junction position and score for each cluster.
        clusters = torch.tensor(clusters, dtype=torch.long, device=device)
        new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, device=device)
        # include_self=False: the zero initialization does not bias the mean.
        new_junc.scatter_reduce_(
            0,
            clusters[:, None].repeat(1, 2),
            endpoints[bs],
            reduce="mean",
            include_self=False,
        )
        junctions.append(new_junc)
        new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device)
        # Each line contributes its score to both of its endpoints.
        new_scores.scatter_reduce_(
            0,
            clusters,
            torch.repeat_interleave(line_scores[bs], 2),
            reduce="mean",
            include_self=False,
        )
        junc_scores.append(new_scores)

        # Compute the new lines by snapping endpoints to their cluster center.
        new_lines.append(junctions[-1][clusters].reshape(-1, 2, 2))
        lines_junc_idx.append(clusters.reshape(-1, 2))

        if force_num_lines:
            # Add random junctions (with no connectivity) so every image has
            # exactly max_num_lines * 2 junctions; their score is zero.
            missing = max_num_lines * 2 - len(junctions[-1])
            junctions[-1] = torch.cat(
                [
                    junctions[-1],
                    torch.rand(missing, 2).to(lines)
                    * lines.new_tensor([[w - 1, h - 1]]),
                ],
                dim=0,
            )
            junc_scores[-1] = torch.cat(
                [junc_scores[-1], torch.zeros(missing).to(lines)], dim=0
            )

            # Self-connections on the diagonal, plus one symmetric edge per
            # line between the two endpoint clusters.
            junc_connect = torch.eye(max_num_lines * 2, dtype=torch.bool, device=device)
            pairs = clusters.reshape(-1, 2)  # these pairs are connected by a line
            junc_connect[pairs[:, 0], pairs[:, 1]] = True
            junc_connect[pairs[:, 1], pairs[:, 0]] = True
            connectivity.append(junc_connect)
        else:
            # Compute the junction connectivity (true junction count only).
            junc_connect = torch.eye(n_clusters, dtype=torch.bool, device=device)
            pairs = clusters.reshape(-1, 2)  # these pairs are connected by a line
            junc_connect[pairs[:, 0], pairs[:, 1]] = True
            junc_connect[pairs[:, 1], pairs[:, 0]] = True
            connectivity.append(junc_connect)

    # Stacking assumes equal junction counts across the batch (guaranteed
    # when force_num_lines is set, or b_size == 1).
    junctions = torch.stack(junctions, dim=0)
    new_lines = torch.stack(new_lines, dim=0)
    lines_junc_idx = torch.stack(lines_junc_idx, dim=0)

    # Interpolate the new junction descriptors; .mT gives [b, num_junc, dim].
    junc_descs = sample_descriptors_corner_conv(junctions, all_descs, s).mT

    # NOTE: junc_scores and connectivity are returned as per-image lists,
    # not stacked tensors.
    return (
        junctions,
        junc_scores,
        junc_descs,
        connectivity,
        new_lines,
        lines_junc_idx,
        num_true_junctions,
    )
|
||||
|
||||
|
||||
class WireframeExtractor(BaseModel):
    """Joint point + line front-end producing a wireframe representation.

    Runs a point extractor and a line extractor, optionally removes keypoints
    that duplicate line endpoints, merges nearby line endpoints into shared
    junctions, and concatenates junctions and keypoints into a single set of
    "points" with a point/line associativity matrix.
    """

    default_conf = {
        # Sub-config forwarded to the keypoint extractor; "name" selects the
        # model via get_model.
        "point_extractor": {
            "name": None,
            "trainable": False,
            "dense_outputs": True,
            "max_num_keypoints": None,
            "force_num_keypoints": False,
        },
        # Sub-config forwarded to the line extractor.
        "line_extractor": {
            "name": None,
            "trainable": False,
            "max_num_lines": None,
            "force_num_lines": False,
            "min_length": 15,
        },
        "wireframe_params": {
            # Drop keypoints that lie within nms_radius of a line endpoint.
            "merge_points": True,
            # Merge nearby line endpoints into shared junctions.
            "merge_line_endpoints": True,
            "nms_radius": 3,
        },
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        self.point_extractor = get_model(self.conf.point_extractor.name)(
            self.conf.point_extractor
        )
        self.line_extractor = get_model(self.conf.line_extractor.name)(
            self.conf.line_extractor
        )

    def _forward(self, data):
        b_size, _, h, w = data["image"].shape
        device = data["image"].device

        # Variable-length outputs cannot be batched.
        if (
            not self.conf.point_extractor.force_num_keypoints
            or not self.conf.line_extractor.force_num_lines
        ):
            assert b_size == 1, "Only batch size of 1 accepted for non padded inputs"

        # Line detection; normalize scores to [0, 1] per image.
        pred = self.line_extractor(data)
        if pred["line_scores"].shape[-1] != 0:
            pred["line_scores"] /= pred["line_scores"].max(dim=1)[0][:, None] + 1e-8

        # Keypoint prediction (merged into the same pred dict).
        pred = {**pred, **self.point_extractor(data)}
        assert (
            "dense_descriptors" in pred
        ), "The KP extractor should return dense descriptors"
        # Stride between the image and the dense descriptor map.
        s_desc = data["image"].shape[2] // pred["dense_descriptors"].shape[2]

        # Remove keypoints that are too close to line endpoints
        if self.conf.wireframe_params.merge_points:
            line_endpts = pred["lines"].reshape(b_size, -1, 2)
            # Pairwise keypoint-to-endpoint distances: [b, n_kp, n_endpts].
            dist_pt_lines = torch.norm(
                pred["keypoints"][:, :, None] - line_endpts[:, None], dim=-1
            )
            # For each keypoint, mark it as valid or to remove
            pts_to_remove = torch.any(
                dist_pt_lines < self.conf.wireframe_params.nms_radius, dim=2
            )
            if self.conf.point_extractor.force_num_keypoints:
                # Replace the points with random ones (zero score) to keep a
                # fixed keypoint count, and resample their descriptors.
                num_to_remove = pts_to_remove.int().sum().item()
                pred["keypoints"][pts_to_remove] = torch.rand(
                    num_to_remove, 2, device=device
                ) * pred["keypoints"].new_tensor([[w - 1, h - 1]])
                pred["keypoint_scores"][pts_to_remove] = 0
                for bs in range(b_size):
                    descrs = sample_descriptors_corner_conv(
                        pred["keypoints"][bs][pts_to_remove[bs]][None],
                        pred["dense_descriptors"][bs][None],
                        s_desc,
                    )
                    pred["descriptors"][bs][pts_to_remove[bs]] = descrs[0].T
            else:
                # Simply remove them (we assume batch_size = 1 here)
                assert len(pred["keypoints"]) == 1
                pred["keypoints"] = pred["keypoints"][0][~pts_to_remove[0]][None]
                pred["keypoint_scores"] = pred["keypoint_scores"][0][~pts_to_remove[0]][
                    None
                ]
                pred["descriptors"] = pred["descriptors"][0][~pts_to_remove[0]][None]

        # Connect the lines together to form a wireframe
        orig_lines = pred["lines"].clone()
        if (
            self.conf.wireframe_params.merge_line_endpoints
            and len(pred["lines"][0]) > 0
        ):
            # Merge first close-by endpoints to connect lines
            (
                line_points,
                line_pts_scores,
                line_descs,
                line_association,
                pred["lines"],
                lines_junc_idx,
                n_true_junctions,
            ) = lines_to_wireframe(
                pred["lines"],
                pred["line_scores"],
                pred["dense_descriptors"],
                s=s_desc,
                nms_radius=self.conf.wireframe_params.nms_radius,
                force_num_lines=self.conf.line_extractor.force_num_lines,
                max_num_lines=self.conf.line_extractor.max_num_lines,
            )

            # Add the keypoints to the junctions and fill the rest with random keypoints
            (all_points, all_scores, all_descs, pl_associativity) = [], [], [], []
            for bs in range(b_size):
                # Junctions come first, keypoints after.
                all_points.append(
                    torch.cat([line_points[bs], pred["keypoints"][bs]], dim=0)
                )
                all_scores.append(
                    torch.cat([line_pts_scores[bs], pred["keypoint_scores"][bs]], dim=0)
                )
                all_descs.append(
                    torch.cat([line_descs[bs], pred["descriptors"][bs]], dim=0)
                )

                # Identity for keypoints; the junction block carries the line
                # connectivity (true junctions only).
                associativity = torch.eye(
                    len(all_points[-1]), dtype=torch.bool, device=device
                )
                associativity[
                    : n_true_junctions[bs], : n_true_junctions[bs]
                ] = line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
                pl_associativity.append(associativity)

            all_points = torch.stack(all_points, dim=0)
            all_scores = torch.stack(all_scores, dim=0)
            all_descs = torch.stack(all_descs, dim=0)
            pl_associativity = torch.stack(pl_associativity, dim=0)
        else:
            # Lines are independent: every endpoint becomes its own junction.
            all_points = torch.cat(
                [pred["lines"].reshape(b_size, -1, 2), pred["keypoints"]], dim=1
            )
            n_pts = all_points.shape[1]
            num_lines = pred["lines"].shape[1]
            n_true_junctions = [num_lines * 2] * b_size
            all_scores = torch.cat(
                [
                    # Each line contributes its score to both endpoints.
                    torch.repeat_interleave(pred["line_scores"], 2, dim=1),
                    pred["keypoint_scores"],
                ],
                dim=1,
            )
            line_descs = sample_descriptors_corner_conv(
                pred["lines"].reshape(b_size, -1, 2), pred["dense_descriptors"], s_desc
            ).mT  # [B, n_lines * 2, desc_dim]
            all_descs = torch.cat([line_descs, pred["descriptors"]], dim=1)
            pl_associativity = torch.eye(n_pts, dtype=torch.bool, device=device)[
                None
            ].repeat(b_size, 1, 1)
            # Endpoint i of line l maps to junction 2*l + i.
            lines_junc_idx = (
                torch.arange(num_lines * 2, device=device)
                .reshape(1, -1, 2)
                .repeat(b_size, 1, 1)
            )

        del pred["dense_descriptors"]  # Remove dense descriptors to save memory
        torch.cuda.empty_cache()

        pred["keypoints"] = all_points
        pred["keypoint_scores"] = all_scores
        pred["descriptors"] = all_descs
        pred["pl_associativity"] = pl_associativity
        pred["num_junctions"] = torch.tensor(n_true_junctions)
        pred["orig_lines"] = orig_lines
        pred["lines_junc_idx"] = lines_junc_idx
        return pred

    def loss(self, pred, data):
        raise NotImplementedError

    def metrics(self, _pred, _data):
        # No metrics for a pure front-end extractor.
        return {}
|
|
@ -0,0 +1,81 @@
|
|||
from ..base_model import BaseModel
|
||||
from ...geometry.gt_generation import (
|
||||
gt_matches_from_pose_depth,
|
||||
gt_line_matches_from_pose_depth,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class DepthMatcher(BaseModel):
    """Ground-truth matcher: derives point and line matches from pose + depth.

    Delegates to ``gt_matches_from_pose_depth`` and
    ``gt_line_matches_from_pose_depth`` and merges their outputs.

    Fix: ``_init`` previously extended ``required_data_keys`` with ``+=``,
    which mutates the *class-level* list in place — every instantiation
    appended duplicate keys shared by all instances. A fresh per-instance
    list is now built instead.
    """

    default_conf = {
        # GT parameters for points
        "use_points": True,
        "th_positive": 3.0,  # reprojection distance (px) for positive matches
        "th_negative": 5.0,  # reprojection distance (px) for negatives
        "th_epi": None,  # add some more epi outliers
        "th_consistency": None,  # check for projection consistency in px
        # GT parameters for lines
        "use_lines": False,
        "n_line_sampled_pts": 50,
        "line_perp_dist_th": 5,
        "overlap_th": 0.2,
        "min_visibility_th": 0.5,
    }

    required_data_keys = ["view0", "view1", "T_0to1", "T_1to0"]

    def _init(self, conf):
        # Build a new list so the shared class attribute is never mutated.
        extra = []
        if self.conf.use_points:
            extra += ["keypoints0", "keypoints1"]
        if self.conf.use_lines:
            extra += [
                "lines0",
                "lines1",
                "valid_lines0",
                "valid_lines1",
            ]
        self.required_data_keys = [*self.required_data_keys, *extra]

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def _forward(self, data):
        result = {}
        if self.conf.use_points:
            if "depth_keypoints0" in data:
                # Forward precomputed per-keypoint depths when available.
                keys = [
                    "depth_keypoints0",
                    "valid_depth_keypoints0",
                    "depth_keypoints1",
                    "valid_depth_keypoints1",
                ]
                kw = {k: data[k] for k in keys}
            else:
                kw = {}
            result = gt_matches_from_pose_depth(
                data["keypoints0"],
                data["keypoints1"],
                data,
                pos_th=self.conf.th_positive,
                neg_th=self.conf.th_negative,
                epi_th=self.conf.th_epi,
                cc_th=self.conf.th_consistency,
                **kw,
            )
        if self.conf.use_lines:
            line_assignment, line_m0, line_m1 = gt_line_matches_from_pose_depth(
                data["lines0"],
                data["lines1"],
                data["valid_lines0"],
                data["valid_lines1"],
                data,
                self.conf.n_line_sampled_pts,
                self.conf.line_perp_dist_th,
                self.conf.overlap_th,
                self.conf.min_visibility_th,
            )
            result["line_matches0"] = line_m0
            result["line_matches1"] = line_m1
            result["line_assignment"] = line_assignment
        return result

    def loss(self, pred, data):
        raise NotImplementedError
|
|
@ -0,0 +1,778 @@
|
|||
import logging
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..utils.metrics import matcher_metrics
|
||||
from ...settings import DATA_PATH
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
ETH_EPS = 1e-8
|
||||
|
||||
|
||||
class GlueStick(BaseModel):
|
||||
default_conf = {
|
||||
"input_dim": 256,
|
||||
"descriptor_dim": 256,
|
||||
"weights": None,
|
||||
"version": "v0.1_arxiv",
|
||||
"keypoint_encoder": [32, 64, 128, 256],
|
||||
"GNN_layers": ["self", "cross"] * 9,
|
||||
"num_line_iterations": 1,
|
||||
"line_attention": False,
|
||||
"filter_threshold": 0.2,
|
||||
"checkpointed": False,
|
||||
"skip_init": False,
|
||||
"inter_supervision": None,
|
||||
"loss": {
|
||||
"nll_weight": 1.0,
|
||||
"nll_balancing": 0.5,
|
||||
"inter_supervision": [0.3, 0.6],
|
||||
},
|
||||
}
|
||||
required_data_keys = [
|
||||
"view0",
|
||||
"view1",
|
||||
"keypoints0",
|
||||
"keypoints1",
|
||||
"descriptors0",
|
||||
"descriptors1",
|
||||
"keypoint_scores0",
|
||||
"keypoint_scores1",
|
||||
"lines0",
|
||||
"lines1",
|
||||
"lines_junc_idx0",
|
||||
"lines_junc_idx1",
|
||||
"line_scores0",
|
||||
"line_scores1",
|
||||
]
|
||||
|
||||
DEFAULT_LOSS_CONF = {"nll_weight": 1.0, "nll_balancing": 0.5}
|
||||
|
||||
url = (
|
||||
"https://github.com/cvg/GlueStick/releases/download/{}/"
|
||||
"checkpoint_GlueStick_MD.tar"
|
||||
)
|
||||
|
||||
def _init(self, conf):
|
||||
if conf.input_dim != conf.descriptor_dim:
|
||||
self.input_proj = nn.Conv1d(
|
||||
conf.input_dim, conf.descriptor_dim, kernel_size=1
|
||||
)
|
||||
nn.init.constant_(self.input_proj.bias, 0.0)
|
||||
|
||||
self.kenc = KeypointEncoder(conf.descriptor_dim, conf.keypoint_encoder)
|
||||
self.lenc = EndPtEncoder(conf.descriptor_dim, conf.keypoint_encoder)
|
||||
self.gnn = AttentionalGNN(
|
||||
conf.descriptor_dim,
|
||||
conf.GNN_layers,
|
||||
checkpointed=conf.checkpointed,
|
||||
inter_supervision=conf.inter_supervision,
|
||||
num_line_iterations=conf.num_line_iterations,
|
||||
line_attention=conf.line_attention,
|
||||
)
|
||||
self.final_proj = nn.Conv1d(
|
||||
conf.descriptor_dim, conf.descriptor_dim, kernel_size=1
|
||||
)
|
||||
nn.init.constant_(self.final_proj.bias, 0.0)
|
||||
nn.init.orthogonal_(self.final_proj.weight, gain=1)
|
||||
self.final_line_proj = nn.Conv1d(
|
||||
conf.descriptor_dim, conf.descriptor_dim, kernel_size=1
|
||||
)
|
||||
nn.init.constant_(self.final_line_proj.bias, 0.0)
|
||||
nn.init.orthogonal_(self.final_line_proj.weight, gain=1)
|
||||
if conf.inter_supervision is not None:
|
||||
self.inter_line_proj = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(conf.descriptor_dim, conf.descriptor_dim, kernel_size=1)
|
||||
for _ in conf.inter_supervision
|
||||
]
|
||||
)
|
||||
self.layer2idx = {}
|
||||
for i, l in enumerate(conf.inter_supervision):
|
||||
nn.init.constant_(self.inter_line_proj[i].bias, 0.0)
|
||||
nn.init.orthogonal_(self.inter_line_proj[i].weight, gain=1)
|
||||
self.layer2idx[l] = i
|
||||
|
||||
bin_score = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.register_parameter("bin_score", bin_score)
|
||||
line_bin_score = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.register_parameter("line_bin_score", line_bin_score)
|
||||
|
||||
if conf.weights:
|
||||
assert isinstance(conf.weights, (Path, str))
|
||||
fname = DATA_PATH / "weights" / f"{conf.weights}_{conf.version}.tar"
|
||||
fname.parent.mkdir(exist_ok=True, parents=True)
|
||||
if Path(conf.weights).exists():
|
||||
logging.info(f'Loading GlueStick model from "{conf.weights}"')
|
||||
state_dict = torch.load(conf.weights, map_location="cpu")
|
||||
elif fname.exists():
|
||||
logging.info(f'Loading GlueStick model from "{fname}"')
|
||||
state_dict = torch.load(fname, map_location="cpu")
|
||||
else:
|
||||
logging.info(
|
||||
"Loading GlueStick model from " f'"{self.url.format(conf.version)}"'
|
||||
)
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
self.url.format(conf.version), file_name=fname
|
||||
)
|
||||
|
||||
if "model" in state_dict:
|
||||
state_dict = {
|
||||
k.replace("matcher.", ""): v
|
||||
for k, v in state_dict["model"].items()
|
||||
if "matcher." in k
|
||||
}
|
||||
state_dict = {
|
||||
k.replace("module.", ""): v for k, v in state_dict.items()
|
||||
}
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
def _forward(self, data):
    """Match keypoints and lines between two images.

    Expects per-image keypoints, descriptors, scores, line endpoints and
    line-junction indices in `data`. Returns a dict with the point and line
    log-assignment matrices (with dustbins), mutual matches and scores.
    """
    device = data["keypoints0"].device
    b_size = len(data["keypoints0"])
    # Fall back to the image tensor shape when no explicit size is given.
    image_size0 = (
        data["view0"]["image_size"]
        if "image_size" in data["view0"]
        else data["view0"]["image"].shape
    )
    image_size1 = (
        data["view1"]["image_size"]
        if "image_size" in data["view1"]
        else data["view1"]["image"].shape
    )

    pred = {}
    desc0, desc1 = data["descriptors0"].mT, data["descriptors1"].mT
    kpts0, kpts1 = data["keypoints0"], data["keypoints1"]

    n_kpts0, n_kpts1 = kpts0.shape[1], kpts1.shape[1]
    n_lines0, n_lines1 = data["lines0"].shape[1], data["lines1"].shape[1]
    if n_kpts0 == 0 or n_kpts1 == 0:
        # No detected keypoints nor lines: return empty, all-unmatched
        # outputs with the correct shapes.
        pred["log_assignment"] = torch.zeros(
            b_size, n_kpts0, n_kpts1, dtype=torch.float, device=device
        )
        pred["matches0"] = torch.full(
            (b_size, n_kpts0), -1, device=device, dtype=torch.int64
        )
        pred["matches1"] = torch.full(
            (b_size, n_kpts1), -1, device=device, dtype=torch.int64
        )
        pred["matching_scores0"] = torch.zeros(
            (b_size, n_kpts0), device=device, dtype=torch.float32
        )
        pred["matching_scores1"] = torch.zeros(
            (b_size, n_kpts1), device=device, dtype=torch.float32
        )
        pred["line_log_assignment"] = torch.zeros(
            b_size, n_lines0, n_lines1, dtype=torch.float, device=device
        )
        pred["line_matches0"] = torch.full(
            (b_size, n_lines0), -1, device=device, dtype=torch.int64
        )
        pred["line_matches1"] = torch.full(
            (b_size, n_lines1), -1, device=device, dtype=torch.int64
        )
        pred["line_matching_scores0"] = torch.zeros(
            (b_size, n_lines0), device=device, dtype=torch.float32
        )
        # Fixed: this was sized with n_kpts1 instead of n_lines1, which is
        # inconsistent with line_matches1 and every other line_* output.
        pred["line_matching_scores1"] = torch.zeros(
            (b_size, n_lines1), device=device, dtype=torch.float32
        )
        return pred

    lines0 = data["lines0"].flatten(1, 2)
    lines1 = data["lines1"].flatten(1, 2)
    # [b_size, num_lines * 2]
    lines_junc_idx0 = data["lines_junc_idx0"].flatten(1, 2)
    lines_junc_idx1 = data["lines_junc_idx1"].flatten(1, 2)

    if self.conf.input_dim != self.conf.descriptor_dim:
        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)

    kpts0 = normalize_keypoints(kpts0, image_size0)
    kpts1 = normalize_keypoints(kpts1, image_size1)

    assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
    assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
    desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
    desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])

    if n_lines0 != 0 and n_lines1 != 0:
        # Pre-compute the line encodings
        lines0 = normalize_keypoints(lines0, image_size0).reshape(
            b_size, n_lines0, 2, 2
        )
        lines1 = normalize_keypoints(lines1, image_size1).reshape(
            b_size, n_lines1, 2, 2
        )
        line_enc0 = self.lenc(lines0, data["line_scores0"])
        line_enc1 = self.lenc(lines1, data["line_scores1"])
    else:
        line_enc0 = torch.zeros(
            b_size,
            self.conf.descriptor_dim,
            n_lines0 * 2,
            dtype=torch.float,
            device=device,
        )
        line_enc1 = torch.zeros(
            b_size,
            self.conf.descriptor_dim,
            n_lines1 * 2,
            dtype=torch.float,
            device=device,
        )

    desc0, desc1 = self.gnn(
        desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
    )

    # Match all points (KP and line junctions)
    mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)

    kp_scores = torch.einsum("bdn,bdm->bnm", mdesc0, mdesc1)
    kp_scores = kp_scores / self.conf.descriptor_dim**0.5
    kp_scores = log_double_softmax(kp_scores, self.bin_score)
    m0, m1, mscores0, mscores1 = self._get_matches(kp_scores)
    pred["log_assignment"] = kp_scores
    pred["matches0"] = m0
    pred["matches1"] = m1
    pred["matching_scores0"] = mscores0
    pred["matching_scores1"] = mscores1

    # Match the lines
    if n_lines0 > 0 and n_lines1 > 0:
        # NOTE(review): the slicing assumes the first 2*n_lines entries of
        # the descriptor sequence are the line endpoints — confirm against
        # the dataloader convention.
        (
            line_scores,
            m0_lines,
            m1_lines,
            mscores0_lines,
            mscores1_lines,
            raw_line_scores,
        ) = self._get_line_matches(
            desc0[:, :, : 2 * n_lines0],
            desc1[:, :, : 2 * n_lines1],
            lines_junc_idx0,
            lines_junc_idx1,
            self.final_line_proj,
        )
        if self.conf.inter_supervision:
            # Also predict line matches from the intermediate descriptors
            # cached by the GNN, for inter-layer supervision.
            for layer in self.conf.inter_supervision:
                (
                    line_scores_i,
                    m0_lines_i,
                    m1_lines_i,
                    mscores0_lines_i,
                    mscores1_lines_i,
                    _,
                ) = self._get_line_matches(
                    self.gnn.inter_layers[layer][0][:, :, : 2 * n_lines0],
                    self.gnn.inter_layers[layer][1][:, :, : 2 * n_lines1],
                    lines_junc_idx0,
                    lines_junc_idx1,
                    self.inter_line_proj[self.layer2idx[layer]],
                )
                pred[f"line_{layer}_log_assignment"] = line_scores_i
                pred[f"line_{layer}_matches0"] = m0_lines_i
                pred[f"line_{layer}_matches1"] = m1_lines_i
                pred[f"line_{layer}_matching_scores0"] = mscores0_lines_i
                pred[f"line_{layer}_matching_scores1"] = mscores1_lines_i
    else:
        line_scores = torch.zeros(
            b_size, n_lines0, n_lines1, dtype=torch.float, device=device
        )
        m0_lines = torch.full(
            (b_size, n_lines0), -1, device=device, dtype=torch.int64
        )
        m1_lines = torch.full(
            (b_size, n_lines1), -1, device=device, dtype=torch.int64
        )
        mscores0_lines = torch.zeros(
            (b_size, n_lines0), device=device, dtype=torch.float32
        )
        mscores1_lines = torch.zeros(
            (b_size, n_lines1), device=device, dtype=torch.float32
        )
        raw_line_scores = torch.zeros(
            b_size, n_lines0, n_lines1, dtype=torch.float, device=device
        )
    pred["line_log_assignment"] = line_scores
    pred["line_matches0"] = m0_lines
    pred["line_matches1"] = m1_lines
    pred["line_matching_scores0"] = mscores0_lines
    pred["line_matching_scores1"] = mscores1_lines
    pred["raw_line_scores"] = raw_line_scores

    return pred
|
||||
|
||||
def _get_matches(self, scores_mat):
    """Extract mutual-nearest-neighbor matches from a log-assignment matrix.

    The last row/column of `scores_mat` is the dustbin and is excluded.
    Returns (matches0, matches1, scores0, scores1) with -1 for unmatched.
    """
    inner = scores_mat[:, :-1, :-1]
    best0, best1 = inner.max(2), inner.max(1)
    idx0, idx1 = best0.indices, best1.indices
    # Keep a pair only when each side is the argmax of the other.
    is_mutual0 = arange_like(idx0, 1)[None] == idx1.gather(1, idx0)
    is_mutual1 = arange_like(idx1, 1)[None] == idx0.gather(1, idx1)
    no_score = scores_mat.new_tensor(0)
    scores0 = torch.where(is_mutual0, best0.values.exp(), no_score)
    scores1 = torch.where(is_mutual1, scores0.gather(1, idx1), no_score)
    # Threshold on the (exponentiated) assignment probability.
    keep0 = is_mutual0 & (scores0 > self.conf.filter_threshold)
    keep1 = is_mutual1 & keep0.gather(1, idx1)
    matches0 = torch.where(keep0, idx0, idx0.new_tensor(-1))
    matches1 = torch.where(keep1, idx1, idx1.new_tensor(-1))
    return matches0, matches1, scores0, scores1
|
||||
|
||||
def _get_line_matches(
    self, ldesc0, ldesc1, lines_junc_idx0, lines_junc_idx1, final_proj
):
    """Match lines between the two images from their junction descriptors.

    Args:
        ldesc0, ldesc1: junction descriptors, [B, D, n_lines * 2].
        lines_junc_idx0, lines_junc_idx1: flat indices [B, n_lines * 2]
            mapping each line endpoint to its junction descriptor.
        final_proj: Conv1d projection applied before scoring (final or
            intermediate-layer head).

    Returns:
        (line_scores, m0_lines, m1_lines, mscores0_lines, mscores1_lines,
        raw_line_scores): the dustbin-augmented log-assignment matrix, the
        mutual matches/scores, and the pre-softmax similarities.
    """
    mldesc0 = final_proj(ldesc0)
    mldesc1 = final_proj(ldesc1)

    line_scores = torch.einsum("bdn,bdm->bnm", mldesc0, mldesc1)
    line_scores = line_scores / self.conf.descriptor_dim**0.5

    # Get the line representation from the junction descriptors
    n2_lines0 = lines_junc_idx0.shape[1]
    n2_lines1 = lines_junc_idx1.shape[1]
    # Gather the endpoint-to-endpoint similarities of every line pair:
    # first along image 1 (columns), then along image 0 (rows).
    line_scores = torch.gather(
        line_scores,
        dim=2,
        index=lines_junc_idx1[:, None, :].repeat(1, line_scores.shape[1], 1),
    )
    line_scores = torch.gather(
        line_scores,
        dim=1,
        index=lines_junc_idx0[:, :, None].repeat(1, 1, n2_lines1),
    )
    # [B, n_lines0, 2, n_lines1, 2]: the 2x2 endpoint score block per pair.
    line_scores = line_scores.reshape((-1, n2_lines0 // 2, 2, n2_lines1 // 2, 2))

    # Match either in one direction or the other
    raw_line_scores = 0.5 * torch.maximum(
        line_scores[:, :, 0, :, 0] + line_scores[:, :, 1, :, 1],
        line_scores[:, :, 0, :, 1] + line_scores[:, :, 1, :, 0],
    )
    line_scores = log_double_softmax(raw_line_scores, self.line_bin_score)
    m0_lines, m1_lines, mscores0_lines, mscores1_lines = self._get_matches(
        line_scores
    )
    return (
        line_scores,
        m0_lines,
        m1_lines,
        mscores0_lines,
        mscores1_lines,
        raw_line_scores,
    )
|
||||
|
||||
def sub_loss(self, pred, data, losses, bin_score, prefix="", layer=-1):
    """Accumulate the NLL loss of one assignment problem (points or lines).

    Args:
        pred: predictions holding '<prefix><layer_>log_assignment'.
        data: batch with GT keys 'gt_<prefix>assignment' and
            'gt_<prefix>matches0/1' (-1 marks unmatchable).
        losses: dict of running loss terms; updated in place and returned.
        bin_score: dustbin parameter, logged for monitoring only.
        prefix: "" for keypoints, "line_" for lines.
        layer: -1 for the final prediction, else the intermediate layer
            index used for inter-layer supervision.

    Returns:
        The updated `losses` dict.
    """
    line_suffix = "" if layer == -1 else f"{layer}_"
    # Intermediate layers are down-weighted by a per-layer config factor.
    layer_weight = (
        1.0
        if layer == -1
        else self.conf.loss.inter_supervision[self.layer2idx[layer]]
    )

    positive = data["gt_" + prefix + "assignment"].float()
    # Clamp counts to >= 1 to avoid dividing by zero on empty GT.
    num_pos = torch.max(positive.sum((1, 2)), positive.new_tensor(1))
    neg0 = (data["gt_" + prefix + "matches0"] == -1).float()
    neg1 = (data["gt_" + prefix + "matches1"] == -1).float()
    num_neg = torch.max(neg0.sum(1) + neg1.sum(1), neg0.new_tensor(1))

    log_assignment = pred[prefix + line_suffix + "log_assignment"]
    # NLL of the GT matches (inner block) ...
    nll_pos = -(log_assignment[:, :-1, :-1] * positive).sum((1, 2))
    nll_pos /= num_pos
    # ... and of the unmatchable entries assigned to the dustbins.
    nll_neg0 = -(log_assignment[:, :-1, -1] * neg0).sum(1)
    nll_neg1 = -(log_assignment[:, -1, :-1] * neg1).sum(1)
    nll_neg = (nll_neg0 + nll_neg1) / num_neg
    nll = (
        self.conf.loss.nll_balancing * nll_pos
        + (1 - self.conf.loss.nll_balancing) * nll_neg
    )
    losses[prefix + line_suffix + "assignment_nll"] = nll
    if self.conf.loss.nll_weight > 0:
        losses["total"] += nll * self.conf.loss.nll_weight * layer_weight

    # Some statistics
    if line_suffix == "":
        losses[prefix + "num_matchable"] = num_pos
        losses[prefix + "num_unmatchable"] = num_neg
        losses[prefix + "sinkhorn_norm"] = (
            log_assignment.exp()[:, :-1].sum(2).mean(1)
        )
        losses[prefix + "bin_score"] = bin_score[None]

    return losses
|
||||
|
||||
def loss(self, pred, data):
    """Compute the training losses and, at eval time, matching metrics."""
    losses = {"total": 0}

    # Keypoint terms, only when both images have detections.
    has_kpts = (
        data["keypoints0"].shape[1] > 0 and data["keypoints1"].shape[1] > 0
    )
    if has_kpts:
        losses = self.sub_loss(pred, data, losses, self.bin_score, prefix="")

    # Line terms, only when both images carry lines.
    has_lines = (
        "lines0" in data
        and "lines1" in data
        and data["lines0"].shape[1] > 0
        and data["lines1"].shape[1] > 0
    )
    if has_lines:
        losses = self.sub_loss(
            pred, data, losses, self.line_bin_score, prefix="line_"
        )
        if self.conf.inter_supervision:
            # Also supervise the intermediate line predictions.
            for layer in self.conf.inter_supervision:
                losses = self.sub_loss(
                    pred,
                    data,
                    losses,
                    self.line_bin_score,
                    prefix="line_",
                    layer=layer,
                )

    # Compute the metrics
    metrics = {}
    if not self.training:
        if (
            "matches0" in pred
            and pred["matches0"].shape[1] > 0
            and pred["matches1"].shape[1] > 0
        ):
            metrics = {**metrics, **matcher_metrics(pred, data, prefix="")}
        if (
            "line_matches0" in pred
            and data["lines0"].shape[1] > 0
            and data["lines1"].shape[1] > 0
        ):
            metrics = {**metrics, **matcher_metrics(pred, data, prefix="line_")}
            if self.conf.inter_supervision:
                for layer in self.conf.inter_supervision:
                    inter_metrics = matcher_metrics(
                        pred, data, prefix=f"line_{layer}_", prefix_gt="line_"
                    )
                    metrics.update(inter_metrics)

    return losses, metrics
|
||||
|
||||
|
||||
def MLP(channels, do_bn=True):
    """Build a pointwise (1x1 Conv1d) MLP over the channel dimension.

    Args:
        channels: list of channel sizes, e.g. [in, hidden, ..., out].
        do_bn: insert BatchNorm1d after every layer except the last.

    Returns:
        nn.Sequential mapping [B, channels[0], N] -> [B, channels[-1], N].
    """
    modules = []
    last = len(channels) - 1
    for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
        modules.append(nn.Conv1d(c_in, c_out, kernel_size=1, bias=True))
        # No normalization/activation after the final layer.
        if idx != last:
            if do_bn:
                modules.append(nn.BatchNorm1d(c_out))
            modules.append(nn.ReLU())
    return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def normalize_keypoints(kpts, shape_or_size):
    """Center keypoints on the image and scale them by its largest side.

    Args:
        kpts: [B, N, 2] (x, y) keypoints.
        shape_or_size: either an image shape tuple/list (..., H, W) or a
            per-image size tensor [B, 2] in (w, h) order.

    Returns:
        Normalized keypoints of the same shape.
    """
    if isinstance(shape_or_size, (tuple, list)):
        # An image shape was given: build a [[w, h]] size tensor from it.
        height, width = shape_or_size[-2:]
        size = kpts.new_tensor([[width, height]])
    else:
        # Already a per-image (w, h) size tensor.
        assert isinstance(shape_or_size, torch.Tensor)
        size = shape_or_size.to(kpts)
    center = size / 2
    scale = size.max(1, keepdim=True).values * 0.7  # somehow we used 0.7 for SG
    return (kpts - center[:, None, :]) / scale[:, None, :]
|
||||
|
||||
|
||||
class KeypointEncoder(nn.Module):
    """Encode keypoint positions and detection scores into descriptors."""

    def __init__(self, feature_dim, layers):
        super().__init__()
        # Input channels: (x, y, score) -> 3.
        self.encoder = MLP([3] + list(layers) + [feature_dim], do_bn=True)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts, scores):
        # kpts: [B, N, 2], scores: [B, N] -> output [B, feature_dim, N].
        feats = torch.cat([kpts.transpose(1, 2), scores.unsqueeze(1)], dim=1)
        return self.encoder(feats)
|
||||
|
||||
|
||||
class EndPtEncoder(nn.Module):
    """Encode line endpoints (position, offset to partner, score)."""

    def __init__(self, feature_dim, layers):
        super().__init__()
        # Input channels: (x, y, dx, dy, score) -> 5.
        self.encoder = MLP([5] + list(layers) + [feature_dim], do_bn=True)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, endpoints, scores):
        # endpoints: [B, N, 2, 2] -> output [B, feature_dim, N * 2].
        b_size, n_pts = endpoints.shape[:2]
        assert tuple(endpoints.shape[-2:]) == (2, 2)
        # Vector pointing from each endpoint to the other one; the sign
        # flips between the two endpoints of a line.
        offset = (endpoints[:, :, 1] - endpoints[:, :, 0]).unsqueeze(2)
        offset = torch.cat([offset, -offset], dim=2)
        offset = offset.reshape(b_size, 2 * n_pts, 2).transpose(1, 2)
        feats = torch.cat(
            [
                endpoints.flatten(1, 2).transpose(1, 2),
                offset,
                scores.repeat(1, 2).unsqueeze(1),
            ],
            dim=1,
        )
        return self.encoder(feats)
|
||||
|
||||
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def attention(query, key, value):
    """Multi-head scaled dot-product attention.

    query/key/value are [B, D, H, N] tensors (D = per-head channels).
    Returns (output [B, D, H, N], probabilities [B, H, N, M]).
    """
    d_head = query.shape[1]
    sim = torch.einsum("bdhn,bdhm->bhnm", query, key) / d_head**0.5
    prob = torch.nn.functional.softmax(sim, dim=-1)
    out = torch.einsum("bhnm,bdhm->bdhn", prob, value)
    return out, prob
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
    """Multi-head attention over 1D descriptor sequences."""

    def __init__(self, h, d_model):
        super().__init__()
        assert d_model % h == 0
        self.dim = d_model // h  # per-head channel dimension
        self.h = h
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        # Independent 1x1-conv projections for query, key and value.
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value):
        b = query.size(0)
        projected = (
            layer(x).view(b, self.dim, self.h, -1)
            for layer, x in zip(self.proj, (query, key, value))
        )
        x, _ = attention(*projected)
        return self.merge(x.contiguous().view(b, self.dim * self.h, -1))
|
||||
|
||||
|
||||
class AttentionalPropagation(nn.Module):
    """One message-passing step: attend to `source`, then mix with an MLP."""

    def __init__(self, num_dim, num_heads, skip_init=False):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, num_dim)
        self.mlp = MLP([num_dim * 2, num_dim * 2, num_dim], do_bn=True)
        nn.init.constant_(self.mlp[-1].bias, 0.0)
        if skip_init:
            # Learnable gate starting at zero: the layer begins as identity.
            self.register_parameter("scaling", nn.Parameter(torch.tensor(0.0)))
        else:
            self.scaling = 1.0

    def forward(self, x, source):
        message = self.attn(x, source, source)
        update = self.mlp(torch.cat([x, message], dim=1))
        return update * self.scaling
|
||||
|
||||
|
||||
class GNNLayer(nn.Module):
    """Self- or cross-attention layer acting on both images at once."""

    def __init__(self, feature_dim, layer_type, skip_init):
        super().__init__()
        assert layer_type in ["cross", "self"]
        self.type = layer_type
        self.update = AttentionalPropagation(feature_dim, 4, skip_init)

    def forward(self, desc0, desc1):
        if self.type == "cross":
            # Each image attends to the other one.
            src0, src1 = desc1, desc0
        elif self.type == "self":
            src0, src1 = desc0, desc1
        else:
            raise ValueError("Unknown layer type: " + self.type)
        delta0 = self.update(desc0, src0)
        delta1 = self.update(desc1, src1)
        return desc0 + delta0, desc1 + delta1
|
||||
|
||||
|
||||
class LineLayer(nn.Module):
    """Message passing along lines: every junction descriptor is updated
    with information from the other endpoint of each line it belongs to."""

    def __init__(self, feature_dim, line_attention=False):
        super().__init__()
        self.dim = feature_dim
        # Message input: [own desc, partner desc, line encoding] -> update.
        self.mlp = MLP([self.dim * 3, self.dim * 2, self.dim], do_bn=True)
        self.line_attention = line_attention
        if line_attention:
            # Projections used to weight messages per endpoint.
            self.proj_node = nn.Conv1d(self.dim, self.dim, kernel_size=1)
            self.proj_neigh = nn.Conv1d(2 * self.dim, self.dim, kernel_size=1)

    def get_endpoint_update(self, ldesc, line_enc, lines_junc_idx):
        # ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2]
        # and lines_junc_idx [bs, n_lines * 2]
        # Create one message per line endpoint
        b_size = lines_junc_idx.shape[0]
        line_desc = torch.gather(
            ldesc, 2, lines_junc_idx[:, None].repeat(1, self.dim, 1)
        )
        # flip([-1]) swaps the two endpoints of every line, so line_desc2
        # holds the descriptor of the partner endpoint.
        line_desc2 = line_desc.reshape(b_size, self.dim, -1, 2).flip([-1])
        message = torch.cat(
            [line_desc, line_desc2.flatten(2, 3).clone(), line_enc], dim=1
        )
        return self.mlp(message)  # [b_size, D, n_lines * 2]

    def get_endpoint_attention(self, ldesc, line_enc, lines_junc_idx):
        # ldesc is [bs, D, n_junc], line_enc [bs, D, n_lines * 2]
        # and lines_junc_idx [bs, n_lines * 2]
        b_size = lines_junc_idx.shape[0]
        expanded_lines_junc_idx = lines_junc_idx[:, None].repeat(1, self.dim, 1)

        # Query: desc of the current node
        query = self.proj_node(ldesc)  # [b_size, D, n_junc]
        query = torch.gather(query, 2, expanded_lines_junc_idx)
        # query is [b_size, D, n_lines * 2]

        # Key: combination of neighboring desc and line encodings
        line_desc = torch.gather(ldesc, 2, expanded_lines_junc_idx)
        line_desc2 = line_desc.reshape(b_size, self.dim, -1, 2).flip([-1])
        key = self.proj_neigh(
            torch.cat([line_desc2.flatten(2, 3).clone(), line_enc], dim=1)
        )  # [b_size, D, n_lines * 2]

        # Compute the attention weights with a custom softmax per junction
        prob = (query * key).sum(dim=1) / self.dim**0.5  # [b_size, n_lines * 2]
        prob = torch.exp(prob - prob.max())
        # Sum the exponentials of all endpoints sharing a junction, then
        # broadcast that normalizer back to every endpoint.
        denom = torch.zeros_like(ldesc[:, 0]).scatter_reduce_(
            dim=1, index=lines_junc_idx, src=prob, reduce="sum", include_self=False
        )  # [b_size, n_junc]
        denom = torch.gather(denom, 1, lines_junc_idx)  # [b_size, n_lines * 2]
        prob = prob / (denom + ETH_EPS)
        return prob  # [b_size, n_lines * 2]

    def forward(
        self, ldesc0, ldesc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
    ):
        """Update junction descriptors of both images with line messages."""
        # Gather the endpoint updates
        lupdate0 = self.get_endpoint_update(ldesc0, line_enc0, lines_junc_idx0)
        lupdate1 = self.get_endpoint_update(ldesc1, line_enc1, lines_junc_idx1)

        update0, update1 = torch.zeros_like(ldesc0), torch.zeros_like(ldesc1)
        dim = ldesc0.shape[1]
        if self.line_attention:
            # Compute an attention for each neighbor and do a weighted average
            prob0 = self.get_endpoint_attention(ldesc0, line_enc0, lines_junc_idx0)
            lupdate0 = lupdate0 * prob0[:, None]
            update0 = update0.scatter_reduce_(
                dim=2,
                index=lines_junc_idx0[:, None].repeat(1, dim, 1),
                src=lupdate0,
                reduce="sum",
                include_self=False,
            )
            prob1 = self.get_endpoint_attention(ldesc1, line_enc1, lines_junc_idx1)
            lupdate1 = lupdate1 * prob1[:, None]
            update1 = update1.scatter_reduce_(
                dim=2,
                index=lines_junc_idx1[:, None].repeat(1, dim, 1),
                src=lupdate1,
                reduce="sum",
                include_self=False,
            )
        else:
            # Average the updates for each junction (requires torch > 1.12)
            update0 = update0.scatter_reduce_(
                dim=2,
                index=lines_junc_idx0[:, None].repeat(1, dim, 1),
                src=lupdate0,
                reduce="mean",
                include_self=False,
            )
            update1 = update1.scatter_reduce_(
                dim=2,
                index=lines_junc_idx1[:, None].repeat(1, dim, 1),
                src=lupdate1,
                reduce="mean",
                include_self=False,
            )

        # Update
        ldesc0 = ldesc0 + update0
        ldesc1 = ldesc1 + update1

        return ldesc0, ldesc1
|
||||
|
||||
|
||||
class AttentionalGNN(nn.Module):
    """Stack of self/cross attention layers interleaved with line message
    passing, operating jointly on the descriptors of both images."""

    def __init__(
        self,
        feature_dim,
        layer_types,
        checkpointed=False,
        skip=False,
        inter_supervision=None,
        num_line_iterations=1,
        line_attention=False,
    ):
        super().__init__()
        # Trade compute for memory with torch.utils.checkpoint if requested.
        self.checkpointed = checkpointed
        self.inter_supervision = inter_supervision
        self.num_line_iterations = num_line_iterations
        # Intermediate descriptors cached for inter-layer supervision
        # (a plain dict on purpose: these are activations, not parameters).
        self.inter_layers = {}
        self.layers = nn.ModuleList(
            [GNNLayer(feature_dim, layer_type, skip) for layer_type in layer_types]
        )
        # One line layer per (self, cross) pair of point layers.
        self.line_layers = nn.ModuleList(
            [
                LineLayer(feature_dim, line_attention)
                for _ in range(len(layer_types) // 2)
            ]
        )

    def forward(
        self, desc0, desc1, line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1
    ):
        """Run all GNN layers; returns the refined descriptors of both images."""
        for i, layer in enumerate(self.layers):
            if self.checkpointed:
                desc0, desc1 = torch.utils.checkpoint.checkpoint(
                    layer, desc0, desc1, preserve_rng_state=False
                )
            else:
                desc0, desc1 = layer(desc0, desc1)
            if (
                layer.type == "self"
                and lines_junc_idx0.shape[1] > 0
                and lines_junc_idx1.shape[1] > 0
            ):
                # Add line self attention layers after every self layer
                for _ in range(self.num_line_iterations):
                    if self.checkpointed:
                        desc0, desc1 = torch.utils.checkpoint.checkpoint(
                            self.line_layers[i // 2],
                            desc0,
                            desc1,
                            line_enc0,
                            line_enc1,
                            lines_junc_idx0,
                            lines_junc_idx1,
                            preserve_rng_state=False,
                        )
                    else:
                        desc0, desc1 = self.line_layers[i // 2](
                            desc0,
                            desc1,
                            line_enc0,
                            line_enc1,
                            lines_junc_idx0,
                            lines_junc_idx1,
                        )

            # Optionally store the line descriptor at intermediate layers
            if (
                self.inter_supervision is not None
                and (i // 2) in self.inter_supervision
                and layer.type == "cross"
            ):
                self.inter_layers[i // 2] = (desc0.clone(), desc1.clone())
        return desc0, desc1
|
||||
|
||||
|
||||
def log_double_softmax(scores, bin_score):
    """Dual-softmax log-assignment with a learned dustbin score.

    Args:
        scores: [B, M, N] raw similarity matrix.
        bin_score: scalar tensor appended as an extra row and column.

    Returns:
        [B, M+1, N+1] log-assignment: the inner block averages the row-wise
        and column-wise log-softmaxes; the last row/column hold the dustbin
        log-probabilities, and the corner stays 0.
    """
    b, m, n = scores.shape
    dustbin = bin_score[None, None, None]
    aug0 = torch.cat([scores, dustbin.expand(b, m, 1)], 2)
    aug1 = torch.cat([scores, dustbin.expand(b, 1, n)], 1)
    log_p0 = torch.nn.functional.log_softmax(aug0, 2)
    log_p1 = torch.nn.functional.log_softmax(aug1, 1)
    out = scores.new_full((b, m + 1, n + 1), 0)
    out[:, :m, :n] = 0.5 * (log_p0[:, :, :n] + log_p1[:, :m, :])
    out[:, :-1, -1] = log_p0[:, :, -1]
    out[:, -1, :-1] = log_p1[:, -1, :]
    return out
|
||||
|
||||
|
||||
def arange_like(x, dim):
    """Return [0, 1, ..., x.shape[dim] - 1] with x's dtype and device.

    Implemented with cumsum instead of arange to remain traceable
    (kept from the original: "traceable in 1.1").
    """
    ones = x.new_ones(x.shape[dim])
    return ones.cumsum(0) - 1
|
|
@ -0,0 +1,66 @@
|
|||
from ..base_model import BaseModel
|
||||
from ...geometry.gt_generation import (
|
||||
gt_matches_from_homography,
|
||||
gt_line_matches_from_homography,
|
||||
)
|
||||
|
||||
|
||||
class HomographyMatcher(BaseModel):
    """Generate ground-truth point/line matches from a known homography."""

    default_conf = {
        # GT parameters for points
        "use_points": True,
        "th_positive": 3.0,
        "th_negative": 3.0,
        # GT parameters for lines
        "use_lines": False,
        "n_line_sampled_pts": 50,
        "line_perp_dist_th": 5,
        "overlap_th": 0.2,
        "min_visibility_th": 0.5,
    }

    required_data_keys = ["H_0to1"]

    def _init(self, conf):
        # Rebind a new list instead of `+=`: augmented assignment would
        # mutate the class-level list in place, leaking extra keys across
        # instances and duplicating them on re-instantiation.
        if self.conf.use_points:
            self.required_data_keys = self.required_data_keys + [
                "keypoints0",
                "keypoints1",
            ]
        if self.conf.use_lines:
            self.required_data_keys = self.required_data_keys + [
                "lines0",
                "lines1",
                "valid_lines0",
                "valid_lines1",
            ]

    def _forward(self, data):
        """Return GT assignments and matches for points and/or lines."""
        result = {}
        if self.conf.use_points:
            result = gt_matches_from_homography(
                data["keypoints0"],
                data["keypoints1"],
                data["H_0to1"],
                pos_th=self.conf.th_positive,
                neg_th=self.conf.th_negative,
            )
        if self.conf.use_lines:
            line_assignment, line_m0, line_m1 = gt_line_matches_from_homography(
                data["lines0"],
                data["lines1"],
                data["valid_lines0"],
                data["valid_lines1"],
                data["view0"]["image"].shape,
                data["view1"]["image"].shape,
                data["H_0to1"],
                self.conf.n_line_sampled_pts,
                self.conf.line_perp_dist_th,
                self.conf.overlap_th,
                self.conf.min_visibility_th,
            )
            result["line_matches0"] = line_m0
            result["line_matches1"] = line_m1
            result["line_assignment"] = line_assignment
        return result

    def loss(self, pred, data):
        # A GT generator has no trainable loss.
        raise NotImplementedError
|
|
@ -0,0 +1,65 @@
|
|||
import kornia
|
||||
import torch
|
||||
|
||||
from ...models import BaseModel
|
||||
|
||||
|
||||
class LoFTRModule(BaseModel):
    """Wrapper around kornia's pretrained (outdoor) LoFTR dense matcher."""

    default_conf = {
        "topk": None,  # keep only the top-k most confident matches
        "zero_pad": False,  # pad inputs to square and pass validity masks
    }
    required_data_keys = ["view0", "view1"]

    def _init(self, conf):
        self.net = kornia.feature.LoFTR(pretrained="outdoor")

    def _forward(self, data):
        image0 = data["view0"]["image"]
        image1 = data["view1"]["image"]
        if self.conf.zero_pad:
            image0, mask0 = self.zero_pad(image0)
            image1, mask1 = self.zero_pad(image1)
            # Fixed: a second unmasked self.net(...) call used to overwrite
            # this result, silently discarding the padding masks.
            res = self.net(
                {"image0": image0, "image1": image1, "mask0": mask0, "mask1": mask1}
            )
        else:
            res = self.net({"image0": image0, "image1": image1})
        topk = self.conf.topk
        if topk is not None and res["confidence"].shape[-1] > topk:
            # Keep only the k most confident correspondences.
            _, top = torch.topk(res["confidence"], topk, -1)
            m_kpts0 = res["keypoints0"][None][:, top]
            m_kpts1 = res["keypoints1"][None][:, top]
            scores = res["confidence"][None][:, top]
        else:
            m_kpts0 = res["keypoints0"][None]
            m_kpts1 = res["keypoints1"][None]
            scores = res["confidence"][None]

        # LoFTR outputs paired correspondences: match i of image 0 is
        # match i of image 1, so the match indices are identities.
        m0 = torch.arange(0, scores.shape[-1]).to(scores.device)[None]
        m1 = torch.arange(0, scores.shape[-1]).to(scores.device)[None]
        return {
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": scores,
            "keypoints0": m_kpts0,
            "keypoints1": m_kpts1,
            "keypoint_scores0": scores,
            "keypoint_scores1": scores,
            "matching_scores1": scores,
        }

    def zero_pad(self, img):
        """Pad `img` to a square canvas.

        Returns (padded image, validity mask) — the mask is 1 on original
        pixels and 0 on the padding.
        """
        b, c, h, w = img.shape
        if h == w:
            # Fixed: used to return the bare tensor, but every caller
            # unpacks two values; the whole image is valid here.
            return img, torch.ones_like(img).squeeze(0).float()
        s = max(h, w)
        image = torch.zeros((b, c, s, s)).to(img)
        image[:, :, :h, :w] = img
        mask = torch.zeros_like(image)
        mask[:, :, :h, :w] = 1.0
        return image, mask.squeeze(0).float()

    def loss(self, pred, data):
        # Fixed: used to `return NotImplementedError` instead of raising.
        raise NotImplementedError
|
|
@ -0,0 +1,610 @@
|
|||
import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List, Callable
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from omegaconf import OmegaConf
|
||||
from ...settings import DATA_PATH
|
||||
from ..utils.losses import NLLLoss
|
||||
from ..utils.metrics import matcher_metrics
|
||||
from pathlib import Path
|
||||
|
||||
FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention")
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(
    kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Center keypoints and scale them by half the largest image side.

    When `size` is None the extent is inferred from the points themselves
    (+1 guards against a zero-sized extent).
    """
    if size is None:
        size = 1 + kpts.max(-2).values - kpts.min(-2).values
    elif not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
    size = size.to(kpts)
    center = size / 2
    half_extent = size.max(-1).values / 2
    return (kpts - center[..., None, :]) / half_extent[..., None, None]
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map consecutive channel pairs (x1, x2) to (-x2, x1).

    This is the 90-degree rotation used by rotary positional embeddings.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=-2)
|
||||
|
||||
|
||||
def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Apply precomputed rotary embeddings: t*cos + rotate_half(t)*sin."""
    cos_part, sin_part = freqs[0], freqs[1]
    return t * cos_part + rotate_half(t) * sin_part
|
||||
|
||||
|
||||
class LearnableFourierPositionalEncoding(nn.Module):
    """Learnable Fourier features producing cached rotary (cos, sin) tables."""

    def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
        super().__init__()
        F_dim = F_dim if F_dim is not None else dim
        self.gamma = gamma
        # Random projection onto F_dim // 2 frequencies.
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector"""
        projected = self.Wr(x)
        emb = torch.stack([torch.cos(projected), torch.sin(projected)], 0)
        emb = emb.unsqueeze(-3)
        # Duplicate each frequency so it covers a pair of channels.
        return emb.repeat_interleave(2, dim=-1)
|
||||
|
||||
|
||||
class TokenConfidence(nn.Module):
    """Per-token confidence head used for early stopping / point pruning."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
        self.loss_fn = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens"""
        # detach(): confidences must not backprop into the descriptors.
        return (
            self.token(desc0.detach()).squeeze(-1),
            self.token(desc1.detach()).squeeze(-1),
        )

    def loss(self, desc0, desc1, la_now, la_final):
        """BCE loss: predict whether the current layer's assignment already
        agrees with the final one.

        Note: self.token[0] applies only the Linear sub-module, producing
        raw logits as required by BCEWithLogitsLoss (the Sigmoid is skipped).
        """
        logit0 = self.token[0](desc0.detach()).squeeze(-1)
        logit1 = self.token[0](desc1.detach()).squeeze(-1)
        la_now, la_final = la_now.detach(), la_final.detach()
        # A token is "correct" when its argmax match (dustbin excluded from
        # the kept rows/columns) is the same now as in the final assignment.
        correct0 = (
            la_final[:, :-1, :].max(-1).indices == la_now[:, :-1, :].max(-1).indices
        )
        correct1 = (
            la_final[:, :, :-1].max(-2).indices == la_now[:, :, :-1].max(-2).indices
        )
        return (
            self.loss_fn(logit0, correct0.float()).mean(-1)
            + self.loss_fn(logit1, correct1.float()).mean(-1)
        ) / 2.0
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Scaled dot-product attention, using FlashAttention/SDPA when available.

    Falls back to an explicit einsum implementation on torch < 2.0.
    """

    def __init__(self, allow_flash: bool) -> None:
        super().__init__()
        if allow_flash and not FLASH_AVAILABLE:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, "
                "consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE

        if FLASH_AVAILABLE:
            torch.backends.cuda.enable_flash_sdp(allow_flash)

    def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend q over (k, v); `mask` is a boolean attn mask (True = keep)."""
        if self.enable_flash and q.device.type == "cuda":
            # use torch 2.0 scaled_dot_product_attention with flash
            if FLASH_AVAILABLE:
                args = [x.half().contiguous() for x in [q, k, v]]
                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
                # Fully-masked rows produce NaNs; zero them out.
                return v if mask is None else v.nan_to_num()
        elif FLASH_AVAILABLE:
            args = [x.contiguous() for x in [q, k, v]]
            v = F.scaled_dot_product_attention(*args, attn_mask=mask)
            return v if mask is None else v.nan_to_num()
        else:
            s = q.shape[-1] ** -0.5
            sim = torch.einsum("...id,...jd->...ij", q, k) * s
            if mask is not None:
                # Fixed: masked_fill is out-of-place and its result was
                # discarded, so the mask was silently ignored here.
                sim = sim.masked_fill(~mask, -float("inf"))
            attn = F.softmax(sim, -1)
            return torch.einsum("...ij,...jd->...id", attn, v)
|
||||
|
||||
|
||||
class SelfBlock(nn.Module):
    """Self-attention block: rotary-encoded multi-head attention followed by
    a gated feed-forward update with a residual connection."""

    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0
        self.head_dim = self.embed_dim // num_heads
        # Fused projection producing queries, keys and values in one matmul.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention(flash)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # FFN applied to [x, message]; the residual is added in forward().
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        encoding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Project once, then split per head: (b, n, 3d) -> (b, heads, n, head_dim, 3).
        fused = self.Wqkv(x).unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        queries, keys, values = fused[..., 0], fused[..., 1], fused[..., 2]
        # Rotary positional encoding is applied to queries and keys only.
        queries = apply_cached_rotary_emb(encoding, queries)
        keys = apply_cached_rotary_emb(encoding, keys)
        context = self.inner_attn(queries, keys, values, mask=mask)
        # Merge heads back and project the aggregated message.
        merged = context.transpose(1, 2).flatten(start_dim=-2)
        message = self.out_proj(merged)
        return x + self.ffn(torch.cat([x, message], -1))
|
||||
|
||||
|
||||
class CrossBlock(nn.Module):
    """Bidirectional cross-attention block between the two images.

    Each image attends over the other's descriptors with shared projection
    weights (queries and keys come from the same `to_qk` projection), then
    each side is updated by a residual feed-forward network.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        # A single projection serves as both query and key (symmetric attention).
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        # FFN applied to [x, message]; residual added in forward().
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        # Optional FlashAttention path; only used on CUDA (see forward()).
        if flash and FLASH_AVAILABLE:
            self.flash = Attention(True)
        else:
            self.flash = None

    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
        # Apply the same function to both images.
        return func(x0), func(x1)

    def forward(
        self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        # Split heads: (b, n, d) -> (b, heads, n, head_dim).
        qk0, qk1, v0, v1 = map(
            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
            (qk0, qk1, v0, v1),
        )
        if self.flash is not None and qk0.device.type == "cuda":
            # Flash path: run the two attention directions separately,
            # transposing the mask for image1 -> image0.
            m0 = self.flash(qk0, qk1, v1, mask)
            m1 = self.flash(
                qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
            )
        else:
            # Split the 1/sqrt(d) scaling across both operands so the single
            # similarity matrix serves both attention directions.
            qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
            sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
            if mask is not None:
                sim = sim.masked_fill(~mask, -float("inf"))
            # Softmax over image1 points (for m0) and over image0 points (for m1).
            attn01 = F.softmax(sim, dim=-1)
            attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
            m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
            m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
            if mask is not None:
                # Fully-masked rows give -inf logits and hence NaN after softmax.
                m0, m1 = m0.nan_to_num(), m1.nan_to_num()
        # Merge heads and project the messages back to embed_dim.
        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
        return x0, x1
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
    """One LightGlue layer: self-attention on each image, then cross-attention
    between the two images."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.self_attn = SelfBlock(*args, **kwargs)
        self.cross_attn = CrossBlock(*args, **kwargs)

    def forward(
        self,
        desc0,
        desc1,
        encoding0,
        encoding1,
        mask0: Optional[torch.Tensor] = None,
        mask1: Optional[torch.Tensor] = None,
    ):
        # With padding masks, delegate to the mask-aware (compilable) path.
        if mask0 is not None and mask1 is not None:
            return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
        desc0 = self.self_attn(desc0, encoding0)
        desc1 = self.self_attn(desc1, encoding1)
        return self.cross_attn(desc0, desc1)

    # This part is compiled and allows padding inputs
    def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
        """Mask-aware variant: derives pairwise attention masks from the
        per-image validity masks before the attention blocks."""
        cross_mask = mask0 & mask1.transpose(-1, -2)
        self_mask0 = mask0 & mask0.transpose(-1, -2)
        self_mask1 = mask1 & mask1.transpose(-1, -2)
        desc0 = self.self_attn(desc0, encoding0, self_mask0)
        desc1 = self.self_attn(desc1, encoding1, self_mask1)
        return self.cross_attn(desc0, desc1, cross_mask)
|
||||
|
||||
|
||||
def sigmoid_log_double_softmax(
    sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity

    Combines a row-wise and a column-wise log-softmax of `sim` with the
    per-point matchability gates (sigmoids of z0/z1), and appends a dustbin
    row/column holding the log-probability of being unmatched.
    """
    batch, m, n = sim.shape
    # Log-probability that each point is matchable at all.
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    # Softmax over image-1 points (rows) and image-0 points (columns).
    row_scores = F.log_softmax(sim, 2)
    col_scores = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    # Assemble the (m+1) x (n+1) matrix; the corner stays at 0.
    scores = sim.new_full((batch, m + 1, n + 1), 0)
    scores[:, :m, :n] = row_scores + col_scores + certainties
    # Dustbins: log-probability of each point being unmatchable.
    scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
    scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
    return scores
|
||||
|
||||
|
||||
class MatchAssignment(nn.Module):
    """Predicts the partial assignment between two sets of descriptors."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        # Scalar matchability logit per point.
        self.matchability = nn.Linear(dim, 1, bias=True)
        # Final projection applied before computing similarities.
        self.final_proj = nn.Linear(dim, dim, bias=True)

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """build assignment matrix from descriptors"""
        mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
        _, _, d = mdesc0.shape
        # Split the usual 1/sqrt(d) attention scaling across both operands.
        scale = d**0.25
        mdesc0, mdesc1 = mdesc0 / scale, mdesc1 / scale
        sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
        z0, z1 = self.matchability(desc0), self.matchability(desc1)
        scores = sigmoid_log_double_softmax(sim, z0, z1)
        return scores, sim

    def get_matchability(self, desc: torch.Tensor):
        """Per-point probability of being matchable, in [0, 1]."""
        return torch.sigmoid(self.matchability(desc)).squeeze(-1)
|
||||
|
||||
|
||||
def filter_matches(scores: torch.Tensor, th: float):
    """obtain matches from a log assignment matrix [Bx M+1 x N+1]

    Returns (m0, m1, mscores0, mscores1) where m0[i] is the index in image 1
    matched to keypoint i of image 0 (-1 if unmatched), and mscores are the
    corresponding confidences.
    """
    # Best counterpart of each point, ignoring the dustbin row/column.
    inner = scores[:, :-1, :-1]
    best0, best1 = inner.max(2), inner.max(1)
    m0, m1 = best0.indices, best1.indices
    # Mutual nearest-neighbour check: i -> j and j -> i must agree.
    idx0 = torch.arange(m0.shape[1], device=m0.device)[None]
    idx1 = torch.arange(m1.shape[1], device=m1.device)[None]
    mutual0 = idx0 == m1.gather(1, m0)
    mutual1 = idx1 == m0.gather(1, m1)
    # Turn log-probabilities into confidences; zero out non-mutual pairs.
    conf0 = best0.values.exp()
    zero = conf0.new_tensor(0)
    mscores0 = torch.where(mutual0, conf0, zero)
    mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
    # Apply the confidence threshold and propagate validity to image 1.
    valid0 = mutual0 & (mscores0 > th)
    valid1 = mutual1 & valid0.gather(1, m1)
    m0 = torch.where(valid0, m0, -1)
    m1 = torch.where(valid1, m1, -1)
    return m0, m1, mscores0, mscores1
|
||||
|
||||
|
||||
class LightGlue(nn.Module):
    """LightGlue feature matcher.

    Runs a stack of self/cross transformer layers over the two keypoint sets
    and predicts a partial assignment matrix per layer (deep supervision at
    training time). At evaluation time it optionally stops early when the
    per-token confidences are high enough, and prunes unmatchable points.
    """

    default_conf = {
        "name": "lightglue",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "add_scale_ori": False,  # append keypoint scale/orientation to positions
        "descriptor_dim": 256,
        "n_layers": 9,
        "num_heads": 4,
        "flash": False,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": -1,  # early stopping, disable with -1
        "width_confidence": -1,  # point pruning, disable with -1
        "filter_threshold": 0.0,  # match threshold
        "checkpointed": False,  # gradient-checkpoint the transformer layers
        "weights": None,  # either a path or the name of pretrained weights (disk, ...)
        "weights_from_version": "v0.1_arxiv",
        "loss": {
            "gamma": 1.0,  # exponential decay of deep-supervision weights
            "fn": "nll",
            "nll_balancing": 0.5,  # positive vs negative term balance
        },
    }

    required_data_keys = ["keypoints0", "keypoints1", "descriptors0", "descriptors1"]

    # Template for downloading official pretrained weights: (version, name).
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"

    def __init__(self, conf) -> None:
        super().__init__()
        self.conf = conf = OmegaConf.merge(self.default_conf, conf)
        # Project input descriptors when they differ from the internal dim.
        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        head_dim = conf.descriptor_dim // conf.num_heads
        # Positional encoding of (x, y [, scale, orientation]) per keypoint.
        self.posenc = LearnableFourierPositionalEncoding(
            2 + 2 * conf.add_scale_ori, head_dim, head_dim
        )

        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim

        self.transformers = nn.ModuleList(
            [TransformerLayer(d, h, conf.flash) for _ in range(n)]
        )

        # One assignment head per layer (deep supervision) and one confidence
        # head per layer except the last.
        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
        self.token_confidence = nn.ModuleList(
            [TokenConfidence(d) for _ in range(n - 1)]
        )

        self.loss_fn = NLLLoss(conf.loss)

        state_dict = None
        if conf.weights is not None:
            # weights can be either a path or an existing file from official LG
            if Path(conf.weights).exists():
                state_dict = torch.load(conf.weights, map_location="cpu")
            elif (Path(DATA_PATH) / conf.weights).exists():
                state_dict = torch.load(
                    str(DATA_PATH / conf.weights), map_location="cpu"
                )
            else:
                # Download from the official release; dots are not allowed
                # in the cached file name.
                fname = (
                    f"{conf.weights}_{conf.weights_from_version}".replace(".", "-")
                    + ".pth"
                )
                state_dict = torch.hub.load_state_dict_from_url(
                    self.url.format(conf.weights_from_version, conf.weights),
                    file_name=fname,
                )

        if state_dict:
            # rename old state dict entries
            for i in range(self.conf.n_layers):
                pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
                pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            self.load_state_dict(state_dict, strict=False)

    def compile(self, mode="reduce-overhead"):
        """torch.compile each transformer layer (disables full point pruning)."""
        if self.conf.width_confidence != -1:
            warnings.warn(
                "Point pruning is partially disabled for compiled forward.",
                stacklevel=2,
            )

        for i in range(self.conf.n_layers):
            self.transformers[i] = torch.compile(
                self.transformers[i], mode=mode, fullgraph=True
            )

    def forward(self, data: dict) -> dict:
        """Match the two keypoint sets in `data`; returns matches, scores,
        the final log-assignment matrix and per-layer reference descriptors."""
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"

        kpts0, kpts1 = data["keypoints0"], data["keypoints1"]
        b, m, _ = kpts0.shape
        b, n, _ = kpts1.shape
        device = kpts0.device
        # Normalize keypoint coordinates with the image sizes when available.
        if "view0" in data.keys() and "view1" in data.keys():
            size0 = data["view0"].get("image_size")
            size1 = data["view1"].get("image_size")
        kpts0 = normalize_keypoints(kpts0, size0).clone()
        kpts1 = normalize_keypoints(kpts1, size1).clone()

        if self.conf.add_scale_ori:
            # Append per-keypoint scale and orientation as extra channels.
            sc0, o0 = data["scales0"], data["oris0"]
            sc1, o1 = data["scales1"], data["oris1"]
            kpts0 = torch.cat(
                [
                    kpts0,
                    sc0 if sc0.dim() == 3 else sc0[..., None],
                    o0 if o0.dim() == 3 else o0[..., None],
                ],
                -1,
            )
            kpts1 = torch.cat(
                [
                    kpts1,
                    sc1 if sc1.dim() == 3 else sc1[..., None],
                    o1 if o1.dim() == 3 else o1[..., None],
                ],
                -1,
            )

        desc0 = data["descriptors0"].contiguous()
        desc1 = data["descriptors1"].contiguous()

        assert desc0.shape[-1] == self.conf.input_dim
        assert desc1.shape[-1] == self.conf.input_dim
        if torch.is_autocast_enabled():
            desc0 = desc0.half()
            desc1 = desc1.half()
        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)
        # cache positional embeddings
        encoding0 = self.posenc(kpts0)
        encoding1 = self.posenc(kpts1)

        # GNN + final_proj + assignment
        # Early stopping / pruning are evaluation-only optimizations.
        do_early_stop = self.conf.depth_confidence > 0 and not self.training
        do_point_pruning = self.conf.width_confidence > 0 and not self.training

        all_desc0, all_desc1 = [], []

        if do_point_pruning:
            # ind0/ind1 track the original indices of the surviving points.
            ind0 = torch.arange(0, m, device=device)[None]
            ind1 = torch.arange(0, n, device=device)[None]
            # We store the index of the layer at which pruning is detected.
            prune0 = torch.ones_like(ind0)
            prune1 = torch.ones_like(ind1)
        token0, token1 = None, None
        for i in range(self.conf.n_layers):
            if self.conf.checkpointed and self.training:
                desc0, desc1 = checkpoint(
                    self.transformers[i], desc0, desc1, encoding0, encoding1
                )
            else:
                desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1)
            if self.training or i == self.conf.n_layers - 1:
                # Keep per-layer descriptors for deep supervision in loss().
                all_desc0.append(desc0)
                all_desc1.append(desc1)
                continue  # no early stopping or adaptive width at last layer

            # only for eval
            if do_early_stop:
                assert b == 1
                token0, token1 = self.token_confidence[i](desc0, desc1)
                if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
                    break
            if do_point_pruning:
                assert b == 1
                # Drop points deemed confidently unmatchable by this layer.
                scores0 = self.log_assignment[i].get_matchability(desc0)
                prunemask0 = self.get_pruning_mask(token0, scores0, i)
                keep0 = torch.where(prunemask0)[1]
                ind0 = ind0.index_select(1, keep0)
                desc0 = desc0.index_select(1, keep0)
                encoding0 = encoding0.index_select(-2, keep0)
                prune0[:, ind0] += 1
                scores1 = self.log_assignment[i].get_matchability(desc1)
                prunemask1 = self.get_pruning_mask(token1, scores1, i)
                keep1 = torch.where(prunemask1)[1]
                ind1 = ind1.index_select(1, keep1)
                desc1 = desc1.index_select(1, keep1)
                encoding1 = encoding1.index_select(-2, keep1)
                prune1[:, ind1] += 1

        desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
        # Assignment from the last executed layer (i survives the loop).
        scores, _ = self.log_assignment[i](desc0, desc1)
        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)

        if do_point_pruning:
            # Scatter matches of the surviving points back to the original
            # index space; everything pruned stays unmatched (-1 / score 0).
            m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
            m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
            m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
            m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
            mscores0_ = torch.zeros((b, m), device=mscores0.device)
            mscores1_ = torch.zeros((b, n), device=mscores1.device)
            mscores0_[:, ind0] = mscores0
            mscores1_[:, ind1] = mscores1
            m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
        else:
            prune0 = torch.ones_like(mscores0) * self.conf.n_layers
            prune1 = torch.ones_like(mscores1) * self.conf.n_layers

        pred = {
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "ref_descriptors0": torch.stack(all_desc0, 1),
            "ref_descriptors1": torch.stack(all_desc1, 1),
            "log_assignment": scores,
            "prune0": prune0,
            "prune1": prune1,
        }

        return pred

    def confidence_threshold(self, layer_index: int) -> float:
        """scaled confidence threshold"""
        # Decays from ~0.9 towards 0.8 with layer depth.
        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
        return np.clip(threshold, 0, 1)

    def get_pruning_mask(
        self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
    ) -> torch.Tensor:
        """mask points which should be removed"""
        keep = scores > (1 - self.conf.width_confidence)
        if confidences is not None:  # Low-confidence points are never pruned.
            # NOTE(review): `self.confidence_thresholds` is not defined in
            # this class — presumably a precomputed list of per-layer values
            # from confidence_threshold(); confirm it is set elsewhere.
            keep |= confidences <= self.confidence_thresholds[layer_index]
        return keep

    def check_if_stop(
        self,
        confidences0: torch.Tensor,
        confidences1: torch.Tensor,
        layer_index: int,
        num_points: int,
    ) -> torch.Tensor:
        """evaluate stopping condition"""
        confidences = torch.cat([confidences0, confidences1], -1)
        # NOTE(review): relies on `self.confidence_thresholds` — see
        # get_pruning_mask above.
        threshold = self.confidence_thresholds[layer_index]
        # Stop once enough points are confidently resolved.
        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
        return ratio_confident > self.conf.depth_confidence

    def pruning_min_kpts(self, device: torch.device):
        """Minimum number of keypoints for which pruning pays off on `device`."""
        # NOTE(review): `self.pruning_keypoint_thresholds` is not defined in
        # this class — confirm it is provided elsewhere before use.
        if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
            return self.pruning_keypoint_thresholds["flash"]
        else:
            return self.pruning_keypoint_thresholds[device.type]

    def loss(self, pred, data):
        """Deeply-supervised NLL over all layers, plus the confidence loss."""

        def loss_params(pred, i):
            # Assignment matrix of layer i from the stored descriptors.
            la, _ = self.log_assignment[i](
                pred["ref_descriptors0"][:, i], pred["ref_descriptors1"][:, i]
            )
            return {
                "log_assignment": la,
            }

        # The last layer has weight 1; earlier layers are added below.
        sum_weights = 1.0
        nll, gt_weights, loss_metrics = self.loss_fn(loss_params(pred, -1), data)
        N = pred["ref_descriptors0"].shape[1]
        losses = {"total": nll, "last": nll.clone().detach(), **loss_metrics}

        if self.training:
            losses["confidence"] = 0.0

        # B = pred['log_assignment'].shape[0]
        losses["row_norm"] = pred["log_assignment"].exp()[:, :-1].sum(2).mean(1)
        for i in range(N - 1):
            params_i = loss_params(pred, i)
            # Reuse the ground-truth weighting computed for the last layer.
            nll, _, _ = self.loss_fn(params_i, data, weights=gt_weights)

            if self.conf.loss.gamma > 0.0:
                # Exponentially down-weight earlier layers.
                weight = self.conf.loss.gamma ** (N - i - 1)
            else:
                weight = i + 1
            sum_weights += weight
            losses["total"] = losses["total"] + nll * weight

            # Confidence-head loss, averaged over the N-1 intermediate layers.
            losses["confidence"] += self.token_confidence[i].loss(
                pred["ref_descriptors0"][:, i],
                pred["ref_descriptors1"][:, i],
                params_i["log_assignment"],
                pred["log_assignment"],
            ) / (N - 1)

            del params_i
        losses["total"] /= sum_weights

        # confidences
        if self.training:
            losses["total"] = losses["total"] + losses["confidence"]

        if not self.training:
            # add metrics
            metrics = matcher_metrics(pred, data)
        else:
            metrics = {}
        return losses, metrics
|
||||
|
||||
|
||||
# Entry point used by the model loader to select this module's main class.
__main_model__ = LightGlue
|
|
@ -0,0 +1,34 @@
|
|||
from ..base_model import BaseModel
|
||||
from lightglue import LightGlue as LightGlue_
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class LightGlue(BaseModel):
    """Thin gluefactory wrapper around the official ``lightglue`` package.

    Inference only: repacks the flat gluefactory batch into the per-image
    dicts the official implementation expects.
    """

    default_conf = {"features": "superpoint", **LightGlue_.default_conf}
    required_data_keys = [
        "view0",
        "keypoints0",
        "descriptors0",
        "view1",
        "keypoints1",
        "descriptors1",
    ]

    def _init(self, conf):
        # The official LightGlue takes the feature name positionally and the
        # remaining configuration as keyword arguments.
        dconf = OmegaConf.to_container(conf)
        self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
        # self.net.compile()

    def _forward(self, data):
        # Repack "<key><idx>" entries into per-image dicts.
        view0 = {
            **{k: data[k + "0"] for k in ["keypoints", "descriptors"]},
            **data["view0"],
        }
        view1 = {
            **{k: data[k + "1"] for k in ["keypoints", "descriptors"]},
            **data["view1"],
        }
        return self.net({"image0": view0, "image1": view1})

    def loss(self, pred, data):
        """Training through this wrapper is not supported."""
        # BUGFIX: `self` was missing from the signature, so calling the
        # bound method raised a confusing TypeError (argument-count
        # mismatch) instead of NotImplementedError.
        raise NotImplementedError
|
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
Nearest neighbor matcher for normalized descriptors.
|
||||
Optionally apply the mutual check and threshold the distance or ratio.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..utils.metrics import matcher_metrics
|
||||
|
||||
|
||||
@torch.no_grad()
def find_nn(sim, ratio_thresh, distance_thresh):
    """Nearest neighbour per row of a similarity matrix, with optional ratio
    and distance tests; unmatched entries are set to -1.

    Thresholds are expressed on the Euclidean distance of unit descriptors,
    using ||a - b||^2 = 2 (1 - a.b).
    """
    # The ratio test needs the two best neighbours, otherwise one suffices.
    k = 2 if ratio_thresh else 1
    sim_nn, ind_nn = sim.topk(k, dim=-1, largest=True)
    dist_nn = 2 * (1 - sim_nn)
    keep = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
    if ratio_thresh:
        keep = keep & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1])
    if distance_thresh:
        keep = keep & (dist_nn[..., 0] <= distance_thresh**2)
    return torch.where(keep, ind_nn[..., 0], ind_nn.new_tensor(-1))
|
||||
|
||||
|
||||
def mutual_check(m0, m1):
    """Keep only matches that are mutual nearest neighbours in both directions."""
    inds0 = torch.arange(m0.shape[-1], device=m0.device)
    inds1 = torch.arange(m1.shape[-1], device=m1.device)
    # Follow each match into the other image and back; -1 entries are
    # clamped to index 0 so gather never fails, and masked out below.
    loop0 = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0)))
    loop1 = torch.gather(m0, -1, torch.where(m1 > -1, m1, m1.new_tensor(0)))
    ok0 = (m0 > -1) & (inds0 == loop0)
    ok1 = (m1 > -1) & (inds1 == loop1)
    m0_new = torch.where(ok0, m0, m0.new_tensor(-1))
    m1_new = torch.where(ok1, m1, m1.new_tensor(-1))
    return m0_new, m1_new
|
||||
|
||||
|
||||
class NearestNeighborMatcher(BaseModel):
    """Nearest-neighbor matcher over descriptor similarity.

    Optionally applies Lowe's ratio test, an absolute distance threshold and
    the mutual-nearest-neighbour check. Supports an optional N-pair loss for
    training descriptor extractors.
    """

    default_conf = {
        "ratio_thresh": None,  # Lowe ratio-test threshold; None disables it
        "distance_thresh": None,  # absolute distance threshold; None disables it
        "mutual_check": True,  # keep only mutual nearest neighbours
        "loss": None,  # training loss; only "N_pair" is implemented
    }
    required_data_keys = ["descriptors0", "descriptors1"]

    def _init(self, conf):
        # A learnable temperature is only needed for the N-pair loss.
        if conf.loss == "N_pair":
            temperature = torch.nn.Parameter(torch.tensor(1.0))
            self.register_parameter("temperature", temperature)

    def _forward(self, data):
        # Dot-product similarity; NOTE(review): descriptors are presumably
        # L2-normalized upstream (the loss checks sim <= 1) — confirm.
        sim = torch.einsum("bnd,bmd->bnm", data["descriptors0"], data["descriptors1"])
        matches0 = find_nn(sim, self.conf.ratio_thresh, self.conf.distance_thresh)
        matches1 = find_nn(
            sim.transpose(1, 2), self.conf.ratio_thresh, self.conf.distance_thresh
        )
        if self.conf.mutual_check:
            matches0, matches1 = mutual_check(matches0, matches1)
        b, m, n = sim.shape
        # Log-assignment matrix with dustbin row/column, for interface
        # compatibility with learned matchers.
        la = sim.new_zeros(b, m + 1, n + 1)
        la[:, :-1, :-1] = F.log_softmax(sim, -1) + F.log_softmax(sim, -2)
        # Binary scores: 1 for matched points, 0 otherwise.
        mscores0 = (matches0 > -1).float()
        mscores1 = (matches1 > -1).float()
        return {
            "matches0": matches0,
            "matches1": matches1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "similarity": sim,
            "log_assignment": la,
        }

    def loss(self, pred, data):
        """N-pair contrastive loss on the similarity matrix.

        Raises NotImplementedError for any other configured loss.
        """
        losses = {}
        if self.conf.loss == "N_pair":
            sim = pred["similarity"]
            if torch.any(sim > (1.0 + 1e-6)):
                logging.warning(f"Similarity larger than 1, max={sim.max()}")
            # Euclidean distance from cosine similarity, clamped for stability.
            scores = torch.sqrt(torch.clamp(2 * (1 - sim), min=1e-6))
            scores = self.temperature * (2 - scores)
            assert not torch.any(torch.isnan(scores)), torch.any(torch.isnan(sim))
            # Row-wise and column-wise log-softmax of the tempered scores.
            prob0 = torch.nn.functional.log_softmax(scores, 2)
            prob1 = torch.nn.functional.log_softmax(scores, 1)

            assignment = data["gt_assignment"].float()
            # Normalize by the number of ground-truth matches (at least 1).
            num = torch.max(assignment.sum((1, 2)), assignment.new_tensor(1))
            nll0 = (prob0 * assignment).sum((1, 2)) / num
            nll1 = (prob1 * assignment).sum((1, 2)) / num
            nll = -(nll0 + nll1) / 2
            losses["n_pair_nll"] = losses["total"] = nll
            losses["num_matchable"] = num
            losses["n_pair_temperature"] = self.temperature[None]
        else:
            raise NotImplementedError
        metrics = {} if self.training else matcher_metrics(pred, data)
        return losses, metrics
|
|
@ -0,0 +1,98 @@
|
|||
"""
|
||||
A two-view sparse feature matching pipeline on triplets.
|
||||
|
||||
If a triplet is found, runs the extractor on three images and
|
||||
then runs matcher/filter/solver for all three pairs.
|
||||
|
||||
Losses and metrics get accumulated accordingly.
|
||||
|
||||
If no triplet is found, this falls back to two_view_pipeline.py
|
||||
"""
|
||||
|
||||
from .two_view_pipeline import TwoViewPipeline
|
||||
import torch
|
||||
from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews
|
||||
|
||||
|
||||
def has_triplet(data):
    """Return True when the batch contains a third view."""
    # we already check for image0 and image1 in required_keys
    return "view2" in data
|
||||
|
||||
|
||||
class TripletPipeline(TwoViewPipeline):
    """Two-view pipeline extended to image triplets.

    Runs the extractor on the three images and the matcher/filter/solver on
    the three pairs 0-1, 0-2 and 1-2, either stacked on the batch dimension
    (batch_triplets) or sequentially. Falls back to the plain two-view
    pipeline when the batch contains no triplet.
    """

    default_conf = {"batch_triplets": True, **TwoViewPipeline.default_conf}

    def _forward(self, data):
        if not has_triplet(data):
            return super()._forward(data)
        # the two-view outputs are stored in
        # pred['0to1'],pred['0to2'], pred['1to2']

        assert not self.conf.run_gt_in_forward
        pred0 = self.extract_view(data, "0")
        pred1 = self.extract_view(data, "1")
        pred2 = self.extract_view(data, "2")

        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
            **{k + "2": v for k, v in pred2.items()},
        }

        def predict_twoview(pred, data):
            # forward pass over one (possibly stacked) two-view sub-batch.
            # BUGFIX: the filter/solver calls previously read `m_data` from
            # the enclosing scope instead of the `data` parameter; both call
            # sites pass that same object, but the closure silently depended
            # on outer state. Use the parameter consistently.
            if self.conf.matcher.name:
                pred = {**pred, **self.matcher({**data, **pred})}

            if self.conf.filter.name:
                pred = {**pred, **self.filter({**data, **pred})}

            if self.conf.solver.name:
                pred = {**pred, **self.solver({**data, **pred})}
            return pred

        if self.conf.batch_triplets:
            B = data["image1"].shape[0]
            # stack the three pairs on the batch dimension
            m_data = stack_twoviews(data)
            m_pred = stack_twoviews(pred)

            # forward pass
            m_pred = predict_twoview(m_pred, m_data)

            # unstack back into per-pair entries
            pred = {**pred, **unstack_twoviews(m_pred, B)}
        else:
            for idx in ["0to1", "0to2", "1to2"]:
                m_data = get_twoview(data, idx)
                m_pred = get_twoview(pred, idx)
                pred[idx] = predict_twoview(m_pred, m_data)
        return pred

    def loss(self, pred, data):
        """Accumulate the two-view losses/metrics over the three pairs (or
        over the stacked batch when batch_triplets is enabled)."""
        if not has_triplet(data):
            return super().loss(pred, data)
        if self.conf.batch_triplets:
            m_data = stack_twoviews(data)
            m_pred = stack_twoviews(pred)
            losses, metrics = super().loss(m_pred, m_data)
        else:
            losses = {}
            metrics = {}
            for idx in ["0to1", "0to2", "1to2"]:
                data_i = get_twoview(data, idx)
                pred_i = pred[idx]
                losses_i, metrics_i = super().loss(pred_i, data_i)
                # Sum losses across pairs; concatenate metrics on the batch dim.
                for k, v in losses_i.items():
                    if k in losses.keys():
                        losses[k] = losses[k] + v
                    else:
                        losses[k] = v
                for k, v in metrics_i.items():
                    if k in metrics.keys():
                        metrics[k] = torch.cat([metrics[k], v], 0)
                    else:
                        metrics[k] = v

        return losses, metrics
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
A two-view sparse feature matching pipeline.
|
||||
|
||||
This model contains sub-models for each step:
|
||||
feature extraction, feature matching, outlier filtering, pose estimation.
|
||||
Each step is optional, and the features or matches can be provided as input.
|
||||
Default: SuperPoint with nearest neighbor matching.
|
||||
|
||||
Convention for the matches: m0[i] is the index of the keypoint in image 1
|
||||
that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
|
||||
"""
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from .base_model import BaseModel
|
||||
from . import get_model
|
||||
|
||||
|
||||
# Shorthand used below to turn a DictConfig into a plain (nested) dict.
to_ctr = OmegaConf.to_container  # convert DictConfig to dict
|
||||
|
||||
|
||||
class TwoViewPipeline(BaseModel):
    """Two-view sparse matching pipeline.

    Chains optional sub-models — extractor, matcher, filter, solver and
    ground-truth generator — each instantiated only when its `name` is set.
    Features can also be supplied pre-computed through a per-view "cache".
    """

    default_conf = {
        "extractor": {
            "name": None,
            "trainable": False,
        },
        "matcher": {"name": None},
        "filter": {"name": None},
        "solver": {"name": None},
        "ground_truth": {"name": None},
        "allow_no_extract": False,  # reuse cached features instead of extracting
        "run_gt_in_forward": False,  # compute GT labels in forward, not in loss
    }
    required_data_keys = ["view0", "view1"]
    strict_conf = False  # need to pass new confs to children models
    # Sub-models polled for losses/metrics in loss(), in this order.
    components = [
        "extractor",
        "matcher",
        "filter",
        "solver",
        "ground_truth",
    ]

    def _init(self, conf):
        # Instantiate only the components whose name is configured; the
        # missing ones are skipped in _forward()/loss().
        if conf.extractor.name:
            self.extractor = get_model(conf.extractor.name)(to_ctr(conf.extractor))

        if conf.matcher.name:
            self.matcher = get_model(conf.matcher.name)(to_ctr(conf.matcher))

        if conf.filter.name:
            self.filter = get_model(conf.filter.name)(to_ctr(conf.filter))

        if conf.solver.name:
            self.solver = get_model(conf.solver.name)(to_ctr(conf.solver))

        if conf.ground_truth.name:
            self.ground_truth = get_model(conf.ground_truth.name)(
                to_ctr(conf.ground_truth)
            )

    def extract_view(self, data, i):
        """Extract features for view `i` ("0"/"1"), reusing cached features
        when present and `allow_no_extract` is set."""
        data_i = data[f"view{i}"]
        pred_i = data_i.get("cache", {})
        skip_extract = len(pred_i) > 0 and self.conf.allow_no_extract
        if self.conf.extractor.name and not skip_extract:
            pred_i = {**pred_i, **self.extractor(data_i)}
        elif self.conf.extractor.name and not self.conf.allow_no_extract:
            # NOTE(review): this branch looks unreachable — reaching the elif
            # requires skip_extract to be True, which itself requires
            # allow_no_extract to be True. Confirm the intended semantics.
            pred_i = {**pred_i, **self.extractor({**data_i, **pred_i})}
        return pred_i

    def _forward(self, data):
        pred0 = self.extract_view(data, "0")
        pred1 = self.extract_view(data, "1")
        # Flatten per-view predictions into "<key>0"/"<key>1" entries.
        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
        }

        # Each stage sees the raw data plus everything predicted so far.
        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}
        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}
        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}

        if self.conf.ground_truth.name and self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
        return pred

    def loss(self, pred, data):
        """Accumulate losses/metrics from every component that defines one."""
        losses = {}
        metrics = {}
        total = 0

        # get labels
        if self.conf.ground_truth.name and not self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})

        for k in self.components:
            apply = True
            # A component can opt out of the total loss via "apply_loss".
            if "apply_loss" in self.conf[k].keys():
                apply = self.conf[k].apply_loss
            if self.conf[k].name and apply:
                try:
                    losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
                except NotImplementedError:
                    # Components without a loss are simply skipped.
                    continue
                losses = {**losses, **losses_}
                metrics = {**metrics, **metrics_}
                total = losses_["total"] + total
        return {**losses, "total": total}, metrics
|
|
@ -0,0 +1,73 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def weight_loss(log_assignment, weights, gamma=0.0):
    """Weighted NLL terms of a (b, m+1, n+1) log-assignment matrix.

    Returns (nll_pos, nll_neg, num_pos, num_neg). `gamma` is accepted for
    interface compatibility but not used here.
    """
    b, m, n = log_assignment.shape
    # The last row/column are the dustbins.
    m, n = m - 1, n - 1

    weighted = log_assignment * weights

    # Normalizers, clamped so images without matches do not divide by zero.
    num_neg0 = weights[:, :m, -1].sum(-1).clamp(min=1.0)
    num_neg1 = weights[:, -1, :n].sum(-1).clamp(min=1.0)
    num_pos = weights[:, :m, :n].sum((-1, -2)).clamp(min=1.0)

    # Positive (matched) term, averaged over ground-truth matches.
    nll_pos = -weighted[:, :m, :n].sum((-1, -2)) / num_pos.clamp(min=1.0)

    # Negative (unmatched) terms from the two dustbins.
    nll_neg0 = -weighted[:, :m, -1].sum(-1)
    nll_neg1 = -weighted[:, -1, :n].sum(-1)
    nll_neg = (nll_neg0 + nll_neg1) / (num_neg0 + num_neg1)

    return nll_pos, nll_neg, num_pos, (num_neg0 + num_neg1) / 2.0
|
||||
|
||||
|
||||
class NLLLoss(nn.Module):
    """Balanced negative log-likelihood over a log-assignment matrix.

    Ground-truth weights are built from `gt_assignment` / `gt_matches{0,1}`
    and can be reused across layers by passing them back in via `weights`.
    """

    default_conf = {
        "nll_balancing": 0.5,  # weight of positive vs negative terms
        "gamma_f": 0.0,  # focal loss
    }

    def __init__(self, conf):
        super().__init__()
        self.conf = OmegaConf.merge(self.default_conf, conf)
        self.loss_fn = self.nll_loss

    def forward(self, pred, data, weights=None):
        """Compute the balanced NLL.

        Returns (nll, weights, metrics) where `weights` is the ground-truth
        weighting, suitable for reuse on other layers' predictions.
        """
        log_assignment = pred["log_assignment"]
        if weights is None:
            weights = self.loss_fn(log_assignment, data)
        nll_pos, nll_neg, num_pos, num_neg = weight_loss(
            log_assignment, weights, gamma=self.conf.gamma_f
        )
        nll = (
            self.conf.nll_balancing * nll_pos + (1 - self.conf.nll_balancing) * nll_neg
        )

        return (
            nll,
            weights,
            {
                "assignment_nll": nll,
                "nll_pos": nll_pos,
                "nll_neg": nll_neg,
                "num_matchable": num_pos,
                "num_unmatchable": num_neg,
            },
        )

    def nll_loss(self, log_assignment, data):
        """Build the (b, m+1, n+1) weight matrix from the ground truth."""
        m, n = data["gt_matches0"].size(-1), data["gt_matches1"].size(-1)
        positive = data["gt_assignment"].float()
        neg0 = (data["gt_matches0"] == -1).float()
        neg1 = (data["gt_matches1"] == -1).float()

        weights = torch.zeros_like(log_assignment)
        weights[:, :m, :n] = positive

        weights[:, :m, -1] = neg0
        # BUGFIX: the dustbin row spans the n keypoints of image 1; this was
        # previously sliced with :m, which raises a shape mismatch (or
        # silently mis-assigns) whenever m != n.
        weights[:, -1, :n] = neg1
        return weights
|
|
@ -0,0 +1,50 @@
|
|||
import torch
|
||||
|
||||
|
||||
@torch.no_grad()
def matcher_metrics(pred, data, prefix="", prefix_gt=None):
    """Compute matching metrics (recall, precision, accuracy, ranking AP).

    Matches are index vectors of shape (b, num_keypoints) where -1 means
    "unmatched"; predictions are read from `pred[f"{prefix}matches0"]` and
    ground truth from `data[f"gt_{prefix_gt}matches0"]`.
    # NOTE(review): values below -1 in gt_matches appear to mark ignored
    # points (they are excluded by the `>= -1` masks) — confirm with the
    # ground-truth generation code.
    """

    def recall(m, gt_m):
        # Fraction of ground-truth matches (gt > -1) that were recovered.
        mask = (gt_m > -1).float()
        return ((m == gt_m) * mask).sum(1) / (1e-8 + mask.sum(1))

    def accuracy(m, gt_m):
        # Agreement over all non-ignored points (including unmatched ones).
        mask = (gt_m >= -1).float()
        return ((m == gt_m) * mask).sum(1) / (1e-8 + mask.sum(1))

    def precision(m, gt_m):
        # Fraction of predicted matches (m > -1) that are correct.
        mask = ((m > -1) & (gt_m >= -1)).float()
        return ((m == gt_m) * mask).sum(1) / (1e-8 + mask.sum(1))

    def ranking_ap(m, gt_m, scores):
        # Average precision of matches ranked by descending score:
        # precision/recall curves are built over the score-sorted points.
        p_mask = ((m > -1) & (gt_m >= -1)).float()
        r_mask = (gt_m > -1).float()
        sort_ind = torch.argsort(-scores)
        sorted_p_mask = torch.gather(p_mask, -1, sort_ind)
        sorted_r_mask = torch.gather(r_mask, -1, sort_ind)
        sorted_tp = torch.gather(m == gt_m, -1, sort_ind)
        # Running precision at each score threshold.
        p_pts = torch.cumsum(sorted_tp * sorted_p_mask, -1) / (
            1e-8 + torch.cumsum(sorted_p_mask, -1)
        )
        # Running recall at each score threshold.
        r_pts = torch.cumsum(sorted_tp * sorted_r_mask, -1) / (
            1e-8 + sorted_r_mask.sum(-1)[:, None]
        )
        # Integrate precision over recall increments (trapezoid-free sum).
        r_pts_diff = r_pts[..., 1:] - r_pts[..., :-1]
        return torch.sum(r_pts_diff * p_pts[:, None, -1], dim=-1)

    # By default the ground-truth keys share the prediction prefix.
    if prefix_gt is None:
        prefix_gt = prefix
    rec = recall(pred[f"{prefix}matches0"], data[f"gt_{prefix_gt}matches0"])
    prec = precision(pred[f"{prefix}matches0"], data[f"gt_{prefix_gt}matches0"])
    acc = accuracy(pred[f"{prefix}matches0"], data[f"gt_{prefix_gt}matches0"])
    ap = ranking_ap(
        pred[f"{prefix}matches0"],
        data[f"gt_{prefix_gt}matches0"],
        pred[f"{prefix}matching_scores0"],
    )
    metrics = {
        f"{prefix}match_recall": rec,
        f"{prefix}match_precision": prec,
        f"{prefix}accuracy": acc,
        f"{prefix}average_precision": ap,
    }
    return metrics
|
|
@ -0,0 +1,69 @@
|
|||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
|
||||
def to_sequence(map):
    """Flatten a spatial map (..., C, H, W) into a sequence (..., H*W, C)."""
    flat = map.flatten(start_dim=-2)
    return flat.transpose(-2, -1)
|
||||
|
||||
|
||||
def to_map(sequence):
    """Inverse of `to_sequence`: reshape (..., e*e, C) into (..., C, e, e).

    Raises:
        AssertionError: if the sequence length is not a perfect square.
    """
    n = sequence.shape[-2]
    e = math.isqrt(n)
    assert e * e == n  # the sequence must come from a square map
    # Bug fix: the original computed this expression but never returned it
    # (it also duplicated the assert above).
    return sequence.transpose(-1, -2).unflatten(-1, [e, e])
|
||||
|
||||
|
||||
def pad_to_length(
    x,
    length: int,
    pad_dim: int = -2,
    mode: str = "zeros",  # zeros, ones, random, random_c
    bounds: Tuple[Optional[int], Optional[int]] = (None, None),
):
    """Pad tensor `x` along `pad_dim` up to `length` entries.

    Modes:
        "zeros"/"ones": constant padding with the same dtype/device as `x`.
        "random": uniform values in `bounds`, falling back to x.min()/x.max().
        "random_c": per-channel uniform values over the last dimension,
            using `bounds` only when `x` is empty along `pad_dim`.

    Raises:
        AssertionError: if `x` is already longer than `length`.
        ValueError: for an unknown `mode`.
    """
    shape = list(x.shape)
    d = x.shape[pad_dim]
    assert d <= length
    if d == length:
        return x
    # Shape of the padding block to append.
    shape[pad_dim] = length - d

    low, high = bounds

    if mode == "zeros":
        xn = torch.zeros(*shape, device=x.device, dtype=x.dtype)
    elif mode == "ones":
        xn = torch.ones(*shape, device=x.device, dtype=x.dtype)
    elif mode == "random":
        # Fall back to the observed value range when no bounds are given.
        low = low if low is not None else x.min()
        high = high if high is not None else x.max()
        xn = torch.empty(*shape, device=x.device).uniform_(low, high)
    elif mode == "random_c":
        low, high = bounds  # we use the bounds as fallback for empty seq.
        # Draw each channel independently within its own observed range.
        # NOTE(review): with d == 0 and bounds left as (None, None) this
        # passes None to uniform_ — callers must supply bounds in that case.
        xn = torch.cat(
            [
                torch.empty(*shape[:-1], 1, device=x.device).uniform_(
                    x[..., i].min() if d > 0 else low,
                    x[..., i].max() if d > 0 else high,
                )
                for i in range(shape[-1])
            ],
            dim=-1,
        )
    else:
        raise ValueError(mode)
    return torch.cat([x, xn], dim=pad_dim)
|
||||
|
||||
|
||||
def pad_and_stack(
    sequences: List[torch.Tensor],
    length: Optional[int] = None,
    pad_dim: int = -2,
    **kwargs,
):
    """Pad all sequences to a common length and stack them into a batch.

    Extra keyword arguments are forwarded to `pad_to_length` (e.g. `mode`).
    """
    if length is None:
        # Default to the longest sequence in the batch.
        length = max(x.shape[pad_dim] for x in sequences)

    padded = [pad_to_length(x, length, pad_dim, **kwargs) for x in sequences]
    return torch.stack(padded, 0)
|
|
@ -0,0 +1,14 @@
|
|||
import inspect
|
||||
from .base_estimator import BaseEstimator
|
||||
|
||||
|
||||
def load_estimator(type, estimator):
    """Import `<this package>.<type>.<estimator>` and return its estimator class.

    The target module must define exactly one subclass of `BaseEstimator`.
    """
    module_path = f"{__name__}.{type}.{estimator}"
    module = __import__(module_path, fromlist=[""])
    members = inspect.getmembers(module, inspect.isclass)
    # Keep only classes defined in the target module itself...
    members = [(n, c) for (n, c) in members if c.__module__ == module_path]
    # ...that implement the BaseEstimator interface.
    members = [(n, c) for (n, c) in members if issubclass(c, BaseEstimator)]
    assert len(members) == 1, members
    return members[0][1]
|
|
@ -0,0 +1,32 @@
|
|||
from omegaconf import OmegaConf
|
||||
from copy import copy
|
||||
|
||||
|
||||
class BaseEstimator:
    """Base class for robust estimators (homography, relative pose, ...).

    Subclasses define `default_conf`, `required_data_keys`, `_init(conf)` and
    `_forward(data)`; calling the instance runs `_forward`.
    """

    # Options every estimator configuration must provide ("???" = required).
    base_default_conf = {
        "name": "???",
        "ransac_th": "???",
    }
    # Inlier thresholds evaluated at test time.
    test_thresholds = [1.0]
    # Keys that `_forward`'s `data` dict must contain; overridden by subclasses.
    required_data_keys = []

    # If True, unknown keys in the user conf raise instead of being merged.
    strict_conf = False

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        # Layer the subclass defaults on top of the base defaults.
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            OmegaConf.set_struct(default_conf, True)

        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = conf = OmegaConf.merge(default_conf, conf)
        # Freeze the merged configuration against later mutation.
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        # Copy so per-instance edits don't leak into the class attribute.
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

    def __call__(self, data):
        # Delegate to the subclass implementation.
        return self._forward(data)
|
|
@ -0,0 +1,72 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from homography_est import (
|
||||
LineSegment,
|
||||
ransac_line_homography,
|
||||
ransac_point_homography,
|
||||
ransac_point_line_homography,
|
||||
)
|
||||
|
||||
from ..base_estimator import BaseEstimator
|
||||
|
||||
|
||||
def H_estimation_hybrid(kpts0=None, kpts1=None, lines0=None, lines1=None, tol_px=5):
    """Estimate a homography from points and lines with hybrid RANSAC.
    All features are expected in x-y convention

    Returns the 3x3 homography (normalized so H[2, 2] == 1 when possible),
    or None if fewer than 4 features are available.
    # NOTE(review): lines are assumed to be arrays of two endpoints
    # (line[0], line[1]) — confirm against the caller's line format.
    """
    # Check that we have at least 4 features
    n_features = 0
    if kpts0 is not None:
        n_features += len(kpts0) + len(kpts1)
    if lines0 is not None:
        n_features += len(lines0) + len(lines1)
    if n_features < 4:
        return None

    if lines0 is None:
        # Point-only RANSAC
        H = ransac_point_homography(kpts0, kpts1, tol_px, False, [])
    elif kpts0 is None:
        # Line-only RANSAC
        ls0 = [LineSegment(line[0], line[1]) for line in lines0]
        ls1 = [LineSegment(line[0], line[1]) for line in lines1]
        H = ransac_line_homography(ls0, ls1, tol_px, False, [])
    else:
        # Point-lines RANSAC
        ls0 = [LineSegment(line[0], line[1]) for line in lines0]
        ls1 = [LineSegment(line[0], line[1]) for line in lines1]
        H = ransac_point_line_homography(kpts0, kpts1, ls0, ls1, tol_px, False, [], [])
    # Normalize the homography scale unless it is degenerate.
    if np.abs(H[-1, -1]) > 1e-8:
        H /= H[-1, -1]
    return H
|
||||
|
||||
|
||||
class PointLineHomographyEstimator(BaseEstimator):
    """Hybrid point/line homography estimator backed by `homography_est`."""

    default_conf = {"ransac_th": 2.0, "options": {}}

    required_data_keys = ["m_kpts0", "m_kpts1", "m_lines0", "m_lines1"]

    def _init(self, conf):
        # No state to set up; the solver is called directly in _forward.
        pass

    def _forward(self, data):
        # NOTE(review): the 0/1 inputs are deliberately swapped below
        # (kpts0 <- m_kpts1, etc.) — presumably to match the library's
        # direction convention so the result is the 0->1 homography; confirm.
        m_features = {
            "kpts0": data["m_kpts1"].numpy() if "m_kpts1" in data else None,
            "kpts1": data["m_kpts0"].numpy() if "m_kpts0" in data else None,
            "lines0": data["m_lines1"].numpy() if "m_lines1" in data else None,
            "lines1": data["m_lines0"].numpy() if "m_lines0" in data else None,
        }
        # Any available feature tensor, used only for device/dtype.
        feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"]
        M = H_estimation_hybrid(**m_features, tol_px=self.conf.ransac_th)
        success = M is not None
        if not success:
            # Fall back to the identity so downstream code always gets a matrix.
            M = torch.eye(3, device=feat.device, dtype=feat.dtype)
        else:
            M = torch.tensor(M).to(feat)

        estimation = {
            "success": success,
            "M_0to1": M,
        }

        return estimation
|
|
@ -0,0 +1,53 @@
|
|||
import cv2
|
||||
import torch
|
||||
|
||||
from ..base_estimator import BaseEstimator
|
||||
|
||||
|
||||
class OpenCVHomographyEstimator(BaseEstimator):
    """Homography estimation via OpenCV's robust `findHomography` solvers."""

    default_conf = {
        "ransac_th": 3.0,
        "options": {"method": "ransac", "max_iters": 3000, "confidence": 0.995},
    }

    required_data_keys = ["m_kpts0", "m_kpts1"]

    def _init(self, conf):
        # Map the configured method name to the OpenCV solver flag.
        solvers = {
            "ransac": cv2.RANSAC,
            "lmeds": cv2.LMEDS,
            "rho": cv2.RHO,
            "usac": cv2.USAC_DEFAULT,
            "usac_fast": cv2.USAC_FAST,
            "usac_accurate": cv2.USAC_ACCURATE,
            "usac_prosac": cv2.USAC_PROSAC,
            "usac_magsac": cv2.USAC_MAGSAC,
        }
        self.solver = solvers[conf.options.method]

    def _forward(self, data):
        pts0, pts1 = data["m_kpts0"], data["m_kpts1"]

        success = False
        try:
            M, mask = cv2.findHomography(
                pts0.numpy(),
                pts1.numpy(),
                self.solver,
                self.conf.ransac_th,
                maxIters=self.conf.options.max_iters,
                confidence=self.conf.options.confidence,
            )
            success = M is not None
        except cv2.error:
            # Degenerate input configurations make OpenCV raise; treat as failure.
            success = False

        if success:
            M = torch.tensor(M).to(pts0)
            inl = torch.tensor(mask).bool().to(pts0.device)
        else:
            # Fall back to the identity with no inliers.
            M = torch.eye(3, device=pts0.device, dtype=pts0.dtype)
            inl = torch.zeros_like(pts0[:, 0]).bool()

        return {
            "success": success,
            "M_0to1": M,
            "inliers": inl,
        }
|
|
@ -0,0 +1,40 @@
|
|||
import poselib
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
|
||||
from ..base_estimator import BaseEstimator
|
||||
|
||||
|
||||
class PoseLibHomographyEstimator(BaseEstimator):
    """Homography estimation backed by the poselib RANSAC solver."""

    default_conf = {"ransac_th": 2.0, "options": {}}

    required_data_keys = ["m_kpts0", "m_kpts1"]

    def _init(self, conf):
        # Stateless: the solver is invoked directly in _forward.
        pass

    def _forward(self, data):
        pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
        # Merge the RANSAC threshold with any extra user options.
        solver_opts = {
            "max_reproj_error": self.conf.ransac_th,
            **OmegaConf.to_container(self.conf.options),
        }
        M, info = poselib.estimate_homography(pts0.numpy(), pts1.numpy(), solver_opts)
        success = M is not None
        if success:
            M = torch.tensor(M).to(pts0)
            inl = torch.tensor(info["inliers"]).bool().to(pts0.device)
        else:
            # Fall back to the identity with no inliers.
            M = torch.eye(3, device=pts0.device, dtype=pts0.dtype)
            inl = torch.zeros_like(pts0[:, 0]).bool()

        estimation = {
            "success": success,
            "M_0to1": M,
            "inliers": inl,
        }

        return estimation
|
|
@ -0,0 +1,64 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from ...geometry.wrappers import Pose
|
||||
from ...geometry.utils import from_homogeneous
|
||||
|
||||
from ..base_estimator import BaseEstimator
|
||||
|
||||
|
||||
class OpenCVRelativePoseEstimator(BaseEstimator):
    """Relative pose (R, t) estimation from 2D matches via OpenCV's
    essential-matrix solvers."""

    default_conf = {
        "ransac_th": 0.5,  # pixel threshold, normalized by focal length below
        "options": {"confidence": 0.99999, "method": "ransac"},
    }

    required_data_keys = ["m_kpts0", "m_kpts1", "camera0", "camera1"]

    def _init(self, conf):
        # Map the configured method name to the OpenCV solver flag.
        self.solver = {"ransac": cv2.RANSAC, "usac_magsac": cv2.USAC_MAGSAC}[
            self.conf.options.method
        ]

    def _forward(self, data):
        kpts0, kpts1 = data["m_kpts0"], data["m_kpts1"]
        camera0 = data["camera0"]
        camera1 = data["camera1"]
        # Defaults for the failure case: no pose, no inliers.
        M, inl = None, torch.zeros_like(kpts0[:, 0]).bool()

        # The 5-point algorithm needs at least 5 correspondences.
        if len(kpts0) >= 5:
            # Convert the pixel threshold to normalized image coordinates.
            f_mean = torch.cat([camera0.f, camera1.f]).mean().item()
            norm_thresh = self.conf.ransac_th / f_mean

            # Undistort/normalize the keypoints so K = identity below.
            pts0 = from_homogeneous(camera0.image2cam(kpts0)).cpu().detach().numpy()
            pts1 = from_homogeneous(camera1.image2cam(kpts1)).cpu().detach().numpy()

            E, mask = cv2.findEssentialMat(
                pts0,
                pts1,
                np.eye(3),
                threshold=norm_thresh,
                prob=self.conf.options.confidence,
                method=self.solver,
            )

            if E is not None:
                # findEssentialMat can return several stacked 3x3 candidates;
                # pick the one whose recovered pose has the most points passing
                # the cheirality check.
                best_num_inliers = 0
                for _E in np.split(E, len(E) / 3):
                    n, R, t, _ = cv2.recoverPose(
                        _E, pts0, pts1, np.eye(3), 1e9, mask=mask
                    )
                    if n > best_num_inliers:
                        best_num_inliers = n
                        # NOTE(review): recoverPose may update `mask` in place,
                        # so this snapshot reflects the current candidate.
                        inl = torch.tensor(mask.ravel() > 0)
                        M = Pose.from_Rt(
                            torch.tensor(R).to(kpts0), torch.tensor(t[:, 0]).to(kpts0)
                        )

        estimation = {
            "success": M is not None,
            # Identity pose as a safe fallback on failure.
            "M_0to1": M if M is not None else Pose.from_4x4mat(torch.eye(4).to(kpts0)),
            "inliers": inl.to(device=kpts0.device),
        }

        return estimation
|
|
@ -0,0 +1,44 @@
|
|||
import poselib
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
from ...geometry.wrappers import Pose
|
||||
|
||||
from ..base_estimator import BaseEstimator
|
||||
|
||||
|
||||
class PoseLibRelativePoseEstimator(BaseEstimator):
    """Relative pose estimation backed by poselib's calibrated solver."""

    default_conf = {"ransac_th": 2.0, "options": {}}

    required_data_keys = ["m_kpts0", "m_kpts1", "camera0", "camera1"]

    def _init(self, conf):
        # Stateless: the solver is invoked directly in _forward.
        pass

    def _forward(self, data):
        pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
        camera0, camera1 = data["camera0"], data["camera1"]
        # Merge the epipolar threshold with any extra user options.
        ransac_opts = {
            "max_epipolar_error": self.conf.ransac_th,
            **OmegaConf.to_container(self.conf.options),
        }
        M, info = poselib.estimate_relative_pose(
            pts0.numpy(),
            pts1.numpy(),
            camera0.to_cameradict(),
            camera1.to_cameradict(),
            ransac_opts,
        )
        if M is not None:
            success = True
            pose = Pose.from_Rt(torch.tensor(M.R), torch.tensor(M.t)).to(pts0)
        else:
            # Fall back to the identity pose on failure.
            success = False
            pose = Pose.from_4x4mat(torch.eye(4)).to(pts0)

        estimation = {
            "success": success,
            "M_0to1": pose,
            # Pop inliers so the remaining info keys can be spread verbatim.
            "inliers": torch.tensor(info.pop("inliers")).to(pts0),
            **info,
        }

        return estimation
|
|
@ -0,0 +1,52 @@
|
|||
import pycolmap
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
from ...geometry.wrappers import Pose
|
||||
|
||||
from ..base_estimator import BaseEstimator
|
||||
|
||||
|
||||
class PycolmapTwoViewEstimator(BaseEstimator):
    """Two-view geometry estimation (pose + configuration type) via pycolmap."""

    # NOTE: evaluated at class-definition time, so pycolmap must be importable.
    default_conf = {
        "ransac_th": 4.0,
        "options": {**pycolmap.TwoViewGeometryOptions().todict()},
    }

    required_data_keys = ["m_kpts0", "m_kpts1", "camera0", "camera1"]

    def _init(self, conf):
        # Build the pycolmap options object once; override its RANSAC error.
        opts = OmegaConf.to_container(conf.options)
        self.options = pycolmap.TwoViewGeometryOptions(opts)
        self.options.ransac.max_error = conf.ransac_th

    def _forward(self, data):
        pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
        camera0 = data["camera0"]
        camera1 = data["camera1"]
        info = pycolmap.two_view_geometry_estimation(
            pts0.numpy(),
            pts1.numpy(),
            camera0.to_cameradict(),
            camera1.to_cameradict(),
            self.options,
        )
        success = info["success"]
        if success:
            # Convert COLMAP's quaternion + translation to a Pose object.
            R = pycolmap.qvec_to_rotmat(info["qvec"])
            t = info["tvec"]
            M = Pose.from_Rt(torch.tensor(R), torch.tensor(t)).to(pts0)
            inl = torch.tensor(info.pop("inliers")).to(pts0)
        else:
            # Identity pose and empty inlier mask as a safe fallback.
            M = Pose.from_4x4mat(torch.eye(4)).to(pts0)
            inl = torch.zeros_like(pts0[:, 0]).bool()

        estimation = {
            "success": success,
            "M_0to1": M,
            "inliers": inl,
            # Degenerate/planar/calibrated etc., as classified by COLMAP.
            "type": str(
                info.get("configuration_type", pycolmap.TwoViewGeometry.UNDEFINED)
            ),
        }

        return estimation
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue