Compare commits
10 Commits
f7b587e881
...
d40e423a42
Author | SHA1 | Date |
---|---|---|
Jeba Kolega | d40e423a42 | |
Rémi Pautrat | 4a8283517f | |
Paul-Edouard Sarlin | ab6e6eef8b | |
Alexander Veicht | e0104fd65b | |
Paul-Edouard Sarlin | 43cf81aa2f | |
Philipp Lindenberger | 0c75e76fd6 | |
Iago Suárez | aa7727675e | |
Philipp Lindenberger | 692c72f94c | |
Philipp Lindenberger | 398c4b8c21 | |
Philipp Lindenberger | 22154a60bc |
|
@ -0,0 +1,30 @@
|
||||||
|
name: Python Tests
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
types: [ assigned, opened, synchronize, reopened ]
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Run Python Tests
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
cache: 'pip'
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get remove libunwind-14-dev || true
|
||||||
|
sudo apt-get install -y libceres-dev libeigen3-dev
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install pytest pytest-cov
|
||||||
|
python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
python -m pip install -e .[dev]
|
||||||
|
python -m pip install -e .[extra]
|
||||||
|
- name: Test with pytest
|
||||||
|
run: |
|
||||||
|
set -o pipefail
|
||||||
|
pytest --junitxml=pytest.xml --cov=gluefactory tests/
|
31
README.md
31
README.md
|
@ -25,7 +25,7 @@ python3 -m pip install -e .[extra]
|
||||||
All models and datasets in gluefactory have auto-downloaders, so you can get started right away!
|
All models and datasets in gluefactory have auto-downloaders, so you can get started right away!
|
||||||
|
|
||||||
## License
|
## License
|
||||||
The code and trained models in Glue Factory are released with an Apache-2.0 license. This includes LightGlue trained with an [open version of SuperPoint](https://github.com/rpautrat/SuperPoint). Third-party models that are not compatible with this license, such as SuperPoint (original) and SuperGlue, are provided in `gluefactory_nonfree`, where each model might follow its own, restrictive license.
|
The code and trained models in Glue Factory are released with an Apache-2.0 license. This includes LightGlue and an [open version of SuperPoint](https://github.com/rpautrat/SuperPoint). Third-party models that are not compatible with this license, such as SuperPoint (original) and SuperGlue, are provided in `gluefactory_nonfree`, where each model might follow its own restrictive license.
|
||||||
|
|
||||||
## Evaluation
|
## Evaluation
|
||||||
|
|
||||||
|
@ -66,8 +66,8 @@ Here are the results as Area Under the Curve (AUC) of the homography error at 1
|
||||||
|
|
||||||
| Methods | DLT | [OpenCV](../gluefactory/robust_estimators/homography/opencv.py) | [PoseLib](../gluefactory/robust_estimators/homography/poselib.py) |
|
| Methods | DLT | [OpenCV](../gluefactory/robust_estimators/homography/opencv.py) | [PoseLib](../gluefactory/robust_estimators/homography/poselib.py) |
|
||||||
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
|
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
|
||||||
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue.yaml) | 32.1 / 65.0 / 75.7 | 32.9 / 55.7 / 68.0 | 37.0 / 68.2 / 78.7 |
|
| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 32.1 / 65.0 / 75.7 | 32.9 / 55.7 / 68.0 | 37.0 / 68.2 / 78.7 |
|
||||||
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue.yaml) | 35.1 / 67.2 / 77.6 | 34.2 / 57.9 / 69.9 | 37.1 / 67.4 / 77.8 |
|
| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 35.1 / 67.2 / 77.6 | 34.2 / 57.9 / 69.9 | 37.1 / 67.4 / 77.8 |
|
||||||
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
@ -159,9 +159,12 @@ Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20
|
||||||
|
|
||||||
| Methods | [pycolmap](../gluefactory/robust_estimators/relative_pose/pycolmap.py) | [OpenCV](../gluefactory/robust_estimators/relative_pose/opencv.py) | [PoseLib](../gluefactory/robust_estimators/relative_pose/poselib.py) |
|
| Methods | [pycolmap](../gluefactory/robust_estimators/relative_pose/pycolmap.py) | [OpenCV](../gluefactory/robust_estimators/relative_pose/opencv.py) | [PoseLib](../gluefactory/robust_estimators/relative_pose/poselib.py) |
|
||||||
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
|
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
|
||||||
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue.yaml) | 54.4 / 70.4 / 82.4 | 48.7 / 65.6 / 79.0 | 64.8 / 77.9 / 87.0 |
|
| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 54.4 / 70.4 / 82.4 | 48.7 / 65.6 / 79.0 | 64.8 / 77.9 / 87.0 |
|
||||||
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue.yaml) | 56.7 / 72.4 / 83.7 | 51.0 / 68.1 / 80.7 | 66.8 / 79.3 / 87.9 |
|
| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 56.7 / 72.4 / 83.7 | 51.0 / 68.1 / 80.7 | 66.8 / 79.3 / 87.9 |
|
||||||
| [SuperPoint + GlueStick](../gluefactory/configs/superpoint+lsd+gluestick.yaml) | 53.2 / 69.8 / 81.9 | 46.3 / 64.2 / 78.1 | 64.4 / 77.5 / 86.5 |
|
| [SIFT (2K) + LightGlue](gluefactory/configs/sift+lightglue-official.yaml) | ? / ? / ? | 43.5 / 61.5 / 75.9 | 60.4 / 74.3 / 84.5 |
|
||||||
|
| [SIFT (4K) + LightGlue](gluefactory/configs/sift+lightglue-official.yaml) | ? / ? / ? | 49.9 / 67.3 / 80.3 | 65.9 / 78.6 / 87.4 |
|
||||||
|
| [ALIKED + LightGlue](gluefactory/configs/aliked+lightglue-official.yaml) | ? / ? / ? | 51.5 / 68.1 / 80.4 | 66.3 / 78.7 / 87.5 |
|
||||||
|
| [SuperPoint + GlueStick](gluefactory/configs/superpoint+lsd+gluestick.yaml) | 53.2 / 69.8 / 81.9 | 46.3 / 64.2 / 78.1 | 64.4 / 77.5 / 86.5 |
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
@ -223,18 +226,18 @@ All training commands automatically download the datasets.
|
||||||
<details>
|
<details>
|
||||||
<summary>[Training LightGlue]</summary>
|
<summary>[Training LightGlue]</summary>
|
||||||
|
|
||||||
We show how to train LightGlue with [SuperPoint open](https://github.com/rpautrat/SuperPoint).
|
We show how to train LightGlue with [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork).
|
||||||
We first pre-train LightGlue on the homography dataset:
|
We first pre-train LightGlue on the homography dataset:
|
||||||
```bash
|
```bash
|
||||||
python -m gluefactory.train sp+lg_homography \ # experiment name
|
python -m gluefactory.train sp+lg_homography \ # experiment name
|
||||||
--conf gluefactory/configs/superpoint-open+lightglue_homography.yaml
|
--conf gluefactory/configs/superpoint+lightglue_homography.yaml
|
||||||
```
|
```
|
||||||
Feel free to use any other experiment name. By default the checkpoints are written to `outputs/training/`. The default batch size of 128 corresponds to the results reported in the paper and requires 2x 3090 GPUs with 24GB of VRAM each as well as PyTorch >= 2.0 (FlashAttention).
|
Feel free to use any other experiment name. By default the checkpoints are written to `outputs/training/`. The default batch size of 128 corresponds to the results reported in the paper and requires 2x 3090 GPUs with 24GB of VRAM each as well as PyTorch >= 2.0 (FlashAttention).
|
||||||
Configurations are managed by [OmegaConf](https://omegaconf.readthedocs.io/) so any entry can be overridden from the command line.
|
Configurations are managed by [OmegaConf](https://omegaconf.readthedocs.io/) so any entry can be overridden from the command line.
|
||||||
If you have PyTorch < 2.0 or weaker GPUs, you may thus need to reduce the batch size via:
|
If you have PyTorch < 2.0 or weaker GPUs, you may thus need to reduce the batch size via:
|
||||||
```bash
|
```bash
|
||||||
python -m gluefactory.train sp+lg_homography \
|
python -m gluefactory.train sp+lg_homography \
|
||||||
--conf gluefactory/configs/superpoint-open+lightglue_homography.yaml \
|
--conf gluefactory/configs/superpoint+lightglue_homography.yaml \
|
||||||
data.batch_size=32 # for 1x 1080 GPU
|
data.batch_size=32 # for 1x 1080 GPU
|
||||||
```
|
```
|
||||||
Be aware that this can impact the overall performance. You might need to adjust the learning rate accordingly.
|
Be aware that this can impact the overall performance. You might need to adjust the learning rate accordingly.
|
||||||
|
@ -242,17 +245,17 @@ Be aware that this can impact the overall performance. You might need to adjust
|
||||||
We then fine-tune the model on the MegaDepth dataset:
|
We then fine-tune the model on the MegaDepth dataset:
|
||||||
```bash
|
```bash
|
||||||
python -m gluefactory.train sp+lg_megadepth \
|
python -m gluefactory.train sp+lg_megadepth \
|
||||||
--conf gluefactory/configs/superpoint-open+lightglue_megadepth.yaml \
|
--conf gluefactory/configs/superpoint+lightglue_megadepth.yaml \
|
||||||
train.load_experiment=sp+lg_homography
|
train.load_experiment=sp+lg_homography
|
||||||
```
|
```
|
||||||
|
|
||||||
Here the default batch size is 32. To speed up training on MegaDepth, we suggest caching the local features before training (requires around 150 GB of disk space):
|
Here the default batch size is 32. To speed up training on MegaDepth, we suggest caching the local features before training (requires around 150 GB of disk space):
|
||||||
```bash
|
```bash
|
||||||
# extract features
|
# extract features
|
||||||
python -m gluefactory.scripts.export_megadepth --method sp_open --num_workers 8
|
python -m gluefactory.scripts.export_megadepth --method sp --num_workers 8
|
||||||
# run training with cached features
|
# run training with cached features
|
||||||
python -m gluefactory.train sp+lg_megadepth \
|
python -m gluefactory.train sp+lg_megadepth \
|
||||||
--conf gluefactory/configs/superpoint-open+lightglue_megadepth.yaml \
|
--conf gluefactory/configs/superpoint+lightglue_megadepth.yaml \
|
||||||
train.load_experiment=sp+lg_homography \
|
train.load_experiment=sp+lg_homography \
|
||||||
data.load_features.do=True
|
data.load_features.do=True
|
||||||
```
|
```
|
||||||
|
@ -297,10 +300,10 @@ Using the following local feature extractors:
|
||||||
| Model | LightGlue config |
|
| Model | LightGlue config |
|
||||||
| --------- | --------- |
|
| --------- | --------- |
|
||||||
| [SuperPoint (open)](https://github.com/rpautrat/SuperPoint) | `superpoint-open+lightglue_{homography,megadepth}.yaml` |
|
| [SuperPoint (open)](https://github.com/rpautrat/SuperPoint) | `superpoint-open+lightglue_{homography,megadepth}.yaml` |
|
||||||
| [SuperPoint (official)](https://github.com/magicleap/SuperPointPretrainedNetwork) | ❌ TODO |
|
| [SuperPoint (official)](https://github.com/magicleap/SuperPointPretrainedNetwork) | `superpoint+lightglue_{homography,megadepth}.yaml` |
|
||||||
| SIFT (via [pycolmap](https://github.com/colmap/pycolmap)) | `sift+lightglue_{homography,megadepth}.yaml` |
|
| SIFT (via [pycolmap](https://github.com/colmap/pycolmap)) | `sift+lightglue_{homography,megadepth}.yaml` |
|
||||||
| [ALIKED](https://github.com/Shiaoming/ALIKED) | `aliked+lightglue_{homography,megadepth}.yaml` |
|
| [ALIKED](https://github.com/Shiaoming/ALIKED) | `aliked+lightglue_{homography,megadepth}.yaml` |
|
||||||
| [DISK](https://github.com/cvlab-epfl/disk) | ❌ TODO |
|
| [DISK](https://github.com/cvlab-epfl/disk) | `disk+lightglue_{homography,megadepth}.yaml` |
|
||||||
| Key.Net + HardNet | ❌ TODO |
|
| Key.Net + HardNet | ❌ TODO |
|
||||||
|
|
||||||
## Coming soon
|
## Coming soon
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 519 KiB |
Binary file not shown.
After Width: | Height: | Size: 580 KiB |
|
@ -0,0 +1,28 @@
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: extractors.aliked
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue_pretrained
|
||||||
|
features: aliked
|
||||||
|
depth_confidence: -1
|
||||||
|
width_confidence: -1
|
||||||
|
filter_threshold: 0.1
|
||||||
|
benchmarks:
|
||||||
|
megadepth1500:
|
||||||
|
data:
|
||||||
|
preprocessing:
|
||||||
|
side: long
|
||||||
|
resize: 1600
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
||||||
|
model:
|
||||||
|
extractor:
|
||||||
|
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,47 @@
|
||||||
|
data:
|
||||||
|
name: homographies
|
||||||
|
data_dir: revisitop1m
|
||||||
|
train_size: 150000
|
||||||
|
val_size: 2000
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 14
|
||||||
|
homography:
|
||||||
|
difficulty: 0.7
|
||||||
|
max_angle: 45
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: extractors.disk_kornia
|
||||||
|
max_num_keypoints: 512
|
||||||
|
force_num_keypoints: True
|
||||||
|
detection_threshold: 0.0
|
||||||
|
trainable: False
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
filter_threshold: 0.1
|
||||||
|
input_dim: 128
|
||||||
|
flash: false
|
||||||
|
checkpointed: true
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 40
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 500
|
||||||
|
lr: 1e-4
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,70 @@
|
||||||
|
data:
|
||||||
|
name: megadepth
|
||||||
|
preprocessing:
|
||||||
|
resize: 1024
|
||||||
|
side: long
|
||||||
|
square_pad: True
|
||||||
|
train_split: train_scenes_clean.txt
|
||||||
|
train_num_per_scene: 300
|
||||||
|
val_split: valid_scenes_clean.txt
|
||||||
|
val_pairs: valid_pairs.txt
|
||||||
|
min_overlap: 0.1
|
||||||
|
max_overlap: 0.7
|
||||||
|
num_overlap_bins: 3
|
||||||
|
read_depth: true
|
||||||
|
read_image: true
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 14
|
||||||
|
load_features:
|
||||||
|
do: false # enable this if you have cached predictions
|
||||||
|
path: exports/megadepth-undist-depth-r1024_DISK-k2048-nms5/{scene}.h5
|
||||||
|
padding_length: 2048
|
||||||
|
padding_fn: pad_local_features
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: extractors.disk_kornia
|
||||||
|
max_num_keypoints: 512
|
||||||
|
force_num_keypoints: True
|
||||||
|
detection_threshold: 0.0
|
||||||
|
trainable: False
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
filter_threshold: 0.1
|
||||||
|
input_dim: 128
|
||||||
|
flash: false
|
||||||
|
checkpointed: true
|
||||||
|
allow_no_extract: True
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 50
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 1000
|
||||||
|
lr: 1e-4
|
||||||
|
lr_schedule:
|
||||||
|
start: 30
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
dataset_callback_fn: sample_new_items
|
||||||
|
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||||
|
benchmarks:
|
||||||
|
megadepth1500:
|
||||||
|
data:
|
||||||
|
preprocessing:
|
||||||
|
side: long
|
||||||
|
resize: 1024
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
||||||
|
model:
|
||||||
|
extractor:
|
||||||
|
max_num_keypoints: 1024
|
|
@ -0,0 +1,69 @@
|
||||||
|
data:
|
||||||
|
name: drone2sat
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 24
|
||||||
|
val_size: 500
|
||||||
|
train_size: -1
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
geo_dataset:
|
||||||
|
uav_dataset_dir: /mnt/drive/uav_dataset
|
||||||
|
satellite_dataset_dir: /mnt/drive/tiles
|
||||||
|
misslabeled_images_path: /mnt/drive/misslabeled.txt
|
||||||
|
sat_zoom_level: 17
|
||||||
|
uav_patch_width: 400
|
||||||
|
uav_patch_height: 400
|
||||||
|
sat_patch_width: 400
|
||||||
|
sat_patch_height: 400
|
||||||
|
test_from_train_ratio: 0.0
|
||||||
|
transform_mean:
|
||||||
|
- 0.485
|
||||||
|
- 0.456
|
||||||
|
- 0.406
|
||||||
|
transform_std:
|
||||||
|
- 0.229
|
||||||
|
- 0.224
|
||||||
|
- 0.225
|
||||||
|
sat_availaible_years:
|
||||||
|
- "2023"
|
||||||
|
- "2021"
|
||||||
|
- "2019"
|
||||||
|
- "2016"
|
||||||
|
max_rotation_angle: 0
|
||||||
|
uav_image_scale: 2.0
|
||||||
|
use_heatmap: false
|
||||||
|
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
nms_radius: 3
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
features: superpoint
|
||||||
|
depth_confidence: -1
|
||||||
|
width_confidence: -1
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: true
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 40
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 1000 # 350000 / 4
|
||||||
|
lr: 1e-3
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,48 @@
|
||||||
|
data:
|
||||||
|
name: satellites
|
||||||
|
data_dir: /mnt/drive/tiles
|
||||||
|
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||||
|
train_size: null
|
||||||
|
val_size: null
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 24
|
||||||
|
homography:
|
||||||
|
difficulty: 0.9
|
||||||
|
max_angle: 359
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
force_num_keypoints: True
|
||||||
|
nms_radius: 3
|
||||||
|
trainable: False
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: true
|
||||||
|
checkpointed: true
|
||||||
|
weights: superpoint
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 5
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 500
|
||||||
|
lr: 1e-7
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,60 @@
|
||||||
|
data:
|
||||||
|
name: satellites
|
||||||
|
data_dir: /mnt/drive/tiles
|
||||||
|
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||||
|
train_size: null
|
||||||
|
val_size: null
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 24
|
||||||
|
homography:
|
||||||
|
difficulty: 0.5
|
||||||
|
max_angle: 359
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
nms_radius: 3
|
||||||
|
#extractor:
|
||||||
|
# name: gluefactory_nonfree.superpoint
|
||||||
|
# max_num_keypoints: 512
|
||||||
|
# force_num_keypoints: True
|
||||||
|
# detection_threshold: 0.0
|
||||||
|
# nms_radius: 3
|
||||||
|
# trainable: False
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
features: superpoint
|
||||||
|
depth_confidence: -1
|
||||||
|
width_confidence: -1
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: true
|
||||||
|
#matcher:
|
||||||
|
# name: matchers.lightglue_pretrained
|
||||||
|
# features: superpoint
|
||||||
|
# depth_confidence: -1
|
||||||
|
# width_confidence: -1
|
||||||
|
# filter_threshold: 0.1
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 5
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 500
|
||||||
|
lr: 1e-4
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,47 @@
|
||||||
|
data:
|
||||||
|
name: satellites
|
||||||
|
data_dir: /mnt/drive/tiles
|
||||||
|
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||||
|
train_size: null
|
||||||
|
val_size: null
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 24
|
||||||
|
homography:
|
||||||
|
difficulty: 0.5
|
||||||
|
max_angle: 180
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
detection_threshold: 0.0
|
||||||
|
nms_radius: 3
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
features: superpoint
|
||||||
|
depth_confidence: -1
|
||||||
|
width_confidence: -1
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: true
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 40
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 500
|
||||||
|
lr: 1e-4
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,28 @@
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: extractors.sift
|
||||||
|
backend: pycolmap_cuda
|
||||||
|
max_num_keypoints: 4096
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue_pretrained
|
||||||
|
features: sift
|
||||||
|
depth_confidence: -1
|
||||||
|
width_confidence: -1
|
||||||
|
filter_threshold: 0.1
|
||||||
|
benchmarks:
|
||||||
|
megadepth1500:
|
||||||
|
data:
|
||||||
|
preprocessing:
|
||||||
|
side: long
|
||||||
|
resize: 1600
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
||||||
|
model:
|
||||||
|
extractor:
|
||||||
|
max_num_keypoints: 1024 # overwrite config above
|
|
@ -14,10 +14,10 @@ model:
|
||||||
name: two_view_pipeline
|
name: two_view_pipeline
|
||||||
extractor:
|
extractor:
|
||||||
name: extractors.sift
|
name: extractors.sift
|
||||||
detector: pycolmap_cuda
|
backend: pycolmap_cuda
|
||||||
max_num_keypoints: 1024
|
max_num_keypoints: 1024
|
||||||
force_num_keypoints: True
|
force_num_keypoints: True
|
||||||
detection_threshold: 0.0001
|
nms_radius: 3
|
||||||
trainable: False
|
trainable: False
|
||||||
ground_truth:
|
ground_truth:
|
||||||
name: matchers.homography_matcher
|
name: matchers.homography_matcher
|
||||||
|
@ -46,3 +46,6 @@ benchmarks:
|
||||||
eval:
|
eval:
|
||||||
estimator: opencv
|
estimator: opencv
|
||||||
ransac_th: 0.5
|
ransac_th: 0.5
|
||||||
|
model:
|
||||||
|
extractor:
|
||||||
|
nms_radius: 0
|
||||||
|
|
|
@ -25,10 +25,10 @@ model:
|
||||||
name: two_view_pipeline
|
name: two_view_pipeline
|
||||||
extractor:
|
extractor:
|
||||||
name: extractors.sift
|
name: extractors.sift
|
||||||
detector: pycolmap_cuda
|
backend: pycolmap_cuda
|
||||||
max_num_keypoints: 2048
|
max_num_keypoints: 2048
|
||||||
force_num_keypoints: True
|
force_num_keypoints: True
|
||||||
detection_threshold: 0.0001
|
nms_radius: 3
|
||||||
trainable: False
|
trainable: False
|
||||||
matcher:
|
matcher:
|
||||||
name: matchers.lightglue
|
name: matchers.lightglue
|
||||||
|
@ -62,6 +62,9 @@ benchmarks:
|
||||||
preprocessing:
|
preprocessing:
|
||||||
side: long
|
side: long
|
||||||
resize: 1600
|
resize: 1600
|
||||||
|
model:
|
||||||
|
extractor:
|
||||||
|
nms_radius: 0
|
||||||
eval:
|
eval:
|
||||||
estimator: opencv
|
estimator: opencv
|
||||||
ransac_th: 0.5
|
ransac_th: 0.5
|
||||||
|
@ -72,3 +75,4 @@ benchmarks:
|
||||||
model:
|
model:
|
||||||
extractor:
|
extractor:
|
||||||
max_num_keypoints: 1024
|
max_num_keypoints: 1024
|
||||||
|
nms_radius: 0
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
data:
|
||||||
|
name: homographies
|
||||||
|
data_dir: revisitop1m
|
||||||
|
train_size: 150000
|
||||||
|
val_size: 2000
|
||||||
|
batch_size: 128
|
||||||
|
num_workers: 14
|
||||||
|
homography:
|
||||||
|
difficulty: 0.7
|
||||||
|
max_angle: 45
|
||||||
|
photometric:
|
||||||
|
name: lg
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 512
|
||||||
|
force_num_keypoints: True
|
||||||
|
detection_threshold: 0.0
|
||||||
|
nms_radius: 3
|
||||||
|
trainable: False
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.homography_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 3
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: false
|
||||||
|
checkpointed: true
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 40
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 500
|
||||||
|
profile: true
|
||||||
|
lr: 1e-4
|
||||||
|
lr_schedule:
|
||||||
|
start: 20
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||||
|
benchmarks:
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
|
@ -0,0 +1,71 @@
|
||||||
|
data:
|
||||||
|
name: megadepth
|
||||||
|
preprocessing:
|
||||||
|
resize: 1024
|
||||||
|
side: long
|
||||||
|
square_pad: True
|
||||||
|
train_split: train_scenes_clean.txt
|
||||||
|
train_num_per_scene: 300
|
||||||
|
val_split: valid_scenes_clean.txt
|
||||||
|
val_pairs: valid_pairs.txt
|
||||||
|
min_overlap: 0.1
|
||||||
|
max_overlap: 0.7
|
||||||
|
num_overlap_bins: 3
|
||||||
|
read_depth: true
|
||||||
|
read_image: true
|
||||||
|
batch_size: 32
|
||||||
|
num_workers: 14
|
||||||
|
load_features:
|
||||||
|
do: false # enable this if you have cached predictions
|
||||||
|
path: exports/megadepth-undist-depth-r1024_SP-k2048-nms3/{scene}.h5
|
||||||
|
padding_length: 2048
|
||||||
|
padding_fn: pad_local_features
|
||||||
|
model:
|
||||||
|
name: two_view_pipeline
|
||||||
|
extractor:
|
||||||
|
name: gluefactory_nonfree.superpoint
|
||||||
|
max_num_keypoints: 2048
|
||||||
|
force_num_keypoints: True
|
||||||
|
detection_threshold: 0.0
|
||||||
|
nms_radius: 3
|
||||||
|
trainable: False
|
||||||
|
matcher:
|
||||||
|
name: matchers.lightglue
|
||||||
|
filter_threshold: 0.1
|
||||||
|
flash: false
|
||||||
|
checkpointed: true
|
||||||
|
ground_truth:
|
||||||
|
name: matchers.depth_matcher
|
||||||
|
th_positive: 3
|
||||||
|
th_negative: 5
|
||||||
|
th_epi: 5
|
||||||
|
allow_no_extract: True
|
||||||
|
train:
|
||||||
|
seed: 0
|
||||||
|
epochs: 50
|
||||||
|
log_every_iter: 100
|
||||||
|
eval_every_iter: 1000
|
||||||
|
lr: 1e-4
|
||||||
|
lr_schedule:
|
||||||
|
start: 30
|
||||||
|
type: exp
|
||||||
|
on_epoch: true
|
||||||
|
exp_div_10: 10
|
||||||
|
dataset_callback_fn: sample_new_items
|
||||||
|
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||||
|
benchmarks:
|
||||||
|
megadepth1500:
|
||||||
|
data:
|
||||||
|
preprocessing:
|
||||||
|
side: long
|
||||||
|
resize: 1600
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
||||||
|
hpatches:
|
||||||
|
eval:
|
||||||
|
estimator: opencv
|
||||||
|
ransac_th: 0.5
|
||||||
|
model:
|
||||||
|
extractor:
|
||||||
|
max_num_keypoints: 1024
|
|
@ -1,14 +1,14 @@
|
||||||
data:
|
data:
|
||||||
name: homographies
|
name: homographies
|
||||||
homography:
|
homography:
|
||||||
difficulty: 0.5
|
difficulty: 0.7
|
||||||
max_angle: 30
|
max_angle: 45
|
||||||
patch_shape: [640, 480]
|
patch_shape: [640, 480]
|
||||||
photometric:
|
photometric:
|
||||||
p: 0.75
|
p: 0.75
|
||||||
train_size: 900000
|
train_size: 900000
|
||||||
val_size: 1000
|
val_size: 1000
|
||||||
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet)
|
batch_size: 160 # 20 per 10GB of GPU mem (12 for triplet)
|
||||||
num_workers: 15
|
num_workers: 15
|
||||||
model:
|
model:
|
||||||
name: gluefactory.models.two_view_pipeline
|
name: gluefactory.models.two_view_pipeline
|
||||||
|
@ -70,4 +70,4 @@ train:
|
||||||
n_steps: 4
|
n_steps: 4
|
||||||
submodules: []
|
submodules: []
|
||||||
# clip_grad: 10 # Use only with mixed precision
|
# clip_grad: 10 # Use only with mixed precision
|
||||||
# load_experiment:
|
# load_experiment:
|
|
@ -1,10 +1,15 @@
|
||||||
data:
|
data:
|
||||||
name: gluefactory.datasets.megadepth
|
name: gluefactory.datasets.megadepth
|
||||||
|
train_num_per_scene: 300
|
||||||
|
val_pairs: valid_pairs.txt
|
||||||
views: 2
|
views: 2
|
||||||
|
min_overlap: 0.1
|
||||||
|
max_overlap: 0.7
|
||||||
|
num_overlap_bins: 3
|
||||||
preprocessing:
|
preprocessing:
|
||||||
resize: 640
|
resize: 640
|
||||||
square_pad: True
|
square_pad: True
|
||||||
batch_size: 60
|
batch_size: 160
|
||||||
num_workers: 15
|
num_workers: 15
|
||||||
model:
|
model:
|
||||||
name: gluefactory.models.two_view_pipeline
|
name: gluefactory.models.two_view_pipeline
|
||||||
|
@ -53,9 +58,9 @@ model:
|
||||||
train:
|
train:
|
||||||
seed: 0
|
seed: 0
|
||||||
epochs: 200
|
epochs: 200
|
||||||
log_every_iter: 10
|
log_every_iter: 400
|
||||||
eval_every_iter: 100
|
eval_every_iter: 700
|
||||||
save_every_iter: 500
|
save_every_iter: 1400
|
||||||
lr: 1e-4
|
lr: 1e-4
|
||||||
lr_schedule:
|
lr_schedule:
|
||||||
type: exp # exp or multi_step
|
type: exp # exp or multi_step
|
||||||
|
|
|
@ -168,6 +168,11 @@ class BaseDataset(metaclass=ABCMeta):
|
||||||
sampler = None
|
sampler = None
|
||||||
if shuffle is None:
|
if shuffle is None:
|
||||||
shuffle = split == "train" and self.conf.shuffle_training
|
shuffle = split == "train" and self.conf.shuffle_training
|
||||||
|
shuffle = split == "val"
|
||||||
|
|
||||||
|
shuffle = True
|
||||||
|
|
||||||
|
print("Shuffle", shuffle)
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
|
|
@ -0,0 +1,694 @@
|
||||||
|
"""
|
||||||
|
Simply load images from a folder or nested folders (does not have any split),
|
||||||
|
and apply homographic adaptations to them. Yields an image pair without border
|
||||||
|
artifacts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import omegaconf
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
from typing import List, Literal
|
||||||
|
from torchvision import transforms
|
||||||
|
import random
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
from ..geometry.homography import (
|
||||||
|
compute_homography,
|
||||||
|
sample_homography_corners,
|
||||||
|
warp_points,
|
||||||
|
)
|
||||||
|
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||||
|
from ..settings import DATA_PATH
|
||||||
|
from ..utils.image import read_image
|
||||||
|
from ..utils.tools import fork_rng
|
||||||
|
from ..visualization.viz2d import plot_image_grid
|
||||||
|
from .augmentations import IdentityAugmentation, augmentations
|
||||||
|
from .base_dataset import BaseDataset
|
||||||
|
from .s_utils import get_random_tiff_patch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class Drone2SatDataset(BaseDataset):
    """Dataset wrapper that serves UAV / satellite image pairs.

    Every split is backed by a ``GeoLocalizationDataset`` configured from the
    ``geo_dataset`` section of the configuration.
    """

    default_conf = {
        # image search
        "data_dir": "revisitop1m",  # the top-level directory
        "image_dir": "jpg/",  # the subdirectory with the images
        "image_list": "revisitop1m.txt",  # optional: list or filename of list
        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
        "metadata_dir": None,
        "uav_dataset_dir": None,
        "satellite_dataset_dir": None,
        # splits
        "train_size": 100,
        "val_size": 10,
        "shuffle_seed": 0,  # or None to skip
        # image loading
        "grayscale": False,
        "triplet": False,
        "right_only": False,  # image0 is orig (rescaled), image1 is right
        "reseed": False,
        "homography": {
            "difficulty": 0.8,
            "translation": 1.0,
            "max_angle": 60,
            "n_angles": 10,
            "patch_shape": [640, 480],
            "min_convexity": 0.05,
        },
        "photometric": {
            "name": "dark",
            "p": 0.75,
            # 'difficulty': 1.0,  # currently unused
        },
        # feature loading
        "load_features": {
            "do": False,
            **CacheLoader.default_conf,
            "collate": False,
            "thresh": 0.0,
            "max_num_keypoints": -1,
            "force_num_keypoints": False,
        },
        # Other geolocalization parameters
        "geo_dataset": {
            "uav_dataset_dir": None,
            "satellite_dataset_dir": None,
            "misslabeled_images_path": None,
            "sat_zoom_level": 17,
            "uav_patch_width": 400,
            "uav_patch_height": 400,
            "sat_patch_width": 400,
            "sat_patch_height": 400,
            "heatmap_kernel_size": 33,
            "test_from_train_ratio": 0.0,
            "transform_mean": [0.485, 0.456, 0.406],
            "transform_std": [0.229, 0.224, 0.225],
            "sat_availaible_years": ["2023", "2021", "2019", "2016"],
            "max_rotation_angle": 10,
            "uav_image_scale": 1,
            "use_heatmap": False,
        },
    }

    # Split name -> (GeoLocalizationDataset partition, name of the configuration
    # entry holding the subset size for that split).
    _SPLITS = {
        "train": ("train", "train_size"),
        "val": ("test", "val_size"),
    }

    def _init(self, conf):
        self.images = {"train": "train", "val": "val"}

    def get_dataset(self, split):
        """Build the ``GeoLocalizationDataset`` backing ``split``.

        Args:
            split: ``"train"`` or ``"val"``. The validation split reads the
                ``"test"`` partition of the underlying dataset.

        Returns:
            A ``GeoLocalizationDataset`` limited to ``train_size`` or
            ``val_size`` samples respectively.

        Raises:
            ValueError: if ``split`` is not one of the supported names (the
                previous implementation silently returned ``None``).
        """
        if split not in self._SPLITS:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {sorted(self._SPLITS)}."
            )
        partition, size_key = self._SPLITS[split]
        geo = self.conf.geo_dataset
        return GeoLocalizationDataset(
            uav_dataset_dir=geo.uav_dataset_dir,
            satellite_dataset_dir=geo.satellite_dataset_dir,
            misslabeled_images_path=geo.misslabeled_images_path,
            dataset=partition,
            sat_zoom_level=geo.sat_zoom_level,
            uav_patch_width=geo.uav_patch_width,
            uav_patch_height=geo.uav_patch_height,
            sat_patch_width=geo.sat_patch_width,
            sat_patch_height=geo.sat_patch_height,
            heatmap_kernel_size=geo.heatmap_kernel_size,
            test_from_train_ratio=geo.test_from_train_ratio,
            transform_mean=geo.transform_mean,
            transform_std=geo.transform_std,
            # The configuration key keeps its historical spelling.
            sat_available_years=geo.sat_availaible_years,
            max_rotation_angle=geo.max_rotation_angle,
            uav_image_scale=geo.uav_image_scale,
            use_heatmap=geo.use_heatmap,
            subset_size=getattr(self.conf, size_key),
        )
|
||||||
|
|
||||||
|
|
||||||
|
class GeoLocalizationDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
    self,
    uav_dataset_dir: str,
    satellite_dataset_dir: str,
    misslabeled_images_path: str,
    dataset: Literal["train", "test"],
    sat_zoom_level: int = 16,
    uav_patch_width: int = 128,
    uav_patch_height: int = 128,
    sat_patch_width: int = 400,
    sat_patch_height: int = 400,
    heatmap_kernel_size: int = 33,
    test_from_train_ratio: float = 0.0,
    transform_mean: List[float] = [0.485, 0.456, 0.406],
    transform_std: List[float] = [0.229, 0.224, 0.225],
    sat_available_years: List[str] = ["2023", "2021", "2019", "2016"],
    max_rotation_angle: int = 10,
    uav_image_scale: float = 1,
    use_heatmap: bool = True,
    subset_size: int = None,
):
    """Index the UAV images on disk and prepare the per-sample transforms.

    Args:
        uav_dataset_dir: root directory scanned recursively for ``.jpeg``
            images and ``.json`` metadata files.
        satellite_dataset_dir: directory handed to the satellite patch loader.
        misslabeled_images_path: text file listing images to exclude.
        dataset: which partition to index, ``"train"`` or ``"test"``.
        sat_zoom_level: stored on the instance.
        uav_patch_width, uav_patch_height: size of the centre crop taken from
            the UAV image.
        sat_patch_width, sat_patch_height: size of the satellite patch.
        heatmap_kernel_size: side of the square drawn in the heatmap.
        test_from_train_ratio: share of the images moved from the train to
            the test partition.
        transform_mean, transform_std: statistics used by
            ``inverse_transforms``.
        sat_available_years: satellite years; every UAV image yields one
            sample per year.
        max_rotation_angle: bound (in degrees) of the random rotation.
        uav_image_scale: divisor applied to the UAV image size before
            cropping.
        use_heatmap: whether samples also carry a heatmap.
        subset_size: maximum number of samples to keep; ``None`` or any
            negative value keeps every sample.
    """
    self.uav_dataset_dir = uav_dataset_dir
    self.satellite_dataset_dir = satellite_dataset_dir
    self.dataset = dataset
    self.sat_zoom_level = sat_zoom_level
    self.uav_patch_width = uav_patch_width
    self.uav_patch_height = uav_patch_height
    self.sat_patch_width = sat_patch_width
    self.sat_patch_height = sat_patch_height
    self.heatmap_kernel_size = heatmap_kernel_size
    self.test_from_train_ratio = test_from_train_ratio
    # Copy the sequences so the shared default lists (or a caller's
    # configuration object) are never aliased by the instance.
    self.transform_mean = list(transform_mean)
    self.transform_std = list(transform_std)
    self.sat_available_years = list(sat_available_years)
    self.misslabeled_images_path = misslabeled_images_path
    self.max_rotation_angle = max_rotation_angle
    self.uav_image_scale = uav_image_scale
    self.use_heatmap = use_heatmap
    self.subset_size = subset_size
    self.metadata_dict = {}

    # The scan below relies on the attributes assigned above.
    self.total_uav_samples = self.count_total_uav_samples()
    self.misslabelled_images = self.read_misslabelled_images(
        self.misslabeled_images_path
    )
    self.entry_paths = self.get_entry_paths(self.uav_dataset_dir)
    self.cleanup_misslabelled_images()

    # NOTE: mean/std normalisation is not applied to the images; only the
    # inverse transform below uses transform_mean / transform_std.
    self.transforms = transforms.Compose([transforms.ToTensor()])
    self.inverse_transforms = transforms.Compose(
        [
            transforms.Normalize(
                mean=[
                    -m / s for m, s in zip(self.transform_mean, self.transform_std)
                ],
                std=[1 / s for s in self.transform_std],
            ),
            transforms.ToPILImage(),
        ]
    )

    # Each entry path produces one sample per satellite year, so the number
    # of paths to keep is the requested sample count divided by the years.
    # ``None`` (the declared default) used to crash here with a TypeError.
    if self.subset_size is not None and self.subset_size >= 0:
        ssize = int(self.subset_size / len(self.sat_available_years))
        ssize = min(ssize, len(self.entry_paths))
        self.entry_paths = self.entry_paths[:ssize]
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """Return the sample count: one sample per (entry path, satellite year)."""
    path_count = len(self.entry_paths)
    year_count = len(self.sat_available_years)
    return path_count * year_count
|
||||||
|
|
||||||
|
def read_misslabelled_images(
    self, path: str = "misslabels/misslabeled.txt"
) -> List[str]:
    """Read the names of mislabelled images, one per line.

    Args:
        path: text file to read. ``None`` means that no exclusion list is
            configured.

    Returns:
        The stripped, non-empty lines of the file, or an empty list when
        ``path`` is ``None``.
    """
    if path is None:
        return []
    with open(path, "r") as f:
        stripped = (line.strip() for line in f)
        # Blank lines are dropped: an empty name is a substring of every
        # path and would make the clean-up step discard a valid image.
        return [name for name in stripped if name]
|
||||||
|
|
||||||
|
def cleanup_misslabelled_images(self) -> None:
    """Remove mislabelled images from ``self.entry_paths`` in place.

    For every name in ``self.misslabelled_images`` the first entry path that
    contains the name as a substring is discarded. Positions are collected in
    a set so that a path matched more than once (for example by a duplicated
    name) is removed exactly once; the previous implementation popped the
    same index repeatedly and thereby dropped unrelated entries.
    """
    doomed = set()
    for image in self.misslabelled_images:
        for position, image_path in enumerate(self.entry_paths):
            if image in image_path:
                doomed.add(position)
                break

    if doomed:
        # Slice assignment keeps the identity of the list object.
        self.entry_paths[:] = [
            image_path
            for position, image_path in enumerate(self.entry_paths)
            if position not in doomed
        ]
|
||||||
|
|
||||||
|
def __getitem__(self, idx) -> dict:
    """
    Retrieves a sample given its index, returning the preprocessed UAV
    and satellite images together with the homography relating them.

    ``idx`` enumerates (entry path, satellite year) pairs: the path is
    ``idx // n_years`` and the year is ``idx % n_years``. The UAV image is
    rotated by a random angle, resized, centre-cropped and converted to a
    tensor; the satellite patch comes from ``get_random_tiff_patch``.

    Returns:
        dict with the keys ``name``, ``original_image_size``, ``H_0to1``,
        ``view0`` (UAV), ``view1`` (satellite) and ``idx``. When
        ``self.use_heatmap`` is set, a ``heatmap`` entry is added as well.
        If the homography estimated from the sampled points is degenerate,
        another randomly chosen sample is returned instead.
    """
    # Imported locally because the module does not import `warnings`.
    import warnings

    year_count = len(self.sat_available_years)
    image_path_index = idx // year_count
    sat_year = self.sat_available_years[idx % year_count]
    rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle)

    image_path = self.entry_paths[image_path_index]
    with Image.open(image_path) as raw_image:
        uav_image = raw_image.convert("RGB")  # Ensure 3-channel image

    original_uav_image_width = uav_image.width
    original_uav_image_height = uav_image.height

    lookup_str, file_number = self.extract_info_from_filename(image_path)
    img_info = self.metadata_dict[lookup_str][file_number]

    lat = img_info["coordinate"]["latitude"]
    lon = img_info["coordinate"]["longitude"]
    fov_vertical = img_info["fovVertical"]

    # The altitude is read from the third "_"-separated token of the file
    # name, up to its first "m".
    try:
        agl_altitude = float(image_path.split("/")[-1].split("_")[2].split("m")[0])
    except (IndexError, ValueError):
        agl_altitude = 150.0
        warnings.warn(
            "Could not extract AGL altitude from filename, using default value of 150m."
        )

    (
        satellite_patch,
        x_sat,
        y_sat,
        x_offset,
        y_offset,
        patch_transform,
    ) = get_random_tiff_patch(
        lat, lon, self.sat_patch_width, self.sat_patch_height, sat_year, self.satellite_dataset_dir
    )

    # Rotate crop center and transform image
    h = int(uav_image.height // self.uav_image_scale)
    w = int(uav_image.width // self.uav_image_scale)

    uav_image = F.rotate(uav_image, rot_angle)
    uav_image = F.resize(uav_image, [h, w])
    uav_image = F.center_crop(
        uav_image, (self.uav_patch_height, self.uav_patch_width)
    )
    uav_image = self.transforms(uav_image)

    # Channels move last before the conversion to a tensor.
    satellite_patch = satellite_patch.transpose(1, 2, 0)
    satellite_patch_pytorch = self.transforms(satellite_patch)
    del satellite_patch

    heatmap = None
    if self.use_heatmap:
        # The size is taken from the tensor: the array above has already
        # been released, and its axes 1 and 2 were not (height, width).
        sat_height, sat_width = satellite_patch_pytorch.shape[-2:]
        heatmap = self.get_heatmap_gt(
            x_sat,
            y_sat,
            sat_height,
            sat_width,
            self.heatmap_kernel_size,
        )

    cropped_uav_image_width = self.calculate_cropped_uav_image_width(
        fov_vertical,
        original_uav_image_width,
        original_uav_image_height,
        self.uav_patch_width,
        self.uav_patch_height,
        agl_altitude,
    )

    satellite_tile_width = self.calculate_cropped_sat_image_width(
        lat, self.sat_patch_width, patch_transform
    )

    scale_factor = cropped_uav_image_width / satellite_tile_width
    # NOTE(review): empirical factor kept from the original code; its
    # derivation is not visible here.
    scale_factor *= 10

    homography_matrix = self.compute_homography(
        rot_angle,
        x_sat,
        y_sat,
        self.uav_patch_width,
        self.uav_patch_height,
        scale_factor,
    )

    # The point correspondences are part of every sample, so they are
    # produced regardless of ``use_heatmap`` (they used to be undefined when
    # the heatmap was enabled).
    points = self.sample_four_points(self.uav_patch_width, self.uav_patch_height)
    warped_points = self.warp_points(points, homography_matrix)

    inverse_homography_matrix = np.linalg.inv(homography_matrix)
    uav_image_data = {
        "image": uav_image,
        "H_": homography_matrix,
        "coords": points,
        "image_size": np.array([self.uav_patch_width, self.uav_patch_height]),
    }

    satellite_patch_data = {
        "image": satellite_patch_pytorch,
        "H_": inverse_homography_matrix,
        "coords": warped_points,
        "image_size": np.array([self.sat_patch_width, self.sat_patch_height]),
    }

    hm = self.compute_homography_points(
        points, warped_points, [1.0 , 1.0]
    )

    # compute_homography_points returns the identity when its linear system
    # is singular; draw a different sample in that case.
    if np.array_equal(hm, np.eye(3)):
        print("Singular matrix")
        return self.__getitem__(random.randint(0, len(self) - 1))

    data = {
        "name": f"{image_path}_{rot_angle}_{sat_year}",
        "original_image_size": np.array(
            [original_uav_image_width, original_uav_image_height]
        ),
        "H_0to1": hm.astype(np.float32),
        "view0": uav_image_data,
        "view1": satellite_patch_data,
        "idx": idx
    }
    if heatmap is not None:
        data["heatmap"] = heatmap

    del img_info
    gc.collect()

    return data
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform):
    """
    Return the width of a satellite patch in world coordinate units.

    The first coefficient of ``patch_transform`` (presumably the pixel width
    in degrees - confirm against the patch loader) is converted with a
    latitude-dependent length of one degree and multiplied by the patch
    width in pixels.
    """
    degree_length = 111320 * np.cos(np.radians(latitude))
    units_per_pixel = patch_transform[0] * degree_length
    return patch_width * units_per_pixel
|
||||||
|
|
||||||
|
def calculate_cropped_uav_image_width(
    self,
    fov_vertical,
    orig_width,
    orig_height,
    crop_width,
    crop_height,
    altitude=150.0,
):
    """
    Return the width of the cropped UAV image in world coordinate units.

    The vertical field of view (degrees) is turned into a horizontal one
    through the aspect ratio of the original image, narrowed by the width
    crop ratio, projected at ``altitude`` and finally scaled by the crop
    ratio once more.

    NOTE(review): the crop ratio enters both the narrowed field of view and
    the final product - confirm that this double use is intended.
    ``crop_height`` does not influence the result.
    """
    half_vertical_tan = np.tan(np.radians(fov_vertical) / 2)
    width_ratio = crop_width / orig_width

    # Horizontal field of view of the full frame, then of the cropped frame.
    horizontal_fov = 2 * np.arctan(half_vertical_tan * (orig_width / orig_height))
    cropped_fov = 2 * np.arctan(np.tan(horizontal_fov / 2) * width_ratio)

    # Footprint width of the cropped field of view at the given altitude.
    footprint_width = 2 * (altitude * np.tan(cropped_fov / 2))
    return footprint_width * width_ratio
|
||||||
|
|
||||||
|
def compute_homography_points(self, pts1_, pts2_, shape):
    """Estimate the homography mapping four points onto four other points.

    Both point sets are first multiplied by ``shape`` reversed. The eight
    unknown entries are obtained from an 8x8 linear system and the ninth
    entry is set to one. When the system is singular, "Singular matrix" is
    printed and the 3x3 identity is returned.
    """
    # Rescale to actual size; ``shape`` is reversed before use.
    scale = np.expand_dims(np.array(shape[::-1], dtype=np.float32), axis=0)
    src = pts1_ * scale
    dst = pts2_ * scale

    # Two equations per correspondence: one for x, one for y.
    equations = []
    for k in range(4):
        p, q = src[k], dst[k]
        equations.append([p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]])
        equations.append([0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]])
    lhs = np.stack(equations, axis=0)
    rhs = np.array([dst[k][c] for k in range(4) for c in range(2)]).reshape(8, 1)

    try:
        solution = np.linalg.solve(lhs, rhs)
    except np.linalg.LinAlgError:
        print("Singular matrix")
        return np.eye(3)

    entries = np.transpose(solution)
    with_one = np.concatenate([entries, np.ones_like(entries[:, :1])], axis=1)
    return np.reshape(with_one, [3, 3])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_homography(
    self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor
):
    """Compose the 3x3 transform from UAV patch to satellite patch pixels.

    The transform moves the UAV patch centre to the origin, scales by
    ``scale_factor``, rotates by ``rot_angle`` degrees and finally translates
    to ``(x_sat, y_sat)``.
    """
    # Angles above 180 degrees are wrapped by one full turn.
    angle_deg = rot_angle - 360 if rot_angle > 180 else rot_angle
    angle = np.radians(angle_deg)
    cos_a, sin_a = np.cos(angle), np.sin(angle)

    rotation = np.array(
        [
            [cos_a, -sin_a, 0],
            [sin_a, cos_a, 0],
            [0, 0, 1],
        ]
    )
    scaling = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
    # Puts the centre of the UAV patch at the origin.
    centering = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]])
    # Moves the origin to the target position in the satellite patch.
    placement = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]])

    return placement @ (rotation @ (scaling @ centering))
|
||||||
|
|
||||||
|
def sample_four_points(self, width: int, height: int) -> np.ndarray:
    """
    Draw four random integer points inside the UAV image.

    Each coordinate lies between ``CENTER_PADDING`` and the image extent
    minus ``PADDING`` (both bounds included). The result has shape (4, 2)
    with the x coordinate first.
    """
    PADDING = 50
    CENTER_PADDING = 10
    x_max = width - PADDING
    y_max = height - PADDING

    samples = []
    for _ in range(4):
        # x is drawn before y for every point.
        px = random.randint(CENTER_PADDING, x_max)
        py = random.randint(CENTER_PADDING, y_max)
        samples.append([px, py])
    return np.array(samples)
|
||||||
|
|
||||||
|
def warp_points(
    self, points: np.ndarray, homography_matrix: np.ndarray
) -> np.ndarray:
    """
    Warps the given points using the given homography matrix.

    Args:
        points: 2D points, one ``(x, y)`` pair per row. Any number of points
            is accepted (the previous implementation required exactly four).
        homography_matrix: 3x3 matrix applied to the homogeneous points.

    Returns:
        Array of shape (N, 2) holding the warped points after the division
        by the homogeneous coordinate.
    """
    pts = np.asarray(points, dtype=float).reshape(-1, 2)
    homogeneous = np.concatenate([pts, np.ones((len(pts), 1))], axis=1)
    warped = np.dot(homography_matrix, homogeneous.T).T
    return warped[:, :2] / warped[:, 2:]
|
||||||
|
|
||||||
|
def count_total_uav_samples(self) -> int:
    """
    Count every ``.jpeg`` file found anywhere below ``self.uav_dataset_dir``.

    All sub-directories are visited, so both train and test images count.
    """
    return sum(
        1
        for _root, _dirs, files in os.walk(self.uav_dataset_dir)
        for candidate in files
        if candidate.endswith(".jpeg")
    )
|
||||||
|
|
||||||
|
def get_number_of_city_samples(self) -> int:
    """
    Return the number of city samples in the dataset.

    TODO: count the city samples instead of returning a fixed number.
    """
    hard_coded_city_count = 11
    return hard_coded_city_count
|
||||||
|
|
||||||
|
def get_entry_paths(self, directory: str) -> List[str]:
    """
    Recursively collect the image paths of the selected partition.

    ``.json`` files met on the way are loaded into the metadata dictionary.
    The image number parsed from each file name is compared with a quota
    derived from ``test_from_train_ratio``: the train partition keeps the
    numbers at or above the quota, the test partition keeps those below it
    (plus everything under a "Test" path when the ratio is zero). The result
    is sorted by the (name, number) pair extracted from each file name.
    """
    selected = []
    names = os.listdir(directory)

    test_quota = int(
        self.total_uav_samples
        * self.test_from_train_ratio
        / self.get_number_of_city_samples()
    )

    for name in names:
        full_path = os.path.join(directory, name)

        # Directories are searched recursively.
        if os.path.isdir(full_path):
            selected.extend(self.get_entry_paths(full_path))
            continue

        under_train = "Train" in full_path
        under_test = "Test" in full_path
        borrows_from_test = self.test_from_train_ratio > 0 and under_test

        if self.dataset == "train" and (under_train or borrows_from_test):
            building_train = True
        elif self.dataset == "test":
            building_train = False
        else:
            continue

        # Only images carry a number; metadata files are loaded, not listed.
        if full_path.endswith(".jpeg"):
            number = self.extract_info_from_filename(full_path)[1]
        else:
            number = None
        if full_path.endswith(".json"):
            self.get_metadata(full_path)
        if number is None:
            continue

        if building_train:
            # Only include images beyond the ones taken for test
            keep = number >= test_quota
        else:
            keep = (number < test_quota and (under_test or under_train)) or (
                self.test_from_train_ratio == 0.0 and under_test
            )
        if keep:
            selected.append(full_path)

    return sorted(selected, key=self.extract_info_from_filename)
|
||||||
|
|
||||||
|
def get_metadata(self, path: str) -> None:
    """
    Load the ``cameraFrames`` entry of a JSON file into ``self.metadata_dict``.

    The dictionary key is the last "/"-separated component of ``path`` with
    ".json" removed.
    """
    with open(path, newline="") as handle:
        frames = json.load(handle)["cameraFrames"]
    key = path.split("/")[-1].replace(".json", "")
    self.metadata_dict[key] = frames
|
||||||
|
|
||||||
|
def extract_info_from_filename(self, filename: str) -> (str, int):
    """
    Split a file name into its prefix and its trailing number.

    ".jpeg" is removed, the last "/"-separated component is taken and cut at
    its last underscore: the text before it is the prefix, the text after it
    is parsed as an integer. ``(None, None)`` is returned, after printing a
    message, when that trailing part is not an integer.
    """
    stem = filename.replace(".jpeg", "").split("/")[-1]
    prefix, _, trailing = stem.rpartition("_")
    try:
        number = int(trailing)
    except ValueError:
        print("Could not extract number from filename: ", filename)
        return None, None
    return prefix, number
|
||||||
|
|
||||||
|
def get_heatmap_gt(
    self, x: int, y: int, height: int, width: int, square_size: int = 33
) -> torch.Tensor:
    """
    Build a (height, width) heatmap that is one inside a square centred on
    (x, y) and zero elsewhere.

    The square spans ``square_size // 2`` pixels on each side of the centre
    and is cut at the borders of the heatmap.
    """
    reach = square_size // 2

    # Half-open index ranges of the square, limited to the heatmap.
    col_start = max(0, x - reach)
    col_stop = min(width, x + reach + 1)
    row_start = max(0, y - reach)
    row_stop = min(height, y + reach + 1)

    target = torch.zeros((height, width))
    target[row_start:row_stop, col_start:col_stop] = 1
    return target
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from .. import logger # overwrite the logger
|
|
@ -194,7 +194,7 @@ class ETH3DDataset(BaseDataset):
|
||||||
if tmp_dir.exists():
|
if tmp_dir.exists():
|
||||||
shutil.rmtree(tmp_dir)
|
shutil.rmtree(tmp_dir)
|
||||||
tmp_dir.mkdir(exist_ok=True, parents=True)
|
tmp_dir.mkdir(exist_ok=True, parents=True)
|
||||||
url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/"
|
url_base = "https://cvg-data.inf.ethz.ch/SOLD2/SOLD2_ETH3D_undistorted/"
|
||||||
zip_name = "ETH3D_undistorted.zip"
|
zip_name = "ETH3D_undistorted.zip"
|
||||||
zip_path = tmp_dir / zip_name
|
zip_path = tmp_dir / zip_name
|
||||||
torch.hub.download_url_to_file(url_base + zip_name, zip_path)
|
torch.hub.download_url_to_file(url_base + zip_name, zip_path)
|
||||||
|
|
|
@ -0,0 +1,163 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
import mercantile
|
||||||
|
import rasterio
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
from rasterio.errors import NotGeoreferencedWarning
|
||||||
|
from rasterio.io import MemoryFile
|
||||||
|
from rasterio.transform import from_bounds
|
||||||
|
from rasterio.merge import merge
|
||||||
|
from PIL import Image
|
||||||
|
import gc
|
||||||
|
import random
|
||||||
|
from osgeo import gdal, osr
|
||||||
|
from affine import Affine
|
||||||
|
|
||||||
|
|
||||||
|
def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
    """Return the distinct neighbours of the neighbours of ``tile``.

    The tiles keep the order in which they are first encountered.
    """
    second_ring = (
        outer
        for inner in mercantile.neighbors(tile)
        for outer in mercantile.neighbors(inner)
    )
    # dict.fromkeys drops repeated tiles while preserving first-seen order.
    return list(dict.fromkeys(second_ring))
|
||||||
|
|
||||||
|
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str) -> (np.ndarray, dict):
    """
    Returns a TIFF map of the given tile using GDAL.

    Every tile returned by ``get_5x5_neighbors`` is read from
    ``{satellite_dataset_dir}/{sat_year}/{z}_{x}_{y}.jpg``, wrapped in an
    in-memory GDAL dataset georeferenced with the tile bounds, and all
    datasets are merged through a virtual raster.

    Returns:
        ``(mosaic, out_meta)``: the merged raster as an array and a dict with
        the keys "driver", "height", "width", "transform" (an ``Affine``) and
        "crs".

    Raises:
        FileNotFoundError: if the image of any required tile is missing.
    """
    tile_data = []
    neighbors = get_5x5_neighbors(tile)

    with warnings.catch_warnings():
        # Every warning category is silenced while the tiles are loaded.
        warnings.filterwarnings("ignore")
        for neighbor in neighbors:
            west, south, east, north = mercantile.bounds(neighbor)
            tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
            if not os.path.exists(tile_path):
                raise FileNotFoundError(f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found.")

            img = Image.open(tile_path)
            img_array = np.array(img)

            # Create an in-memory GDAL dataset
            mem_driver = gdal.GetDriverByName('MEM')
            dataset = mem_driver.Create('', img_array.shape[1], img_array.shape[0], 3, gdal.GDT_Byte)
            # Three bands are written; the indexing assumes the image array
            # has at least three channels on its last axis.
            for i in range(3):
                dataset.GetRasterBand(i + 1).WriteArray(img_array[:, :, i])

            # Set GeoTransform and Projection
            # GDAL order: origin x, pixel width, row rotation, origin y,
            # column rotation, pixel height (negative: rows run north to south).
            geotransform = (west, (east - west) / img_array.shape[1], 0, north, 0, -(north - south) / img_array.shape[0])
            dataset.SetGeoTransform(geotransform)
            # NOTE(review): the projection is tagged EPSG:3857 while the
            # geotransform is built from mercantile.bounds - confirm that the
            # units of the two agree.
            srs = osr.SpatialReference()
            srs.ImportFromEPSG(3857)
            dataset.SetProjection(srs.ExportToWkt())

            tile_data.append(dataset)

    # Merge tiles using GDAL
    vrt_options = gdal.BuildVRTOptions()
    vrt = gdal.BuildVRT('', [td for td in tile_data], options=vrt_options)
    mosaic = vrt.ReadAsArray()

    # Get metadata
    out_trans = vrt.GetGeoTransform()
    out_crs = vrt.GetProjection()
    out_trans = Affine.from_gdal(*out_trans)
    # "height" and "width" come from axes 1 and 2 of the mosaic array.
    out_meta = {
        "driver": "GTiff",
        "height": mosaic.shape[1],
        "width": mosaic.shape[2],
        "transform": out_trans,
        "crs": out_crs,
    }

    # Clean up
    for td in tile_data:
        td.FlushCache()
    vrt = None
    gc.collect()

    return mosaic, out_meta
|
||||||
|
|
||||||
|
def get_random_tiff_patch(
    lat: float,
    lon: float,
    patch_width: int,
    patch_height: int,
    sat_year: str,
    satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
    zoom_level: int = 17,
    margin: int = 120,
) -> (np.ndarray, int, int, int, int, rasterio.transform.Affine):
    """
    Returns a random patch from the satellite image.

    The mosaic around the tile containing ``(lat, lon)`` is loaded and a
    window of ``patch_width`` x ``patch_height`` pixels is cut out at a random
    offset such that the target location lies more than ``margin`` pixels
    away from every border of the window. The offset is afterwards clamped so
    that the window stays inside the mosaic, which may reduce that distance.

    Args:
        lat, lon: target location.
        patch_width, patch_height: size of the patch in pixels.
        sat_year: sub-directory of the satellite dataset to read from.
        satellite_dataset_dir: root directory of the satellite tiles.
        zoom_level: zoom level used to look up the tile (previously fixed
            to 17).
        margin: minimum distance in pixels between the target location and
            the patch border (previously the constant ``KS = 120``).

    Returns:
        ``(patch, x, y, x_offset, y_offset, patch_transform)``: the patch
        with the channel axis first, the target location in patch pixels, the
        offset of the patch inside the mosaic, and the affine transform of
        the patch.

    Raises:
        ValueError: if the patch is too small to honour ``margin``.
    """
    # Fail early with a clear message instead of an "empty range" error
    # from random.randint after the mosaic has already been loaded.
    smallest_side = 2 * margin + 2
    if patch_width < smallest_side or patch_height < smallest_side:
        raise ValueError(
            f"Patch size {patch_width}x{patch_height} is too small for a margin "
            f"of {margin} pixels (each side needs at least {smallest_side})."
        )

    tile = get_tile_from_coord(lat, lon, zoom_level)
    mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)

    transform = out_meta["transform"]
    del out_meta

    x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)

    # Randomly select an offset within the valid range
    x_offset = random.randint(x_pixel - patch_width + margin + 1, x_pixel - margin - 1)
    y_offset = random.randint(y_pixel - patch_height + margin + 1, y_pixel - margin - 1)

    # Keep the window inside the mosaic.
    x_offset = np.clip(x_offset, 0, mosaic.shape[-1] - patch_width)
    y_offset = np.clip(y_offset, 0, mosaic.shape[-2] - patch_height)

    # Update x, y to reflect the clamping of x_offset and y_offset
    x, y = x_pixel - x_offset, y_pixel - y_offset
    patch = mosaic[
        :, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
    ]

    # Same pixel size and rotation as the mosaic, origin moved to the patch.
    patch_transform = rasterio.transform.Affine(
        transform.a,
        transform.b,
        transform.c + x_offset * transform.a + y_offset * transform.b,
        transform.d,
        transform.e,
        transform.f + x_offset * transform.d + y_offset * transform.e,
    )

    gc.collect()

    return patch, x, y, x_offset, y_offset, patch_transform
|
||||||
|
|
||||||
|
def get_tile_from_coord(
    lat: float, lng: float, zoom_level: int
) -> mercantile.Tile:
    """Look up the map tile that contains a geographic position.

    Args:
        lat: Latitude of the position.
        lng: Longitude of the position.
        zoom_level: Zoom level of the requested tile.

    Returns:
        The tile reported by ``mercantile.tile`` for the position.
    """
    # mercantile takes longitude first, the opposite of this signature.
    return mercantile.tile(lng, lat, zoom_level)
|
||||||
|
|
||||||
|
def geo_to_pixel_coordinates(
    lat: float, lon: float, transform: rasterio.transform.Affine
) -> (int, int):
    """Map a (lat, lon) position onto the pixel grid of a georeferenced raster.

    Args:
        lat: Latitude of the position.
        lon: Longitude of the position.
        transform: Affine transform of the raster; its inverse is applied to
            ``(lon, lat)``.

    Returns:
        ``(x_pixel, y_pixel)``, each rounded with the built-in ``round``.
    """
    world_to_pixel = ~transform
    col, row = world_to_pixel * (lon, lat)
    return round(col), round(row)
|
|
@ -0,0 +1,109 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
import mercantile
|
||||||
|
import rasterio
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
from rasterio.errors import NotGeoreferencedWarning
|
||||||
|
from rasterio.io import MemoryFile
|
||||||
|
from rasterio.transform import from_bounds
|
||||||
|
from rasterio.merge import merge
|
||||||
|
from PIL import Image
|
||||||
|
import gc
|
||||||
|
|
||||||
|
|
||||||
|
def get_3x3_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
    """Collect the tiles surrounding ``tile`` followed by ``tile`` itself.

    Args:
        tile: Centre tile.

    Returns:
        The de-duplicated result of ``mercantile.neighbors(tile)`` in its
        original order, with ``tile`` appended as the last element.
    """
    unique = []
    for candidate in mercantile.neighbors(tile):
        if candidate in unique:
            continue
        unique.append(candidate)

    # The centre tile always goes last, whether or not it was seen above.
    return unique + [tile]
|
||||||
|
|
||||||
|
def get_tiff_map(
    tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str
) -> tuple[np.ndarray, dict]:
    """Build a georeferenced mosaic from ``tile`` and its neighbours.

    Every tile returned by ``get_3x3_neighbors`` is read from
    ``{satellite_dataset_dir}/{sat_year}/{z}_{x}_{y}.jpg``, wrapped in an
    in-memory GeoTIFF whose transform is derived from ``mercantile.bounds``,
    and all of them are merged with ``rasterio.merge.merge``.

    Args:
        tile: Centre tile of the mosaic.
        sat_year: Name of the sub-directory holding the imagery of one year.
        satellite_dataset_dir: Root directory of the tile images.

    Returns:
        ``(mosaic, out_meta)``: the merged array and the metadata of the first
        tile updated with the mosaic height (``mosaic.shape[1]``), width
        (``mosaic.shape[2]``) and transform.

    Raises:
        FileNotFoundError: If the image of any required tile is missing.
    """
    memfiles = []
    datasets = []
    try:
        with warnings.catch_warnings():
            # The JPEG tiles carry no georeferencing of their own.
            warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

            for neighbor in get_3x3_neighbors(tile):
                west, south, east, north = mercantile.bounds(neighbor)
                tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
                if not os.path.exists(tile_path):
                    raise FileNotFoundError(
                        f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
                    )

                # Read pixels and size from a single handle that is closed
                # right away instead of being left to the garbage collector.
                with rasterio.open(tile_path) as src:
                    data = src.read()
                _, height, width = data.shape

                memfile = MemoryFile()
                memfiles.append(memfile)
                # NOTE(review): mercantile.bounds looks like it yields lon/lat
                # degrees while the CRS is labelled EPSG:3857; the label is kept
                # because callers apply the transform to (lon, lat) directly.
                with memfile.open(
                    driver="GTiff",
                    height=height,
                    width=width,
                    count=3,
                    dtype="uint8",
                    crs="EPSG:3857",
                    transform=from_bounds(west, south, east, north, width, height),
                ) as dataset:
                    dataset.write(data)
                datasets.append(memfile.open())

            mosaic, out_trans = merge(datasets)

            out_meta = datasets[0].meta.copy()
            out_meta.update(
                {
                    "driver": "GTiff",
                    "height": mosaic.shape[1],
                    "width": mosaic.shape[2],
                    "transform": out_trans,
                    "crs": "EPSG:3857",
                }
            )
    finally:
        # Release readers first, then their backing in-memory files, even when
        # a tile was missing or the merge failed.
        for opened in datasets:
            opened.close()
        for memfile in memfiles:
            memfile.close()

    gc.collect()

    return mosaic, out_meta
|
||||||
|
|
||||||
|
def get_random_tiff_patch(
    lat: float,
    lon: float,
    satellite_dataset_dir: str,
    sat_years: tuple[str, ...] = ("2023", "2021", "2019", "2016"),
) -> np.ndarray:
    """Load the tile mosaic around a position for a randomly chosen year.

    Despite the name, no cropping happens here: the full mosaic produced by
    ``get_tiff_map`` for the zoom-17 tile containing ``(lat, lon)`` is returned.

    Args:
        lat: Latitude of the position.
        lon: Longitude of the position.
        satellite_dataset_dir: Root directory of the tile images.
        sat_years: Candidate year sub-directories; one is drawn with
            ``random.choice``. Defaults to the previously hard-coded years.

    Returns:
        The mosaic array returned by ``get_tiff_map``.

    Raises:
        FileNotFoundError: Propagated from ``get_tiff_map`` when a tile image
            is missing for the drawn year.
    """
    tile = get_tile_from_coord(lat, lon, 17)

    # Randomly select a satellite year
    sat_year = random.choice(sat_years)

    mosaic, _ = get_tiff_map(tile, sat_year, satellite_dataset_dir)
    return mosaic
|
||||||
|
|
||||||
|
def get_tile_from_coord(
    lat: float, lng: float, zoom_level: int
) -> mercantile.Tile:
    """Look up the map tile that contains a geographic position.

    Args:
        lat: Latitude of the position.
        lng: Longitude of the position.
        zoom_level: Zoom level of the requested tile.

    Returns:
        The tile reported by ``mercantile.tile`` for the position.
    """
    # mercantile takes longitude first, the opposite of this signature.
    return mercantile.tile(lng, lat, zoom_level)
|
|
@ -0,0 +1,297 @@
|
||||||
|
"""
|
||||||
|
Simply load images from a folder or nested folders (does not have any split),
|
||||||
|
and apply homographic adaptations to it. Yields an image pair without border
|
||||||
|
artifacts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
import mercantile
|
||||||
|
import rasterio
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import omegaconf
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..geometry.homography import (
|
||||||
|
compute_homography,
|
||||||
|
sample_homography_corners,
|
||||||
|
warp_points,
|
||||||
|
)
|
||||||
|
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||||
|
from ..settings import DATA_PATH
|
||||||
|
from ..utils.image import read_image
|
||||||
|
from ..utils.tools import fork_rng
|
||||||
|
from ..visualization.viz2d import plot_image_grid
|
||||||
|
from .augmentations import IdentityAugmentation, augmentations
|
||||||
|
from .base_dataset import BaseDataset
|
||||||
|
from .sat_utils import get_random_tiff_patch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_homography(img, conf: dict, size: list):
    """Sample a homography for ``img`` and warp the image with it.

    Args:
        img: Image array; its first two dimensions, reversed, are passed to
            ``sample_homography_corners`` as the shape.
        conf: Keyword arguments forwarded to ``sample_homography_corners``.
        size: Output size handed to ``cv2.warpPerspective``.

    Returns:
        A dict with the warped ``"image"``, the float32 homography ``"H_"``,
        the float32 ``"coords"`` returned by the sampler and the float32
        ``"image_size"`` built from ``size``.
    """
    reversed_shape = img.shape[:2][::-1]
    H, _, coords, _ = sample_homography_corners(reversed_shape, **conf)
    warped = cv2.warpPerspective(img, H, tuple(size))
    return {
        "image": warped,
        "H_": H.astype(np.float32),
        "coords": coords.astype(np.float32),
        "image_size": np.array(size, dtype=np.float32),
    }
|
||||||
|
|
||||||
|
|
||||||
|
class SatelliteDataset(BaseDataset):
    """Homography dataset whose items are geographic coordinates.

    The coordinates are read from the text file given by ``metadata_dir``;
    the first 80% of the entries form the ``train`` split and the rest the
    ``val`` split.
    """

    default_conf = {
        # image search
        "data_dir": "revisitop1m",  # passed to get_random_tiff_patch as tile root
        "image_dir": "jpg/",  # the subdirectory with the images
        "image_list": "revisitop1m.txt",  # optional: list or filename of list
        "glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
        "metadata_dir": None,  # path of the coordinates text file (required)
        # splits
        "train_size": 100,
        "val_size": 10,
        "shuffle_seed": 0,  # or None to skip
        # image loading
        "grayscale": False,
        "triplet": False,
        "right_only": False,  # image0 is orig (rescaled), image1 is right
        "reseed": False,
        "homography": {
            "difficulty": 0.8,
            "translation": 1.0,
            "max_angle": 60,
            "n_angles": 10,
            "patch_shape": [640, 480],
            "min_convexity": 0.05,
        },
        "photometric": {
            "name": "dark",
            "p": 0.75,
            # 'difficulty': 1.0,  # currently unused
        },
        # feature loading
        "load_features": {
            "do": False,
            **CacheLoader.default_conf,
            "collate": False,
            "thresh": 0.0,
            "max_num_keypoints": -1,
            "force_num_keypoints": False,
        },
    }

    def _init(self, conf):
        """Parse the coordinates file and build the train/val splits.

        Each non-blank line must contain exactly one comma; the text after the
        last ``:`` on either side of it is parsed as latitude and longitude.

        Raises:
            ValueError: If ``conf.metadata_dir`` is not set, or a line cannot
                be parsed.
        """
        coordinates_file = conf.metadata_dir
        if coordinates_file is None:
            raise ValueError(
                "SatelliteDataset requires `metadata_dir` to point to the "
                "coordinates file."
            )

        parsed_coordinates = []
        with open(coordinates_file, "r") as cf:
            for line in cf:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                if not line.strip():
                    continue
                lat_part, lon_part = line.split(",")
                lat = float(lat_part.split(":")[-1].strip())
                lon = float(lon_part.split(":")[-1].strip())
                parsed_coordinates.append((lat, lon))

        # First 80% of the entries for training, remaining 20% for validation.
        n_train = int(len(parsed_coordinates) * 0.8)
        self.images = {
            "train": parsed_coordinates[:n_train],
            "val": parsed_coordinates[n_train:],
        }

    def get_dataset(self, split):
        """Return the torch dataset for ``split`` (``"train"`` or ``"val"``)."""
        return _Dataset(self.conf, self.images[split], split)
|
||||||
|
|
||||||
|
|
||||||
|
class _Dataset(torch.utils.data.Dataset):
    """Per-split dataset yielding two homography-warped views of a mosaic.

    Each item is identified by a ``(lat, lon)`` pair. The mosaic around that
    position is loaded with ``get_random_tiff_patch`` and warped twice with
    separately sampled homographies.
    """

    def __init__(self, conf, image_names, split):
        """Store the configuration and build the photometric augmentations.

        Args:
            conf: Dataset configuration.
            image_names: Sequence of ``(lat, lon)`` pairs of this split.
            split: Name of the split.
        """
        self.conf = conf
        self.split = split
        self.image_names = np.array(image_names)
        self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir

        aug_conf = conf.photometric
        aug_name = aug_conf.name
        assert (
            aug_name in augmentations.keys()
        ), f'{aug_name} not in {" ".join(augmentations.keys())}'
        # NOTE: both views get their own photometric augmentation;
        # `right_only` only zeroes the left homography difficulty in getitem.
        self.photo_augment = augmentations[aug_name](aug_conf)
        self.left_augment = augmentations[aug_name](aug_conf)
        self.img_to_tensor = IdentityAugmentation()

        if conf.load_features.do:
            self.feature_loader = CacheLoader(conf.load_features)

    def _transform_keypoints(self, features, data):
        """Transform keypoints by a homography, threshold them,
        and potentially keep only the best ones."""
        # Warp points
        features["keypoints"] = warp_points(
            features["keypoints"], data["H_"], inverse=False
        )
        # assumes data["image"] is laid out as (C, H, W) -- TODO confirm
        h, w = data["image"].shape[1:3]
        valid = (
            (features["keypoints"][:, 0] >= 0)
            & (features["keypoints"][:, 0] <= w - 1)
            & (features["keypoints"][:, 1] >= 0)
            & (features["keypoints"][:, 1] <= h - 1)
        )
        # NOTE(review): only "keypoints" is filtered here; other per-keypoint
        # entries keep their original length -- confirm this is intended.
        features["keypoints"] = features["keypoints"][valid]

        # Threshold
        if self.conf.load_features.thresh > 0:
            valid = features["keypoint_scores"] >= self.conf.load_features.thresh
            features = {k: v[valid] for k, v in features.items()}

        # Get the top keypoints and pad
        n = self.conf.load_features.max_num_keypoints
        if n > -1:
            inds = np.argsort(-features["keypoint_scores"])
            features = {k: v[inds[:n]] for k, v in features.items()}

            if self.conf.load_features.force_num_keypoints:
                features = pad_local_features(
                    features, self.conf.load_features.max_num_keypoints
                )

        return features

    def __getitem__(self, idx):
        """Return item ``idx``; with ``conf.reseed`` the RNG is forked from
        ``conf.seed + idx`` for the duration of the call."""
        if self.conf.reseed:
            with fork_rng(self.conf.seed + idx, False):
                return self.getitem(idx)
        else:
            return self.getitem(idx)

    def _read_view(self, img, H_conf, ps, left=False):
        """Warp ``img`` with a sampled homography and augment the result.

        Args:
            img: Source image passed to ``sample_homography``.
            H_conf: Homography sampling configuration.
            ps: Output patch shape.
            left: Use ``left_augment`` instead of ``photo_augment``.

        Returns:
            The dict from ``sample_homography`` with ``"image"`` replaced by
            the augmented tensor and, when feature loading is enabled, the
            transformed cached features under ``"cache"``.
        """
        data = sample_homography(img, H_conf, ps)
        if left:
            data["image"] = self.left_augment(data["image"], return_tensor=True)
        else:
            data["image"] = self.photo_augment(data["image"], return_tensor=True)

        gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
        if self.conf.grayscale:
            # Weighted channel sum; requires a 3-channel image.
            data["image"] = (data["image"] * gs).sum(0, keepdim=True)

        if self.conf.load_features.do:
            features = self.feature_loader({k: [v] for k, v in data.items()})
            features = self._transform_keypoints(features, data)
            data["cache"] = features

        return data

    def getitem(self, idx):
        """Build the image pair for the coordinate stored at ``idx``.

        Returns:
            A dict with the item ``"name"``, the ``"original_image_size"`` of
            the mosaic as (width, height), the float32 ``"H_0to1"`` computed
            by ``compute_homography`` from the two views' corner coordinates,
            the index and the two views.
        """
        lat, lon = self.image_names[idx]
        img = get_random_tiff_patch(lat, lon, self.conf.data_dir)

        if img is None:
            logger.warning("Image %f %f could not be read.", lat, lon)
            # Always 3 channels in HWC: the grayscale conversion happens later
            # on the tensor and needs the colour channels.
            img = np.zeros((1024, 1024, 3))
        else:
            # The mosaic is channel-first; OpenCV warping expects HWC.
            img = img.transpose(1, 2, 0)

        img = img.astype(np.float32) / 255.0
        size = img.shape[:2][::-1]
        ps = self.conf.homography.patch_shape

        left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
        if self.conf.right_only:
            left_conf["difficulty"] = 0.0

        data0 = self._read_view(img, left_conf, ps, left=True)
        data1 = self._read_view(img, self.conf.homography, ps, left=False)

        H = compute_homography(data0["coords"], data1["coords"], [1, 1])
        name = f"lat{lat}_lon{lon}"

        data = {
            "name": name,
            "original_image_size": np.array(size),
            "H_0to1": H.astype(np.float32),
            "idx": idx,
            "view0": data0,
            "view1": data1,
        }

        return data

    def __len__(self):
        """Number of coordinates in this split."""
        return len(self.image_names)
|
||||||
|
|
||||||
|
|
||||||
|
def visualize(args):
    """Plot a few training pairs and save the figure to ``implot.png``.

    Args:
        args: Parsed command-line arguments providing ``dotlist`` (OmegaConf
            overrides), ``num_items`` (number of pairs) and ``dpi``.
    """
    conf = {
        "batch_size": 1,
        "num_workers": 1,
        "prefetch_factor": 1,
    }
    conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
    dataset = SatelliteDataset(conf)
    loader = dataset.get_data_loader("train")
    logger.info("The dataset has %d elements.", len(loader))

    with fork_rng(seed=dataset.conf.seed):
        images = []
        for _, data in zip(range(args.num_items), loader):
            # A real list, not a one-shot generator, so the row can be
            # measured and iterated more than once by the plotting helper.
            images.append(
                [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
            )
    plot_image_grid(images, dpi=args.dpi)
    plt.tight_layout()
    # savefig writes the current figure; imsave needs an image array and
    # raised a TypeError when called with a file name only.
    plt.savefig("implot.png")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    # Command line: two integer options plus free-form OmegaConf overrides.
    cli_parser = argparse.ArgumentParser()
    for flag, default in (("--num_items", 8), ("--dpi", 100)):
        cli_parser.add_argument(flag, type=int, default=default)
    cli_parser.add_argument("dotlist", nargs="*")

    args = cli_parser.parse_intermixed_args()
    visualize(args)
|
|
@ -5,6 +5,7 @@ from pprint import pprint
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
@ -12,6 +13,7 @@ from ..datasets import get_dataset
|
||||||
from ..models.cache_loader import CacheLoader
|
from ..models.cache_loader import CacheLoader
|
||||||
from ..settings import EVAL_PATH
|
from ..settings import EVAL_PATH
|
||||||
from ..utils.export_predictions import export_predictions
|
from ..utils.export_predictions import export_predictions
|
||||||
|
from ..utils.tensor import map_tensor
|
||||||
from ..utils.tools import AUCMetric
|
from ..utils.tools import AUCMetric
|
||||||
from ..visualization.viz2d import plot_cumulative
|
from ..visualization.viz2d import plot_cumulative
|
||||||
from .eval_pipeline import EvalPipeline
|
from .eval_pipeline import EvalPipeline
|
||||||
|
@ -105,9 +107,11 @@ class HPatchesPipeline(EvalPipeline):
|
||||||
cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
|
cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
|
||||||
for i, data in enumerate(tqdm(loader)):
|
for i, data in enumerate(tqdm(loader)):
|
||||||
pred = cache_loader(data)
|
pred = cache_loader(data)
|
||||||
|
# Remove batch dimension
|
||||||
|
data = map_tensor(data, lambda t: torch.squeeze(t, dim=0))
|
||||||
# add custom evaluations here
|
# add custom evaluations here
|
||||||
if "keypoints0" in pred:
|
if "keypoints0" in pred:
|
||||||
results_i = eval_matches_homography(data, pred, {})
|
results_i = eval_matches_homography(data, pred)
|
||||||
results_i = {**results_i, **eval_homography_dlt(data, pred)}
|
results_i = {**results_i, **eval_homography_dlt(data, pred)}
|
||||||
else:
|
else:
|
||||||
results_i = {}
|
results_i = {}
|
||||||
|
|
|
@ -89,6 +89,11 @@ def load_model(model_conf, checkpoint):
|
||||||
model = load_experiment(checkpoint, conf=model_conf).eval()
|
model = load_experiment(checkpoint, conf=model_conf).eval()
|
||||||
else:
|
else:
|
||||||
model = get_model("two_view_pipeline")(model_conf).eval()
|
model = get_model("two_view_pipeline")(model_conf).eval()
|
||||||
|
if not model.is_initialized():
|
||||||
|
raise ValueError(
|
||||||
|
"The provided model has non-initialized parameters. "
|
||||||
|
+ "Try to load a checkpoint instead."
|
||||||
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import kornia
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from kornia.geometry.homography import find_homography_dlt
|
||||||
|
|
||||||
from ..geometry.epipolar import generalized_epi_dist, relative_pose_error
|
from ..geometry.epipolar import generalized_epi_dist, relative_pose_error
|
||||||
from ..geometry.gt_generation import IGNORE_FEATURE
|
from ..geometry.gt_generation import IGNORE_FEATURE
|
||||||
from ..geometry.homography import homography_corner_error, sym_homography_error
|
from ..geometry.homography import homography_corner_error, sym_homography_error
|
||||||
from ..robust_estimators import load_estimator
|
from ..robust_estimators import load_estimator
|
||||||
|
from ..utils.tensor import index_batch
|
||||||
from ..utils.tools import AUCMetric
|
from ..utils.tools import AUCMetric
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,6 +27,16 @@ def get_matches_scores(kpts0, kpts1, matches0, mscores0):
|
||||||
return pts0, pts1, scores
|
return pts0, pts1, scores
|
||||||
|
|
||||||
|
|
||||||
|
def eval_per_batch_item(data: dict, pred: dict, eval_f, *args, **kwargs):
    """Run ``eval_f`` on every item of a batch and regroup the results by key.

    Args:
        data: Batched data, split into items with ``index_batch``.
        pred: Batched predictions, split into items with ``index_batch``.
        eval_f: Callable invoked as ``eval_f(data_i, pred_i, *args, **kwargs)``
            and returning a dict.
        *args: Extra positional arguments forwarded to ``eval_f``.
        **kwargs: Extra keyword arguments forwarded to ``eval_f``.

    Returns:
        A dict mapping each key of the first item's result to the list of that
        key's values over all items.
    """
    per_item = []
    for data_i, pred_i in zip(index_batch(data), index_batch(pred)):
        per_item.append(eval_f(data_i, pred_i, *args, **kwargs))

    # Turn the list of dicts into a dict of lists, keyed like the first result.
    regrouped = {}
    for key in per_item[0].keys():
        regrouped[key] = [item[key] for item in per_item]
    return regrouped
|
||||||
|
|
||||||
|
|
||||||
def eval_matches_epipolar(data: dict, pred: dict) -> dict:
|
def eval_matches_epipolar(data: dict, pred: dict) -> dict:
|
||||||
check_keys_recursive(data, ["view0", "view1", "T_0to1"])
|
check_keys_recursive(data, ["view0", "view1", "T_0to1"])
|
||||||
check_keys_recursive(
|
check_keys_recursive(
|
||||||
|
@ -58,23 +69,25 @@ def eval_matches_epipolar(data: dict, pred: dict) -> dict:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def eval_matches_homography(data: dict, pred: dict, conf) -> dict:
|
def eval_matches_homography(data: dict, pred: dict) -> dict:
|
||||||
check_keys_recursive(data, ["H_0to1"])
|
check_keys_recursive(data, ["H_0to1"])
|
||||||
check_keys_recursive(
|
check_keys_recursive(
|
||||||
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
|
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
|
||||||
)
|
)
|
||||||
|
|
||||||
H_gt = data["H_0to1"]
|
H_gt = data["H_0to1"]
|
||||||
|
if H_gt.ndim > 2:
|
||||||
|
return eval_per_batch_item(data, pred, eval_matches_homography)
|
||||||
|
|
||||||
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
||||||
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
||||||
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
||||||
err = sym_homography_error(pts0, pts1, H_gt[0])
|
err = sym_homography_error(pts0, pts1, H_gt)
|
||||||
results = {}
|
results = {}
|
||||||
results["prec@1px"] = (err < 1).float().mean().nan_to_num().item()
|
results["prec@1px"] = (err < 1).float().mean().nan_to_num().item()
|
||||||
results["prec@3px"] = (err < 3).float().mean().nan_to_num().item()
|
results["prec@3px"] = (err < 3).float().mean().nan_to_num().item()
|
||||||
results["num_matches"] = pts0.shape[0]
|
results["num_matches"] = pts0.shape[0]
|
||||||
results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0
|
results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,7 +97,7 @@ def eval_relative_pose_robust(data, pred, conf):
|
||||||
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
|
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
|
||||||
)
|
)
|
||||||
|
|
||||||
T_gt = data["T_0to1"][0]
|
T_gt = data["T_0to1"]
|
||||||
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
||||||
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
||||||
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
||||||
|
@ -107,9 +120,8 @@ def eval_relative_pose_robust(data, pred, conf):
|
||||||
else:
|
else:
|
||||||
# R, t, inl = ret
|
# R, t, inl = ret
|
||||||
M = est["M_0to1"]
|
M = est["M_0to1"]
|
||||||
R, t = M.numpy()
|
|
||||||
inl = est["inliers"].numpy()
|
inl = est["inliers"].numpy()
|
||||||
r_error, t_error = relative_pose_error(T_gt, R, t)
|
t_error, r_error = relative_pose_error(T_gt, M.R, M.t)
|
||||||
results["rel_pose_error"] = max(r_error, t_error)
|
results["rel_pose_error"] = max(r_error, t_error)
|
||||||
results["ransac_inl"] = np.sum(inl)
|
results["ransac_inl"] = np.sum(inl)
|
||||||
results["ransac_inl%"] = np.mean(inl)
|
results["ransac_inl%"] = np.mean(inl)
|
||||||
|
@ -119,6 +131,9 @@ def eval_relative_pose_robust(data, pred, conf):
|
||||||
|
|
||||||
def eval_homography_robust(data, pred, conf):
|
def eval_homography_robust(data, pred, conf):
|
||||||
H_gt = data["H_0to1"]
|
H_gt = data["H_0to1"]
|
||||||
|
if H_gt.ndim > 2:
|
||||||
|
return eval_per_batch_item(data, pred, eval_relative_pose_robust, conf)
|
||||||
|
|
||||||
estimator = load_estimator("homography", conf["estimator"])(conf)
|
estimator = load_estimator("homography", conf["estimator"])(conf)
|
||||||
|
|
||||||
data_ = {}
|
data_ = {}
|
||||||
|
@ -158,24 +173,26 @@ def eval_homography_robust(data, pred, conf):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def eval_homography_dlt(data, pred, *args):
|
def eval_homography_dlt(data, pred):
|
||||||
H_gt = data["H_0to1"]
|
H_gt = data["H_0to1"]
|
||||||
H_inf = torch.ones_like(H_gt) * float("inf")
|
H_inf = torch.ones_like(H_gt) * float("inf")
|
||||||
|
|
||||||
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
||||||
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
||||||
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
||||||
|
scores = scores.to(pts0)
|
||||||
results = {}
|
results = {}
|
||||||
try:
|
try:
|
||||||
Hdlt = kornia.geometry.homography.find_homography_dlt(
|
if H_gt.ndim == 2:
|
||||||
pts0[None], pts1[None], scores[None].to(pts0)
|
pts0, pts1, scores = pts0[None], pts1[None], scores[None]
|
||||||
)[0]
|
h_dlt = find_homography_dlt(pts0, pts1, scores)
|
||||||
|
if H_gt.ndim == 2:
|
||||||
|
h_dlt = h_dlt[0]
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
Hdlt = H_inf
|
h_dlt = H_inf
|
||||||
|
|
||||||
error_dlt = homography_corner_error(Hdlt, H_gt, data["view0"]["image_size"])
|
error_dlt = homography_corner_error(h_dlt, H_gt, data["view0"]["image_size"])
|
||||||
results["H_error_dlt"] = error_dlt.item()
|
results["H_error_dlt"] = error_dlt.item()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .utils import skew_symmetric, to_homogeneous
|
from .utils import skew_symmetric, to_homogeneous
|
||||||
|
@ -124,39 +123,33 @@ def decompose_essential_matrix(E):
|
||||||
|
|
||||||
|
|
||||||
# pose errors
|
# pose errors
|
||||||
# TODO: port to torch and batch
|
# TODO: test for batched data
|
||||||
def angle_error_mat(R1, R2):
|
def angle_error_mat(R1, R2):
|
||||||
cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
|
cos = (torch.trace(torch.einsum("...ij, ...jk -> ...ik", R1.T, R2)) - 1) / 2
|
||||||
cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
|
cos = torch.clip(cos, -1.0, 1.0) # numerical errors can make it out of bounds
|
||||||
return np.rad2deg(np.abs(np.arccos(cos)))
|
return torch.rad2deg(torch.abs(torch.arccos(cos)))
|
||||||
|
|
||||||
|
|
||||||
def angle_error_vec(v1, v2):
|
def angle_error_vec(v1, v2, eps=1e-10):
|
||||||
n = np.linalg.norm(v1) * np.linalg.norm(v2)
|
n = torch.clip(v1.norm(dim=-1) * v2.norm(dim=-1), min=eps)
|
||||||
return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
|
v1v2 = (v1 * v2).sum(dim=-1) # dot product in the last dimension
|
||||||
|
return torch.rad2deg(torch.arccos(torch.clip(v1v2 / n, -1.0, 1.0)))
|
||||||
|
|
||||||
|
|
||||||
def compute_pose_error(T_0to1, R, t):
|
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0, eps=1e-10):
|
||||||
R_gt = T_0to1[:3, :3]
|
if isinstance(T_0to1, torch.Tensor):
|
||||||
t_gt = T_0to1[:3, 3]
|
R_gt, t_gt = T_0to1[:3, :3], T_0to1[:3, 3]
|
||||||
error_t = angle_error_vec(t, t_gt)
|
else:
|
||||||
error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
|
R_gt, t_gt = T_0to1.R, T_0to1.t
|
||||||
error_R = angle_error_mat(R, R_gt)
|
R_gt, t_gt = torch.squeeze(R_gt), torch.squeeze(t_gt)
|
||||||
return error_t, error_R
|
|
||||||
|
|
||||||
|
|
||||||
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
|
|
||||||
# angle error between 2 vectors
|
# angle error between 2 vectors
|
||||||
R_gt, t_gt = T_0to1.numpy()
|
t_err = angle_error_vec(t, t_gt, eps)
|
||||||
n = np.linalg.norm(t) * np.linalg.norm(t_gt)
|
t_err = torch.minimum(t_err, 180 - t_err) # handle E ambiguity
|
||||||
t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
|
if t_gt.norm() < ignore_gt_t_thr: # pure rotation is challenging
|
||||||
t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
|
|
||||||
if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
|
|
||||||
t_err = 0
|
t_err = 0
|
||||||
|
|
||||||
# angle error between 2 rotation matrices
|
# angle error between 2 rotation matrices
|
||||||
cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
|
r_err = angle_error_mat(R, R_gt)
|
||||||
cos = np.clip(cos, -1.0, 1.0) # handle numercial errors
|
|
||||||
R_err = np.rad2deg(np.abs(np.arccos(cos)))
|
|
||||||
|
|
||||||
return t_err, R_err
|
return t_err, r_err
|
||||||
|
|
|
@ -53,6 +53,7 @@ def sample_homography_corners(
|
||||||
min_pts1 = create_center_patch(shape, (pwidth, pheight))
|
min_pts1 = create_center_patch(shape, (pwidth, pheight))
|
||||||
full = create_center_patch(shape)
|
full = create_center_patch(shape)
|
||||||
pts2 = create_center_patch(patch_shape)
|
pts2 = create_center_patch(patch_shape)
|
||||||
|
|
||||||
scale = min_pts1 - full
|
scale = min_pts1 - full
|
||||||
found_valid = False
|
found_valid = False
|
||||||
cnt = -1
|
cnt = -1
|
||||||
|
@ -68,7 +69,9 @@ def sample_homography_corners(
|
||||||
|
|
||||||
# Rotation
|
# Rotation
|
||||||
if n_angles > 0 and difficulty > 0:
|
if n_angles > 0 and difficulty > 0:
|
||||||
angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
|
#angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
|
||||||
|
# In our case the difficulty parameter should not affect the rotation
|
||||||
|
angles = np.linspace(-max_angle, max_angle, n_angles)
|
||||||
rng.shuffle(angles)
|
rng.shuffle(angles)
|
||||||
rng.shuffle(angles)
|
rng.shuffle(angles)
|
||||||
angles = np.concatenate([[0.0], angles], axis=0)
|
angles = np.concatenate([[0.0], angles], axis=0)
|
||||||
|
@ -164,7 +167,8 @@ def warp_points_torch(points, H, inverse=True):
|
||||||
The inverse is used to be coherent with tf.contrib.image.transform
|
The inverse is used to be coherent with tf.contrib.image.transform
|
||||||
Arguments:
|
Arguments:
|
||||||
points: batched list of N points, shape (B, N, 2).
|
points: batched list of N points, shape (B, N, 2).
|
||||||
homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
|
H: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
|
||||||
|
inverse: Whether to multiply the points by H or the inverse of H
|
||||||
Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warps.
|
Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warps.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -333,7 +337,7 @@ def sym_homography_error_all(kpts0, kpts1, H):
|
||||||
|
|
||||||
|
|
||||||
def homography_corner_error(T, T_gt, image_size):
|
def homography_corner_error(T, T_gt, image_size):
|
||||||
W, H = image_size[:, 0], image_size[:, 1]
|
W, H = image_size[..., 0], image_size[..., 1]
|
||||||
corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
|
corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
|
||||||
corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2))
|
corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2))
|
||||||
corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2))
|
corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2))
|
||||||
|
|
|
@ -23,6 +23,7 @@ def from_homogeneous(points, eps=0.0):
|
||||||
"""Remove the homogeneous dimension of N-dimensional points.
|
"""Remove the homogeneous dimension of N-dimensional points.
|
||||||
Args:
|
Args:
|
||||||
points: torch.Tensor or numpy.ndarray with size (..., N+1).
|
points: torch.Tensor or numpy.ndarray with size (..., N+1).
|
||||||
|
eps: Epsilon value to prevent zero division.
|
||||||
Returns:
|
Returns:
|
||||||
A torch.Tensor or numpy ndarray with size (..., N).
|
A torch.Tensor or numpy ndarray with size (..., N).
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -10,6 +10,7 @@ class DinoV2(BaseModel):
|
||||||
|
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
|
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
img = data["image"]
|
img = data["image"]
|
||||||
|
|
|
@ -60,6 +60,8 @@ class BaseModel(nn.Module, metaclass=MetaModel):
|
||||||
required_data_keys = []
|
required_data_keys = []
|
||||||
strict_conf = False
|
strict_conf = False
|
||||||
|
|
||||||
|
are_weights_initialized = False
|
||||||
|
|
||||||
def __init__(self, conf):
|
def __init__(self, conf):
|
||||||
"""Perform some logic and call the _init method of the child model."""
|
"""Perform some logic and call the _init method of the child model."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -125,3 +127,31 @@ class BaseModel(nn.Module, metaclass=MetaModel):
|
||||||
def loss(self, pred, data):
|
def loss(self, pred, data):
|
||||||
"""To be implemented by the child class."""
|
"""To be implemented by the child class."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def load_state_dict(self, *args, **kwargs):
    """Forward to the parent ``load_state_dict`` and flag the weights as loaded.

    All positional and keyword arguments are passed through unchanged, and
    the value returned by the parent implementation is handed back to the
    caller.
    """
    result = super().load_state_dict(*args, **kwargs)
    # Loading a state dict counts as initializing the weights.
    self.set_initialized()
    return result
|
||||||
|
|
||||||
|
def is_initialized(self):
    """Recursively check whether the model is initialized, i.e. weights are loaded.

    Each direct child is inspected:
      * a BaseModel child must report itself as initialized;
      * any other child is accepted if it has no parameters, or if this
        model's own ``are_weights_initialized`` flag is set.

    Returns:
        The conjunction of the checks above over all direct children
        (True when there are no children).
    """
    ready = True  # neutral element of the running logical AND
    for _, child in self.named_children():
        if isinstance(child, BaseModel):
            # Nested BaseModel: delegate to its own recursive check.
            ready = ready and child.is_initialized()
            continue
        # Plain module: only parameter-carrying children depend on the flag.
        param_count = len(list(child.parameters()))
        child_ok = param_count == 0 or self.are_weights_initialized
        ready = ready and child_ok
    return ready
|
||||||
|
|
||||||
|
def set_initialized(self, to: bool = True):
    """Recursively set the initialization state.

    Args:
        to: value stored in ``are_weights_initialized`` on this model and,
            recursively, on every direct child that is itself a BaseModel.
    """
    self.are_weights_initialized = to
    # Walk the child modules, not the parameters: parameters are tensors and
    # can never be BaseModel instances, so iterating them would never
    # propagate the flag to nested models.
    for _, child in self.named_children():
        if isinstance(child, BaseModel):
            child.set_initialized(to)
|
||||||
|
|
|
@ -29,6 +29,15 @@ def pad_local_features(pred: dict, seq_l: int):
|
||||||
pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros")
|
pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros")
|
||||||
if "oris" in pred.keys():
|
if "oris" in pred.keys():
|
||||||
pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros")
|
pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros")
|
||||||
|
|
||||||
|
if "depth_keypoints" in pred.keys():
|
||||||
|
pred["depth_keypoints"] = pad_to_length(
|
||||||
|
pred["depth_keypoints"], seq_l, -1, mode="zeros"
|
||||||
|
)
|
||||||
|
if "valid_depth_keypoints" in pred.keys():
|
||||||
|
pred["valid_depth_keypoints"] = pad_to_length(
|
||||||
|
pred["valid_depth_keypoints"], seq_l, -1, mode="zeros"
|
||||||
|
)
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ class DISK(BaseModel):
|
||||||
|
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
|
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _get_dense_outputs(self, images):
|
def _get_dense_outputs(self, images):
|
||||||
B = images.shape[0]
|
B = images.shape[0]
|
||||||
|
|
|
@ -21,6 +21,7 @@ class KeyNetAffNetHardNet(BaseModel):
|
||||||
upright=conf.upright,
|
upright=conf.upright,
|
||||||
scale_laf=conf.scale_laf,
|
scale_laf=conf.scale_laf,
|
||||||
)
|
)
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
image = data["image"]
|
image = data["image"]
|
||||||
|
|
|
@ -1,238 +1,233 @@
|
||||||
|
import warnings
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pycolmap
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from kornia.color import rgb_to_grayscale
|
||||||
from scipy.spatial import KDTree
|
from packaging import version
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pycolmap
|
||||||
|
except ImportError:
|
||||||
|
pycolmap = None
|
||||||
|
|
||||||
from ..base_model import BaseModel
|
from ..base_model import BaseModel
|
||||||
from ..utils.misc import pad_to_length
|
from ..utils.misc import pad_to_length
|
||||||
|
|
||||||
EPS = 1e-6
|
|
||||||
|
def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
    """Select a subset of detections and return their indices.

    Three successive filters are applied:
      1. among points that fall in the same pixel cell, keep those with the
         highest strength (``scores`` if given, otherwise ``scales``);
      2. among the survivors of a cell, keep those with the smallest
         absolute angle (an arbitrary tie-break);
      3. if ``nms_radius > 0``, keep only points whose strength is the
         maximum of the surrounding square window of side
         ``2 * nms_radius + 1``.

    Args:
        points: array indexed as (N, 2); the first column is used as the
            index along the width axis and the second along the height axis.
        scales: per-point array, used as the strength when ``scores`` is None.
        angles: per-point array; only its absolute value is used.
        image_shape: pair ``(h, w)`` giving the size of the working grid.
        nms_radius: half-size of the NMS window; NMS is skipped unless > 0.
        scores: optional per-point array that replaces ``scales`` as strength.

    Returns:
        Integer array of indices into the input arrays of the kept points.
    """
    height, width = image_shape
    # Pixel cell of every point as stacked (row, col) indices, shape (2, N).
    # NOTE(review): the -0.5 shift presumably maps pixel-center coordinates
    # to integer cells - confirm the coordinate convention with the caller.
    cells = np.round(points - 0.5).astype(int).T[::-1]

    # Stage 1: per cell, keep the points reaching the maximum strength.
    strength = scales if scores is None else scores
    strongest = np.zeros((height, width))
    np.maximum.at(strongest, tuple(cells), strength)
    keep = np.where(strongest[tuple(cells)] == strength)[0]

    # Stage 2: per cell, keep the points with the lowest absolute angle.
    cells = cells[:, keep]
    abs_angles = np.abs(angles[keep])
    lowest = np.full((height, width), np.inf)
    np.minimum.at(lowest, tuple(cells), abs_angles)
    has_lowest_angle = lowest[tuple(cells)] == abs_angles
    cells = cells[:, has_lowest_angle]
    keep = keep[has_lowest_angle]

    if nms_radius > 0:
        # Stage 3: non-maximum suppression on a dense strength map.
        response = np.zeros((height, width))
        response[tuple(cells)] = strength[keep]  # scores or scale

        window_max = torch.nn.functional.max_pool2d(
            torch.from_numpy(response).unsqueeze(0),
            kernel_size=nms_radius * 2 + 1,
            stride=1,
            padding=nms_radius,
        ).squeeze(0)
        survives = response == window_max.numpy()
        keep = keep[survives[tuple(cells)]]
    return keep
|
||||||
|
|
||||||
|
|
||||||
def sift_to_rootsift(x):
|
def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
|
||||||
x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS)
|
x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
|
||||||
x = np.sqrt(x.clip(min=EPS))
|
x.clip_(min=eps).sqrt_()
|
||||||
x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS)
|
return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
# from OpenGlue
|
def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
|
||||||
def nms_keypoints(kpts: np.ndarray, responses: np.ndarray, radius: float) -> np.ndarray:
|
|
||||||
# TODO: add approximate tree
|
|
||||||
kd_tree = KDTree(kpts)
|
|
||||||
|
|
||||||
sorted_idx = np.argsort(-responses)
|
|
||||||
kpts_to_keep_idx = []
|
|
||||||
removed_idx = set()
|
|
||||||
|
|
||||||
for idx in sorted_idx:
|
|
||||||
# skip point if it was already removed
|
|
||||||
if idx in removed_idx:
|
|
||||||
continue
|
|
||||||
|
|
||||||
kpts_to_keep_idx.append(idx)
|
|
||||||
point = kpts[idx]
|
|
||||||
neighbors = kd_tree.query_ball_point(point, r=radius)
|
|
||||||
# Variable `neighbors` contains the `point` itself
|
|
||||||
removed_idx.update(neighbors)
|
|
||||||
|
|
||||||
mask = np.zeros((kpts.shape[0],), dtype=bool)
|
|
||||||
mask[kpts_to_keep_idx] = True
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def detect_kpts_opencv(
|
|
||||||
features: cv2.Feature2D, image: np.ndarray, describe: bool = True
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
"""
|
||||||
Detect keypoints using OpenCV Detector.
|
Detect keypoints using OpenCV Detector.
|
||||||
Optionally, perform NMS and filter top-response keypoints.
|
|
||||||
Optionally, perform description.
|
Optionally, perform description.
|
||||||
Args:
|
Args:
|
||||||
features: OpenCV based keypoints detector and descriptor
|
features: OpenCV based keypoints detector and descriptor
|
||||||
image: Grayscale image of uint8 data type
|
image: Grayscale image of uint8 data type
|
||||||
describe: flag indicating whether to simultaneously compute descriptors
|
|
||||||
Returns:
|
Returns:
|
||||||
kpts: 1D array of detected cv2.KeyPoint
|
keypoints: 1D array of detected cv2.KeyPoint
|
||||||
|
scores: 1D array of responses
|
||||||
|
descriptors: 1D array of descriptors
|
||||||
"""
|
"""
|
||||||
if describe:
|
detections, descriptors = features.detectAndCompute(image, None)
|
||||||
kpts, descriptors = features.detectAndCompute(image, None)
|
points = np.array([k.pt for k in detections], dtype=np.float32)
|
||||||
else:
|
scores = np.array([k.response for k in detections], dtype=np.float32)
|
||||||
kpts = features.detect(image, None)
|
scales = np.array([k.size for k in detections], dtype=np.float32)
|
||||||
kpts = np.array(kpts)
|
angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
|
||||||
|
return points, scores, scales, angles, descriptors
|
||||||
responses = np.array([k.response for k in kpts], dtype=np.float32)
|
|
||||||
|
|
||||||
# select all
|
|
||||||
top_score_idx = ...
|
|
||||||
pts = np.array([k.pt for k in kpts], dtype=np.float32)
|
|
||||||
scales = np.array([k.size for k in kpts], dtype=np.float32)
|
|
||||||
angles = np.array([k.angle for k in kpts], dtype=np.float32)
|
|
||||||
spts = np.concatenate([pts, scales[..., None], angles[..., None]], -1)
|
|
||||||
|
|
||||||
if describe:
|
|
||||||
return spts[top_score_idx], responses[top_score_idx], descriptors[top_score_idx]
|
|
||||||
else:
|
|
||||||
return spts[top_score_idx], responses[top_score_idx]
|
|
||||||
|
|
||||||
|
|
||||||
class SIFT(BaseModel):
|
class SIFT(BaseModel):
|
||||||
default_conf = {
|
default_conf = {
|
||||||
"has_detector": True,
|
|
||||||
"has_descriptor": True,
|
|
||||||
"descriptor_dim": 128,
|
|
||||||
"pycolmap_options": {
|
|
||||||
"first_octave": 0,
|
|
||||||
"peak_threshold": 0.005,
|
|
||||||
"edge_threshold": 10,
|
|
||||||
},
|
|
||||||
"rootsift": True,
|
"rootsift": True,
|
||||||
"nms_radius": None,
|
"nms_radius": 0, # None to disable filtering entirely.
|
||||||
"max_num_keypoints": -1,
|
"max_num_keypoints": 4096,
|
||||||
"max_num_keypoints_val": None,
|
"backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
|
||||||
|
"detection_threshold": 0.0066667, # from COLMAP
|
||||||
|
"edge_threshold": 10,
|
||||||
|
"first_octave": -1, # only used by pycolmap, the default of COLMAP
|
||||||
|
"num_octaves": 4,
|
||||||
"force_num_keypoints": False,
|
"force_num_keypoints": False,
|
||||||
"randomize_keypoints_training": False,
|
|
||||||
"detector": "pycolmap", # ['pycolmap', 'pycolmap_cpu', 'pycolmap_cuda', 'cv2']
|
|
||||||
"detection_threshold": None,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
required_data_keys = ["image"]
|
required_data_keys = ["image"]
|
||||||
|
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
self.sift = None # lazy loading
|
backend = self.conf.backend
|
||||||
|
if backend.startswith("pycolmap"):
|
||||||
@torch.no_grad()
|
if pycolmap is None:
|
||||||
def extract_features(self, image):
|
raise ImportError(
|
||||||
image_np = image.cpu().numpy()[0]
|
"Cannot find module pycolmap: install it with pip"
|
||||||
assert image.shape[0] == 1
|
"or use backend=opencv."
|
||||||
assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS
|
)
|
||||||
|
options = {
|
||||||
detector = str(self.conf.detector)
|
"peak_threshold": self.conf.detection_threshold,
|
||||||
|
"edge_threshold": self.conf.edge_threshold,
|
||||||
if self.sift is None and detector.startswith("pycolmap"):
|
"first_octave": self.conf.first_octave,
|
||||||
options = OmegaConf.to_container(self.conf.pycolmap_options)
|
"num_octaves": self.conf.num_octaves,
|
||||||
|
"normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
|
||||||
|
}
|
||||||
device = (
|
device = (
|
||||||
"auto" if detector == "pycolmap" else detector.replace("pycolmap_", "")
|
"auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
|
||||||
)
|
)
|
||||||
if self.conf.rootsift == "rootsift":
|
if (
|
||||||
options["normalization"] = pycolmap.Normalization.L1_ROOT
|
backend == "pycolmap_cpu" or not pycolmap.has_cuda
|
||||||
|
) and pycolmap.__version__ < "0.5.0":
|
||||||
|
warnings.warn(
|
||||||
|
"The pycolmap CPU SIFT is buggy in version < 0.5.0, "
|
||||||
|
"consider upgrading pycolmap or use the CUDA version.",
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
options["normalization"] = pycolmap.Normalization.L2
|
options["max_num_features"] = self.conf.max_num_keypoints
|
||||||
if self.conf.detection_threshold is not None:
|
|
||||||
options["peak_threshold"] = self.conf.detection_threshold
|
|
||||||
options["max_num_features"] = self.conf.max_num_keypoints
|
|
||||||
self.sift = pycolmap.Sift(options=options, device=device)
|
self.sift = pycolmap.Sift(options=options, device=device)
|
||||||
elif self.sift is None and self.conf.detector == "cv2":
|
elif backend == "opencv":
|
||||||
self.sift = cv2.SIFT_create(contrastThreshold=self.conf.detection_threshold)
|
self.sift = cv2.SIFT_create(
|
||||||
|
contrastThreshold=self.conf.detection_threshold,
|
||||||
|
nfeatures=self.conf.max_num_keypoints,
|
||||||
|
edgeThreshold=self.conf.edge_threshold,
|
||||||
|
nOctaveLayers=self.conf.num_octaves,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
|
||||||
|
)
|
||||||
|
|
||||||
if detector.startswith("pycolmap"):
|
def extract_single_image(self, image: torch.Tensor):
|
||||||
keypoints, scores, descriptors = self.sift.extract(image_np)
|
image_np = image.cpu().numpy().squeeze(0)
|
||||||
elif detector == "cv2":
|
|
||||||
|
if self.conf.backend.startswith("pycolmap"):
|
||||||
|
if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
|
||||||
|
detections, descriptors = self.sift.extract(image_np)
|
||||||
|
scores = None # Scores are not exposed by COLMAP anymore.
|
||||||
|
else:
|
||||||
|
detections, scores, descriptors = self.sift.extract(image_np)
|
||||||
|
keypoints = detections[:, :2] # Keep only (x, y).
|
||||||
|
scales, angles = detections[:, -2:].T
|
||||||
|
if scores is not None and (
|
||||||
|
self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
|
||||||
|
):
|
||||||
|
# Set the scores as a combination of abs. response and scale.
|
||||||
|
scores = np.abs(scores) * scales
|
||||||
|
elif self.conf.backend == "opencv":
|
||||||
# TODO: Check if opencv keypoints are already in corner convention
|
# TODO: Check if opencv keypoints are already in corner convention
|
||||||
keypoints, scores, descriptors = detect_kpts_opencv(
|
keypoints, scores, scales, angles, descriptors = run_opencv_sift(
|
||||||
self.sift, (image_np * 255.0).astype(np.uint8)
|
self.sift, (image_np * 255.0).astype(np.uint8)
|
||||||
)
|
)
|
||||||
|
pred = {
|
||||||
|
"keypoints": keypoints,
|
||||||
|
"scales": scales,
|
||||||
|
"oris": angles,
|
||||||
|
"descriptors": descriptors,
|
||||||
|
}
|
||||||
|
if scores is not None:
|
||||||
|
pred["keypoint_scores"] = scores
|
||||||
|
|
||||||
|
# sometimes pycolmap returns points outside the image. We remove them
|
||||||
|
if self.conf.backend.startswith("pycolmap"):
|
||||||
|
is_inside = (
|
||||||
|
pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
|
||||||
|
).all(-1)
|
||||||
|
pred = {k: v[is_inside] for k, v in pred.items()}
|
||||||
|
|
||||||
if self.conf.nms_radius is not None:
|
if self.conf.nms_radius is not None:
|
||||||
mask = nms_keypoints(keypoints[:, :2], scores, self.conf.nms_radius)
|
keep = filter_dog_point(
|
||||||
keypoints = keypoints[mask]
|
pred["keypoints"],
|
||||||
scores = scores[mask]
|
pred["scales"],
|
||||||
descriptors = descriptors[mask]
|
pred["oris"],
|
||||||
|
image_np.shape,
|
||||||
|
self.conf.nms_radius,
|
||||||
|
pred["keypoint_scores"],
|
||||||
|
)
|
||||||
|
pred = {k: v[keep] for k, v in pred.items()}
|
||||||
|
|
||||||
scales = keypoints[:, 2]
|
pred = {k: torch.from_numpy(v) for k, v in pred.items()}
|
||||||
oris = np.rad2deg(keypoints[:, 3])
|
if scores is not None:
|
||||||
|
# Keep the k keypoints with highest score
|
||||||
if self.conf.has_descriptor:
|
num_points = self.conf.max_num_keypoints
|
||||||
# We still renormalize because COLMAP does not normalize well,
|
if num_points is not None and len(pred["keypoints"]) > num_points:
|
||||||
# maybe due to numerical errors
|
indices = torch.topk(pred["keypoint_scores"], num_points).indices
|
||||||
if self.conf.rootsift:
|
pred = {k: v[indices] for k, v in pred.items()}
|
||||||
descriptors = sift_to_rootsift(descriptors)
|
|
||||||
descriptors = torch.from_numpy(descriptors)
|
|
||||||
keypoints = torch.from_numpy(keypoints[:, :2]) # keep only x, y
|
|
||||||
scales = torch.from_numpy(scales)
|
|
||||||
oris = torch.from_numpy(oris)
|
|
||||||
scores = torch.from_numpy(scores)
|
|
||||||
|
|
||||||
# Keep the k keypoints with highest score
|
|
||||||
max_kps = self.conf.max_num_keypoints
|
|
||||||
|
|
||||||
# for val we allow different
|
|
||||||
if not self.training and self.conf.max_num_keypoints_val is not None:
|
|
||||||
max_kps = self.conf.max_num_keypoints_val
|
|
||||||
|
|
||||||
if max_kps is not None and max_kps > 0:
|
|
||||||
if self.conf.randomize_keypoints_training and self.training:
|
|
||||||
# instead of selecting top-k, sample k by score weights
|
|
||||||
raise NotImplementedError
|
|
||||||
elif max_kps < scores.shape[0]:
|
|
||||||
# TODO: check that the scores from PyCOLMAP are 100% correct,
|
|
||||||
# follow https://github.com/mihaidusmanu/pycolmap/issues/8
|
|
||||||
indices = torch.topk(scores, max_kps).indices
|
|
||||||
keypoints = keypoints[indices]
|
|
||||||
scales = scales[indices]
|
|
||||||
oris = oris[indices]
|
|
||||||
scores = scores[indices]
|
|
||||||
if self.conf.has_descriptor:
|
|
||||||
descriptors = descriptors[indices]
|
|
||||||
|
|
||||||
if self.conf.force_num_keypoints:
|
if self.conf.force_num_keypoints:
|
||||||
keypoints = pad_to_length(
|
num_points = min(self.conf.max_num_keypoints, len(pred["keypoints"]))
|
||||||
keypoints,
|
pred["keypoints"] = pad_to_length(
|
||||||
max_kps,
|
pred["keypoints"],
|
||||||
|
num_points,
|
||||||
-2,
|
-2,
|
||||||
mode="random_c",
|
mode="random_c",
|
||||||
bounds=(0, min(image.shape[1:])),
|
bounds=(0, min(image.shape[1:])),
|
||||||
)
|
)
|
||||||
scores = pad_to_length(scores, max_kps, -1, mode="zeros")
|
pred["scales"] = pad_to_length(pred["scales"], num_points, -1, mode="zeros")
|
||||||
scales = pad_to_length(scales, max_kps, -1, mode="zeros")
|
pred["oris"] = pad_to_length(pred["oris"], num_points, -1, mode="zeros")
|
||||||
oris = pad_to_length(oris, max_kps, -1, mode="zeros")
|
pred["descriptors"] = pad_to_length(
|
||||||
if self.conf.has_descriptor:
|
pred["descriptors"], num_points, -2, mode="zeros"
|
||||||
descriptors = pad_to_length(descriptors, max_kps, -2, mode="zeros")
|
)
|
||||||
|
if pred["keypoint_scores"] is not None:
|
||||||
pred = {
|
scores = pad_to_length(
|
||||||
"keypoints": keypoints,
|
pred["keypoint_scores"], num_points, -1, mode="zeros"
|
||||||
"scales": scales,
|
)
|
||||||
"oris": oris,
|
|
||||||
"keypoint_scores": scores,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.conf.has_descriptor:
|
|
||||||
pred["descriptors"] = descriptors
|
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
@torch.no_grad()
|
def _forward(self, data: dict) -> dict:
|
||||||
def _forward(self, data):
|
|
||||||
pred = {
|
|
||||||
"keypoints": [],
|
|
||||||
"scales": [],
|
|
||||||
"oris": [],
|
|
||||||
"keypoint_scores": [],
|
|
||||||
"descriptors": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
image = data["image"]
|
image = data["image"]
|
||||||
if image.shape[1] == 3: # RGB
|
if image.shape[1] == 3:
|
||||||
scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
|
image = rgb_to_grayscale(image)
|
||||||
image = (image * scale).sum(1, keepdim=True).cpu()
|
device = image.device
|
||||||
|
image = image.cpu()
|
||||||
for k in range(image.shape[0]):
|
pred = []
|
||||||
|
for k in range(len(image)):
|
||||||
img = image[k]
|
img = image[k]
|
||||||
if "image_size" in data.keys():
|
if "image_size" in data.keys():
|
||||||
# avoid extracting points in padded areas
|
# avoid extracting points in padded areas
|
||||||
w, h = data["image_size"][k]
|
w, h = data["image_size"][k]
|
||||||
img = img[:, :h, :w]
|
img = img[:, :h, :w]
|
||||||
p = self.extract_features(img)
|
p = self.extract_single_image(img)
|
||||||
for k, v in p.items():
|
pred.append(p)
|
||||||
pred[k].append(v)
|
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
|
||||||
|
if self.conf.rootsift:
|
||||||
if (image.shape[0] == 1) or self.conf.force_num_keypoints:
|
pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
|
||||||
pred = {k: torch.stack(pred[k], 0) for k in pred.keys()}
|
|
||||||
|
|
||||||
pred = {k: pred[k].to(device=data["image"].device) for k in pred.keys()}
|
|
||||||
|
|
||||||
pred["oris"] = torch.deg2rad(pred["oris"])
|
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
def loss(self, pred, data):
|
def loss(self, pred, data):
|
||||||
|
|
|
@ -19,12 +19,13 @@ class KorniaSIFT(BaseModel):
|
||||||
self.sift = kornia.feature.SIFTFeature(
|
self.sift = kornia.feature.SIFTFeature(
|
||||||
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
|
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
|
||||||
)
|
)
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
lafs, scores, descriptors = self.sift(data["image"])
|
lafs, scores, descriptors = self.sift(data["image"])
|
||||||
keypoints = kornia.feature.get_laf_center(lafs)
|
keypoints = kornia.feature.get_laf_center(lafs)
|
||||||
scales = kornia.feature.get_laf_scale(lafs)
|
scales = kornia.feature.get_laf_scale(lafs).squeeze(-1).squeeze(-1)
|
||||||
oris = kornia.feature.get_laf_orientation(lafs)
|
oris = kornia.feature.get_laf_orientation(lafs).squeeze(-1)
|
||||||
pred = {
|
pred = {
|
||||||
"keypoints": keypoints, # @TODO: confirm keypoints are in corner convention
|
"keypoints": keypoints, # @TODO: confirm keypoints are in corner convention
|
||||||
"scales": scales,
|
"scales": scales,
|
||||||
|
|
|
@ -34,13 +34,14 @@ class DeepLSD(BaseModel):
|
||||||
ckpt = torch.load(ckpt, map_location="cpu")
|
ckpt = torch.load(ckpt, map_location="cpu")
|
||||||
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
|
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
|
||||||
self.net.load_state_dict(ckpt["model"])
|
self.net.load_state_dict(ckpt["model"])
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def download_model(self, path):
|
def download_model(self, path):
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
if not path.parent.is_dir():
|
if not path.parent.is_dir():
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
|
link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar"
|
||||||
cmd = ["wget", link, "-O", path]
|
cmd = ["wget", link, "-O", path]
|
||||||
print("Downloading DeepLSD model...")
|
print("Downloading DeepLSD model...")
|
||||||
subprocess.run(cmd, check=True)
|
subprocess.run(cmd, check=True)
|
||||||
|
|
|
@ -119,7 +119,7 @@ class GlueStick(BaseModel):
|
||||||
"Loading GlueStick model from " f'"{self.url.format(conf.version)}"'
|
"Loading GlueStick model from " f'"{self.url.format(conf.version)}"'
|
||||||
)
|
)
|
||||||
state_dict = torch.hub.load_state_dict_from_url(
|
state_dict = torch.hub.load_state_dict_from_url(
|
||||||
self.url.format(conf.version), file_name=fname
|
self.url.format(conf.version), file_name=fname, map_location="cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
if "model" in state_dict:
|
if "model" in state_dict:
|
||||||
|
@ -131,7 +131,7 @@ class GlueStick(BaseModel):
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k.replace("module.", ""): v for k, v in state_dict.items()
|
k.replace("module.", ""): v for k, v in state_dict.items()
|
||||||
}
|
}
|
||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
device = data["keypoints0"].device
|
device = data["keypoints0"].device
|
||||||
|
@ -200,8 +200,6 @@ class GlueStick(BaseModel):
|
||||||
kpts0 = normalize_keypoints(kpts0, image_size0)
|
kpts0 = normalize_keypoints(kpts0, image_size0)
|
||||||
kpts1 = normalize_keypoints(kpts1, image_size1)
|
kpts1 = normalize_keypoints(kpts1, image_size1)
|
||||||
|
|
||||||
assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
|
|
||||||
assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
|
|
||||||
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
|
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
|
||||||
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])
|
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ class LoFTRModule(BaseModel):
|
||||||
|
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
self.net = kornia.feature.LoFTR(pretrained="outdoor")
|
self.net = kornia.feature.LoFTR(pretrained="outdoor")
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
image0 = data["view0"]["image"]
|
image0 = data["view0"]["image"]
|
||||||
|
|
|
@ -17,17 +17,18 @@ class LightGlue(BaseModel):
|
||||||
|
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
dconf = OmegaConf.to_container(conf)
|
dconf = OmegaConf.to_container(conf)
|
||||||
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
|
self.net = LightGlue_(dconf.pop("features"), **dconf)
|
||||||
# self.net.compile()
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
|
required_keys = ["keypoints", "descriptors", "scales", "oris"]
|
||||||
view0 = {
|
view0 = {
|
||||||
**{k: data[k + "0"] for k in ["keypoints", "descriptors"]},
|
|
||||||
**data["view0"],
|
**data["view0"],
|
||||||
|
**{k: data[k + "0"] for k in required_keys if (k + "0") in data},
|
||||||
}
|
}
|
||||||
view1 = {
|
view1 = {
|
||||||
**{k: data[k + "1"] for k in ["keypoints", "descriptors"]},
|
|
||||||
**data["view1"],
|
**data["view1"],
|
||||||
|
**{k: data[k + "1"] for k in required_keys if (k + "1") in data},
|
||||||
}
|
}
|
||||||
return self.net({"image0": view0, "image1": view1})
|
return self.net({"image0": view0, "image1": view1})
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ from homography_est import (
|
||||||
ransac_point_line_homography,
|
ransac_point_line_homography,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ...utils.tensor import batch_to_numpy
|
||||||
from ..base_estimator import BaseEstimator
|
from ..base_estimator import BaseEstimator
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,19 +51,20 @@ class PointLineHomographyEstimator(BaseEstimator):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
m_features = {
|
|
||||||
"kpts0": data["m_kpts1"].numpy() if "m_kpts1" in data else None,
|
|
||||||
"kpts1": data["m_kpts0"].numpy() if "m_kpts0" in data else None,
|
|
||||||
"lines0": data["m_lines1"].numpy() if "m_lines1" in data else None,
|
|
||||||
"lines1": data["m_lines0"].numpy() if "m_lines0" in data else None,
|
|
||||||
}
|
|
||||||
feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"]
|
feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"]
|
||||||
|
data = batch_to_numpy(data)
|
||||||
|
m_features = {
|
||||||
|
"kpts0": data["m_kpts1"] if "m_kpts1" in data else None,
|
||||||
|
"kpts1": data["m_kpts0"] if "m_kpts0" in data else None,
|
||||||
|
"lines0": data["m_lines1"] if "m_lines1" in data else None,
|
||||||
|
"lines1": data["m_lines0"] if "m_lines0" in data else None,
|
||||||
|
}
|
||||||
M = H_estimation_hybrid(**m_features, tol_px=self.conf.ransac_th)
|
M = H_estimation_hybrid(**m_features, tol_px=self.conf.ransac_th)
|
||||||
success = M is not None
|
success = M is not None
|
||||||
if not success:
|
if not success:
|
||||||
M = torch.eye(3, device=feat.device, dtype=feat.dtype)
|
M = torch.eye(3, device=feat.device, dtype=feat.dtype)
|
||||||
else:
|
else:
|
||||||
M = torch.tensor(M).to(feat)
|
M = torch.from_numpy(M).to(feat)
|
||||||
|
|
||||||
estimation = {
|
estimation = {
|
||||||
"success": success,
|
"success": success,
|
||||||
|
|
|
@ -16,8 +16,8 @@ class PoseLibHomographyEstimator(BaseEstimator):
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
|
pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
|
||||||
M, info = poselib.estimate_homography(
|
M, info = poselib.estimate_homography(
|
||||||
pts0.numpy(),
|
pts0.detach().cpu().numpy(),
|
||||||
pts1.numpy(),
|
pts1.detach().cpu().numpy(),
|
||||||
{
|
{
|
||||||
"max_reproj_error": self.conf.ransac_th,
|
"max_reproj_error": self.conf.ransac_th,
|
||||||
**OmegaConf.to_container(self.conf.options),
|
**OmegaConf.to_container(self.conf.options),
|
||||||
|
|
|
@ -37,14 +37,13 @@ configs = {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"cv2-sift": {
|
"cv2-sift": {
|
||||||
"name": f"r{resize}_cv2-SIFT-k{n_kpts}",
|
"name": f"r{resize}_opencv-SIFT-k{n_kpts}",
|
||||||
"keys": ["keypoints", "descriptors", "keypoint_scores", "oris", "scales"],
|
"keys": ["keypoints", "descriptors", "keypoint_scores", "oris", "scales"],
|
||||||
"gray": True,
|
"gray": True,
|
||||||
"conf": {
|
"conf": {
|
||||||
"name": "extractors.sift",
|
"name": "extractors.sift",
|
||||||
"max_num_keypoints": 4096,
|
"max_num_keypoints": 4096,
|
||||||
"detection_threshold": 0.001,
|
"backend": "opencv",
|
||||||
"detector": "cv2",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"pycolmap-sift": {
|
"pycolmap-sift": {
|
||||||
|
@ -54,11 +53,7 @@ configs = {
|
||||||
"conf": {
|
"conf": {
|
||||||
"name": "extractors.sift",
|
"name": "extractors.sift",
|
||||||
"max_num_keypoints": n_kpts,
|
"max_num_keypoints": n_kpts,
|
||||||
"detection_threshold": 0.0001,
|
"backend": "pycolmap",
|
||||||
"detector": "pycolmap",
|
|
||||||
"pycolmap_options": {
|
|
||||||
"first_octave": -1,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"pycolmap-sift-gpu": {
|
"pycolmap-sift-gpu": {
|
||||||
|
@ -68,11 +63,7 @@ configs = {
|
||||||
"conf": {
|
"conf": {
|
||||||
"name": "extractors.sift",
|
"name": "extractors.sift",
|
||||||
"max_num_keypoints": n_kpts,
|
"max_num_keypoints": n_kpts,
|
||||||
"detection_threshold": 0.0066666,
|
"backend": "pycolmap_cuda",
|
||||||
"detector": "pycolmap_cuda",
|
|
||||||
"pycolmap_options": {
|
|
||||||
"first_octave": -1,
|
|
||||||
},
|
|
||||||
"nms_radius": 3,
|
"nms_radius": 3,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -133,15 +124,18 @@ def run_export(feature_file, scene, args):
|
||||||
|
|
||||||
conf = OmegaConf.create(conf)
|
conf = OmegaConf.create(conf)
|
||||||
|
|
||||||
keys = configs[args.method]["keys"] + ["depth_keypoints", "valid_depth_keypoints"]
|
keys = configs[args.method]["keys"]
|
||||||
dataset = get_dataset(conf.data.name)(conf.data)
|
dataset = get_dataset(conf.data.name)(conf.data)
|
||||||
loader = dataset.get_data_loader(conf.split or "test")
|
loader = dataset.get_data_loader(conf.split or "test")
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = get_model(conf.model.name)(conf.model).eval().to(device)
|
model = get_model(conf.model.name)(conf.model).eval().to(device)
|
||||||
|
|
||||||
callback_fn = None
|
if args.export_sparse_depth:
|
||||||
# callback_fn=get_kp_depth # use this to store the depth of each keypoint
|
callback_fn = get_kp_depth # use this to store the depth of each keypoint
|
||||||
|
keys = keys + ["depth_keypoints", "valid_depth_keypoints"]
|
||||||
|
else:
|
||||||
|
callback_fn = None
|
||||||
export_predictions(
|
export_predictions(
|
||||||
loader, model, feature_file, as_half=True, keys=keys, callback_fn=callback_fn
|
loader, model, feature_file, as_half=True, keys=keys, callback_fn=callback_fn
|
||||||
)
|
)
|
||||||
|
@ -153,6 +147,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--method", type=str, default="sp")
|
parser.add_argument("--method", type=str, default="sp")
|
||||||
parser.add_argument("--scenes", type=str, default=None)
|
parser.add_argument("--scenes", type=str, default=None)
|
||||||
parser.add_argument("--num_workers", type=int, default=0)
|
parser.add_argument("--num_workers", type=int, default=0)
|
||||||
|
parser.add_argument("--export_sparse_depth", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
export_name = configs[args.method]["name"]
|
export_name = configs[args.method]["name"]
|
||||||
|
|
|
@ -12,6 +12,7 @@ import signal
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydoc import locate
|
from pydoc import locate
|
||||||
|
import gc
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -48,11 +49,12 @@ default_train_conf = {
|
||||||
"optimizer_options": {}, # optional arguments passed to the optimizer
|
"optimizer_options": {}, # optional arguments passed to the optimizer
|
||||||
"lr": 0.001, # learning rate
|
"lr": 0.001, # learning rate
|
||||||
"lr_schedule": {
|
"lr_schedule": {
|
||||||
"type": None,
|
"type": None, # string in {factor, exp, member of torch.optim.lr_scheduler}
|
||||||
"start": 0,
|
"start": 0,
|
||||||
"exp_div_10": 0,
|
"exp_div_10": 0,
|
||||||
"on_epoch": False,
|
"on_epoch": False,
|
||||||
"factor": 1.0,
|
"factor": 1.0,
|
||||||
|
"options": {}, # add lr_scheduler arguments here
|
||||||
},
|
},
|
||||||
"lr_scaling": [(100, ["dampingnet.const"])],
|
"lr_scaling": [(100, ["dampingnet.const"])],
|
||||||
"eval_every_iter": 1000, # interval for evaluation on the validation set
|
"eval_every_iter": 1000, # interval for evaluation on the validation set
|
||||||
|
@ -88,6 +90,7 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
||||||
for i, data in enumerate(
|
for i, data in enumerate(
|
||||||
tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
|
tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
|
||||||
):
|
):
|
||||||
|
|
||||||
data = batch_to_device(data, device, non_blocking=True)
|
data = batch_to_device(data, device, non_blocking=True)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pred = model(data)
|
pred = model(data)
|
||||||
|
@ -101,7 +104,6 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
||||||
pred[v["predictions"]],
|
pred[v["predictions"]],
|
||||||
mask=pred[v["mask"]] if "mask" in v.keys() else None,
|
mask=pred[v["mask"]] if "mask" in v.keys() else None,
|
||||||
)
|
)
|
||||||
del pred, data
|
|
||||||
numbers = {**metrics, **{"loss/" + k: v for k, v in losses.items()}}
|
numbers = {**metrics, **{"loss/" + k: v for k, v in losses.items()}}
|
||||||
for k, v in numbers.items():
|
for k, v in numbers.items():
|
||||||
if k not in results:
|
if k not in results:
|
||||||
|
@ -117,7 +119,9 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
||||||
if k in conf.recall_metrics.keys():
|
if k in conf.recall_metrics.keys():
|
||||||
q = conf.recall_metrics[k]
|
q = conf.recall_metrics[k]
|
||||||
results[k + f"_recall{int(q)}"].update(v)
|
results[k + f"_recall{int(q)}"].update(v)
|
||||||
del numbers
|
|
||||||
|
del pred, data, losses, metrics
|
||||||
|
gc.collect()
|
||||||
results = {k: results[k].compute() for k in results}
|
results = {k: results[k].compute() for k in results}
|
||||||
return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
|
return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
|
||||||
|
|
||||||
|
@ -141,6 +145,26 @@ def filter_parameters(params, regexp):
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr_scheduler(optimizer, conf):
    """Build the learning-rate scheduler described by an ``lr_schedule`` config.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        conf: config node exposing ``type``, ``start``, ``factor``,
            ``exp_div_10`` and ``options``.

    Returns:
        For ``type`` in {None, "factor", "exp"} a ``MultiplicativeLR`` (legacy
        schedules, kept for backward compatibility). Any other ``type`` is
        looked up by name in ``torch.optim.lr_scheduler`` and constructed with
        ``conf.options`` as keyword arguments.

    Raises:
        ValueError: if ``type`` does not name a scheduler in
            ``torch.optim.lr_scheduler``, or if the "exp" schedule is requested
            with ``exp_div_10 == 0``.
    """
    schedule_type = conf.type

    if schedule_type not in ("factor", "exp", None):
        scheduler_cls = getattr(torch.optim.lr_scheduler, schedule_type, None)
        if scheduler_cls is None:
            raise ValueError(f"Unknown lr_schedule type: {schedule_type!r}")
        return scheduler_cls(optimizer, **conf.options)

    # Legacy schedules. MultiplicativeLR multiplies the current learning rate by
    # the returned value at every scheduler step, so a constant multiplier below
    # 1 yields a geometric decay once `start` has been reached.
    if schedule_type is None:
        start, multiplier = 0, 1
    elif schedule_type == "factor":
        start, multiplier = conf.start, conf.factor
    else:  # "exp": the learning rate is divided by 10 every `exp_div_10` steps
        if conf.exp_div_10 == 0:
            # Fail at construction instead of with a ZeroDivisionError in the
            # middle of training.
            raise ValueError(
                "lr_schedule.exp_div_10 must be non-zero for the 'exp' schedule"
            )
        start, multiplier = conf.start, 10 ** (-1 / conf.exp_div_10)

    # Kept as a plain nested function, like the original implementation.
    def lr_fn(it):  # noqa: E306
        return 1.0 if it < start else multiplier

    return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
|
||||||
|
|
||||||
|
|
||||||
def pack_lr_parameters(params, base_lr, lr_scaling):
|
def pack_lr_parameters(params, base_lr, lr_scaling):
|
||||||
"""Pack each group of parameters with the respective scaled learning rate."""
|
"""Pack each group of parameters with the respective scaled learning rate."""
|
||||||
filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
|
filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
|
||||||
|
@ -236,6 +260,7 @@ def training(rank, conf, output_dir, args):
|
||||||
|
|
||||||
dataset = get_dataset(data_conf.name)(data_conf)
|
dataset = get_dataset(data_conf.name)(data_conf)
|
||||||
|
|
||||||
|
|
||||||
# Optionally load a different validation dataset than the training one
|
# Optionally load a different validation dataset than the training one
|
||||||
val_data_conf = conf.get("data_val", None)
|
val_data_conf = conf.get("data_val", None)
|
||||||
if val_data_conf is None:
|
if val_data_conf is None:
|
||||||
|
@ -310,22 +335,7 @@ def training(rank, conf, output_dir, args):
|
||||||
|
|
||||||
results = None # fix bug with it saving
|
results = None # fix bug with it saving
|
||||||
|
|
||||||
def lr_fn(it): # noqa: E306
|
lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_schedule)
|
||||||
if conf.train.lr_schedule.type is None:
|
|
||||||
return 1
|
|
||||||
if conf.train.lr_schedule.type == "factor":
|
|
||||||
return (
|
|
||||||
1.0
|
|
||||||
if it < conf.train.lr_schedule.start
|
|
||||||
else conf.train.lr_schedule.factor
|
|
||||||
)
|
|
||||||
if conf.train.lr_schedule.type == "exp":
|
|
||||||
gam = 10 ** (-1 / conf.train.lr_schedule.exp_div_10)
|
|
||||||
return 1.0 if it < conf.train.lr_schedule.start else gam
|
|
||||||
else:
|
|
||||||
raise ValueError(conf.train.lr_schedule.type)
|
|
||||||
|
|
||||||
lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
|
|
||||||
if args.restore:
|
if args.restore:
|
||||||
optimizer.load_state_dict(init_cp["optimizer"])
|
optimizer.load_state_dict(init_cp["optimizer"])
|
||||||
if "lr_scheduler" in init_cp:
|
if "lr_scheduler" in init_cp:
|
||||||
|
@ -402,7 +412,9 @@ def training(rank, conf, output_dir, args):
|
||||||
getattr(loader.dataset, conf.train.dataset_callback_fn)(
|
getattr(loader.dataset, conf.train.dataset_callback_fn)(
|
||||||
conf.train.seed + epoch
|
conf.train.seed + epoch
|
||||||
)
|
)
|
||||||
|
|
||||||
for it, data in enumerate(train_loader):
|
for it, data in enumerate(train_loader):
|
||||||
|
|
||||||
tot_it = (len(train_loader) * epoch + it) * (
|
tot_it = (len(train_loader) * epoch + it) * (
|
||||||
args.n_gpus if args.distributed else 1
|
args.n_gpus if args.distributed else 1
|
||||||
)
|
)
|
||||||
|
@ -421,10 +433,17 @@ def training(rank, conf, output_dir, args):
|
||||||
loss = torch.mean(losses["total"])
|
loss = torch.mean(losses["total"])
|
||||||
if torch.isnan(loss).any():
|
if torch.isnan(loss).any():
|
||||||
print(f"Detected NAN, skipping iteration {it}")
|
print(f"Detected NAN, skipping iteration {it}")
|
||||||
|
print("name", data["name"])
|
||||||
|
print("loss", loss)
|
||||||
|
print("losses", losses)
|
||||||
|
print("data", data)
|
||||||
|
print("pred", pred)
|
||||||
|
raise RuntimeError("Detected NAN in training.")
|
||||||
del pred, data, loss, losses
|
del pred, data, loss, losses
|
||||||
continue
|
continue
|
||||||
|
|
||||||
do_backward = loss.requires_grad
|
do_backward = loss.requires_grad
|
||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
do_backward = torch.tensor(do_backward).float().to(device)
|
do_backward = torch.tensor(do_backward).float().to(device)
|
||||||
torch.distributed.all_reduce(
|
torch.distributed.all_reduce(
|
||||||
|
@ -463,7 +482,6 @@ def training(rank, conf, output_dir, args):
|
||||||
else:
|
else:
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger.warning(f"Skip iteration {it} due to detach.")
|
logger.warning(f"Skip iteration {it} due to detach.")
|
||||||
|
|
||||||
if args.profile:
|
if args.profile:
|
||||||
prof.step()
|
prof.step()
|
||||||
|
|
||||||
|
@ -502,8 +520,11 @@ def training(rank, conf, output_dir, args):
|
||||||
norm = torch.norm(param.grad.detach(), 2)
|
norm = torch.norm(param.grad.detach(), 2)
|
||||||
grad_txt += f"{name} {norm.item():.3f} \n"
|
grad_txt += f"{name} {norm.item():.3f} \n"
|
||||||
writer.add_text("grad/summary", grad_txt, tot_n_samples)
|
writer.add_text("grad/summary", grad_txt, tot_n_samples)
|
||||||
del pred, data, loss, losses
|
|
||||||
|
|
||||||
|
pred.clear()
|
||||||
|
data.clear()
|
||||||
|
del pred, data, loss, losses
|
||||||
|
gc.collect()
|
||||||
# Run validation
|
# Run validation
|
||||||
if (
|
if (
|
||||||
(
|
(
|
||||||
|
@ -523,6 +544,7 @@ def training(rank, conf, output_dir, args):
|
||||||
pbar=(rank == -1),
|
pbar=(rank == -1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
str_results = [
|
str_results = [
|
||||||
f"{k} {v:.3E}"
|
f"{k} {v:.3E}"
|
||||||
|
@ -563,6 +585,10 @@ def training(rank, conf, output_dir, args):
|
||||||
f"figures/{i}_{name}", fig, tot_n_samples
|
f"figures/{i}_{name}", fig, tot_n_samples
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache() # should be cleared at the first iter
|
torch.cuda.empty_cache() # should be cleared at the first iter
|
||||||
|
str_results.clear()
|
||||||
|
pr_metrics.clear()
|
||||||
|
del str_results, figures, pr_metrics
|
||||||
|
|
||||||
|
|
||||||
if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
|
if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
|
||||||
if results is None:
|
if results is None:
|
||||||
|
@ -616,7 +642,7 @@ def training(rank, conf, output_dir, args):
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
def main_worker(rank, conf, output_dir, args):
|
def main_worker(rank, conf, output_dir, aprgs):
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
with capture_outputs(output_dir / "log.txt"):
|
with capture_outputs(output_dir / "log.txt"):
|
||||||
training(rank, conf, output_dir, args)
|
training(rank, conf, output_dir, args)
|
||||||
|
|
|
@ -33,6 +33,15 @@ def batch_to_device(batch, device, non_blocking=True):
|
||||||
|
|
||||||
return map_tensor(batch, _func)
|
return map_tensor(batch, _func)
|
||||||
|
|
||||||
|
def detach_tensors(batch):
    """Return ``batch`` with every tensor replaced by its detached version.

    ``map_tensor`` is used to walk the batch and call ``.detach()`` on the
    tensors it visits. The detached tensors no longer keep the computational
    graph alive, which helps to free up memory.
    """
    return map_tensor(batch, lambda tensor: tensor.detach())
|
||||||
|
|
||||||
def rbd(data: dict) -> dict:
|
def rbd(data: dict) -> dict:
|
||||||
"""Remove batch dimension from elements in data"""
|
"""Remove batch dimension from elements in data"""
|
||||||
|
@ -40,3 +49,9 @@ def rbd(data: dict) -> dict:
|
||||||
k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
|
k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
|
||||||
for k, v in data.items()
|
for k, v in data.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def index_batch(tensor_dict):
    """Yield the elements of a batched dict one batch index at a time.

    The batch size is taken from the length of the first value in
    ``tensor_dict``. For every index ``i`` the function yields the result of
    ``map_tensor(tensor_dict, lambda t: t[i])``.

    An empty ``tensor_dict`` yields nothing. Previously the inner ``next()``
    raised ``StopIteration`` inside the generator, which Python turns into a
    ``RuntimeError``.
    """
    if not tensor_dict:
        return

    batch_size = len(next(iter(tensor_dict.values())))
    for i in range(batch_size):
        # Bind `i` as a default so the selector keeps its own index even if it
        # were to be called after the loop has advanced.
        yield map_tensor(tensor_dict, lambda t, i=i: t[i])
|
||||||
|
|
|
@ -67,8 +67,12 @@ class MedianMetric:
|
||||||
else:
|
else:
|
||||||
return np.nanmedian(self._elements)
|
return np.nanmedian(self._elements)
|
||||||
|
|
||||||
|
def __len__(self):
    """Return the number of values currently held in ``self._elements``."""
    return len(self._elements)
|
||||||
|
|
||||||
|
|
||||||
class PRMetric:
|
class PRMetric:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.labels = []
|
self.labels = []
|
||||||
self.predictions = []
|
self.predictions = []
|
||||||
|
|
|
@ -208,14 +208,14 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axe
|
||||||
kpts0[:, 1],
|
kpts0[:, 1],
|
||||||
c=color,
|
c=color,
|
||||||
s=ps,
|
s=ps,
|
||||||
label=None if labels is None else labels[0],
|
label=None if labels is None or len(labels) == 0 else labels[0],
|
||||||
)
|
)
|
||||||
ax1.scatter(
|
ax1.scatter(
|
||||||
kpts1[:, 0],
|
kpts1[:, 0],
|
||||||
kpts1[:, 1],
|
kpts1[:, 1],
|
||||||
c=color,
|
c=color,
|
||||||
s=ps,
|
s=ps,
|
||||||
label=None if labels is None else labels[1],
|
label=None if labels is None or len(labels) == 0 else labels[1],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -338,6 +338,14 @@ class SuperPoint(BaseModel):
|
||||||
for k, d in zip(keypoints, dense_desc)
|
for k, d in zip(keypoints, dense_desc)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if isinstance(desc, list):
|
||||||
|
if all(isinstance(d, torch.Tensor) for d in desc):
|
||||||
|
# If desc is a list of tensors
|
||||||
|
desc = torch.stack(desc)
|
||||||
|
else:
|
||||||
|
# If desc is a list of non-tensor elements
|
||||||
|
desc = torch.stack([torch.tensor(d) for d in desc])
|
||||||
|
|
||||||
pred = {
|
pred = {
|
||||||
"keypoints": keypoints + 0.5,
|
"keypoints": keypoints + 0.5,
|
||||||
"keypoint_scores": scores,
|
"keypoint_scores": scores,
|
||||||
|
|
|
@ -38,12 +38,12 @@ urls = {Repository = "https://github.com/cvg/glue-factory"}
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
extra = [
|
extra = [
|
||||||
"pycolmap",
|
"pycolmap",
|
||||||
"poselib @ git+https://github.com/PoseLib/PoseLib.git",
|
"poselib @ git+https://github.com/PoseLib/PoseLib.git@9c8f3ca1baba69e19726cc7caded574873ec1f9e",
|
||||||
"pytlsd @ git+https://github.com/iago-suarez/pytlsd.git",
|
"pytlsd @ git+https://github.com/iago-suarez/pytlsd.git@v0.0.5",
|
||||||
"deeplsd @ git+https://github.com/cvg/DeepLSD.git",
|
"deeplsd @ git+https://github.com/cvg/DeepLSD.git",
|
||||||
"homography_est @ git+https://github.com/rpautrat/homography_est.git",
|
"homography_est @ git+https://github.com/rpautrat/homography_est.git@17b200d528e6aa8ac61a878a29265bf5f9d36c41",
|
||||||
]
|
]
|
||||||
dev = ["black", "flake8", "isort"]
|
dev = ["black", "flake8", "isort", "parameterized"]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
include = ["gluefactory*"]
|
include = ["gluefactory*"]
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
#!/bin/bash
# Launch a Glue Factory training run.
# Any extra command-line arguments are forwarded to gluefactory.train, so
# further config overrides can be appended when calling this script.
set -euo pipefail

# Earlier experiment presets, kept for reference:
#python -m gluefactory.train sp+lg_homography \
#    --conf gluefactory/configs/superpoint+lightglue_homography.yaml \
#    data.batch_size=32  # for 1x 1080 GPU

#python -m gluefactory.train satellites_aug_both_40 \
#    --conf gluefactory/configs/satellites3.yaml \
#    data.batch_size=4  # for 1x 1080 GPU

# Active experiment; batch size chosen for 1x 1080 GPU.
python -m gluefactory.train drone2sat_v1 \
    --conf gluefactory/configs/drone2sat_v1.yaml \
    data.batch_size=6 \
    "$@"
|
|
@ -0,0 +1,22 @@
|
||||||
|
#!/bin/bash
# Periodically append CPU, disk, memory/swap and network statistics to a log
# file. Runs until interrupted. Uses mpstat, iostat, free and ifstat.

# Log destination; export LOGFILE before running to override the default.
LOGFILE="${LOGFILE:-/var/log/h/system_stats.log}"

# Create the log directory up front so the appends below cannot fail on a
# missing path; stop immediately if it cannot be created.
mkdir -p "$(dirname "$LOGFILE")" || exit 1

echo "Logging CPU, Memory, Swap, Disk, and Network usage to $LOGFILE"

while true; do
    # One grouped redirection per cycle instead of one append per command.
    {
        echo "----- $(date) -----"
        echo "CPU Usage:"
        mpstat -P ALL 1 1

        echo "Disk Usage:"
        iostat

        echo "Memory and Swap Usage:"
        free -m

        echo "Network Usage:"
        ifstat 1 1
    } >> "$LOGFILE"

    sleep 5
done
|
|
@ -0,0 +1,88 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from gluefactory.eval.utils import eval_matches_homography
|
||||||
|
from gluefactory.geometry.homography import warp_points_torch
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvalUtils(unittest.TestCase):
    """Tests for ``eval_matches_homography`` on small synthetic keypoint sets."""

    @staticmethod
    def default_pts():
        """Return the four corners of a 10x10 square as a (4, 2) float tensor."""
        return torch.tensor(
            [
                [10.0, 10.0],
                [10.0, 20.0],
                [20.0, 20.0],
                [20.0, 10.0],
            ]
        )

    @staticmethod
    def default_pred(kps0, kps1):
        """Build predictions in which keypoint i of image 0 matches keypoint i of image 1."""
        num_kps0 = len(kps0)
        return {
            "keypoints0": kps0,
            "keypoints1": kps1,
            "matches0": torch.arange(num_kps0),
            # One score per keypoint of image 0, aligned with "matches0".
            "matching_scores0": torch.ones(num_kps0),
        }

    def test_eval_matches_homography_trivial(self):
        """Identity homography with identical keypoints gives perfect precision."""
        data = {"H_0to1": torch.eye(3)}
        kps = self.default_pts()
        pred = self.default_pred(kps, kps)

        results = eval_matches_homography(data, pred)

        self.assertEqual(results["prec@1px"], 1)
        self.assertEqual(results["prec@3px"], 1)
        self.assertEqual(results["num_matches"], 4)
        self.assertEqual(results["num_keypoints"], 4)

    def test_eval_matches_homography_real(self):
        """Keypoints warped by the ground-truth homography are all correct matches."""
        data = {"H_0to1": torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])}
        kps0 = self.default_pts()
        kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
        pred = self.default_pred(kps0, kps1)

        results = eval_matches_homography(data, pred)

        self.assertEqual(results["prec@1px"], 1)
        self.assertEqual(results["prec@3px"], 1)

    def test_eval_matches_homography_real_outliers(self):
        """A single displaced keypoint lowers the 1px precision but not the 3px one."""
        data = {"H_0to1": torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])}
        kps0 = self.default_pts()
        kps0 = torch.cat([kps0, torch.tensor([[5.0, 5.0]])])
        kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
        # Move one keypoint 1.5 pixels away in x and y (about 2.12 px in total):
        # more than 1 px but less than 3 px.
        kps1[-1] += 1.5
        pred = self.default_pred(kps0, kps1)

        results = eval_matches_homography(data, pred)

        self.assertAlmostEqual(results["prec@1px"], 0.8)
        self.assertAlmostEqual(results["prec@3px"], 1.0)

    def test_eval_matches_homography_batched(self):
        """Batched inputs are evaluated per element."""
        H0 = torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])
        H1 = torch.tensor([[0.7, 0.1, -5], [-0.1, 0.65, 13], [0, 0, 1.0]])
        data = {"H_0to1": torch.stack([H0, H1])}
        kps0 = torch.stack([self.default_pts(), self.default_pts().flip(0)])
        kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
        # In the first element of the batch there is one outlier
        kps1[0, -1] += 5
        matches0 = torch.stack([torch.arange(4), torch.arange(4)])
        # In the second element of the batch there are only 2 matches
        matches0[1, :2] = -1
        pred = {
            "keypoints0": kps0,
            "keypoints1": kps1,
            "matches0": matches0,
            "matching_scores0": torch.ones_like(matches0),
        }

        results = eval_matches_homography(data, pred)

        self.assertAlmostEqual(results["prec@1px"][0], 0.75)
        self.assertAlmostEqual(results["prec@1px"][1], 1.0)
        self.assertAlmostEqual(results["num_matches"][0], 4)
        self.assertAlmostEqual(results["num_matches"][1], 2)
|
|
@ -0,0 +1,132 @@
|
||||||
|
import unittest
|
||||||
|
from collections import namedtuple
|
||||||
|
from os.path import splitext
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch.cuda
|
||||||
|
from kornia import image_to_tensor
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from parameterized import parameterized
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from gluefactory import logger
|
||||||
|
from gluefactory.eval.utils import (
|
||||||
|
eval_homography_dlt,
|
||||||
|
eval_homography_robust,
|
||||||
|
eval_matches_homography,
|
||||||
|
)
|
||||||
|
from gluefactory.models.two_view_pipeline import TwoViewPipeline
|
||||||
|
from gluefactory.settings import root
|
||||||
|
from gluefactory.utils.image import ImagePreprocessor
|
||||||
|
from gluefactory.utils.tensor import map_tensor
|
||||||
|
from gluefactory.utils.tools import set_seed
|
||||||
|
from gluefactory.visualization.viz2d import (
|
||||||
|
plot_color_line_matches,
|
||||||
|
plot_images,
|
||||||
|
plot_matches,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_input_data(cv_img0, cv_img1, device):
    """Turn two OpenCV images into the two-view input dict used by the test.

    Each image is converted with ``image_to_tensor``, scaled to [0, 1] by
    dividing by 255 and passed through a default ``ImagePreprocessor``. Every
    entry of the resulting views then gets a leading batch dimension and is
    moved to ``device``; NumPy arrays are converted to tensors first.
    """
    preprocess = ImagePreprocessor({})

    def _prepare_view(cv_img):
        return preprocess(image_to_tensor(cv_img).float() / 255)

    def _batch_on_device(value):
        tensor = value if isinstance(value, Tensor) else torch.from_numpy(value)
        return tensor[None].to(device)

    views = {"view0": _prepare_view(cv_img0), "view1": _prepare_view(cv_img1)}
    return map_tensor(views, _batch_on_device)
|
||||||
|
|
||||||
|
|
||||||
|
# Thresholds a method has to meet: more than `num_matches` matches, precision
# at 3px above `prec3px`, and RANSAC homography error below `h_error`.
ExpectedResults = namedtuple("ExpectedResults", ("num_matches", "prec3px", "h_error"))


class TestIntegration(unittest.TestCase):
    """End-to-end check of several matching pipelines on the boat image pair."""

    # (config file, robust homography estimator, expected thresholds)
    methods_to_test = [
        ("superpoint+NN.yaml", "poselib", ExpectedResults(1300, 0.8, 1.0)),
        ("superpoint-open+NN.yaml", "poselib", ExpectedResults(1300, 0.8, 1.0)),
        (
            "superpoint+lsd+gluestick.yaml",
            "homography_est",
            ExpectedResults(1300, 0.8, 1.0),
        ),
        (
            "superpoint+lightglue-official.yaml",
            "poselib",
            ExpectedResults(1300, 0.8, 1.0),
        ),
    ]

    # Set to True to save and display match visualizations while debugging.
    visualize = False

    @parameterized.expand(methods_to_test)
    @torch.no_grad()
    def test_real_homography(self, conf_file, estimator, exp_results):
        """Run one pipeline on the image pair and compare against the thresholds."""
        set_seed(0)
        model_path = root / "gluefactory" / "configs" / conf_file
        img_path0 = root / "assets" / "boat1.png"
        img_path1 = root / "assets" / "boat2.png"
        h_gt = torch.tensor(
            [
                [0.85799, 0.21669, 9.4839],
                [-0.21177, 0.85855, 130.48],
                [1.5015e-06, 9.2033e-07, 1],
            ]
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        gs = TwoViewPipeline(OmegaConf.load(model_path).model).to(device).eval()

        cv_img0, cv_img1 = cv2.imread(str(img_path0)), cv2.imread(str(img_path1))
        # cv2.imread signals a missing or unreadable file by returning None
        # rather than raising, so fail here with the offending path.
        self.assertIsNotNone(cv_img0, f"Could not read image {img_path0}")
        self.assertIsNotNone(cv_img1, f"Could not read image {img_path1}")

        data = create_input_data(cv_img0, cv_img1, device)
        pred = gs(data)
        # Drop the batch dimension added by create_input_data.
        pred = map_tensor(
            pred, lambda t: torch.squeeze(t, dim=0) if isinstance(t, Tensor) else t
        )
        data["H_0to1"] = h_gt.to(device)
        data["H_1to0"] = torch.linalg.inv(h_gt).to(device)

        results = eval_matches_homography(data, pred)
        results = {**results, **eval_homography_dlt(data, pred)}
        results = {
            **results,
            **eval_homography_robust(
                data,
                pred,
                {"estimator": estimator},
            ),
        }

        logger.info(results)
        self.assertGreater(results["num_matches"], exp_results.num_matches)
        self.assertGreater(results["prec@3px"], exp_results.prec3px)
        self.assertLess(results["H_error_ransac"], exp_results.h_error)

        if self.visualize:
            pred = map_tensor(
                pred, lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t
            )
            kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
            m0 = pred["matches0"]
            valid0 = m0 != -1  # -1 marks an unmatched keypoint
            kpm0, kpm1 = kp0[valid0], kp1[m0[valid0]]

            plot_images([cv_img0, cv_img1])
            plot_matches(kpm0, kpm1, a=0.0)
            plt.savefig(f"{splitext(conf_file)[0]}_point_matches.svg")

            if "lines0" in pred and "lines1" in pred:
                lines0, lines1 = pred["lines0"], pred["lines1"]
                lm0 = pred["line_matches0"]
                lvalid0 = lm0 != -1
                linem0, linem1 = lines0[lvalid0], lines1[lm0[lvalid0]]

                plot_images([cv_img0, cv_img1])
                plot_color_line_matches([linem0, linem1])
                plt.savefig(f"{splitext(conf_file)[0]}_line_matches.svg")
            # NOTE(review): placed at the `visualize` level so the point-match
            # figure is shown even without line predictions; confirm against
            # the original indentation.
            plt.show()
|
Loading…
Reference in New Issue