Support additional LR schedulers (#11)
--------- Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com>main
parent
43cf81aa2f
commit
e0104fd65b
|
@ -48,11 +48,12 @@ default_train_conf = {
|
||||||
"optimizer_options": {}, # optional arguments passed to the optimizer
|
"optimizer_options": {}, # optional arguments passed to the optimizer
|
||||||
"lr": 0.001, # learning rate
|
"lr": 0.001, # learning rate
|
||||||
"lr_schedule": {
|
"lr_schedule": {
|
||||||
"type": None,
|
"type": None, # string in {factor, exp, member of torch.optim.lr_scheduler}
|
||||||
"start": 0,
|
"start": 0,
|
||||||
"exp_div_10": 0,
|
"exp_div_10": 0,
|
||||||
"on_epoch": False,
|
"on_epoch": False,
|
||||||
"factor": 1.0,
|
"factor": 1.0,
|
||||||
|
"options": {}, # add lr_scheduler arguments here
|
||||||
},
|
},
|
||||||
"lr_scaling": [(100, ["dampingnet.const"])],
|
"lr_scaling": [(100, ["dampingnet.const"])],
|
||||||
"eval_every_iter": 1000, # interval for evaluation on the validation set
|
"eval_every_iter": 1000, # interval for evaluation on the validation set
|
||||||
|
@ -141,6 +142,26 @@ def filter_parameters(params, regexp):
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr_scheduler(optimizer, conf):
|
||||||
|
"""Get lr scheduler specified by conf.train.lr_schedule."""
|
||||||
|
if conf.type not in ["factor", "exp", None]:
|
||||||
|
return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
|
||||||
|
|
||||||
|
# backward compatibility
|
||||||
|
def lr_fn(it): # noqa: E306
|
||||||
|
if conf.type is None:
|
||||||
|
return 1
|
||||||
|
if conf.type == "factor":
|
||||||
|
return 1.0 if it < conf.start else conf.factor
|
||||||
|
if conf.type == "exp":
|
||||||
|
gam = 10 ** (-1 / conf.exp_div_10)
|
||||||
|
return 1.0 if it < conf.start else gam
|
||||||
|
else:
|
||||||
|
raise ValueError(conf.type)
|
||||||
|
|
||||||
|
return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
|
||||||
|
|
||||||
|
|
||||||
def pack_lr_parameters(params, base_lr, lr_scaling):
|
def pack_lr_parameters(params, base_lr, lr_scaling):
|
||||||
"""Pack each group of parameters with the respective scaled learning rate."""
|
"""Pack each group of parameters with the respective scaled learning rate."""
|
||||||
filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
|
filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
|
||||||
|
@ -310,22 +331,7 @@ def training(rank, conf, output_dir, args):
|
||||||
|
|
||||||
results = None # fix bug with it saving
|
results = None # fix bug with it saving
|
||||||
|
|
||||||
def lr_fn(it): # noqa: E306
|
lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_scheduler)
|
||||||
if conf.train.lr_schedule.type is None:
|
|
||||||
return 1
|
|
||||||
if conf.train.lr_schedule.type == "factor":
|
|
||||||
return (
|
|
||||||
1.0
|
|
||||||
if it < conf.train.lr_schedule.start
|
|
||||||
else conf.train.lr_schedule.factor
|
|
||||||
)
|
|
||||||
if conf.train.lr_schedule.type == "exp":
|
|
||||||
gam = 10 ** (-1 / conf.train.lr_schedule.exp_div_10)
|
|
||||||
return 1.0 if it < conf.train.lr_schedule.start else gam
|
|
||||||
else:
|
|
||||||
raise ValueError(conf.train.lr_schedule.type)
|
|
||||||
|
|
||||||
lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
|
|
||||||
if args.restore:
|
if args.restore:
|
||||||
optimizer.load_state_dict(init_cp["optimizer"])
|
optimizer.load_state_dict(init_cp["optimizer"])
|
||||||
if "lr_scheduler" in init_cp:
|
if "lr_scheduler" in init_cp:
|
||||||
|
|
Loading…
Reference in New Issue