Fix GlueStick training config (#29)

* Update training config of GlueStick

* Remove unnecessary checks in GlueStick

* Update ETH3D download link

* Update link to undistorted ETH3D

* Update link to download DeepLSD

---------

Co-authored-by: pautratr <pautratr@student.ethz.ch>
main
Rémi Pautrat 2023-11-01 14:36:58 +01:00 committed by GitHub
parent ab6e6eef8b
commit 4a8283517f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 16 additions and 13 deletions

View File

@ -1,14 +1,14 @@
data:
name: homographies
homography:
difficulty: 0.5
max_angle: 30
difficulty: 0.7
max_angle: 45
patch_shape: [640, 480]
photometric:
p: 0.75
train_size: 900000
val_size: 1000
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet)
batch_size: 160 # 20 per 10GB of GPU mem (12 for triplet)
num_workers: 15
model:
name: gluefactory.models.two_view_pipeline
@ -70,4 +70,4 @@ train:
n_steps: 4
submodules: []
# clip_grad: 10 # Use only with mixed precision
# load_experiment:
# load_experiment:

View File

@ -1,10 +1,15 @@
data:
name: gluefactory.datasets.megadepth
train_num_per_scene: 300
val_pairs: valid_pairs.txt
views: 2
min_overlap: 0.1
max_overlap: 0.7
num_overlap_bins: 3
preprocessing:
resize: 640
square_pad: True
batch_size: 60
batch_size: 160
num_workers: 15
model:
name: gluefactory.models.two_view_pipeline
@ -53,9 +58,9 @@ model:
train:
seed: 0
epochs: 200
log_every_iter: 10
eval_every_iter: 100
save_every_iter: 500
log_every_iter: 400
eval_every_iter: 700
save_every_iter: 1400
lr: 1e-4
lr_schedule:
type: exp # exp or multi_step

View File

@ -194,7 +194,7 @@ class ETH3DDataset(BaseDataset):
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True, parents=True)
url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/"
url_base = "https://cvg-data.inf.ethz.ch/SOLD2/SOLD2_ETH3D_undistorted/"
zip_name = "ETH3D_undistorted.zip"
zip_path = tmp_dir / zip_name
torch.hub.download_url_to_file(url_base + zip_name, zip_path)

View File

@ -41,7 +41,7 @@ class DeepLSD(BaseModel):
if not path.parent.is_dir():
path.parent.mkdir(parents=True, exist_ok=True)
link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar"
cmd = ["wget", link, "-O", path]
print("Downloading DeepLSD model...")
subprocess.run(cmd, check=True)

View File

@ -131,7 +131,7 @@ class GlueStick(BaseModel):
state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items()
}
self.load_state_dict(state_dict)
self.load_state_dict(state_dict, strict=False)
def _forward(self, data):
device = data["keypoints0"].device
@ -200,8 +200,6 @@ class GlueStick(BaseModel):
kpts0 = normalize_keypoints(kpts0, image_size0)
kpts1 = normalize_keypoints(kpts1, image_size1)
assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])