Fix GlueStick training config (#29)

* Update training config of GlueStick

* Remove unnecessary checks in GlueStick

* Update ETH3D download link

* Update link to undistorted ETH3D

* Update link to download DeepLSD

---------

Co-authored-by: pautratr <pautratr@student.ethz.ch>
main
Rémi Pautrat 2023-11-01 14:36:58 +01:00 committed by GitHub
parent ab6e6eef8b
commit 4a8283517f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 16 additions and 13 deletions

View File

@ -1,14 +1,14 @@
data: data:
name: homographies name: homographies
homography: homography:
difficulty: 0.5 difficulty: 0.7
max_angle: 30 max_angle: 45
patch_shape: [640, 480] patch_shape: [640, 480]
photometric: photometric:
p: 0.75 p: 0.75
train_size: 900000 train_size: 900000
val_size: 1000 val_size: 1000
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet) batch_size: 160 # 20 per 10GB of GPU mem (12 for triplet)
num_workers: 15 num_workers: 15
model: model:
name: gluefactory.models.two_view_pipeline name: gluefactory.models.two_view_pipeline

View File

@ -1,10 +1,15 @@
data: data:
name: gluefactory.datasets.megadepth name: gluefactory.datasets.megadepth
train_num_per_scene: 300
val_pairs: valid_pairs.txt
views: 2 views: 2
min_overlap: 0.1
max_overlap: 0.7
num_overlap_bins: 3
preprocessing: preprocessing:
resize: 640 resize: 640
square_pad: True square_pad: True
batch_size: 60 batch_size: 160
num_workers: 15 num_workers: 15
model: model:
name: gluefactory.models.two_view_pipeline name: gluefactory.models.two_view_pipeline
@ -53,9 +58,9 @@ model:
train: train:
seed: 0 seed: 0
epochs: 200 epochs: 200
log_every_iter: 10 log_every_iter: 400
eval_every_iter: 100 eval_every_iter: 700
save_every_iter: 500 save_every_iter: 1400
lr: 1e-4 lr: 1e-4
lr_schedule: lr_schedule:
type: exp # exp or multi_step type: exp # exp or multi_step

View File

@ -194,7 +194,7 @@ class ETH3DDataset(BaseDataset):
if tmp_dir.exists(): if tmp_dir.exists():
shutil.rmtree(tmp_dir) shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True, parents=True) tmp_dir.mkdir(exist_ok=True, parents=True)
url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/" url_base = "https://cvg-data.inf.ethz.ch/SOLD2/SOLD2_ETH3D_undistorted/"
zip_name = "ETH3D_undistorted.zip" zip_name = "ETH3D_undistorted.zip"
zip_path = tmp_dir / zip_name zip_path = tmp_dir / zip_name
torch.hub.download_url_to_file(url_base + zip_name, zip_path) torch.hub.download_url_to_file(url_base + zip_name, zip_path)

View File

@ -41,7 +41,7 @@ class DeepLSD(BaseModel):
if not path.parent.is_dir(): if not path.parent.is_dir():
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download" link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar"
cmd = ["wget", link, "-O", path] cmd = ["wget", link, "-O", path]
print("Downloading DeepLSD model...") print("Downloading DeepLSD model...")
subprocess.run(cmd, check=True) subprocess.run(cmd, check=True)

View File

@ -131,7 +131,7 @@ class GlueStick(BaseModel):
state_dict = { state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items() k.replace("module.", ""): v for k, v in state_dict.items()
} }
self.load_state_dict(state_dict) self.load_state_dict(state_dict, strict=False)
def _forward(self, data): def _forward(self, data):
device = data["keypoints0"].device device = data["keypoints0"].device
@ -200,8 +200,6 @@ class GlueStick(BaseModel):
kpts0 = normalize_keypoints(kpts0, image_size0) kpts0 = normalize_keypoints(kpts0, image_size0)
kpts1 = normalize_keypoints(kpts1, image_size1) kpts1 = normalize_keypoints(kpts1, image_size1)
assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"]) desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"]) desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])