Compare commits
10 Commits
f7b587e881
...
d40e423a42
Author | SHA1 | Date |
---|---|---|
Jeba Kolega | d40e423a42 | |
Rémi Pautrat | 4a8283517f | |
Paul-Edouard Sarlin | ab6e6eef8b | |
Alexander Veicht | e0104fd65b | |
Paul-Edouard Sarlin | 43cf81aa2f | |
Philipp Lindenberger | 0c75e76fd6 | |
Iago Suárez | aa7727675e | |
Philipp Lindenberger | 692c72f94c | |
Philipp Lindenberger | 398c4b8c21 | |
Philipp Lindenberger | 22154a60bc |
|
@ -0,0 +1,30 @@
|
|||
name: Python Tests
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
types: [ assigned, opened, synchronize, reopened ]
|
||||
jobs:
|
||||
build:
|
||||
name: Run Python Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: 'pip'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get remove libunwind-14-dev || true
|
||||
sudo apt-get install -y libceres-dev libeigen3-dev
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest pytest-cov
|
||||
python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
python -m pip install -e .[dev]
|
||||
python -m pip install -e .[extra]
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
set -o pipefail
|
||||
pytest --junitxml=pytest.xml --cov=gluefactory tests/
|
31
README.md
31
README.md
|
@ -25,7 +25,7 @@ python3 -m pip install -e .[extra]
|
|||
All models and datasets in gluefactory have auto-downloaders, so you can get started right away!
|
||||
|
||||
## License
|
||||
The code and trained models in Glue Factory are released with an Apache-2.0 license. This includes LightGlue trained with an [open version of SuperPoint](https://github.com/rpautrat/SuperPoint). Third-party models that are not compatible with this license, such as SuperPoint (original) and SuperGlue, are provided in `gluefactory_nonfree`, where each model might follow its own, restrictive license.
|
||||
The code and trained models in Glue Factory are released with an Apache-2.0 license. This includes LightGlue and an [open version of SuperPoint](https://github.com/rpautrat/SuperPoint). Third-party models that are not compatible with this license, such as SuperPoint (original) and SuperGlue, are provided in `gluefactory_nonfree`, where each model might follow its own, restrictive license.
|
||||
|
||||
## Evaluation
|
||||
|
||||
|
@ -66,8 +66,8 @@ Here are the results as Area Under the Curve (AUC) of the homography error at 1
|
|||
|
||||
| Methods | DLT | [OpenCV](../gluefactory/robust_estimators/homography/opencv.py) | [PoseLib](../gluefactory/robust_estimators/homography/poselib.py) |
|
||||
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue.yaml) | 32.1 / 65.0 / 75.7 | 32.9 / 55.7 / 68.0 | 37.0 / 68.2 / 78.7 |
|
||||
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue.yaml) | 35.1 / 67.2 / 77.6 | 34.2 / 57.9 / 69.9 | 37.1 / 67.4 / 77.8 |
|
||||
| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 32.1 / 65.0 / 75.7 | 32.9 / 55.7 / 68.0 | 37.0 / 68.2 / 78.7 |
|
||||
| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 35.1 / 67.2 / 77.6 | 34.2 / 57.9 / 69.9 | 37.1 / 67.4 / 77.8 |
|
||||
|
||||
|
||||
</details>
|
||||
|
@ -159,9 +159,12 @@ Here are the results as Area Under the Curve (AUC) of the pose error at 5/10/20
|
|||
|
||||
| Methods | [pycolmap](../gluefactory/robust_estimators/relative_pose/pycolmap.py) | [OpenCV](../gluefactory/robust_estimators/relative_pose/opencv.py) | [PoseLib](../gluefactory/robust_estimators/relative_pose/poselib.py) |
|
||||
| ------------------------------------------------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| [SuperPoint + SuperGlue](../gluefactory/configs/superpoint+superglue.yaml) | 54.4 / 70.4 / 82.4 | 48.7 / 65.6 / 79.0 | 64.8 / 77.9 / 87.0 |
|
||||
| [SuperPoint + LightGlue](../gluefactory/configs/superpoint+lightglue.yaml) | 56.7 / 72.4 / 83.7 | 51.0 / 68.1 / 80.7 | 66.8 / 79.3 / 87.9 |
|
||||
| [SuperPoint + GlueStick](../gluefactory/configs/superpoint+lsd+gluestick.yaml) | 53.2 / 69.8 / 81.9 | 46.3 / 64.2 / 78.1 | 64.4 / 77.5 / 86.5 |
|
||||
| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 54.4 / 70.4 / 82.4 | 48.7 / 65.6 / 79.0 | 64.8 / 77.9 / 87.0 |
|
||||
| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 56.7 / 72.4 / 83.7 | 51.0 / 68.1 / 80.7 | 66.8 / 79.3 / 87.9 |
|
||||
| [SIFT (2K) + LightGlue](gluefactory/configs/sift+lightglue-official.yaml) | ? / ? / ? | 43.5 / 61.5 / 75.9 | 60.4 / 74.3 / 84.5 |
|
||||
| [SIFT (4K) + LightGlue](gluefactory/configs/sift+lightglue-official.yaml) | ? / ? / ? | 49.9 / 67.3 / 80.3 | 65.9 / 78.6 / 87.4 |
|
||||
| [ALIKED + LightGlue](gluefactory/configs/aliked+lightglue-official.yaml) | ? / ? / ? | 51.5 / 68.1 / 80.4 | 66.3 / 78.7 / 87.5 |
|
||||
| [SuperPoint + GlueStick](gluefactory/configs/superpoint+lsd+gluestick.yaml) | 53.2 / 69.8 / 81.9 | 46.3 / 64.2 / 78.1 | 64.4 / 77.5 / 86.5 |
|
||||
|
||||
</details>
|
||||
|
||||
|
@ -223,18 +226,18 @@ All training commands automatically download the datasets.
|
|||
<details>
|
||||
<summary>[Training LightGlue]</summary>
|
||||
|
||||
We show how to train LightGlue with [SuperPoint open](https://github.com/rpautrat/SuperPoint).
|
||||
We show how to train LightGlue with [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork).
|
||||
We first pre-train LightGlue on the homography dataset:
|
||||
```bash
|
||||
python -m gluefactory.train sp+lg_homography \ # experiment name
|
||||
--conf gluefactory/configs/superpoint-open+lightglue_homography.yaml
|
||||
--conf gluefactory/configs/superpoint+lightglue_homography.yaml
|
||||
```
|
||||
Feel free to use any other experiment name. By default the checkpoints are written to `outputs/training/`. The default batch size of 128 corresponds to the results reported in the paper and requires 2x 3090 GPUs with 24GB of VRAM each as well as PyTorch >= 2.0 (FlashAttention).
|
||||
Configurations are managed by [OmegaConf](https://omegaconf.readthedocs.io/) so any entry can be overridden from the command line.
|
||||
If you have PyTorch < 2.0 or weaker GPUs, you may thus need to reduce the batch size via:
|
||||
```bash
|
||||
python -m gluefactory.train sp+lg_homography \
|
||||
--conf gluefactory/configs/superpoint-open+lightglue_homography.yaml \
|
||||
--conf gluefactory/configs/superpoint+lightglue_homography.yaml \
|
||||
data.batch_size=32 # for 1x 1080 GPU
|
||||
```
|
||||
Be aware that this can impact the overall performance. You might need to adjust the learning rate accordingly.
|
||||
|
@ -242,17 +245,17 @@ Be aware that this can impact the overall performance. You might need to adjust
|
|||
We then fine-tune the model on the MegaDepth dataset:
|
||||
```bash
|
||||
python -m gluefactory.train sp+lg_megadepth \
|
||||
--conf gluefactory/configs/superpoint-open+lightglue_megadepth.yaml \
|
||||
--conf gluefactory/configs/superpoint+lightglue_megadepth.yaml \
|
||||
train.load_experiment=sp+lg_homography
|
||||
```
|
||||
|
||||
Here the default batch size is 32. To speed up training on MegaDepth, we suggest to cache the local features before training (requires around 150 GB of disk space):
|
||||
```bash
|
||||
# extract features
|
||||
python -m gluefactory.scripts.export_megadepth --method sp_open --num_workers 8
|
||||
python -m gluefactory.scripts.export_megadepth --method sp --num_workers 8
|
||||
# run training with cached features
|
||||
python -m gluefactory.train sp+lg_megadepth \
|
||||
--conf gluefactory/configs/superpoint-open+lightglue_megadepth.yaml \
|
||||
--conf gluefactory/configs/superpoint+lightglue_megadepth.yaml \
|
||||
train.load_experiment=sp+lg_homography \
|
||||
data.load_features.do=True
|
||||
```
|
||||
|
@ -297,10 +300,10 @@ Using the following local feature extractors:
|
|||
| Model | LightGlue config |
|
||||
| --------- | --------- |
|
||||
| [SuperPoint (open)](https://github.com/rpautrat/SuperPoint) | `superpoint-open+lightglue_{homography,megadepth}.yaml` |
|
||||
| [SuperPoint (official)](https://github.com/magicleap/SuperPointPretrainedNetwork) | ❌ TODO |
|
||||
| [SuperPoint (official)](https://github.com/magicleap/SuperPointPretrainedNetwork) | `superpoint+lightglue_{homography,megadepth}.yaml` |
|
||||
| SIFT (via [pycolmap](https://github.com/colmap/pycolmap)) | `sift+lightglue_{homography,megadepth}.yaml` |
|
||||
| [ALIKED](https://github.com/Shiaoming/ALIKED) | `aliked+lightglue_{homography,megadepth}.yaml` |
|
||||
| [DISK](https://github.com/cvlab-epfl/disk) | ❌ TODO |
|
||||
| [DISK](https://github.com/cvlab-epfl/disk) | `disk+lightglue_{homography,megadepth}.yaml` |
|
||||
| Key.Net + HardNet | ❌ TODO |
|
||||
|
||||
## Coming soon
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 519 KiB |
Binary file not shown.
After Width: | Height: | Size: 580 KiB |
|
@ -0,0 +1,28 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.aliked
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
matcher:
|
||||
name: matchers.lightglue_pretrained
|
||||
features: aliked
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -0,0 +1,47 @@
|
|||
data:
|
||||
name: homographies
|
||||
data_dir: revisitop1m
|
||||
train_size: 150000
|
||||
val_size: 2000
|
||||
batch_size: 128
|
||||
num_workers: 14
|
||||
homography:
|
||||
difficulty: 0.7
|
||||
max_angle: 45
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.disk_kornia
|
||||
max_num_keypoints: 512
|
||||
force_num_keypoints: True
|
||||
detection_threshold: 0.0
|
||||
trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
input_dim: 128
|
||||
flash: false
|
||||
checkpointed: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,70 @@
|
|||
data:
|
||||
name: megadepth
|
||||
preprocessing:
|
||||
resize: 1024
|
||||
side: long
|
||||
square_pad: True
|
||||
train_split: train_scenes_clean.txt
|
||||
train_num_per_scene: 300
|
||||
val_split: valid_scenes_clean.txt
|
||||
val_pairs: valid_pairs.txt
|
||||
min_overlap: 0.1
|
||||
max_overlap: 0.7
|
||||
num_overlap_bins: 3
|
||||
read_depth: true
|
||||
read_image: true
|
||||
batch_size: 32
|
||||
num_workers: 14
|
||||
load_features:
|
||||
do: false # enable this if you have cached predictions
|
||||
path: exports/megadepth-undist-depth-r1024_DISK-k2048-nms5/{scene}.h5
|
||||
padding_length: 2048
|
||||
padding_fn: pad_local_features
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.disk_kornia
|
||||
max_num_keypoints: 512
|
||||
force_num_keypoints: True
|
||||
detection_threshold: 0.0
|
||||
trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
input_dim: 128
|
||||
flash: false
|
||||
checkpointed: true
|
||||
allow_no_extract: True
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 50
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 1000
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 30
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
dataset_callback_fn: sample_new_items
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1024
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024
|
|
@ -0,0 +1,69 @@
|
|||
data:
|
||||
name: drone2sat
|
||||
batch_size: 4
|
||||
num_workers: 24
|
||||
val_size: 500
|
||||
train_size: -1
|
||||
photometric:
|
||||
name: lg
|
||||
geo_dataset:
|
||||
uav_dataset_dir: /mnt/drive/uav_dataset
|
||||
satellite_dataset_dir: /mnt/drive/tiles
|
||||
misslabeled_images_path: /mnt/drive/misslabeled.txt
|
||||
sat_zoom_level: 17
|
||||
uav_patch_width: 400
|
||||
uav_patch_height: 400
|
||||
sat_patch_width: 400
|
||||
sat_patch_height: 400
|
||||
test_from_train_ratio: 0.0
|
||||
transform_mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
transform_std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
sat_availaible_years:
|
||||
- "2023"
|
||||
- "2021"
|
||||
- "2019"
|
||||
- "2016"
|
||||
max_rotation_angle: 0
|
||||
uav_image_scale: 2.0
|
||||
use_heatmap: false
|
||||
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
features: superpoint
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
flash: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 1000 # 350000 / 4
|
||||
lr: 1e-3
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,48 @@
|
|||
data:
|
||||
name: satellites
|
||||
data_dir: /mnt/drive/tiles
|
||||
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||
train_size: null
|
||||
val_size: null
|
||||
batch_size: 128
|
||||
num_workers: 24
|
||||
homography:
|
||||
difficulty: 0.9
|
||||
max_angle: 359
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
force_num_keypoints: True
|
||||
nms_radius: 3
|
||||
trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: true
|
||||
checkpointed: true
|
||||
weights: superpoint
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 5
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-7
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,60 @@
|
|||
data:
|
||||
name: satellites
|
||||
data_dir: /mnt/drive/tiles
|
||||
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||
train_size: null
|
||||
val_size: null
|
||||
batch_size: 128
|
||||
num_workers: 24
|
||||
homography:
|
||||
difficulty: 0.5
|
||||
max_angle: 359
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
#extractor:
|
||||
# name: gluefactory_nonfree.superpoint
|
||||
# max_num_keypoints: 512
|
||||
# force_num_keypoints: True
|
||||
# detection_threshold: 0.0
|
||||
# nms_radius: 3
|
||||
# trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
features: superpoint
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
flash: true
|
||||
#matcher:
|
||||
# name: matchers.lightglue_pretrained
|
||||
# features: superpoint
|
||||
# depth_confidence: -1
|
||||
# width_confidence: -1
|
||||
# filter_threshold: 0.1
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 5
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,47 @@
|
|||
data:
|
||||
name: satellites
|
||||
data_dir: /mnt/drive/tiles
|
||||
metadata_dir: /home/ml-node/Documents/glue-factory/data/satellites/coords.txt
|
||||
train_size: null
|
||||
val_size: null
|
||||
batch_size: 128
|
||||
num_workers: 24
|
||||
homography:
|
||||
difficulty: 0.5
|
||||
max_angle: 180
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
features: superpoint
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
flash: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,28 @@
|
|||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.sift
|
||||
backend: pycolmap_cuda
|
||||
max_num_keypoints: 4096
|
||||
matcher:
|
||||
name: matchers.lightglue_pretrained
|
||||
features: sift
|
||||
depth_confidence: -1
|
||||
width_confidence: -1
|
||||
filter_threshold: 0.1
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024 # overwrite config above
|
|
@ -14,10 +14,10 @@ model:
|
|||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.sift
|
||||
detector: pycolmap_cuda
|
||||
backend: pycolmap_cuda
|
||||
max_num_keypoints: 1024
|
||||
force_num_keypoints: True
|
||||
detection_threshold: 0.0001
|
||||
nms_radius: 3
|
||||
trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
|
@ -46,3 +46,6 @@ benchmarks:
|
|||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
nms_radius: 0
|
||||
|
|
|
@ -25,10 +25,10 @@ model:
|
|||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: extractors.sift
|
||||
detector: pycolmap_cuda
|
||||
backend: pycolmap_cuda
|
||||
max_num_keypoints: 2048
|
||||
force_num_keypoints: True
|
||||
detection_threshold: 0.0001
|
||||
nms_radius: 3
|
||||
trainable: False
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
|
@ -62,6 +62,9 @@ benchmarks:
|
|||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
model:
|
||||
extractor:
|
||||
nms_radius: 0
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
|
@ -72,3 +75,4 @@ benchmarks:
|
|||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024
|
||||
nms_radius: 0
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
data:
|
||||
name: homographies
|
||||
data_dir: revisitop1m
|
||||
train_size: 150000
|
||||
val_size: 2000
|
||||
batch_size: 128
|
||||
num_workers: 14
|
||||
homography:
|
||||
difficulty: 0.7
|
||||
max_angle: 45
|
||||
photometric:
|
||||
name: lg
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 512
|
||||
force_num_keypoints: True
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
trainable: False
|
||||
ground_truth:
|
||||
name: matchers.homography_matcher
|
||||
th_positive: 3
|
||||
th_negative: 3
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: false
|
||||
checkpointed: true
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 40
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 500
|
||||
profile: true
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 20
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
|
@ -0,0 +1,71 @@
|
|||
data:
|
||||
name: megadepth
|
||||
preprocessing:
|
||||
resize: 1024
|
||||
side: long
|
||||
square_pad: True
|
||||
train_split: train_scenes_clean.txt
|
||||
train_num_per_scene: 300
|
||||
val_split: valid_scenes_clean.txt
|
||||
val_pairs: valid_pairs.txt
|
||||
min_overlap: 0.1
|
||||
max_overlap: 0.7
|
||||
num_overlap_bins: 3
|
||||
read_depth: true
|
||||
read_image: true
|
||||
batch_size: 32
|
||||
num_workers: 14
|
||||
load_features:
|
||||
do: false # enable this if you have cached predictions
|
||||
path: exports/megadepth-undist-depth-r1024_SP-k2048-nms3/{scene}.h5
|
||||
padding_length: 2048
|
||||
padding_fn: pad_local_features
|
||||
model:
|
||||
name: two_view_pipeline
|
||||
extractor:
|
||||
name: gluefactory_nonfree.superpoint
|
||||
max_num_keypoints: 2048
|
||||
force_num_keypoints: True
|
||||
detection_threshold: 0.0
|
||||
nms_radius: 3
|
||||
trainable: False
|
||||
matcher:
|
||||
name: matchers.lightglue
|
||||
filter_threshold: 0.1
|
||||
flash: false
|
||||
checkpointed: true
|
||||
ground_truth:
|
||||
name: matchers.depth_matcher
|
||||
th_positive: 3
|
||||
th_negative: 5
|
||||
th_epi: 5
|
||||
allow_no_extract: True
|
||||
train:
|
||||
seed: 0
|
||||
epochs: 50
|
||||
log_every_iter: 100
|
||||
eval_every_iter: 1000
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
start: 30
|
||||
type: exp
|
||||
on_epoch: true
|
||||
exp_div_10: 10
|
||||
dataset_callback_fn: sample_new_items
|
||||
plot: [5, 'gluefactory.visualization.visualize_batch.make_match_figures']
|
||||
benchmarks:
|
||||
megadepth1500:
|
||||
data:
|
||||
preprocessing:
|
||||
side: long
|
||||
resize: 1600
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
hpatches:
|
||||
eval:
|
||||
estimator: opencv
|
||||
ransac_th: 0.5
|
||||
model:
|
||||
extractor:
|
||||
max_num_keypoints: 1024
|
|
@ -1,14 +1,14 @@
|
|||
data:
|
||||
name: homographies
|
||||
homography:
|
||||
difficulty: 0.5
|
||||
max_angle: 30
|
||||
difficulty: 0.7
|
||||
max_angle: 45
|
||||
patch_shape: [640, 480]
|
||||
photometric:
|
||||
p: 0.75
|
||||
train_size: 900000
|
||||
val_size: 1000
|
||||
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet)
|
||||
batch_size: 160 # 20 per 10GB of GPU mem (12 for triplet)
|
||||
num_workers: 15
|
||||
model:
|
||||
name: gluefactory.models.two_view_pipeline
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
data:
|
||||
name: gluefactory.datasets.megadepth
|
||||
train_num_per_scene: 300
|
||||
val_pairs: valid_pairs.txt
|
||||
views: 2
|
||||
min_overlap: 0.1
|
||||
max_overlap: 0.7
|
||||
num_overlap_bins: 3
|
||||
preprocessing:
|
||||
resize: 640
|
||||
square_pad: True
|
||||
batch_size: 60
|
||||
batch_size: 160
|
||||
num_workers: 15
|
||||
model:
|
||||
name: gluefactory.models.two_view_pipeline
|
||||
|
@ -53,9 +58,9 @@ model:
|
|||
train:
|
||||
seed: 0
|
||||
epochs: 200
|
||||
log_every_iter: 10
|
||||
eval_every_iter: 100
|
||||
save_every_iter: 500
|
||||
log_every_iter: 400
|
||||
eval_every_iter: 700
|
||||
save_every_iter: 1400
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
type: exp # exp or multi_step
|
||||
|
|
|
@ -168,6 +168,11 @@ class BaseDataset(metaclass=ABCMeta):
|
|||
sampler = None
|
||||
if shuffle is None:
|
||||
shuffle = split == "train" and self.conf.shuffle_training
|
||||
shuffle = split == "val"
|
||||
|
||||
shuffle = True
|
||||
|
||||
print("Shuffle", shuffle)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
|
|
|
@ -0,0 +1,694 @@
|
|||
"""
|
||||
Simply load images from a folder or nested folders (does not have any split),
|
||||
and apply homographic adaptations to it. Yields an image pair without border
|
||||
artifacts.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from typing import List, Literal
|
||||
from torchvision import transforms
|
||||
import random
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
|
||||
|
||||
from ..geometry.homography import (
|
||||
compute_homography,
|
||||
sample_homography_corners,
|
||||
warp_points,
|
||||
)
|
||||
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||
from ..settings import DATA_PATH
|
||||
from ..utils.image import read_image
|
||||
from ..utils.tools import fork_rng
|
||||
from ..visualization.viz2d import plot_image_grid
|
||||
from .augmentations import IdentityAugmentation, augmentations
|
||||
from .base_dataset import BaseDataset
|
||||
from .s_utils import get_random_tiff_patch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Drone2SatDataset(BaseDataset):
|
||||
default_conf = {
|
||||
# image search
|
||||
"data_dir": "revisitop1m", # the top-level directory
|
||||
"image_dir": "jpg/", # the subdirectory with the images
|
||||
"image_list": "revisitop1m.txt", # optional: list or filename of list
|
||||
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
|
||||
"metadata_dir": None,
|
||||
"uav_dataset_dir": None,
|
||||
"satellite_dataset_dir": None,
|
||||
# splits
|
||||
"train_size": 100,
|
||||
"val_size": 10,
|
||||
"shuffle_seed": 0, # or None to skip
|
||||
# image loading
|
||||
"grayscale": False,
|
||||
"triplet": False,
|
||||
"right_only": False, # image0 is orig (rescaled), image1 is right
|
||||
"reseed": False,
|
||||
"homography": {
|
||||
"difficulty": 0.8,
|
||||
"translation": 1.0,
|
||||
"max_angle": 60,
|
||||
"n_angles": 10,
|
||||
"patch_shape": [640, 480],
|
||||
"min_convexity": 0.05,
|
||||
},
|
||||
"photometric": {
|
||||
"name": "dark",
|
||||
"p": 0.75,
|
||||
# 'difficulty': 1.0, # currently unused
|
||||
},
|
||||
# feature loading
|
||||
"load_features": {
|
||||
"do": False,
|
||||
**CacheLoader.default_conf,
|
||||
"collate": False,
|
||||
"thresh": 0.0,
|
||||
"max_num_keypoints": -1,
|
||||
"force_num_keypoints": False,
|
||||
},
|
||||
# Other geolocaliztion parameters
|
||||
"geo_dataset" :{
|
||||
"uav_dataset_dir": None,
|
||||
"satellite_dataset_dir": None,
|
||||
"misslabeled_images_path": None,
|
||||
"sat_zoom_level": 17,
|
||||
"uav_patch_width": 400,
|
||||
"uav_patch_height": 400,
|
||||
"sat_patch_width": 400,
|
||||
"sat_patch_height": 400,
|
||||
"heatmap_kernel_size": 33,
|
||||
"test_from_train_ratio": 0.0,
|
||||
"transform_mean": [0.485, 0.456, 0.406],
|
||||
"transform_std": [0.229, 0.224, 0.225],
|
||||
"sat_availaible_years": ["2023", "2021", "2019", "2016"],
|
||||
"max_rotation_angle": 10,
|
||||
"uav_image_scale": 1,
|
||||
"use_heatmap": False,
|
||||
}
|
||||
}
|
||||
|
||||
def _init(self, conf):
|
||||
self.images = {"train": "train", "val": "val"}
|
||||
|
||||
def get_dataset(self, split):
|
||||
if split == "val":
|
||||
return GeoLocalizationDataset(
|
||||
uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
|
||||
satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
|
||||
misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
|
||||
dataset="test",
|
||||
sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
|
||||
uav_patch_width=self.conf.geo_dataset.uav_patch_width,
|
||||
uav_patch_height=self.conf.geo_dataset.uav_patch_height,
|
||||
sat_patch_width=self.conf.geo_dataset.sat_patch_width,
|
||||
sat_patch_height=self.conf.geo_dataset.sat_patch_height,
|
||||
heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
|
||||
test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
|
||||
transform_mean=self.conf.geo_dataset.transform_mean,
|
||||
transform_std=self.conf.geo_dataset.transform_std,
|
||||
sat_available_years=self.conf.geo_dataset.sat_availaible_years,
|
||||
max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
|
||||
uav_image_scale=self.conf.geo_dataset.uav_image_scale,
|
||||
use_heatmap=self.conf.geo_dataset.use_heatmap,
|
||||
subset_size=self.conf.val_size,
|
||||
)
|
||||
elif split == "train":
|
||||
return GeoLocalizationDataset(
|
||||
uav_dataset_dir=self.conf.geo_dataset.uav_dataset_dir,
|
||||
satellite_dataset_dir=self.conf.geo_dataset.satellite_dataset_dir,
|
||||
misslabeled_images_path=self.conf.geo_dataset.misslabeled_images_path,
|
||||
dataset="train",
|
||||
sat_zoom_level=self.conf.geo_dataset.sat_zoom_level,
|
||||
uav_patch_width=self.conf.geo_dataset.uav_patch_width,
|
||||
uav_patch_height=self.conf.geo_dataset.uav_patch_height,
|
||||
sat_patch_width=self.conf.geo_dataset.sat_patch_width,
|
||||
sat_patch_height=self.conf.geo_dataset.sat_patch_height,
|
||||
heatmap_kernel_size=self.conf.geo_dataset.heatmap_kernel_size,
|
||||
test_from_train_ratio=self.conf.geo_dataset.test_from_train_ratio,
|
||||
transform_mean=self.conf.geo_dataset.transform_mean,
|
||||
transform_std=self.conf.geo_dataset.transform_std,
|
||||
sat_available_years=self.conf.geo_dataset.sat_availaible_years,
|
||||
max_rotation_angle=self.conf.geo_dataset.max_rotation_angle,
|
||||
uav_image_scale=self.conf.geo_dataset.uav_image_scale,
|
||||
use_heatmap=self.conf.geo_dataset.use_heatmap,
|
||||
subset_size=self.conf.train_size,
|
||||
)
|
||||
|
||||
|
||||
class GeoLocalizationDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
uav_dataset_dir: str,
|
||||
satellite_dataset_dir: str,
|
||||
misslabeled_images_path: str,
|
||||
dataset: Literal["train", "test"],
|
||||
sat_zoom_level: int = 16,
|
||||
uav_patch_width: int = 128,
|
||||
uav_patch_height: int = 128,
|
||||
sat_patch_width: int = 400,
|
||||
sat_patch_height: int = 400,
|
||||
heatmap_kernel_size: int = 33,
|
||||
test_from_train_ratio: float = 0.0,
|
||||
transform_mean: List[float] = [0.485, 0.456, 0.406],
|
||||
transform_std: List[float] = [0.229, 0.224, 0.225],
|
||||
sat_available_years: List[str] = ["2023", "2021", "2019", "2016"],
|
||||
max_rotation_angle: int = 10,
|
||||
uav_image_scale: float = 1,
|
||||
use_heatmap: bool = True,
|
||||
subset_size: int = None,
|
||||
):
|
||||
self.uav_dataset_dir = uav_dataset_dir
|
||||
self.satellite_dataset_dir = satellite_dataset_dir
|
||||
self.dataset = dataset
|
||||
self.sat_zoom_level = sat_zoom_level
|
||||
self.uav_patch_width = uav_patch_width
|
||||
self.uav_patch_height = uav_patch_height
|
||||
self.heatmap_kernel_size = heatmap_kernel_size
|
||||
self.test_from_train_ratio = test_from_train_ratio
|
||||
self.transform_mean = transform_mean
|
||||
self.transform_std = transform_std
|
||||
self.misslabeled_images_path = misslabeled_images_path
|
||||
self.metadata_dict = {}
|
||||
self.max_rotation_angle = max_rotation_angle
|
||||
self.total_uav_samples = self.count_total_uav_samples()
|
||||
self.misslabelled_images = self.read_misslabelled_images(self.misslabeled_images_path)
|
||||
self.entry_paths = self.get_entry_paths(self.uav_dataset_dir)
|
||||
self.cleanup_misslabelled_images()
|
||||
self.transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
#transforms.Normalize(self.transform_mean, self.transform_std),
|
||||
]
|
||||
)
|
||||
self.sat_available_years = sat_available_years
|
||||
self.uav_image_scale = uav_image_scale
|
||||
self.use_heatmap = use_heatmap
|
||||
self.sat_patch_width = sat_patch_width
|
||||
self.sat_patch_height = sat_patch_height
|
||||
self.subset_size = subset_size
|
||||
|
||||
self.inverse_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Normalize(
|
||||
mean=[
|
||||
-m / s for m, s in zip(self.transform_mean, self.transform_std)
|
||||
],
|
||||
std=[1 / s for s in self.transform_std],
|
||||
),
|
||||
transforms.ToPILImage(),
|
||||
]
|
||||
)
|
||||
|
||||
if self.subset_size != -1:
|
||||
ssize = int(self.subset_size / len(self.sat_available_years))
|
||||
ssize = min(ssize, len(self.entry_paths))
|
||||
self.entry_paths = self.entry_paths[:ssize]
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return (
|
||||
len(self.entry_paths)
|
||||
* len(self.sat_available_years)
|
||||
)
|
||||
|
||||
def read_misslabelled_images(
|
||||
self, path: str = "misslabels/misslabeled.txt"
|
||||
) -> List[str]:
|
||||
with open(path, "r") as f:
|
||||
lines = f.readlines()
|
||||
return [line.strip() for line in lines]
|
||||
|
||||
def cleanup_misslabelled_images(self) -> None:
|
||||
indices_to_delete = []
|
||||
|
||||
for image in self.misslabelled_images:
|
||||
for image_path in self.entry_paths:
|
||||
if image in image_path:
|
||||
index = self.entry_paths.index(image_path)
|
||||
indices_to_delete.append(index)
|
||||
break
|
||||
|
||||
sorted_tuples = sorted(indices_to_delete, reverse=True)
|
||||
|
||||
for index in sorted_tuples:
|
||||
self.entry_paths.pop(index)
|
||||
|
||||
def __getitem__(self, idx) -> dict:
|
||||
"""
|
||||
Retrieves a sample given its index, returning the preprocessed UAV
|
||||
and satellite images, along with their associated heatmap and metadata.
|
||||
"""
|
||||
|
||||
image_path_index = idx // (
|
||||
len(self.sat_available_years)
|
||||
)
|
||||
|
||||
sat_year = self.sat_available_years[idx % len(self.sat_available_years)]
|
||||
rot_angle = random.randint(-self.max_rotation_angle, self.max_rotation_angle)
|
||||
|
||||
image_path = self.entry_paths[image_path_index]
|
||||
uav_image = Image.open(image_path).convert("RGB") # Ensure 3-channel image
|
||||
|
||||
original_uav_image_width = uav_image.width
|
||||
original_uav_image_height = uav_image.height
|
||||
|
||||
lookup_str, file_number = self.extract_info_from_filename(image_path)
|
||||
img_info = self.metadata_dict[lookup_str][file_number]
|
||||
|
||||
lat, lon = (
|
||||
img_info["coordinate"]["latitude"],
|
||||
img_info["coordinate"]["longitude"],
|
||||
)
|
||||
|
||||
fov_vertical = img_info["fovVertical"]
|
||||
|
||||
try:
|
||||
agl_altitude = float(image_path.split("/")[-1].split("_")[2].split("m")[0])
|
||||
except IndexError:
|
||||
agl_altitude = 150.0
|
||||
warnings.warn(
|
||||
"Could not extract AGL altitude from filename, using default value of 150m."
|
||||
)
|
||||
|
||||
(
|
||||
satellite_patch,
|
||||
x_sat,
|
||||
y_sat,
|
||||
x_offset,
|
||||
y_offset,
|
||||
patch_transform,
|
||||
) = get_random_tiff_patch(
|
||||
lat, lon, self.sat_patch_width, self.sat_patch_height, sat_year, self.satellite_dataset_dir
|
||||
)
|
||||
|
||||
# Rotate crop center and transform image
|
||||
h = np.ceil(uav_image.height // self.uav_image_scale).astype(int)
|
||||
w = np.ceil(uav_image.width // self.uav_image_scale).astype(int)
|
||||
|
||||
uav_image = F.rotate(uav_image, rot_angle)
|
||||
uav_image = F.resize(uav_image, [h, w])
|
||||
uav_image = F.center_crop(
|
||||
uav_image, (self.uav_patch_height, self.uav_patch_width)
|
||||
)
|
||||
uav_image = self.transforms(uav_image)
|
||||
|
||||
satellite_patch = satellite_patch.transpose(1, 2, 0)
|
||||
satellite_patch_pytorch = self.transforms(satellite_patch)
|
||||
del satellite_patch
|
||||
|
||||
if self.use_heatmap:
|
||||
heatmap = self.get_heatmap_gt(
|
||||
x_sat,
|
||||
y_sat,
|
||||
satellite_patch.shape[1],
|
||||
satellite_patch.shape[2],
|
||||
self.heatmap_kernel_size,
|
||||
)
|
||||
|
||||
cropped_uav_image_width = self.calculate_cropped_uav_image_width(
|
||||
fov_vertical,
|
||||
original_uav_image_width,
|
||||
original_uav_image_height,
|
||||
self.uav_patch_width,
|
||||
self.uav_patch_height,
|
||||
agl_altitude,
|
||||
)
|
||||
|
||||
satellite_tile_width = self.calculate_cropped_sat_image_width(
|
||||
lat, self.sat_patch_width, patch_transform
|
||||
)
|
||||
|
||||
scale_factor = cropped_uav_image_width / satellite_tile_width
|
||||
scale_factor *= 10
|
||||
|
||||
homography_matrix = self.compute_homography(
|
||||
rot_angle,
|
||||
x_sat,
|
||||
y_sat,
|
||||
self.uav_patch_width,
|
||||
self.uav_patch_height,
|
||||
scale_factor,
|
||||
)
|
||||
|
||||
if not self.use_heatmap:
|
||||
# Sample four points from the UAV image
|
||||
points = self.sample_four_points(
|
||||
self.uav_patch_width, self.uav_patch_height
|
||||
)
|
||||
|
||||
# Transform the points
|
||||
warped_points = self.warp_points(points, homography_matrix)
|
||||
#img_info["warped_points_sat"] = warped_points
|
||||
#img_info["warped_points_uav"] = points
|
||||
|
||||
# img_info["cropped_uav_image_width"] = cropped_uav_image_width
|
||||
# img_info["satellite_tile_width"] = satellite_tile_width
|
||||
# img_info["scale_factor"] = scale_factor
|
||||
# img_info["filename"] = image_path
|
||||
# img_info["rot_angle"] = rot_angle
|
||||
# img_info["x_sat"] = x_sat
|
||||
# img_info["y_sat"] = y_sat
|
||||
# img_info["x_offset"] = x_offset
|
||||
# img_info["y_offset"] = y_offset
|
||||
# img_info["patch_transform"] = patch_transform
|
||||
# img_info["uav_image_scale"] = self.uav_image_scale
|
||||
# img_info["homography_matrix_uav_to_sat"] = homography_matrix
|
||||
# img_info["homography_matrix_sat_to_uav"] = np.linalg.inv(homography_matrix)
|
||||
# img_info["agl_altitude"] = agl_altitude
|
||||
# img_info["original_uav_image_width"] = original_uav_image_width
|
||||
# img_info["original_drone_image_height"] = original_uav_image_height
|
||||
# img_info["fov_vertical"] = fov_vertical
|
||||
#
|
||||
|
||||
inverse_homography_matrix = np.linalg.inv(homography_matrix)
|
||||
uav_image_data = {
|
||||
"image": uav_image,
|
||||
"H_": homography_matrix,
|
||||
"coords": points,
|
||||
"image_size": np.array([self.uav_patch_width, self.uav_patch_height]),
|
||||
}
|
||||
|
||||
satellite_patch_data = {
|
||||
"image": satellite_patch_pytorch,
|
||||
"H_": inverse_homography_matrix,
|
||||
"coords": warped_points,
|
||||
"image_size": np.array([self.sat_patch_width, self.sat_patch_height]),
|
||||
}
|
||||
|
||||
#hm = self.compute_homography_points(
|
||||
# warped_points, points, [self.uav_patch_width, self.uav_patch_height]
|
||||
#)
|
||||
|
||||
hm = self.compute_homography_points(
|
||||
points, warped_points, [1.0 , 1.0]
|
||||
)
|
||||
|
||||
if np.array_equal(hm, np.eye(3)):
|
||||
print("Singular matrix")
|
||||
return self.__getitem__(random.randint(0, len(self) - 1))
|
||||
|
||||
data = {
|
||||
"name": f"{image_path}_{rot_angle}_{sat_year}",
|
||||
"original_image_size": np.array(
|
||||
[original_uav_image_width, original_uav_image_height]
|
||||
),
|
||||
"H_0to1": hm.astype(np.float32),
|
||||
"view0": uav_image_data,
|
||||
"view1": satellite_patch_data,
|
||||
"idx": idx
|
||||
}
|
||||
del img_info
|
||||
gc.collect()
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def calculate_cropped_sat_image_width(self, latitude, patch_width, patch_transform):
|
||||
"""
|
||||
Computes the width of the satellite image in world coordinate units
|
||||
"""
|
||||
length_of_degree = 111320 * np.cos(np.radians(latitude))
|
||||
scale_x_meters = patch_transform[0] * length_of_degree
|
||||
satellite_tile_width = patch_width * scale_x_meters
|
||||
return satellite_tile_width
|
||||
|
||||
def calculate_cropped_uav_image_width(
|
||||
self,
|
||||
fov_vertical,
|
||||
orig_width,
|
||||
orig_height,
|
||||
crop_width,
|
||||
crop_height,
|
||||
altitude=150.0,
|
||||
):
|
||||
"""
|
||||
Computes the width of the UAV image in world coordinate units
|
||||
"""
|
||||
# Convert fov from degrees to radians
|
||||
fov_rad = np.radians(fov_vertical)
|
||||
|
||||
# Calculate the full width of the UAV image
|
||||
full_width = 2 * (altitude * np.tan(fov_rad / 2))
|
||||
|
||||
# Determine the cropping ratio
|
||||
crop_ratio_width = crop_width / orig_width
|
||||
crop_ratio_height = crop_height / orig_height
|
||||
|
||||
# Calculate the adjusted horizontal fov
|
||||
fov_horizontal = 2 * np.arctan(np.tan(fov_rad / 2) * (orig_width / orig_height))
|
||||
adjusted_fov_horizontal = 2 * np.arctan(
|
||||
np.tan(fov_horizontal / 2) * crop_ratio_width
|
||||
)
|
||||
|
||||
# Calculate the new full width using the adjusted horizontal fov
|
||||
full_width = 2 * (altitude * np.tan(adjusted_fov_horizontal / 2))
|
||||
|
||||
# Adjust the width according to the crop ratio
|
||||
cropped_width = full_width * crop_ratio_width
|
||||
|
||||
return cropped_width
|
||||
|
||||
def compute_homography_points(self, pts1_, pts2_, shape):
|
||||
"""Compute the homography matrix from 4 point correspondences"""
|
||||
# Rescale to actual size
|
||||
shape = np.array(shape[::-1], dtype=np.float32) # different convention [y, x]
|
||||
pts1 = pts1_ * np.expand_dims(shape, axis=0)
|
||||
pts2 = pts2_ * np.expand_dims(shape, axis=0)
|
||||
|
||||
def ax(p, q):
|
||||
return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]]
|
||||
|
||||
def ay(p, q):
|
||||
return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]]
|
||||
|
||||
def flat2mat(H):
|
||||
return np.reshape(np.concatenate([H, np.ones_like(H[:, :1])], axis=1), [3, 3])
|
||||
|
||||
a_mat = np.stack([f(pts1[i], pts2[i]) for i in range(4) for f in (ax, ay)], axis=0)
|
||||
p_mat = np.transpose(
|
||||
np.stack([[pts2[i][j] for i in range(4) for j in range(2)]], axis=0)
|
||||
)
|
||||
|
||||
try:
|
||||
homography = np.transpose(np.linalg.solve(a_mat, p_mat))
|
||||
except np.linalg.LinAlgError:
|
||||
print("Singular matrix")
|
||||
return np.eye(3)
|
||||
return flat2mat(homography)
|
||||
|
||||
|
||||
def compute_homography(
|
||||
self, rot_angle, x_sat, y_sat, uav_width, uav_height, scale_factor
|
||||
):
|
||||
# Adjust rot_angle if it's greater than 180 degrees
|
||||
if rot_angle > 180:
|
||||
rot_angle -= 360
|
||||
# Convert rotation angle to radians
|
||||
theta = np.radians(rot_angle)
|
||||
|
||||
# Rotation matrix
|
||||
R = np.array(
|
||||
[
|
||||
[np.cos(theta), -np.sin(theta), 0],
|
||||
[np.sin(theta), np.cos(theta), 0],
|
||||
[0, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Scale matrix
|
||||
S = np.array([[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]])
|
||||
|
||||
# Translation matrix to center the UAV image
|
||||
T_uav = np.array([[1, 0, -uav_width / 2], [0, 1, -uav_height / 2], [0, 0, 1]])
|
||||
|
||||
# Translation matrix to move to the satellite image position
|
||||
T_sat = np.array([[1, 0, x_sat], [0, 1, y_sat], [0, 0, 1]])
|
||||
|
||||
# Compute the combined homography matrix
|
||||
H = np.dot(T_sat, np.dot(R, np.dot(S, T_uav)))
|
||||
|
||||
return H
|
||||
|
||||
def sample_four_points(self, width: int, height: int) -> np.ndarray:
|
||||
"""
|
||||
Samples four points from the UAV image.
|
||||
"""
|
||||
PADDING = 50
|
||||
CENTER_PADDING = 10
|
||||
points = np.array(
|
||||
[
|
||||
[random.randint(CENTER_PADDING, width - PADDING), random.randint(CENTER_PADDING, height - PADDING)]
|
||||
for _ in range(4)
|
||||
]
|
||||
)
|
||||
return points
|
||||
|
||||
def warp_points(
|
||||
self, points: np.ndarray, homography_matrix: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Warps the given points using the given homography matrix.
|
||||
"""
|
||||
points = np.array(points)
|
||||
points = np.concatenate([points, np.ones((4, 1))], axis=1)
|
||||
points = np.dot(homography_matrix, points.T).T
|
||||
points = points[:, :2] / points[:, 2:]
|
||||
return points
|
||||
|
||||
def count_total_uav_samples(self) -> int:
|
||||
"""
|
||||
Count the total number of uav image samples in the dataset
|
||||
(train + test)
|
||||
"""
|
||||
total_samples = 0
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(self.uav_dataset_dir):
|
||||
# Skip the test folder
|
||||
for filename in filenames:
|
||||
if filename.endswith(".jpeg"):
|
||||
total_samples += 1
|
||||
return total_samples
|
||||
|
||||
def get_number_of_city_samples(self) -> int:
|
||||
"""
|
||||
TODO: Count the total number of city samples in the dataset
|
||||
"""
|
||||
return 11
|
||||
|
||||
def get_entry_paths(self, directory: str) -> List[str]:
|
||||
"""
|
||||
Recursively retrieves paths to image and metadata files in the given directory.
|
||||
"""
|
||||
entry_paths = []
|
||||
entries = os.listdir(directory)
|
||||
|
||||
images_to_take_per_folder = int(
|
||||
self.total_uav_samples
|
||||
* self.test_from_train_ratio
|
||||
/ self.get_number_of_city_samples()
|
||||
)
|
||||
|
||||
for entry in entries:
|
||||
entry_path = os.path.join(directory, entry)
|
||||
|
||||
# If it's a directory, recurse into it
|
||||
if os.path.isdir(entry_path):
|
||||
entry_paths += self.get_entry_paths(entry_path)
|
||||
|
||||
# Handle train dataset
|
||||
elif (self.dataset == "train" and "Train" in entry_path) or (
|
||||
self.dataset == "train"
|
||||
and self.test_from_train_ratio > 0
|
||||
and "Test" in entry_path
|
||||
):
|
||||
if entry_path.endswith(".jpeg"):
|
||||
_, number = self.extract_info_from_filename(entry_path)
|
||||
else:
|
||||
number = None
|
||||
if entry_path.endswith(".json"):
|
||||
self.get_metadata(entry_path)
|
||||
if number is None:
|
||||
continue
|
||||
if (
|
||||
number >= images_to_take_per_folder
|
||||
): # Only include images beyond the ones taken for test
|
||||
if entry_path.endswith(".jpeg"):
|
||||
entry_paths.append(entry_path)
|
||||
|
||||
# Handle test dataset
|
||||
elif self.dataset == "test":
|
||||
if entry_path.endswith(".jpeg"):
|
||||
_, number = self.extract_info_from_filename(entry_path)
|
||||
else:
|
||||
number = None
|
||||
if entry_path.endswith(".json"):
|
||||
self.get_metadata(entry_path)
|
||||
|
||||
if number is None:
|
||||
continue
|
||||
if (
|
||||
("Test" in entry_path and number < images_to_take_per_folder)
|
||||
or (number < images_to_take_per_folder and "Train" in entry_path)
|
||||
or (self.test_from_train_ratio == 0.0 and "Test" in entry_path)
|
||||
):
|
||||
if entry_path.endswith(".jpeg"):
|
||||
entry_paths.append(entry_path)
|
||||
|
||||
return sorted(entry_paths, key=self.extract_info_from_filename)
|
||||
|
||||
def get_metadata(self, path: str) -> None:
|
||||
"""
|
||||
Extracts metadata from a JSON file and stores it in the metadata dictionary.
|
||||
"""
|
||||
with open(path, newline="") as jsonfile:
|
||||
json_dict = json.load(jsonfile)
|
||||
path = path.split("/")[-1]
|
||||
path = path.replace(".json", "")
|
||||
self.metadata_dict[path] = json_dict["cameraFrames"]
|
||||
|
||||
def extract_info_from_filename(self, filename: str) -> (str, int):
|
||||
"""
|
||||
Extracts information from the filename.
|
||||
"""
|
||||
filename_without_ext = filename.replace(".jpeg", "")
|
||||
segments = filename_without_ext.split("/")
|
||||
info = segments[-1]
|
||||
try:
|
||||
number = int(info.split("_")[-1])
|
||||
except ValueError:
|
||||
print("Could not extract number from filename: ", filename)
|
||||
return None, None
|
||||
|
||||
info = "_".join(info.split("_")[:-1])
|
||||
|
||||
return info, number
|
||||
|
||||
def get_heatmap_gt(
|
||||
self, x: int, y: int, height: int, width: int, square_size: int = 33
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns 2D heatmap ground truth for the given x and y coordinates,
|
||||
with the given square size.
|
||||
"""
|
||||
x_map, y_map = x, y
|
||||
|
||||
heatmap = torch.zeros((height, width))
|
||||
|
||||
half_size = square_size // 2
|
||||
|
||||
# Calculate the valid range for the square
|
||||
start_x = max(0, x_map - half_size)
|
||||
end_x = min(
|
||||
width, x_map + half_size + 1
|
||||
) # +1 to include the end_x in the square
|
||||
start_y = max(0, y_map - half_size)
|
||||
end_y = min(
|
||||
height, y_map + half_size + 1
|
||||
) # +1 to include the end_y in the square
|
||||
|
||||
heatmap[start_y:end_y, start_x:end_x] = 1
|
||||
|
||||
return heatmap
|
||||
|
||||
if __name__ == "__main__":
|
||||
from .. import logger # overwrite the logger
|
|
@ -194,7 +194,7 @@ class ETH3DDataset(BaseDataset):
|
|||
if tmp_dir.exists():
|
||||
shutil.rmtree(tmp_dir)
|
||||
tmp_dir.mkdir(exist_ok=True, parents=True)
|
||||
url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/"
|
||||
url_base = "https://cvg-data.inf.ethz.ch/SOLD2/SOLD2_ETH3D_undistorted/"
|
||||
zip_name = "ETH3D_undistorted.zip"
|
||||
zip_path = tmp_dir / zip_name
|
||||
torch.hub.download_url_to_file(url_base + zip_name, zip_path)
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
import os
|
||||
import mercantile
|
||||
import rasterio
|
||||
import numpy as np
|
||||
import random
|
||||
import warnings
|
||||
from rasterio.errors import NotGeoreferencedWarning
|
||||
from rasterio.io import MemoryFile
|
||||
from rasterio.transform import from_bounds
|
||||
from rasterio.merge import merge
|
||||
from PIL import Image
|
||||
import gc
|
||||
import random
|
||||
from osgeo import gdal, osr
|
||||
from affine import Affine
|
||||
|
||||
|
||||
def get_5x5_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
|
||||
neighbors = []
|
||||
for main_neighbour in mercantile.neighbors(tile):
|
||||
for sub_neighbour in mercantile.neighbors(main_neighbour):
|
||||
if sub_neighbour not in neighbors:
|
||||
neighbors.append(sub_neighbour)
|
||||
return neighbors
|
||||
|
||||
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir: str) -> (np.ndarray, dict):
|
||||
"""
|
||||
Returns a TIFF map of the given tile using GDAL.
|
||||
"""
|
||||
tile_data = []
|
||||
neighbors = get_5x5_neighbors(tile)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
for neighbor in neighbors:
|
||||
west, south, east, north = mercantile.bounds(neighbor)
|
||||
tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
|
||||
if not os.path.exists(tile_path):
|
||||
raise FileNotFoundError(f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found.")
|
||||
|
||||
img = Image.open(tile_path)
|
||||
img_array = np.array(img)
|
||||
|
||||
# Create an in-memory GDAL dataset
|
||||
mem_driver = gdal.GetDriverByName('MEM')
|
||||
dataset = mem_driver.Create('', img_array.shape[1], img_array.shape[0], 3, gdal.GDT_Byte)
|
||||
for i in range(3):
|
||||
dataset.GetRasterBand(i + 1).WriteArray(img_array[:, :, i])
|
||||
|
||||
# Set GeoTransform and Projection
|
||||
geotransform = (west, (east - west) / img_array.shape[1], 0, north, 0, -(north - south) / img_array.shape[0])
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
srs = osr.SpatialReference()
|
||||
srs.ImportFromEPSG(3857)
|
||||
dataset.SetProjection(srs.ExportToWkt())
|
||||
|
||||
tile_data.append(dataset)
|
||||
|
||||
# Merge tiles using GDAL
|
||||
vrt_options = gdal.BuildVRTOptions()
|
||||
vrt = gdal.BuildVRT('', [td for td in tile_data], options=vrt_options)
|
||||
mosaic = vrt.ReadAsArray()
|
||||
|
||||
# Get metadata
|
||||
out_trans = vrt.GetGeoTransform()
|
||||
out_crs = vrt.GetProjection()
|
||||
out_trans = Affine.from_gdal(*out_trans)
|
||||
out_meta = {
|
||||
"driver": "GTiff",
|
||||
"height": mosaic.shape[1],
|
||||
"width": mosaic.shape[2],
|
||||
"transform": out_trans,
|
||||
"crs": out_crs,
|
||||
}
|
||||
|
||||
# Clean up
|
||||
for td in tile_data:
|
||||
td.FlushCache()
|
||||
vrt = None
|
||||
gc.collect()
|
||||
|
||||
return mosaic, out_meta
|
||||
|
||||
def get_random_tiff_patch(
|
||||
lat: float,
|
||||
lon: float,
|
||||
patch_width: int,
|
||||
patch_height: int,
|
||||
sat_year: str,
|
||||
satellite_dataset_dir: str = "/mnt/drive/satellite_dataset",
|
||||
) -> (np.ndarray, int, int, int, int, rasterio.transform.Affine):
|
||||
"""
|
||||
Returns a random patch from the satellite image.
|
||||
"""
|
||||
|
||||
tile = get_tile_from_coord(lat, lon, 17)
|
||||
|
||||
mosaic, out_meta = get_tiff_map(tile, sat_year, satellite_dataset_dir)
|
||||
|
||||
|
||||
transform = out_meta["transform"]
|
||||
del out_meta
|
||||
|
||||
x_pixel, y_pixel = geo_to_pixel_coordinates(lat, lon, transform)
|
||||
|
||||
# TODO
|
||||
# Temporal constant, replace with a better solution
|
||||
KS = 120
|
||||
|
||||
x_offset_range = [
|
||||
x_pixel - patch_width + KS + 1,
|
||||
x_pixel - KS - 1,
|
||||
]
|
||||
y_offset_range = [
|
||||
y_pixel - patch_height + KS + 1,
|
||||
y_pixel - KS - 1,
|
||||
]
|
||||
|
||||
# Randomly select an offset within the valid range
|
||||
x_offset = random.randint(*x_offset_range)
|
||||
y_offset = random.randint(*y_offset_range)
|
||||
|
||||
x_offset = np.clip(x_offset, 0, mosaic.shape[-1] - patch_width)
|
||||
y_offset = np.clip(y_offset, 0, mosaic.shape[-2] - patch_height)
|
||||
|
||||
# Update x, y to reflect the clamping of x_offset and y_offset
|
||||
x, y = x_pixel - x_offset, y_pixel - y_offset
|
||||
patch = mosaic[
|
||||
:, y_offset : y_offset + patch_height, x_offset : x_offset + patch_width
|
||||
]
|
||||
|
||||
patch_transform = rasterio.transform.Affine(
|
||||
transform.a,
|
||||
transform.b,
|
||||
transform.c + x_offset * transform.a + y_offset * transform.b,
|
||||
transform.d,
|
||||
transform.e,
|
||||
transform.f + x_offset * transform.d + y_offset * transform.e,
|
||||
)
|
||||
|
||||
gc.collect()
|
||||
|
||||
return patch, x, y, x_offset, y_offset, patch_transform
|
||||
|
||||
def get_tile_from_coord(
|
||||
lat: float, lng: float, zoom_level: int
|
||||
) -> mercantile.Tile:
|
||||
"""
|
||||
Returns the tile containing the given coordinates.
|
||||
"""
|
||||
tile = mercantile.tile(lng, lat, zoom_level)
|
||||
return tile
|
||||
|
||||
def geo_to_pixel_coordinates(
|
||||
lat: float, lon: float, transform: rasterio.transform.Affine
|
||||
) -> (int, int):
|
||||
"""
|
||||
Converts a pair of (lat, lon) coordinates to pixel coordinates.
|
||||
"""
|
||||
x_pixel, y_pixel = ~transform * (lon, lat)
|
||||
return round(x_pixel), round(y_pixel)
|
|
@ -0,0 +1,109 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
import os
|
||||
import mercantile
|
||||
import rasterio
|
||||
import numpy as np
|
||||
import random
|
||||
import warnings
|
||||
from rasterio.errors import NotGeoreferencedWarning
|
||||
from rasterio.io import MemoryFile
|
||||
from rasterio.transform import from_bounds
|
||||
from rasterio.merge import merge
|
||||
from PIL import Image
|
||||
import gc
|
||||
|
||||
|
||||
def get_3x3_neighbors(tile: mercantile.Tile) -> list[mercantile.Tile]:
|
||||
neighbors = []
|
||||
for neighbour in mercantile.neighbors(tile):
|
||||
if neighbour not in neighbors:
|
||||
neighbors.append(neighbour)
|
||||
|
||||
neighbors.append(tile)
|
||||
return neighbors
|
||||
|
||||
def get_tiff_map(tile: mercantile.Tile, sat_year: str, satellite_dataset_dir:str) -> (np.ndarray, dict):
|
||||
"""
|
||||
Returns a TIFF map of the given tile.
|
||||
"""
|
||||
tile_data = []
|
||||
neighbors = get_3x3_neighbors(tile)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)
|
||||
for neighbor in neighbors:
|
||||
west, south, east, north = mercantile.bounds(neighbor)
|
||||
tile_path = f"{satellite_dataset_dir}/{sat_year}/{neighbor.z}_{neighbor.x}_{neighbor.y}.jpg"
|
||||
if not os.path.exists(tile_path):
|
||||
raise FileNotFoundError(
|
||||
f"Tile {neighbor.z}_{neighbor.x}_{neighbor.y} not found."
|
||||
)
|
||||
|
||||
with Image.open(tile_path) as img:
|
||||
width, height = img.size
|
||||
memfile = MemoryFile()
|
||||
with memfile.open(
|
||||
driver="GTiff",
|
||||
height=height,
|
||||
width=width,
|
||||
count=3,
|
||||
dtype="uint8",
|
||||
crs="EPSG:3857",
|
||||
transform=from_bounds(west, south, east, north, width, height),
|
||||
) as dataset:
|
||||
data = rasterio.open(tile_path).read()
|
||||
dataset.write(data)
|
||||
tile_data.append(memfile.open())
|
||||
memfile.close()
|
||||
|
||||
mosaic, out_trans = merge(tile_data)
|
||||
|
||||
out_meta = tile_data[0].meta.copy()
|
||||
out_meta.update(
|
||||
{
|
||||
"driver": "GTiff",
|
||||
"height": mosaic.shape[1],
|
||||
"width": mosaic.shape[2],
|
||||
"transform": out_trans,
|
||||
"crs": "EPSG:3857",
|
||||
}
|
||||
)
|
||||
|
||||
# Clean up MemoryFile instances to free up memory
|
||||
for td in tile_data:
|
||||
td.close()
|
||||
|
||||
del neighbors
|
||||
del tile_data
|
||||
gc.collect()
|
||||
|
||||
return mosaic, out_meta
|
||||
|
||||
def get_random_tiff_patch(
|
||||
lat: float,
|
||||
lon: float,
|
||||
satellite_dataset_dir: str,
|
||||
) -> (np.ndarray):
|
||||
"""
|
||||
Returns a random patch from the satellite image.
|
||||
"""
|
||||
|
||||
tile = get_tile_from_coord(lat, lon, 17)
|
||||
sat_years = ["2023", "2021", "2019", "2016"]
|
||||
|
||||
# Randomly select a satellite year
|
||||
sat_year = random.choice(sat_years)
|
||||
|
||||
mosaic, _ = get_tiff_map(tile, sat_year, satellite_dataset_dir)
|
||||
|
||||
return mosaic
|
||||
|
||||
def get_tile_from_coord(
|
||||
lat: float, lng: float, zoom_level: int
|
||||
) -> mercantile.Tile:
|
||||
"""
|
||||
Returns the tile containing the given coordinates.
|
||||
"""
|
||||
tile = mercantile.tile(lng, lat, zoom_level)
|
||||
return tile
|
|
@ -0,0 +1,297 @@
|
|||
"""
|
||||
Simply load images from a folder or nested folders (does not have any split),
|
||||
and apply homographic adaptations to it. Yields an image pair without border
|
||||
artifacts.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
import mercantile
|
||||
import rasterio
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from ..geometry.homography import (
|
||||
compute_homography,
|
||||
sample_homography_corners,
|
||||
warp_points,
|
||||
)
|
||||
from ..models.cache_loader import CacheLoader, pad_local_features
|
||||
from ..settings import DATA_PATH
|
||||
from ..utils.image import read_image
|
||||
from ..utils.tools import fork_rng
|
||||
from ..visualization.viz2d import plot_image_grid
|
||||
from .augmentations import IdentityAugmentation, augmentations
|
||||
from .base_dataset import BaseDataset
|
||||
from .sat_utils import get_random_tiff_patch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sample_homography(img, conf: dict, size: list):
|
||||
data = {}
|
||||
H, _, coords, _ = sample_homography_corners(img.shape[:2][::-1], **conf)
|
||||
data["image"] = cv2.warpPerspective(img, H, tuple(size))
|
||||
data["H_"] = H.astype(np.float32)
|
||||
data["coords"] = coords.astype(np.float32)
|
||||
data["image_size"] = np.array(size, dtype=np.float32)
|
||||
return data
|
||||
|
||||
|
||||
class SatelliteDataset(BaseDataset):
|
||||
default_conf = {
|
||||
# image search
|
||||
"data_dir": "revisitop1m", # the top-level directory
|
||||
"image_dir": "jpg/", # the subdirectory with the images
|
||||
"image_list": "revisitop1m.txt", # optional: list or filename of list
|
||||
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
|
||||
"metadata_dir": None,
|
||||
# splits
|
||||
"train_size": 100,
|
||||
"val_size": 10,
|
||||
"shuffle_seed": 0, # or None to skip
|
||||
# image loading
|
||||
"grayscale": False,
|
||||
"triplet": False,
|
||||
"right_only": False, # image0 is orig (rescaled), image1 is right
|
||||
"reseed": False,
|
||||
"homography": {
|
||||
"difficulty": 0.8,
|
||||
"translation": 1.0,
|
||||
"max_angle": 60,
|
||||
"n_angles": 10,
|
||||
"patch_shape": [640, 480],
|
||||
"min_convexity": 0.05,
|
||||
},
|
||||
"photometric": {
|
||||
"name": "dark",
|
||||
"p": 0.75,
|
||||
# 'difficulty': 1.0, # currently unused
|
||||
},
|
||||
# feature loading
|
||||
"load_features": {
|
||||
"do": False,
|
||||
**CacheLoader.default_conf,
|
||||
"collate": False,
|
||||
"thresh": 0.0,
|
||||
"max_num_keypoints": -1,
|
||||
"force_num_keypoints": False,
|
||||
},
|
||||
}
|
||||
|
||||
def _init(self, conf):
|
||||
data_dir = conf.data_dir
|
||||
coordinates_file = conf.metadata_dir
|
||||
|
||||
with open(coordinates_file, "r") as cf:
|
||||
coordinates = cf.readlines()
|
||||
|
||||
parsed_coordinates = []
|
||||
|
||||
for coordinate in coordinates:
|
||||
lat_part, lon_part = coordinate.split(',')
|
||||
lat = float(lat_part.split(':')[-1].strip())
|
||||
lon = float(lon_part.split(':')[-1].strip())
|
||||
parsed_coordinates.append((lat, lon))
|
||||
|
||||
# Split into train and val 20% val
|
||||
|
||||
train_images = parsed_coordinates[:int(len(parsed_coordinates) * 0.8)]
|
||||
val_images = parsed_coordinates[int(len(parsed_coordinates) * 0.8):]
|
||||
|
||||
self.images = {"train": train_images, "val": val_images}
|
||||
|
||||
def get_dataset(self, split):
|
||||
return _Dataset(self.conf, self.images[split], split)
|
||||
|
||||
|
||||
class _Dataset(torch.utils.data.Dataset):
|
||||
def __init__(self, conf, image_names, split):
|
||||
self.conf = conf
|
||||
self.split = split
|
||||
self.image_names = np.array(image_names)
|
||||
self.image_dir = DATA_PATH / conf.data_dir / conf.image_dir
|
||||
|
||||
aug_conf = conf.photometric
|
||||
aug_name = aug_conf.name
|
||||
assert (
|
||||
aug_name in augmentations.keys()
|
||||
), f'{aug_name} not in {" ".join(augmentations.keys())}'
|
||||
#self.left_augment = (
|
||||
# IdentityAugmentation() if conf.right_only else self.photo_augment
|
||||
#)
|
||||
self.photo_augment = augmentations[aug_name](aug_conf)
|
||||
self.left_augment = augmentations[aug_name](aug_conf)
|
||||
self.img_to_tensor = IdentityAugmentation()
|
||||
|
||||
if conf.load_features.do:
|
||||
self.feature_loader = CacheLoader(conf.load_features)
|
||||
|
||||
def _transform_keypoints(self, features, data):
|
||||
"""Transform keypoints by a homography, threshold them,
|
||||
and potentially keep only the best ones."""
|
||||
# Warp points
|
||||
features["keypoints"] = warp_points(
|
||||
features["keypoints"], data["H_"], inverse=False
|
||||
)
|
||||
h, w = data["image"].shape[1:3]
|
||||
valid = (
|
||||
(features["keypoints"][:, 0] >= 0)
|
||||
& (features["keypoints"][:, 0] <= w - 1)
|
||||
& (features["keypoints"][:, 1] >= 0)
|
||||
& (features["keypoints"][:, 1] <= h - 1)
|
||||
)
|
||||
features["keypoints"] = features["keypoints"][valid]
|
||||
|
||||
# Threshold
|
||||
if self.conf.load_features.thresh > 0:
|
||||
valid = features["keypoint_scores"] >= self.conf.load_features.thresh
|
||||
features = {k: v[valid] for k, v in features.items()}
|
||||
|
||||
# Get the top keypoints and pad
|
||||
n = self.conf.load_features.max_num_keypoints
|
||||
if n > -1:
|
||||
inds = np.argsort(-features["keypoint_scores"])
|
||||
features = {k: v[inds[:n]] for k, v in features.items()}
|
||||
|
||||
if self.conf.load_features.force_num_keypoints:
|
||||
features = pad_local_features(
|
||||
features, self.conf.load_features.max_num_keypoints
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.conf.reseed:
|
||||
with fork_rng(self.conf.seed + idx, False):
|
||||
return self.getitem(idx)
|
||||
else:
|
||||
return self.getitem(idx)
|
||||
|
||||
def _read_view(self, img, H_conf, ps, left=False):
|
||||
data = sample_homography(img, H_conf, ps)
|
||||
if left:
|
||||
data["image"] = self.left_augment(data["image"], return_tensor=True)
|
||||
else:
|
||||
data["image"] = self.photo_augment(data["image"], return_tensor=True)
|
||||
|
||||
gs = data["image"].new_tensor([0.299, 0.587, 0.114]).view(3, 1, 1)
|
||||
if self.conf.grayscale:
|
||||
data["image"] = (data["image"] * gs).sum(0, keepdim=True)
|
||||
|
||||
if self.conf.load_features.do:
|
||||
features = self.feature_loader({k: [v] for k, v in data.items()})
|
||||
features = self._transform_keypoints(features, data)
|
||||
data["cache"] = features
|
||||
|
||||
return data
|
||||
|
||||
def getitem(self, idx):
|
||||
# Generate a list of coordinates, based on the coordinate do the split
|
||||
lat, lon = self.image_names[idx]
|
||||
img = get_random_tiff_patch(lat, lon, self.conf.data_dir)
|
||||
|
||||
if img is None:
|
||||
logging.warning("Image %f %f could not be read.", lat, lon)
|
||||
img = np.zeros((1024, 1024) + (() if self.conf.grayscale else (3,)))
|
||||
|
||||
img = img.transpose(1,2,0)
|
||||
img = img.astype(np.float32) / 255.0
|
||||
#write_status = cv2.imwrite(f"/mnt/drive/{str(idx)}.jpg", img)
|
||||
#if write_status == True:
|
||||
# print("Writing success")
|
||||
#else:
|
||||
# print("fuck")
|
||||
size = img.shape[:2][::-1]
|
||||
ps = self.conf.homography.patch_shape
|
||||
|
||||
left_conf = omegaconf.OmegaConf.to_container(self.conf.homography)
|
||||
if self.conf.right_only:
|
||||
left_conf["difficulty"] = 0.0
|
||||
|
||||
data0 = self._read_view(img, left_conf, ps, left=True)
|
||||
data1 = self._read_view(img, self.conf.homography, ps, left=False)
|
||||
|
||||
print("Data0 homography_matrix", data0["H_"])
|
||||
print("Data1 homography matrix", data1["H_"])
|
||||
|
||||
# image_1 = data0["image"]
|
||||
# image_1 = image_1.numpy()
|
||||
# image_1 = image_1.transpose(1,2,0)
|
||||
# image_1 = np.uint8(image_1 * 255)
|
||||
#
|
||||
# image_2 = data1["image"]
|
||||
# image_2 = image_2.numpy()
|
||||
# image_2 = image_2.transpose(1,2,0)
|
||||
# image_2 = np.uint8(image_2 * 255)
|
||||
#
|
||||
# PIL_IMAGE = Image.fromarray(image_1)
|
||||
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}.jpg")
|
||||
# PIL_IMAGE = Image.fromarray(image_2)
|
||||
# PIL_IMAGE.save(f"/mnt/drive/{str(idx)}-s.jpg")
|
||||
#
|
||||
# exit()
|
||||
#
|
||||
H = compute_homography(data0["coords"], data1["coords"], [1, 1])
|
||||
|
||||
print("COmputed homography", H)
|
||||
|
||||
name = f"lat{lat}_lon{lon}"
|
||||
|
||||
data = {
|
||||
"name": name,
|
||||
"original_image_size": np.array(size),
|
||||
"H_0to1": H.astype(np.float32),
|
||||
"idx": idx,
|
||||
"view0": data0,
|
||||
"view1": data1,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_names)
|
||||
|
||||
|
||||
def visualize(args):
|
||||
conf = {
|
||||
"batch_size": 1,
|
||||
"num_workers": 1,
|
||||
"prefetch_factor": 1,
|
||||
}
|
||||
conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
|
||||
dataset = SatelliteDataset(conf)
|
||||
loader = dataset.get_data_loader("train")
|
||||
logger.info("The dataset has %d elements.", len(loader))
|
||||
|
||||
with fork_rng(seed=dataset.conf.seed):
|
||||
images = []
|
||||
for _, data in zip(range(args.num_items), loader):
|
||||
images.append(
|
||||
(data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2))
|
||||
)
|
||||
plot_image_grid(images, dpi=args.dpi)
|
||||
plt.tight_layout()
|
||||
plt.imsave("implot.png")
|
||||
#plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from .. import logger # overwrite the logger
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num_items", type=int, default=8)
|
||||
parser.add_argument("--dpi", type=int, default=100)
|
||||
parser.add_argument("dotlist", nargs="*")
|
||||
args = parser.parse_intermixed_args()
|
||||
visualize(args)
|
|
@ -5,6 +5,7 @@ from pprint import pprint
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
|
@ -12,6 +13,7 @@ from ..datasets import get_dataset
|
|||
from ..models.cache_loader import CacheLoader
|
||||
from ..settings import EVAL_PATH
|
||||
from ..utils.export_predictions import export_predictions
|
||||
from ..utils.tensor import map_tensor
|
||||
from ..utils.tools import AUCMetric
|
||||
from ..visualization.viz2d import plot_cumulative
|
||||
from .eval_pipeline import EvalPipeline
|
||||
|
@ -105,9 +107,11 @@ class HPatchesPipeline(EvalPipeline):
|
|||
cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
|
||||
for i, data in enumerate(tqdm(loader)):
|
||||
pred = cache_loader(data)
|
||||
# Remove batch dimension
|
||||
data = map_tensor(data, lambda t: torch.squeeze(t, dim=0))
|
||||
# add custom evaluations here
|
||||
if "keypoints0" in pred:
|
||||
results_i = eval_matches_homography(data, pred, {})
|
||||
results_i = eval_matches_homography(data, pred)
|
||||
results_i = {**results_i, **eval_homography_dlt(data, pred)}
|
||||
else:
|
||||
results_i = {}
|
||||
|
|
|
@ -89,6 +89,11 @@ def load_model(model_conf, checkpoint):
|
|||
model = load_experiment(checkpoint, conf=model_conf).eval()
|
||||
else:
|
||||
model = get_model("two_view_pipeline")(model_conf).eval()
|
||||
if not model.is_initialized():
|
||||
raise ValueError(
|
||||
"The provided model has non-initialized parameters. "
|
||||
+ "Try to load a checkpoint instead."
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import kornia
|
||||
import numpy as np
|
||||
import torch
|
||||
from kornia.geometry.homography import find_homography_dlt
|
||||
|
||||
from ..geometry.epipolar import generalized_epi_dist, relative_pose_error
|
||||
from ..geometry.gt_generation import IGNORE_FEATURE
|
||||
from ..geometry.homography import homography_corner_error, sym_homography_error
|
||||
from ..robust_estimators import load_estimator
|
||||
from ..utils.tensor import index_batch
|
||||
from ..utils.tools import AUCMetric
|
||||
|
||||
|
||||
|
@ -26,6 +27,16 @@ def get_matches_scores(kpts0, kpts1, matches0, mscores0):
|
|||
return pts0, pts1, scores
|
||||
|
||||
|
||||
def eval_per_batch_item(data: dict, pred: dict, eval_f, *args, **kwargs):
|
||||
# Batched data
|
||||
results = [
|
||||
eval_f(data_i, pred_i, *args, **kwargs)
|
||||
for data_i, pred_i in zip(index_batch(data), index_batch(pred))
|
||||
]
|
||||
# Return a dictionary of lists with the evaluation of each item
|
||||
return {k: [r[k] for r in results] for k in results[0].keys()}
|
||||
|
||||
|
||||
def eval_matches_epipolar(data: dict, pred: dict) -> dict:
|
||||
check_keys_recursive(data, ["view0", "view1", "T_0to1"])
|
||||
check_keys_recursive(
|
||||
|
@ -58,23 +69,25 @@ def eval_matches_epipolar(data: dict, pred: dict) -> dict:
|
|||
return results
|
||||
|
||||
|
||||
def eval_matches_homography(data: dict, pred: dict, conf) -> dict:
|
||||
def eval_matches_homography(data: dict, pred: dict) -> dict:
|
||||
check_keys_recursive(data, ["H_0to1"])
|
||||
check_keys_recursive(
|
||||
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
|
||||
)
|
||||
|
||||
H_gt = data["H_0to1"]
|
||||
if H_gt.ndim > 2:
|
||||
return eval_per_batch_item(data, pred, eval_matches_homography)
|
||||
|
||||
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
||||
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
||||
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
||||
err = sym_homography_error(pts0, pts1, H_gt[0])
|
||||
err = sym_homography_error(pts0, pts1, H_gt)
|
||||
results = {}
|
||||
results["prec@1px"] = (err < 1).float().mean().nan_to_num().item()
|
||||
results["prec@3px"] = (err < 3).float().mean().nan_to_num().item()
|
||||
results["num_matches"] = pts0.shape[0]
|
||||
results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
@ -84,7 +97,7 @@ def eval_relative_pose_robust(data, pred, conf):
|
|||
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
|
||||
)
|
||||
|
||||
T_gt = data["T_0to1"][0]
|
||||
T_gt = data["T_0to1"]
|
||||
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
||||
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
||||
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
||||
|
@ -107,9 +120,8 @@ def eval_relative_pose_robust(data, pred, conf):
|
|||
else:
|
||||
# R, t, inl = ret
|
||||
M = est["M_0to1"]
|
||||
R, t = M.numpy()
|
||||
inl = est["inliers"].numpy()
|
||||
r_error, t_error = relative_pose_error(T_gt, R, t)
|
||||
t_error, r_error = relative_pose_error(T_gt, M.R, M.t)
|
||||
results["rel_pose_error"] = max(r_error, t_error)
|
||||
results["ransac_inl"] = np.sum(inl)
|
||||
results["ransac_inl%"] = np.mean(inl)
|
||||
|
@ -119,6 +131,9 @@ def eval_relative_pose_robust(data, pred, conf):
|
|||
|
||||
def eval_homography_robust(data, pred, conf):
|
||||
H_gt = data["H_0to1"]
|
||||
if H_gt.ndim > 2:
|
||||
return eval_per_batch_item(data, pred, eval_relative_pose_robust, conf)
|
||||
|
||||
estimator = load_estimator("homography", conf["estimator"])(conf)
|
||||
|
||||
data_ = {}
|
||||
|
@ -158,24 +173,26 @@ def eval_homography_robust(data, pred, conf):
|
|||
return results
|
||||
|
||||
|
||||
def eval_homography_dlt(data, pred, *args):
|
||||
def eval_homography_dlt(data, pred):
|
||||
H_gt = data["H_0to1"]
|
||||
H_inf = torch.ones_like(H_gt) * float("inf")
|
||||
|
||||
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
||||
m0, scores0 = pred["matches0"], pred["matching_scores0"]
|
||||
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
|
||||
scores = scores.to(pts0)
|
||||
results = {}
|
||||
try:
|
||||
Hdlt = kornia.geometry.homography.find_homography_dlt(
|
||||
pts0[None], pts1[None], scores[None].to(pts0)
|
||||
)[0]
|
||||
if H_gt.ndim == 2:
|
||||
pts0, pts1, scores = pts0[None], pts1[None], scores[None]
|
||||
h_dlt = find_homography_dlt(pts0, pts1, scores)
|
||||
if H_gt.ndim == 2:
|
||||
h_dlt = h_dlt[0]
|
||||
except AssertionError:
|
||||
Hdlt = H_inf
|
||||
h_dlt = H_inf
|
||||
|
||||
error_dlt = homography_corner_error(Hdlt, H_gt, data["view0"]["image_size"])
|
||||
error_dlt = homography_corner_error(h_dlt, H_gt, data["view0"]["image_size"])
|
||||
results["H_error_dlt"] = error_dlt.item()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils import skew_symmetric, to_homogeneous
|
||||
|
@ -124,39 +123,33 @@ def decompose_essential_matrix(E):
|
|||
|
||||
|
||||
# pose errors
|
||||
# TODO: port to torch and batch
|
||||
# TODO: test for batched data
|
||||
def angle_error_mat(R1, R2):
|
||||
cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
|
||||
cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
|
||||
return np.rad2deg(np.abs(np.arccos(cos)))
|
||||
cos = (torch.trace(torch.einsum("...ij, ...jk -> ...ik", R1.T, R2)) - 1) / 2
|
||||
cos = torch.clip(cos, -1.0, 1.0) # numerical errors can make it out of bounds
|
||||
return torch.rad2deg(torch.abs(torch.arccos(cos)))
|
||||
|
||||
|
||||
def angle_error_vec(v1, v2):
|
||||
n = np.linalg.norm(v1) * np.linalg.norm(v2)
|
||||
return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
|
||||
def angle_error_vec(v1, v2, eps=1e-10):
|
||||
n = torch.clip(v1.norm(dim=-1) * v2.norm(dim=-1), min=eps)
|
||||
v1v2 = (v1 * v2).sum(dim=-1) # dot product in the last dimension
|
||||
return torch.rad2deg(torch.arccos(torch.clip(v1v2 / n, -1.0, 1.0)))
|
||||
|
||||
|
||||
def compute_pose_error(T_0to1, R, t):
|
||||
R_gt = T_0to1[:3, :3]
|
||||
t_gt = T_0to1[:3, 3]
|
||||
error_t = angle_error_vec(t, t_gt)
|
||||
error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
|
||||
error_R = angle_error_mat(R, R_gt)
|
||||
return error_t, error_R
|
||||
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0, eps=1e-10):
|
||||
if isinstance(T_0to1, torch.Tensor):
|
||||
R_gt, t_gt = T_0to1[:3, :3], T_0to1[:3, 3]
|
||||
else:
|
||||
R_gt, t_gt = T_0to1.R, T_0to1.t
|
||||
R_gt, t_gt = torch.squeeze(R_gt), torch.squeeze(t_gt)
|
||||
|
||||
|
||||
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
|
||||
# angle error between 2 vectors
|
||||
R_gt, t_gt = T_0to1.numpy()
|
||||
n = np.linalg.norm(t) * np.linalg.norm(t_gt)
|
||||
t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
|
||||
t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
|
||||
if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
|
||||
t_err = angle_error_vec(t, t_gt, eps)
|
||||
t_err = torch.minimum(t_err, 180 - t_err) # handle E ambiguity
|
||||
if t_gt.norm() < ignore_gt_t_thr: # pure rotation is challenging
|
||||
t_err = 0
|
||||
|
||||
# angle error between 2 rotation matrices
|
||||
cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
|
||||
cos = np.clip(cos, -1.0, 1.0) # handle numercial errors
|
||||
R_err = np.rad2deg(np.abs(np.arccos(cos)))
|
||||
r_err = angle_error_mat(R, R_gt)
|
||||
|
||||
return t_err, R_err
|
||||
return t_err, r_err
|
||||
|
|
|
@ -53,6 +53,7 @@ def sample_homography_corners(
|
|||
min_pts1 = create_center_patch(shape, (pwidth, pheight))
|
||||
full = create_center_patch(shape)
|
||||
pts2 = create_center_patch(patch_shape)
|
||||
|
||||
scale = min_pts1 - full
|
||||
found_valid = False
|
||||
cnt = -1
|
||||
|
@ -68,7 +69,9 @@ def sample_homography_corners(
|
|||
|
||||
# Rotation
|
||||
if n_angles > 0 and difficulty > 0:
|
||||
angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
|
||||
#angles = np.linspace(-max_angle * difficulty, max_angle * difficulty, n_angles)
|
||||
# In our case the difficulty parameter should not affect the rotation
|
||||
angles = np.linspace(-max_angle, max_angle, n_angles)
|
||||
rng.shuffle(angles)
|
||||
rng.shuffle(angles)
|
||||
angles = np.concatenate([[0.0], angles], axis=0)
|
||||
|
@ -164,7 +167,8 @@ def warp_points_torch(points, H, inverse=True):
|
|||
The inverse is used to be coherent with tf.contrib.image.transform
|
||||
Arguments:
|
||||
points: batched list of N points, shape (B, N, 2).
|
||||
homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
|
||||
H: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
|
||||
inverse: Whether to multiply the points by H or the inverse of H
|
||||
Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warps.
|
||||
"""
|
||||
|
||||
|
@ -333,7 +337,7 @@ def sym_homography_error_all(kpts0, kpts1, H):
|
|||
|
||||
|
||||
def homography_corner_error(T, T_gt, image_size):
|
||||
W, H = image_size[:, 0], image_size[:, 1]
|
||||
W, H = image_size[..., 0], image_size[..., 1]
|
||||
corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
|
||||
corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2))
|
||||
corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2))
|
||||
|
|
|
@ -23,6 +23,7 @@ def from_homogeneous(points, eps=0.0):
|
|||
"""Remove the homogeneous dimension of N-dimensional points.
|
||||
Args:
|
||||
points: torch.Tensor or numpy.ndarray with size (..., N+1).
|
||||
eps: Epsilon value to prevent zero division.
|
||||
Returns:
|
||||
A torch.Tensor or numpy ndarray with size (..., N).
|
||||
"""
|
||||
|
|
|
@ -10,6 +10,7 @@ class DinoV2(BaseModel):
|
|||
|
||||
def _init(self, conf):
|
||||
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
|
||||
self.set_initialized()
|
||||
|
||||
def _forward(self, data):
|
||||
img = data["image"]
|
||||
|
|
|
@ -60,6 +60,8 @@ class BaseModel(nn.Module, metaclass=MetaModel):
|
|||
required_data_keys = []
|
||||
strict_conf = False
|
||||
|
||||
are_weights_initialized = False
|
||||
|
||||
def __init__(self, conf):
|
||||
"""Perform some logic and call the _init method of the child model."""
|
||||
super().__init__()
|
||||
|
@ -125,3 +127,31 @@ class BaseModel(nn.Module, metaclass=MetaModel):
|
|||
def loss(self, pred, data):
|
||||
"""To be implemented by the child class."""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
"""Load the state dict of the model, and set the model to initialized."""
|
||||
ret = super().load_state_dict(*args, **kwargs)
|
||||
self.set_initialized()
|
||||
return ret
|
||||
|
||||
def is_initialized(self):
|
||||
"""Recursively check if the model is initialized, i.e. weights are loaded"""
|
||||
is_initialized = True # initialize to true and perform recursive and
|
||||
for _, w in self.named_children():
|
||||
if isinstance(w, BaseModel):
|
||||
# if children is BaseModel, we perform recursive check
|
||||
is_initialized = is_initialized and w.is_initialized()
|
||||
else:
|
||||
# else, we check if self is initialized or the children has no params
|
||||
n_params = len(list(w.parameters()))
|
||||
is_initialized = is_initialized and (
|
||||
n_params == 0 or self.are_weights_initialized
|
||||
)
|
||||
return is_initialized
|
||||
|
||||
def set_initialized(self, to: bool = True):
|
||||
"""Recursively set the initialization state."""
|
||||
self.are_weights_initialized = to
|
||||
for _, w in self.named_parameters():
|
||||
if isinstance(w, BaseModel):
|
||||
w.set_initialized(to)
|
||||
|
|
|
@ -29,6 +29,15 @@ def pad_local_features(pred: dict, seq_l: int):
|
|||
pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros")
|
||||
if "oris" in pred.keys():
|
||||
pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros")
|
||||
|
||||
if "depth_keypoints" in pred.keys():
|
||||
pred["depth_keypoints"] = pad_to_length(
|
||||
pred["depth_keypoints"], seq_l, -1, mode="zeros"
|
||||
)
|
||||
if "valid_depth_keypoints" in pred.keys():
|
||||
pred["valid_depth_keypoints"] = pad_to_length(
|
||||
pred["valid_depth_keypoints"], seq_l, -1, mode="zeros"
|
||||
)
|
||||
return pred
|
||||
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ class DISK(BaseModel):
|
|||
|
||||
def _init(self, conf):
|
||||
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
|
||||
self.set_initialized()
|
||||
|
||||
def _get_dense_outputs(self, images):
|
||||
B = images.shape[0]
|
||||
|
|
|
@ -21,6 +21,7 @@ class KeyNetAffNetHardNet(BaseModel):
|
|||
upright=conf.upright,
|
||||
scale_laf=conf.scale_laf,
|
||||
)
|
||||
self.set_initialized()
|
||||
|
||||
def _forward(self, data):
|
||||
image = data["image"]
|
||||
|
|
|
@ -1,238 +1,233 @@
|
|||
import warnings
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pycolmap
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from scipy.spatial import KDTree
|
||||
from kornia.color import rgb_to_grayscale
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
import pycolmap
|
||||
except ImportError:
|
||||
pycolmap = None
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..utils.misc import pad_to_length
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
|
||||
h, w = image_shape
|
||||
ij = np.round(points - 0.5).astype(int).T[::-1]
|
||||
|
||||
# Remove duplicate points (identical coordinates).
|
||||
# Pick highest scale or score
|
||||
s = scales if scores is None else scores
|
||||
buffer = np.zeros((h, w))
|
||||
np.maximum.at(buffer, tuple(ij), s)
|
||||
keep = np.where(buffer[tuple(ij)] == s)[0]
|
||||
|
||||
# Pick lowest angle (arbitrary).
|
||||
ij = ij[:, keep]
|
||||
buffer[:] = np.inf
|
||||
o_abs = np.abs(angles[keep])
|
||||
np.minimum.at(buffer, tuple(ij), o_abs)
|
||||
mask = buffer[tuple(ij)] == o_abs
|
||||
ij = ij[:, mask]
|
||||
keep = keep[mask]
|
||||
|
||||
if nms_radius > 0:
|
||||
# Apply NMS on the remaining points
|
||||
buffer[:] = 0
|
||||
buffer[tuple(ij)] = s[keep] # scores or scale
|
||||
|
||||
local_max = torch.nn.functional.max_pool2d(
|
||||
torch.from_numpy(buffer).unsqueeze(0),
|
||||
kernel_size=nms_radius * 2 + 1,
|
||||
stride=1,
|
||||
padding=nms_radius,
|
||||
).squeeze(0)
|
||||
is_local_max = buffer == local_max.numpy()
|
||||
keep = keep[is_local_max[tuple(ij)]]
|
||||
return keep
|
||||
|
||||
|
||||
def sift_to_rootsift(x):
|
||||
x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS)
|
||||
x = np.sqrt(x.clip(min=EPS))
|
||||
x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS)
|
||||
return x
|
||||
def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
|
||||
x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
|
||||
x.clip_(min=eps).sqrt_()
|
||||
return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
|
||||
|
||||
|
||||
# from OpenGlue
|
||||
def nms_keypoints(kpts: np.ndarray, responses: np.ndarray, radius: float) -> np.ndarray:
|
||||
# TODO: add approximate tree
|
||||
kd_tree = KDTree(kpts)
|
||||
|
||||
sorted_idx = np.argsort(-responses)
|
||||
kpts_to_keep_idx = []
|
||||
removed_idx = set()
|
||||
|
||||
for idx in sorted_idx:
|
||||
# skip point if it was already removed
|
||||
if idx in removed_idx:
|
||||
continue
|
||||
|
||||
kpts_to_keep_idx.append(idx)
|
||||
point = kpts[idx]
|
||||
neighbors = kd_tree.query_ball_point(point, r=radius)
|
||||
# Variable `neighbors` contains the `point` itself
|
||||
removed_idx.update(neighbors)
|
||||
|
||||
mask = np.zeros((kpts.shape[0],), dtype=bool)
|
||||
mask[kpts_to_keep_idx] = True
|
||||
return mask
|
||||
|
||||
|
||||
def detect_kpts_opencv(
|
||||
features: cv2.Feature2D, image: np.ndarray, describe: bool = True
|
||||
) -> np.ndarray:
|
||||
def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Detect keypoints using OpenCV Detector.
|
||||
Optionally, perform NMS and filter top-response keypoints.
|
||||
Optionally, perform description.
|
||||
Args:
|
||||
features: OpenCV based keypoints detector and descriptor
|
||||
image: Grayscale image of uint8 data type
|
||||
describe: flag indicating whether to simultaneously compute descriptors
|
||||
Returns:
|
||||
kpts: 1D array of detected cv2.KeyPoint
|
||||
keypoints: 1D array of detected cv2.KeyPoint
|
||||
scores: 1D array of responses
|
||||
descriptors: 1D array of descriptors
|
||||
"""
|
||||
if describe:
|
||||
kpts, descriptors = features.detectAndCompute(image, None)
|
||||
else:
|
||||
kpts = features.detect(image, None)
|
||||
kpts = np.array(kpts)
|
||||
|
||||
responses = np.array([k.response for k in kpts], dtype=np.float32)
|
||||
|
||||
# select all
|
||||
top_score_idx = ...
|
||||
pts = np.array([k.pt for k in kpts], dtype=np.float32)
|
||||
scales = np.array([k.size for k in kpts], dtype=np.float32)
|
||||
angles = np.array([k.angle for k in kpts], dtype=np.float32)
|
||||
spts = np.concatenate([pts, scales[..., None], angles[..., None]], -1)
|
||||
|
||||
if describe:
|
||||
return spts[top_score_idx], responses[top_score_idx], descriptors[top_score_idx]
|
||||
else:
|
||||
return spts[top_score_idx], responses[top_score_idx]
|
||||
detections, descriptors = features.detectAndCompute(image, None)
|
||||
points = np.array([k.pt for k in detections], dtype=np.float32)
|
||||
scores = np.array([k.response for k in detections], dtype=np.float32)
|
||||
scales = np.array([k.size for k in detections], dtype=np.float32)
|
||||
angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
|
||||
return points, scores, scales, angles, descriptors
|
||||
|
||||
|
||||
class SIFT(BaseModel):
|
||||
default_conf = {
|
||||
"has_detector": True,
|
||||
"has_descriptor": True,
|
||||
"descriptor_dim": 128,
|
||||
"pycolmap_options": {
|
||||
"first_octave": 0,
|
||||
"peak_threshold": 0.005,
|
||||
"edge_threshold": 10,
|
||||
},
|
||||
"rootsift": True,
|
||||
"nms_radius": None,
|
||||
"max_num_keypoints": -1,
|
||||
"max_num_keypoints_val": None,
|
||||
"nms_radius": 0, # None to disable filtering entirely.
|
||||
"max_num_keypoints": 4096,
|
||||
"backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
|
||||
"detection_threshold": 0.0066667, # from COLMAP
|
||||
"edge_threshold": 10,
|
||||
"first_octave": -1, # only used by pycolmap, the default of COLMAP
|
||||
"num_octaves": 4,
|
||||
"force_num_keypoints": False,
|
||||
"randomize_keypoints_training": False,
|
||||
"detector": "pycolmap", # ['pycolmap', 'pycolmap_cpu', 'pycolmap_cuda', 'cv2']
|
||||
"detection_threshold": None,
|
||||
}
|
||||
|
||||
required_data_keys = ["image"]
|
||||
|
||||
def _init(self, conf):
|
||||
self.sift = None # lazy loading
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_features(self, image):
|
||||
image_np = image.cpu().numpy()[0]
|
||||
assert image.shape[0] == 1
|
||||
assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS
|
||||
|
||||
detector = str(self.conf.detector)
|
||||
|
||||
if self.sift is None and detector.startswith("pycolmap"):
|
||||
options = OmegaConf.to_container(self.conf.pycolmap_options)
|
||||
device = (
|
||||
"auto" if detector == "pycolmap" else detector.replace("pycolmap_", "")
|
||||
backend = self.conf.backend
|
||||
if backend.startswith("pycolmap"):
|
||||
if pycolmap is None:
|
||||
raise ImportError(
|
||||
"Cannot find module pycolmap: install it with pip"
|
||||
"or use backend=opencv."
|
||||
)
|
||||
options = {
|
||||
"peak_threshold": self.conf.detection_threshold,
|
||||
"edge_threshold": self.conf.edge_threshold,
|
||||
"first_octave": self.conf.first_octave,
|
||||
"num_octaves": self.conf.num_octaves,
|
||||
"normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
|
||||
}
|
||||
device = (
|
||||
"auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
|
||||
)
|
||||
if (
|
||||
backend == "pycolmap_cpu" or not pycolmap.has_cuda
|
||||
) and pycolmap.__version__ < "0.5.0":
|
||||
warnings.warn(
|
||||
"The pycolmap CPU SIFT is buggy in version < 0.5.0, "
|
||||
"consider upgrading pycolmap or use the CUDA version.",
|
||||
stacklevel=1,
|
||||
)
|
||||
if self.conf.rootsift == "rootsift":
|
||||
options["normalization"] = pycolmap.Normalization.L1_ROOT
|
||||
else:
|
||||
options["normalization"] = pycolmap.Normalization.L2
|
||||
if self.conf.detection_threshold is not None:
|
||||
options["peak_threshold"] = self.conf.detection_threshold
|
||||
options["max_num_features"] = self.conf.max_num_keypoints
|
||||
self.sift = pycolmap.Sift(options=options, device=device)
|
||||
elif self.sift is None and self.conf.detector == "cv2":
|
||||
self.sift = cv2.SIFT_create(contrastThreshold=self.conf.detection_threshold)
|
||||
|
||||
if detector.startswith("pycolmap"):
|
||||
keypoints, scores, descriptors = self.sift.extract(image_np)
|
||||
elif detector == "cv2":
|
||||
# TODO: Check if opencv keypoints are already in corner convention
|
||||
keypoints, scores, descriptors = detect_kpts_opencv(
|
||||
self.sift, (image_np * 255.0).astype(np.uint8)
|
||||
elif backend == "opencv":
|
||||
self.sift = cv2.SIFT_create(
|
||||
contrastThreshold=self.conf.detection_threshold,
|
||||
nfeatures=self.conf.max_num_keypoints,
|
||||
edgeThreshold=self.conf.edge_threshold,
|
||||
nOctaveLayers=self.conf.num_octaves,
|
||||
)
|
||||
else:
|
||||
backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
|
||||
raise ValueError(
|
||||
f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
|
||||
)
|
||||
|
||||
def extract_single_image(self, image: torch.Tensor):
|
||||
image_np = image.cpu().numpy().squeeze(0)
|
||||
|
||||
if self.conf.backend.startswith("pycolmap"):
|
||||
if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
|
||||
detections, descriptors = self.sift.extract(image_np)
|
||||
scores = None # Scores are not exposed by COLMAP anymore.
|
||||
else:
|
||||
detections, scores, descriptors = self.sift.extract(image_np)
|
||||
keypoints = detections[:, :2] # Keep only (x, y).
|
||||
scales, angles = detections[:, -2:].T
|
||||
if scores is not None and (
|
||||
self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
|
||||
):
|
||||
# Set the scores as a combination of abs. response and scale.
|
||||
scores = np.abs(scores) * scales
|
||||
elif self.conf.backend == "opencv":
|
||||
# TODO: Check if opencv keypoints are already in corner convention
|
||||
keypoints, scores, scales, angles, descriptors = run_opencv_sift(
|
||||
self.sift, (image_np * 255.0).astype(np.uint8)
|
||||
)
|
||||
pred = {
|
||||
"keypoints": keypoints,
|
||||
"scales": scales,
|
||||
"oris": angles,
|
||||
"descriptors": descriptors,
|
||||
}
|
||||
if scores is not None:
|
||||
pred["keypoint_scores"] = scores
|
||||
|
||||
# sometimes pycolmap returns points outside the image. We remove them
|
||||
if self.conf.backend.startswith("pycolmap"):
|
||||
is_inside = (
|
||||
pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
|
||||
).all(-1)
|
||||
pred = {k: v[is_inside] for k, v in pred.items()}
|
||||
|
||||
if self.conf.nms_radius is not None:
|
||||
mask = nms_keypoints(keypoints[:, :2], scores, self.conf.nms_radius)
|
||||
keypoints = keypoints[mask]
|
||||
scores = scores[mask]
|
||||
descriptors = descriptors[mask]
|
||||
|
||||
scales = keypoints[:, 2]
|
||||
oris = np.rad2deg(keypoints[:, 3])
|
||||
|
||||
if self.conf.has_descriptor:
|
||||
# We still renormalize because COLMAP does not normalize well,
|
||||
# maybe due to numerical errors
|
||||
if self.conf.rootsift:
|
||||
descriptors = sift_to_rootsift(descriptors)
|
||||
descriptors = torch.from_numpy(descriptors)
|
||||
keypoints = torch.from_numpy(keypoints[:, :2]) # keep only x, y
|
||||
scales = torch.from_numpy(scales)
|
||||
oris = torch.from_numpy(oris)
|
||||
scores = torch.from_numpy(scores)
|
||||
keep = filter_dog_point(
|
||||
pred["keypoints"],
|
||||
pred["scales"],
|
||||
pred["oris"],
|
||||
image_np.shape,
|
||||
self.conf.nms_radius,
|
||||
pred["keypoint_scores"],
|
||||
)
|
||||
pred = {k: v[keep] for k, v in pred.items()}
|
||||
|
||||
pred = {k: torch.from_numpy(v) for k, v in pred.items()}
|
||||
if scores is not None:
|
||||
# Keep the k keypoints with highest score
|
||||
max_kps = self.conf.max_num_keypoints
|
||||
|
||||
# for val we allow different
|
||||
if not self.training and self.conf.max_num_keypoints_val is not None:
|
||||
max_kps = self.conf.max_num_keypoints_val
|
||||
|
||||
if max_kps is not None and max_kps > 0:
|
||||
if self.conf.randomize_keypoints_training and self.training:
|
||||
# instead of selecting top-k, sample k by score weights
|
||||
raise NotImplementedError
|
||||
elif max_kps < scores.shape[0]:
|
||||
# TODO: check that the scores from PyCOLMAP are 100% correct,
|
||||
# follow https://github.com/mihaidusmanu/pycolmap/issues/8
|
||||
indices = torch.topk(scores, max_kps).indices
|
||||
keypoints = keypoints[indices]
|
||||
scales = scales[indices]
|
||||
oris = oris[indices]
|
||||
scores = scores[indices]
|
||||
if self.conf.has_descriptor:
|
||||
descriptors = descriptors[indices]
|
||||
num_points = self.conf.max_num_keypoints
|
||||
if num_points is not None and len(pred["keypoints"]) > num_points:
|
||||
indices = torch.topk(pred["keypoint_scores"], num_points).indices
|
||||
pred = {k: v[indices] for k, v in pred.items()}
|
||||
|
||||
if self.conf.force_num_keypoints:
|
||||
keypoints = pad_to_length(
|
||||
keypoints,
|
||||
max_kps,
|
||||
num_points = min(self.conf.max_num_keypoints, len(pred["keypoints"]))
|
||||
pred["keypoints"] = pad_to_length(
|
||||
pred["keypoints"],
|
||||
num_points,
|
||||
-2,
|
||||
mode="random_c",
|
||||
bounds=(0, min(image.shape[1:])),
|
||||
)
|
||||
scores = pad_to_length(scores, max_kps, -1, mode="zeros")
|
||||
scales = pad_to_length(scales, max_kps, -1, mode="zeros")
|
||||
oris = pad_to_length(oris, max_kps, -1, mode="zeros")
|
||||
if self.conf.has_descriptor:
|
||||
descriptors = pad_to_length(descriptors, max_kps, -2, mode="zeros")
|
||||
|
||||
pred = {
|
||||
"keypoints": keypoints,
|
||||
"scales": scales,
|
||||
"oris": oris,
|
||||
"keypoint_scores": scores,
|
||||
}
|
||||
|
||||
if self.conf.has_descriptor:
|
||||
pred["descriptors"] = descriptors
|
||||
pred["scales"] = pad_to_length(pred["scales"], num_points, -1, mode="zeros")
|
||||
pred["oris"] = pad_to_length(pred["oris"], num_points, -1, mode="zeros")
|
||||
pred["descriptors"] = pad_to_length(
|
||||
pred["descriptors"], num_points, -2, mode="zeros"
|
||||
)
|
||||
if pred["keypoint_scores"] is not None:
|
||||
scores = pad_to_length(
|
||||
pred["keypoint_scores"], num_points, -1, mode="zeros"
|
||||
)
|
||||
return pred
|
||||
|
||||
@torch.no_grad()
|
||||
def _forward(self, data):
|
||||
pred = {
|
||||
"keypoints": [],
|
||||
"scales": [],
|
||||
"oris": [],
|
||||
"keypoint_scores": [],
|
||||
"descriptors": [],
|
||||
}
|
||||
|
||||
def _forward(self, data: dict) -> dict:
|
||||
image = data["image"]
|
||||
if image.shape[1] == 3: # RGB
|
||||
scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
|
||||
image = (image * scale).sum(1, keepdim=True).cpu()
|
||||
|
||||
for k in range(image.shape[0]):
|
||||
if image.shape[1] == 3:
|
||||
image = rgb_to_grayscale(image)
|
||||
device = image.device
|
||||
image = image.cpu()
|
||||
pred = []
|
||||
for k in range(len(image)):
|
||||
img = image[k]
|
||||
if "image_size" in data.keys():
|
||||
# avoid extracting points in padded areas
|
||||
w, h = data["image_size"][k]
|
||||
img = img[:, :h, :w]
|
||||
p = self.extract_features(img)
|
||||
for k, v in p.items():
|
||||
pred[k].append(v)
|
||||
|
||||
if (image.shape[0] == 1) or self.conf.force_num_keypoints:
|
||||
pred = {k: torch.stack(pred[k], 0) for k in pred.keys()}
|
||||
|
||||
pred = {k: pred[k].to(device=data["image"].device) for k in pred.keys()}
|
||||
|
||||
pred["oris"] = torch.deg2rad(pred["oris"])
|
||||
p = self.extract_single_image(img)
|
||||
pred.append(p)
|
||||
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
|
||||
if self.conf.rootsift:
|
||||
pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
|
||||
return pred
|
||||
|
||||
def loss(self, pred, data):
|
||||
|
|
|
@ -19,12 +19,13 @@ class KorniaSIFT(BaseModel):
|
|||
self.sift = kornia.feature.SIFTFeature(
|
||||
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
|
||||
)
|
||||
self.set_initialized()
|
||||
|
||||
def _forward(self, data):
|
||||
lafs, scores, descriptors = self.sift(data["image"])
|
||||
keypoints = kornia.feature.get_laf_center(lafs)
|
||||
scales = kornia.feature.get_laf_scale(lafs)
|
||||
oris = kornia.feature.get_laf_orientation(lafs)
|
||||
scales = kornia.feature.get_laf_scale(lafs).squeeze(-1).squeeze(-1)
|
||||
oris = kornia.feature.get_laf_orientation(lafs).squeeze(-1)
|
||||
pred = {
|
||||
"keypoints": keypoints, # @TODO: confirm keypoints are in corner convention
|
||||
"scales": scales,
|
||||
|
|
|
@ -34,13 +34,14 @@ class DeepLSD(BaseModel):
|
|||
ckpt = torch.load(ckpt, map_location="cpu")
|
||||
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
|
||||
self.net.load_state_dict(ckpt["model"])
|
||||
self.set_initialized()
|
||||
|
||||
def download_model(self, path):
|
||||
import subprocess
|
||||
|
||||
if not path.parent.is_dir():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
|
||||
link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar"
|
||||
cmd = ["wget", link, "-O", path]
|
||||
print("Downloading DeepLSD model...")
|
||||
subprocess.run(cmd, check=True)
|
||||
|
|
|
@ -119,7 +119,7 @@ class GlueStick(BaseModel):
|
|||
"Loading GlueStick model from " f'"{self.url.format(conf.version)}"'
|
||||
)
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
self.url.format(conf.version), file_name=fname
|
||||
self.url.format(conf.version), file_name=fname, map_location="cpu"
|
||||
)
|
||||
|
||||
if "model" in state_dict:
|
||||
|
@ -131,7 +131,7 @@ class GlueStick(BaseModel):
|
|||
state_dict = {
|
||||
k.replace("module.", ""): v for k, v in state_dict.items()
|
||||
}
|
||||
self.load_state_dict(state_dict)
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def _forward(self, data):
|
||||
device = data["keypoints0"].device
|
||||
|
@ -200,8 +200,6 @@ class GlueStick(BaseModel):
|
|||
kpts0 = normalize_keypoints(kpts0, image_size0)
|
||||
kpts1 = normalize_keypoints(kpts1, image_size1)
|
||||
|
||||
assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
|
||||
assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
|
||||
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
|
||||
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ class LoFTRModule(BaseModel):
|
|||
|
||||
def _init(self, conf):
|
||||
self.net = kornia.feature.LoFTR(pretrained="outdoor")
|
||||
self.set_initialized()
|
||||
|
||||
def _forward(self, data):
|
||||
image0 = data["view0"]["image"]
|
||||
|
|
|
@ -17,17 +17,18 @@ class LightGlue(BaseModel):
|
|||
|
||||
def _init(self, conf):
|
||||
dconf = OmegaConf.to_container(conf)
|
||||
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
|
||||
# self.net.compile()
|
||||
self.net = LightGlue_(dconf.pop("features"), **dconf)
|
||||
self.set_initialized()
|
||||
|
||||
def _forward(self, data):
|
||||
required_keys = ["keypoints", "descriptors", "scales", "oris"]
|
||||
view0 = {
|
||||
**{k: data[k + "0"] for k in ["keypoints", "descriptors"]},
|
||||
**data["view0"],
|
||||
**{k: data[k + "0"] for k in required_keys if (k + "0") in data},
|
||||
}
|
||||
view1 = {
|
||||
**{k: data[k + "1"] for k in ["keypoints", "descriptors"]},
|
||||
**data["view1"],
|
||||
**{k: data[k + "1"] for k in required_keys if (k + "1") in data},
|
||||
}
|
||||
return self.net({"image0": view0, "image1": view1})
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from homography_est import (
|
|||
ransac_point_line_homography,
|
||||
)
|
||||
|
||||
from ...utils.tensor import batch_to_numpy
|
||||
from ..base_estimator import BaseEstimator
|
||||
|
||||
|
||||
|
@ -50,19 +51,20 @@ class PointLineHomographyEstimator(BaseEstimator):
|
|||
pass
|
||||
|
||||
def _forward(self, data):
|
||||
m_features = {
|
||||
"kpts0": data["m_kpts1"].numpy() if "m_kpts1" in data else None,
|
||||
"kpts1": data["m_kpts0"].numpy() if "m_kpts0" in data else None,
|
||||
"lines0": data["m_lines1"].numpy() if "m_lines1" in data else None,
|
||||
"lines1": data["m_lines0"].numpy() if "m_lines0" in data else None,
|
||||
}
|
||||
feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"]
|
||||
data = batch_to_numpy(data)
|
||||
m_features = {
|
||||
"kpts0": data["m_kpts1"] if "m_kpts1" in data else None,
|
||||
"kpts1": data["m_kpts0"] if "m_kpts0" in data else None,
|
||||
"lines0": data["m_lines1"] if "m_lines1" in data else None,
|
||||
"lines1": data["m_lines0"] if "m_lines0" in data else None,
|
||||
}
|
||||
M = H_estimation_hybrid(**m_features, tol_px=self.conf.ransac_th)
|
||||
success = M is not None
|
||||
if not success:
|
||||
M = torch.eye(3, device=feat.device, dtype=feat.dtype)
|
||||
else:
|
||||
M = torch.tensor(M).to(feat)
|
||||
M = torch.from_numpy(M).to(feat)
|
||||
|
||||
estimation = {
|
||||
"success": success,
|
||||
|
|
|
@ -16,8 +16,8 @@ class PoseLibHomographyEstimator(BaseEstimator):
|
|||
def _forward(self, data):
|
||||
pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
|
||||
M, info = poselib.estimate_homography(
|
||||
pts0.numpy(),
|
||||
pts1.numpy(),
|
||||
pts0.detach().cpu().numpy(),
|
||||
pts1.detach().cpu().numpy(),
|
||||
{
|
||||
"max_reproj_error": self.conf.ransac_th,
|
||||
**OmegaConf.to_container(self.conf.options),
|
||||
|
|
|
@ -37,14 +37,13 @@ configs = {
|
|||
},
|
||||
},
|
||||
"cv2-sift": {
|
||||
"name": f"r{resize}_cv2-SIFT-k{n_kpts}",
|
||||
"name": f"r{resize}_opencv-SIFT-k{n_kpts}",
|
||||
"keys": ["keypoints", "descriptors", "keypoint_scores", "oris", "scales"],
|
||||
"gray": True,
|
||||
"conf": {
|
||||
"name": "extractors.sift",
|
||||
"max_num_keypoints": 4096,
|
||||
"detection_threshold": 0.001,
|
||||
"detector": "cv2",
|
||||
"backend": "opencv",
|
||||
},
|
||||
},
|
||||
"pycolmap-sift": {
|
||||
|
@ -54,11 +53,7 @@ configs = {
|
|||
"conf": {
|
||||
"name": "extractors.sift",
|
||||
"max_num_keypoints": n_kpts,
|
||||
"detection_threshold": 0.0001,
|
||||
"detector": "pycolmap",
|
||||
"pycolmap_options": {
|
||||
"first_octave": -1,
|
||||
},
|
||||
"backend": "pycolmap",
|
||||
},
|
||||
},
|
||||
"pycolmap-sift-gpu": {
|
||||
|
@ -68,11 +63,7 @@ configs = {
|
|||
"conf": {
|
||||
"name": "extractors.sift",
|
||||
"max_num_keypoints": n_kpts,
|
||||
"detection_threshold": 0.0066666,
|
||||
"detector": "pycolmap_cuda",
|
||||
"pycolmap_options": {
|
||||
"first_octave": -1,
|
||||
},
|
||||
"backend": "pycolmap_cuda",
|
||||
"nms_radius": 3,
|
||||
},
|
||||
},
|
||||
|
@ -133,15 +124,18 @@ def run_export(feature_file, scene, args):
|
|||
|
||||
conf = OmegaConf.create(conf)
|
||||
|
||||
keys = configs[args.method]["keys"] + ["depth_keypoints", "valid_depth_keypoints"]
|
||||
keys = configs[args.method]["keys"]
|
||||
dataset = get_dataset(conf.data.name)(conf.data)
|
||||
loader = dataset.get_data_loader(conf.split or "test")
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = get_model(conf.model.name)(conf.model).eval().to(device)
|
||||
|
||||
if args.export_sparse_depth:
|
||||
callback_fn = get_kp_depth # use this to store the depth of each keypoint
|
||||
keys = keys + ["depth_keypoints", "valid_depth_keypoints"]
|
||||
else:
|
||||
callback_fn = None
|
||||
# callback_fn=get_kp_depth # use this to store the depth of each keypoint
|
||||
export_predictions(
|
||||
loader, model, feature_file, as_half=True, keys=keys, callback_fn=callback_fn
|
||||
)
|
||||
|
@ -153,6 +147,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--method", type=str, default="sp")
|
||||
parser.add_argument("--scenes", type=str, default=None)
|
||||
parser.add_argument("--num_workers", type=int, default=0)
|
||||
parser.add_argument("--export_sparse_depth", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
export_name = configs[args.method]["name"]
|
||||
|
|
|
@ -12,6 +12,7 @@ import signal
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from pydoc import locate
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -48,11 +49,12 @@ default_train_conf = {
|
|||
"optimizer_options": {}, # optional arguments passed to the optimizer
|
||||
"lr": 0.001, # learning rate
|
||||
"lr_schedule": {
|
||||
"type": None,
|
||||
"type": None, # string in {factor, exp, member of torch.optim.lr_scheduler}
|
||||
"start": 0,
|
||||
"exp_div_10": 0,
|
||||
"on_epoch": False,
|
||||
"factor": 1.0,
|
||||
"options": {}, # add lr_scheduler arguments here
|
||||
},
|
||||
"lr_scaling": [(100, ["dampingnet.const"])],
|
||||
"eval_every_iter": 1000, # interval for evaluation on the validation set
|
||||
|
@ -88,6 +90,7 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
|||
for i, data in enumerate(
|
||||
tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)
|
||||
):
|
||||
|
||||
data = batch_to_device(data, device, non_blocking=True)
|
||||
with torch.no_grad():
|
||||
pred = model(data)
|
||||
|
@ -101,7 +104,6 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
|||
pred[v["predictions"]],
|
||||
mask=pred[v["mask"]] if "mask" in v.keys() else None,
|
||||
)
|
||||
del pred, data
|
||||
numbers = {**metrics, **{"loss/" + k: v for k, v in losses.items()}}
|
||||
for k, v in numbers.items():
|
||||
if k not in results:
|
||||
|
@ -117,7 +119,9 @@ def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
|
|||
if k in conf.recall_metrics.keys():
|
||||
q = conf.recall_metrics[k]
|
||||
results[k + f"_recall{int(q)}"].update(v)
|
||||
del numbers
|
||||
|
||||
del pred, data, losses, metrics
|
||||
gc.collect()
|
||||
results = {k: results[k].compute() for k in results}
|
||||
return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
|
||||
|
||||
|
@ -141,6 +145,26 @@ def filter_parameters(params, regexp):
|
|||
return params
|
||||
|
||||
|
||||
def get_lr_scheduler(optimizer, conf):
|
||||
"""Get lr scheduler specified by conf.train.lr_schedule."""
|
||||
if conf.type not in ["factor", "exp", None]:
|
||||
return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
|
||||
|
||||
# backward compatibility
|
||||
def lr_fn(it): # noqa: E306
|
||||
if conf.type is None:
|
||||
return 1
|
||||
if conf.type == "factor":
|
||||
return 1.0 if it < conf.start else conf.factor
|
||||
if conf.type == "exp":
|
||||
gam = 10 ** (-1 / conf.exp_div_10)
|
||||
return 1.0 if it < conf.start else gam
|
||||
else:
|
||||
raise ValueError(conf.type)
|
||||
|
||||
return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
|
||||
|
||||
|
||||
def pack_lr_parameters(params, base_lr, lr_scaling):
|
||||
"""Pack each group of parameters with the respective scaled learning rate."""
|
||||
filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
|
||||
|
@ -236,6 +260,7 @@ def training(rank, conf, output_dir, args):
|
|||
|
||||
dataset = get_dataset(data_conf.name)(data_conf)
|
||||
|
||||
|
||||
# Optionally load a different validation dataset than the training one
|
||||
val_data_conf = conf.get("data_val", None)
|
||||
if val_data_conf is None:
|
||||
|
@ -310,22 +335,7 @@ def training(rank, conf, output_dir, args):
|
|||
|
||||
results = None # fix bug with it saving
|
||||
|
||||
def lr_fn(it): # noqa: E306
|
||||
if conf.train.lr_schedule.type is None:
|
||||
return 1
|
||||
if conf.train.lr_schedule.type == "factor":
|
||||
return (
|
||||
1.0
|
||||
if it < conf.train.lr_schedule.start
|
||||
else conf.train.lr_schedule.factor
|
||||
)
|
||||
if conf.train.lr_schedule.type == "exp":
|
||||
gam = 10 ** (-1 / conf.train.lr_schedule.exp_div_10)
|
||||
return 1.0 if it < conf.train.lr_schedule.start else gam
|
||||
else:
|
||||
raise ValueError(conf.train.lr_schedule.type)
|
||||
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
|
||||
lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_schedule)
|
||||
if args.restore:
|
||||
optimizer.load_state_dict(init_cp["optimizer"])
|
||||
if "lr_scheduler" in init_cp:
|
||||
|
@ -402,7 +412,9 @@ def training(rank, conf, output_dir, args):
|
|||
getattr(loader.dataset, conf.train.dataset_callback_fn)(
|
||||
conf.train.seed + epoch
|
||||
)
|
||||
|
||||
for it, data in enumerate(train_loader):
|
||||
|
||||
tot_it = (len(train_loader) * epoch + it) * (
|
||||
args.n_gpus if args.distributed else 1
|
||||
)
|
||||
|
@ -421,10 +433,17 @@ def training(rank, conf, output_dir, args):
|
|||
loss = torch.mean(losses["total"])
|
||||
if torch.isnan(loss).any():
|
||||
print(f"Detected NAN, skipping iteration {it}")
|
||||
print("name", data["name"])
|
||||
print("loss", loss)
|
||||
print("losses", losses)
|
||||
print("data", data)
|
||||
print("pred", pred)
|
||||
raise RuntimeError("Detected NAN in training.")
|
||||
del pred, data, loss, losses
|
||||
continue
|
||||
|
||||
do_backward = loss.requires_grad
|
||||
|
||||
if args.distributed:
|
||||
do_backward = torch.tensor(do_backward).float().to(device)
|
||||
torch.distributed.all_reduce(
|
||||
|
@ -463,7 +482,6 @@ def training(rank, conf, output_dir, args):
|
|||
else:
|
||||
if rank == 0:
|
||||
logger.warning(f"Skip iteration {it} due to detach.")
|
||||
|
||||
if args.profile:
|
||||
prof.step()
|
||||
|
||||
|
@ -502,8 +520,11 @@ def training(rank, conf, output_dir, args):
|
|||
norm = torch.norm(param.grad.detach(), 2)
|
||||
grad_txt += f"{name} {norm.item():.3f} \n"
|
||||
writer.add_text("grad/summary", grad_txt, tot_n_samples)
|
||||
del pred, data, loss, losses
|
||||
|
||||
pred.clear()
|
||||
data.clear()
|
||||
del pred, data, loss, losses
|
||||
gc.collect()
|
||||
# Run validation
|
||||
if (
|
||||
(
|
||||
|
@ -523,6 +544,7 @@ def training(rank, conf, output_dir, args):
|
|||
pbar=(rank == -1),
|
||||
)
|
||||
|
||||
|
||||
if rank == 0:
|
||||
str_results = [
|
||||
f"{k} {v:.3E}"
|
||||
|
@ -563,6 +585,10 @@ def training(rank, conf, output_dir, args):
|
|||
f"figures/{i}_{name}", fig, tot_n_samples
|
||||
)
|
||||
torch.cuda.empty_cache() # should be cleared at the first iter
|
||||
str_results.clear()
|
||||
pr_metrics.clear()
|
||||
del str_results, figures, pr_metrics
|
||||
|
||||
|
||||
if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
|
||||
if results is None:
|
||||
|
@ -616,7 +642,7 @@ def training(rank, conf, output_dir, args):
|
|||
writer.close()
|
||||
|
||||
|
||||
def main_worker(rank, conf, output_dir, args):
|
||||
def main_worker(rank, conf, output_dir, aprgs):
|
||||
if rank == 0:
|
||||
with capture_outputs(output_dir / "log.txt"):
|
||||
training(rank, conf, output_dir, args)
|
||||
|
|
|
@ -33,6 +33,15 @@ def batch_to_device(batch, device, non_blocking=True):
|
|||
|
||||
return map_tensor(batch, _func)
|
||||
|
||||
def detach_tensors(batch):
|
||||
"""
|
||||
Detach all tensors in a batch recursively.
|
||||
This is useful for detaching tensors from the computational graph to free up memory.
|
||||
"""
|
||||
def _detach(tensor):
|
||||
return tensor.detach()
|
||||
|
||||
return map_tensor(batch, _detach)
|
||||
|
||||
def rbd(data: dict) -> dict:
|
||||
"""Remove batch dimension from elements in data"""
|
||||
|
@ -40,3 +49,9 @@ def rbd(data: dict) -> dict:
|
|||
k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
|
||||
for k, v in data.items()
|
||||
}
|
||||
|
||||
|
||||
def index_batch(tensor_dict):
|
||||
batch_size = len(next(iter(tensor_dict.values())))
|
||||
for i in range(batch_size):
|
||||
yield map_tensor(tensor_dict, lambda t: t[i])
|
||||
|
|
|
@ -67,8 +67,12 @@ class MedianMetric:
|
|||
else:
|
||||
return np.nanmedian(self._elements)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._elements)
|
||||
|
||||
|
||||
class PRMetric:
|
||||
|
||||
def __init__(self):
|
||||
self.labels = []
|
||||
self.predictions = []
|
||||
|
|
|
@ -208,14 +208,14 @@ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axe
|
|||
kpts0[:, 1],
|
||||
c=color,
|
||||
s=ps,
|
||||
label=None if labels is None else labels[0],
|
||||
label=None if labels is None or len(labels) == 0 else labels[0],
|
||||
)
|
||||
ax1.scatter(
|
||||
kpts1[:, 0],
|
||||
kpts1[:, 1],
|
||||
c=color,
|
||||
s=ps,
|
||||
label=None if labels is None else labels[1],
|
||||
label=None if labels is None or len(labels) == 0 else labels[1],
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -338,6 +338,14 @@ class SuperPoint(BaseModel):
|
|||
for k, d in zip(keypoints, dense_desc)
|
||||
]
|
||||
|
||||
if isinstance(desc, list):
|
||||
if all(isinstance(d, torch.Tensor) for d in desc):
|
||||
# If desc is a list of tensors
|
||||
desc = torch.stack(desc)
|
||||
else:
|
||||
# If desc is a list of non-tensor elements
|
||||
desc = torch.stack([torch.tensor(d) for d in desc])
|
||||
|
||||
pred = {
|
||||
"keypoints": keypoints + 0.5,
|
||||
"keypoint_scores": scores,
|
||||
|
|
|
@ -38,12 +38,12 @@ urls = {Repository = "https://github.com/cvg/glue-factory"}
|
|||
[project.optional-dependencies]
|
||||
extra = [
|
||||
"pycolmap",
|
||||
"poselib @ git+https://github.com/PoseLib/PoseLib.git",
|
||||
"pytlsd @ git+https://github.com/iago-suarez/pytlsd.git",
|
||||
"poselib @ git+https://github.com/PoseLib/PoseLib.git@9c8f3ca1baba69e19726cc7caded574873ec1f9e",
|
||||
"pytlsd @ git+https://github.com/iago-suarez/pytlsd.git@v0.0.5",
|
||||
"deeplsd @ git+https://github.com/cvg/DeepLSD.git",
|
||||
"homography_est @ git+https://github.com/rpautrat/homography_est.git",
|
||||
"homography_est @ git+https://github.com/rpautrat/homography_est.git@17b200d528e6aa8ac61a878a29265bf5f9d36c41",
|
||||
]
|
||||
dev = ["black", "flake8", "isort"]
|
||||
dev = ["black", "flake8", "isort", "parameterized"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["gluefactory*"]
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
#python -m gluefactory.train sp+lg_homography \
|
||||
# --conf gluefactory/configs/superpoint+lightglue_homography.yaml \
|
||||
# data.batch_size=32 # for 1x 1080 GPU
|
||||
|
||||
#python -m gluefactory.train satellites_aug_both_40 \
|
||||
# --conf gluefactory/configs/satellites3.yaml \
|
||||
# data.batch_size=4 # for 1x 1080 GPU
|
||||
|
||||
|
||||
python -m gluefactory.train drone2sat_v1 \
|
||||
--conf gluefactory/configs/drone2sat_v1.yaml \
|
||||
data.batch_size=6 # for 1x 1080 GPU
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
|
||||
LOGFILE="/var/log/h/system_stats.log"
|
||||
|
||||
echo "Logging CPU, Memory, Swap, Disk, and Network usage to $LOGFILE"
|
||||
|
||||
while true; do
|
||||
echo "----- $(date) -----" >> $LOGFILE
|
||||
echo "CPU Usage:" >> $LOGFILE
|
||||
mpstat -P ALL 1 1 >> $LOGFILE
|
||||
|
||||
echo "Disk Usage:" >> $LOGFILE
|
||||
iostat >> $LOGFILE
|
||||
|
||||
echo "Memory and Swap Usage:" >> $LOGFILE
|
||||
free -m >> $LOGFILE
|
||||
|
||||
echo "Network Usage:" >> $LOGFILE
|
||||
ifstat 1 1 >> $LOGFILE
|
||||
|
||||
sleep 5
|
||||
done
|
|
@ -0,0 +1,88 @@
|
|||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from gluefactory.eval.utils import eval_matches_homography
|
||||
from gluefactory.geometry.homography import warp_points_torch
|
||||
|
||||
|
||||
class TestEvalUtils(unittest.TestCase):
|
||||
@staticmethod
|
||||
def default_pts():
|
||||
return torch.tensor(
|
||||
[
|
||||
[10.0, 10.0],
|
||||
[10.0, 20.0],
|
||||
[20.0, 20.0],
|
||||
[20.0, 10.0],
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def default_pred(kps0, kps1):
|
||||
return {
|
||||
"keypoints0": kps0,
|
||||
"keypoints1": kps1,
|
||||
"matches0": torch.arange(len(kps0)),
|
||||
"matching_scores0": torch.ones(len(kps1)),
|
||||
}
|
||||
|
||||
def test_eval_matches_homography_trivial(self):
|
||||
data = {"H_0to1": torch.eye(3)}
|
||||
kps = self.default_pts()
|
||||
pred = self.default_pred(kps, kps)
|
||||
|
||||
results = eval_matches_homography(data, pred)
|
||||
|
||||
self.assertEqual(results["prec@1px"], 1)
|
||||
self.assertEqual(results["prec@3px"], 1)
|
||||
self.assertEqual(results["num_matches"], 4)
|
||||
self.assertEqual(results["num_keypoints"], 4)
|
||||
|
||||
def test_eval_matches_homography_real(self):
|
||||
data = {"H_0to1": torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])}
|
||||
kps0 = self.default_pts()
|
||||
kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
|
||||
pred = self.default_pred(kps0, kps1)
|
||||
|
||||
results = eval_matches_homography(data, pred)
|
||||
|
||||
self.assertEqual(results["prec@1px"], 1)
|
||||
self.assertEqual(results["prec@3px"], 1)
|
||||
|
||||
def test_eval_matches_homography_real_outliers(self):
|
||||
data = {"H_0to1": torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])}
|
||||
kps0 = self.default_pts()
|
||||
kps0 = torch.cat([kps0, torch.tensor([[5.0, 5.0]])])
|
||||
kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
|
||||
# Move one keypoint 1.5 pixels away in x and y
|
||||
kps1[-1] += 1.5
|
||||
pred = self.default_pred(kps0, kps1)
|
||||
|
||||
results = eval_matches_homography(data, pred)
|
||||
self.assertAlmostEqual(results["prec@1px"], 0.8)
|
||||
self.assertAlmostEqual(results["prec@3px"], 1.0)
|
||||
|
||||
def test_eval_matches_homography_batched(self):
|
||||
H0 = torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])
|
||||
H1 = torch.tensor([[0.7, 0.1, -5], [-0.1, 0.65, 13], [0, 0, 1.0]])
|
||||
data = {"H_0to1": torch.stack([H0, H1])}
|
||||
kps0 = torch.stack([self.default_pts(), self.default_pts().flip(0)])
|
||||
kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False)
|
||||
# In the first element of the batch there is one outlier
|
||||
kps1[0, -1] += 5
|
||||
matches0 = torch.stack([torch.arange(4), torch.arange(4)])
|
||||
# In the second element of the batch there is only 2 matches
|
||||
matches0[1, :2] = -1
|
||||
pred = {
|
||||
"keypoints0": kps0,
|
||||
"keypoints1": kps1,
|
||||
"matches0": matches0,
|
||||
"matching_scores0": torch.ones_like(matches0),
|
||||
}
|
||||
|
||||
results = eval_matches_homography(data, pred)
|
||||
self.assertAlmostEqual(results["prec@1px"][0], 0.75)
|
||||
self.assertAlmostEqual(results["prec@1px"][1], 1.0)
|
||||
self.assertAlmostEqual(results["num_matches"][0], 4)
|
||||
self.assertAlmostEqual(results["num_matches"][1], 2)
|
|
@ -0,0 +1,132 @@
|
|||
import unittest
|
||||
from collections import namedtuple
|
||||
from os.path import splitext
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import torch.cuda
|
||||
from kornia import image_to_tensor
|
||||
from omegaconf import OmegaConf
|
||||
from parameterized import parameterized
|
||||
from torch import Tensor
|
||||
|
||||
from gluefactory import logger
|
||||
from gluefactory.eval.utils import (
|
||||
eval_homography_dlt,
|
||||
eval_homography_robust,
|
||||
eval_matches_homography,
|
||||
)
|
||||
from gluefactory.models.two_view_pipeline import TwoViewPipeline
|
||||
from gluefactory.settings import root
|
||||
from gluefactory.utils.image import ImagePreprocessor
|
||||
from gluefactory.utils.tensor import map_tensor
|
||||
from gluefactory.utils.tools import set_seed
|
||||
from gluefactory.visualization.viz2d import (
|
||||
plot_color_line_matches,
|
||||
plot_images,
|
||||
plot_matches,
|
||||
)
|
||||
|
||||
|
||||
def create_input_data(cv_img0, cv_img1, device):
|
||||
img0 = image_to_tensor(cv_img0).float() / 255
|
||||
img1 = image_to_tensor(cv_img1).float() / 255
|
||||
ip = ImagePreprocessor({})
|
||||
data = {"view0": ip(img0), "view1": ip(img1)}
|
||||
data = map_tensor(
|
||||
data,
|
||||
lambda t: t[None].to(device)
|
||||
if isinstance(t, Tensor)
|
||||
else torch.from_numpy(t)[None].to(device),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
ExpectedResults = namedtuple("ExpectedResults", ("num_matches", "prec3px", "h_error"))
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
methods_to_test = [
|
||||
("superpoint+NN.yaml", "poselib", ExpectedResults(1300, 0.8, 1.0)),
|
||||
("superpoint-open+NN.yaml", "poselib", ExpectedResults(1300, 0.8, 1.0)),
|
||||
(
|
||||
"superpoint+lsd+gluestick.yaml",
|
||||
"homography_est",
|
||||
ExpectedResults(1300, 0.8, 1.0),
|
||||
),
|
||||
(
|
||||
"superpoint+lightglue-official.yaml",
|
||||
"poselib",
|
||||
ExpectedResults(1300, 0.8, 1.0),
|
||||
),
|
||||
]
|
||||
|
||||
visualize = False
|
||||
|
||||
@parameterized.expand(methods_to_test)
|
||||
@torch.no_grad()
|
||||
def test_real_homography(self, conf_file, estimator, exp_results):
|
||||
set_seed(0)
|
||||
model_path = root / "gluefactory" / "configs" / conf_file
|
||||
img_path0 = root / "assets" / "boat1.png"
|
||||
img_path1 = root / "assets" / "boat2.png"
|
||||
h_gt = torch.tensor(
|
||||
[
|
||||
[0.85799, 0.21669, 9.4839],
|
||||
[-0.21177, 0.85855, 130.48],
|
||||
[1.5015e-06, 9.2033e-07, 1],
|
||||
]
|
||||
)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
gs = TwoViewPipeline(OmegaConf.load(model_path).model).to(device).eval()
|
||||
|
||||
cv_img0, cv_img1 = cv2.imread(str(img_path0)), cv2.imread(str(img_path1))
|
||||
data = create_input_data(cv_img0, cv_img1, device)
|
||||
pred = gs(data)
|
||||
pred = map_tensor(
|
||||
pred, lambda t: torch.squeeze(t, dim=0) if isinstance(t, Tensor) else t
|
||||
)
|
||||
data["H_0to1"] = h_gt.to(device)
|
||||
data["H_1to0"] = torch.linalg.inv(h_gt).to(device)
|
||||
|
||||
results = eval_matches_homography(data, pred)
|
||||
results = {**results, **eval_homography_dlt(data, pred)}
|
||||
|
||||
results = {
|
||||
**results,
|
||||
**eval_homography_robust(
|
||||
data,
|
||||
pred,
|
||||
{"estimator": estimator},
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(results)
|
||||
self.assertGreater(results["num_matches"], exp_results.num_matches)
|
||||
self.assertGreater(results["prec@3px"], exp_results.prec3px)
|
||||
self.assertLess(results["H_error_ransac"], exp_results.h_error)
|
||||
|
||||
if self.visualize:
|
||||
pred = map_tensor(
|
||||
pred, lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t
|
||||
)
|
||||
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
|
||||
m0 = pred["matches0"]
|
||||
valid0 = m0 != -1
|
||||
kpm0, kpm1 = kp0[valid0], kp1[m0[valid0]]
|
||||
|
||||
plot_images([cv_img0, cv_img1])
|
||||
plot_matches(kpm0, kpm1, a=0.0)
|
||||
plt.savefig(f"{splitext(conf_file)[0]}_point_matches.svg")
|
||||
|
||||
if "lines0" in pred and "lines1" in pred:
|
||||
lines0, lines1 = pred["lines0"], pred["lines1"]
|
||||
lm0 = pred["line_matches0"]
|
||||
lvalid0 = lm0 != -1
|
||||
linem0, linem1 = lines0[lvalid0], lines1[lm0[lvalid0]]
|
||||
|
||||
plot_images([cv_img0, cv_img1])
|
||||
plot_color_line_matches([linem0, linem1])
|
||||
plt.savefig(f"{splitext(conf_file)[0]}_line_matches.svg")
|
||||
plt.show()
|
Loading…
Reference in New Issue