Fix GlueStick training config (#29)
* Update training config of GlueStick * Remove unnecessary checks in GlueStick * Update ETH3D download link * Update link to undistorted ETH3D * Update link to download DeepLSD --------- Co-authored-by: pautratr <pautratr@student.ethz.ch>main
parent
ab6e6eef8b
commit
4a8283517f
|
@ -1,14 +1,14 @@
|
||||||
data:
|
data:
|
||||||
name: homographies
|
name: homographies
|
||||||
homography:
|
homography:
|
||||||
difficulty: 0.5
|
difficulty: 0.7
|
||||||
max_angle: 30
|
max_angle: 45
|
||||||
patch_shape: [640, 480]
|
patch_shape: [640, 480]
|
||||||
photometric:
|
photometric:
|
||||||
p: 0.75
|
p: 0.75
|
||||||
train_size: 900000
|
train_size: 900000
|
||||||
val_size: 1000
|
val_size: 1000
|
||||||
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet)
|
batch_size: 160 # 20 per 10GB of GPU mem (12 for triplet)
|
||||||
num_workers: 15
|
num_workers: 15
|
||||||
model:
|
model:
|
||||||
name: gluefactory.models.two_view_pipeline
|
name: gluefactory.models.two_view_pipeline
|
||||||
|
|
|
@ -1,10 +1,15 @@
|
||||||
data:
|
data:
|
||||||
name: gluefactory.datasets.megadepth
|
name: gluefactory.datasets.megadepth
|
||||||
|
train_num_per_scene: 300
|
||||||
|
val_pairs: valid_pairs.txt
|
||||||
views: 2
|
views: 2
|
||||||
|
min_overlap: 0.1
|
||||||
|
max_overlap: 0.7
|
||||||
|
num_overlap_bins: 3
|
||||||
preprocessing:
|
preprocessing:
|
||||||
resize: 640
|
resize: 640
|
||||||
square_pad: True
|
square_pad: True
|
||||||
batch_size: 60
|
batch_size: 160
|
||||||
num_workers: 15
|
num_workers: 15
|
||||||
model:
|
model:
|
||||||
name: gluefactory.models.two_view_pipeline
|
name: gluefactory.models.two_view_pipeline
|
||||||
|
@ -53,9 +58,9 @@ model:
|
||||||
train:
|
train:
|
||||||
seed: 0
|
seed: 0
|
||||||
epochs: 200
|
epochs: 200
|
||||||
log_every_iter: 10
|
log_every_iter: 400
|
||||||
eval_every_iter: 100
|
eval_every_iter: 700
|
||||||
save_every_iter: 500
|
save_every_iter: 1400
|
||||||
lr: 1e-4
|
lr: 1e-4
|
||||||
lr_schedule:
|
lr_schedule:
|
||||||
type: exp # exp or multi_step
|
type: exp # exp or multi_step
|
||||||
|
|
|
@ -194,7 +194,7 @@ class ETH3DDataset(BaseDataset):
|
||||||
if tmp_dir.exists():
|
if tmp_dir.exists():
|
||||||
shutil.rmtree(tmp_dir)
|
shutil.rmtree(tmp_dir)
|
||||||
tmp_dir.mkdir(exist_ok=True, parents=True)
|
tmp_dir.mkdir(exist_ok=True, parents=True)
|
||||||
url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/"
|
url_base = "https://cvg-data.inf.ethz.ch/SOLD2/SOLD2_ETH3D_undistorted/"
|
||||||
zip_name = "ETH3D_undistorted.zip"
|
zip_name = "ETH3D_undistorted.zip"
|
||||||
zip_path = tmp_dir / zip_name
|
zip_path = tmp_dir / zip_name
|
||||||
torch.hub.download_url_to_file(url_base + zip_name, zip_path)
|
torch.hub.download_url_to_file(url_base + zip_name, zip_path)
|
||||||
|
|
|
@ -41,7 +41,7 @@ class DeepLSD(BaseModel):
|
||||||
|
|
||||||
if not path.parent.is_dir():
|
if not path.parent.is_dir():
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
|
link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar"
|
||||||
cmd = ["wget", link, "-O", path]
|
cmd = ["wget", link, "-O", path]
|
||||||
print("Downloading DeepLSD model...")
|
print("Downloading DeepLSD model...")
|
||||||
subprocess.run(cmd, check=True)
|
subprocess.run(cmd, check=True)
|
||||||
|
|
|
@ -131,7 +131,7 @@ class GlueStick(BaseModel):
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k.replace("module.", ""): v for k, v in state_dict.items()
|
k.replace("module.", ""): v for k, v in state_dict.items()
|
||||||
}
|
}
|
||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
device = data["keypoints0"].device
|
device = data["keypoints0"].device
|
||||||
|
@ -200,8 +200,6 @@ class GlueStick(BaseModel):
|
||||||
kpts0 = normalize_keypoints(kpts0, image_size0)
|
kpts0 = normalize_keypoints(kpts0, image_size0)
|
||||||
kpts1 = normalize_keypoints(kpts1, image_size1)
|
kpts1 = normalize_keypoints(kpts1, image_size1)
|
||||||
|
|
||||||
assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
|
|
||||||
assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
|
|
||||||
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
|
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
|
||||||
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])
|
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue