Fix GlueStick training config (#29)
* Update training config of GlueStick * Remove unnecessary checks in GlueStick * Update ETH3D download link * Update link to undistorted ETH3D * Update link to download DeepLSD --------- Co-authored-by: pautratr <pautratr@student.ethz.ch>main
parent
ab6e6eef8b
commit
4a8283517f
|
@ -1,14 +1,14 @@
|
|||
data:
|
||||
name: homographies
|
||||
homography:
|
||||
difficulty: 0.5
|
||||
max_angle: 30
|
||||
difficulty: 0.7
|
||||
max_angle: 45
|
||||
patch_shape: [640, 480]
|
||||
photometric:
|
||||
p: 0.75
|
||||
train_size: 900000
|
||||
val_size: 1000
|
||||
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet)
|
||||
batch_size: 160 # 20 per 10GB of GPU mem (12 for triplet)
|
||||
num_workers: 15
|
||||
model:
|
||||
name: gluefactory.models.two_view_pipeline
|
||||
|
@ -70,4 +70,4 @@ train:
|
|||
n_steps: 4
|
||||
submodules: []
|
||||
# clip_grad: 10 # Use only with mixed precision
|
||||
# load_experiment:
|
||||
# load_experiment:
|
|
@ -1,10 +1,15 @@
|
|||
data:
|
||||
name: gluefactory.datasets.megadepth
|
||||
train_num_per_scene: 300
|
||||
val_pairs: valid_pairs.txt
|
||||
views: 2
|
||||
min_overlap: 0.1
|
||||
max_overlap: 0.7
|
||||
num_overlap_bins: 3
|
||||
preprocessing:
|
||||
resize: 640
|
||||
square_pad: True
|
||||
batch_size: 60
|
||||
batch_size: 160
|
||||
num_workers: 15
|
||||
model:
|
||||
name: gluefactory.models.two_view_pipeline
|
||||
|
@ -53,9 +58,9 @@ model:
|
|||
train:
|
||||
seed: 0
|
||||
epochs: 200
|
||||
log_every_iter: 10
|
||||
eval_every_iter: 100
|
||||
save_every_iter: 500
|
||||
log_every_iter: 400
|
||||
eval_every_iter: 700
|
||||
save_every_iter: 1400
|
||||
lr: 1e-4
|
||||
lr_schedule:
|
||||
type: exp # exp or multi_step
|
||||
|
|
|
@ -194,7 +194,7 @@ class ETH3DDataset(BaseDataset):
|
|||
if tmp_dir.exists():
|
||||
shutil.rmtree(tmp_dir)
|
||||
tmp_dir.mkdir(exist_ok=True, parents=True)
|
||||
url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/"
|
||||
url_base = "https://cvg-data.inf.ethz.ch/SOLD2/SOLD2_ETH3D_undistorted/"
|
||||
zip_name = "ETH3D_undistorted.zip"
|
||||
zip_path = tmp_dir / zip_name
|
||||
torch.hub.download_url_to_file(url_base + zip_name, zip_path)
|
||||
|
|
|
@ -41,7 +41,7 @@ class DeepLSD(BaseModel):
|
|||
|
||||
if not path.parent.is_dir():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
|
||||
link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar"
|
||||
cmd = ["wget", link, "-O", path]
|
||||
print("Downloading DeepLSD model...")
|
||||
subprocess.run(cmd, check=True)
|
||||
|
|
|
@ -131,7 +131,7 @@ class GlueStick(BaseModel):
|
|||
state_dict = {
|
||||
k.replace("module.", ""): v for k, v in state_dict.items()
|
||||
}
|
||||
self.load_state_dict(state_dict)
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def _forward(self, data):
|
||||
device = data["keypoints0"].device
|
||||
|
@ -200,8 +200,6 @@ class GlueStick(BaseModel):
|
|||
kpts0 = normalize_keypoints(kpts0, image_size0)
|
||||
kpts1 = normalize_keypoints(kpts1, image_size1)
|
||||
|
||||
assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
|
||||
assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
|
||||
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
|
||||
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])
|
||||
|
||||
|
|
Loading…
Reference in New Issue