Support additional LR schedulers (#11)
--------- Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com>main
parent
43cf81aa2f
commit
e0104fd65b
|
@ -48,11 +48,12 @@ default_train_conf = {
|
|||
"optimizer_options": {}, # optional arguments passed to the optimizer
|
||||
"lr": 0.001, # learning rate
|
||||
"lr_schedule": {
|
||||
"type": None,
|
||||
"type": None, # string in {factor, exp, member of torch.optim.lr_scheduler}
|
||||
"start": 0,
|
||||
"exp_div_10": 0,
|
||||
"on_epoch": False,
|
||||
"factor": 1.0,
|
||||
"options": {}, # add lr_scheduler arguments here
|
||||
},
|
||||
"lr_scaling": [(100, ["dampingnet.const"])],
|
||||
"eval_every_iter": 1000, # interval for evaluation on the validation set
|
||||
|
@ -141,6 +142,26 @@ def filter_parameters(params, regexp):
|
|||
return params
|
||||
|
||||
|
||||
def get_lr_scheduler(optimizer, conf):
|
||||
"""Get lr scheduler specified by conf.train.lr_schedule."""
|
||||
if conf.type not in ["factor", "exp", None]:
|
||||
return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
|
||||
|
||||
# backward compatibility
|
||||
def lr_fn(it): # noqa: E306
|
||||
if conf.type is None:
|
||||
return 1
|
||||
if conf.type == "factor":
|
||||
return 1.0 if it < conf.start else conf.factor
|
||||
if conf.type == "exp":
|
||||
gam = 10 ** (-1 / conf.exp_div_10)
|
||||
return 1.0 if it < conf.start else gam
|
||||
else:
|
||||
raise ValueError(conf.type)
|
||||
|
||||
return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
|
||||
|
||||
|
||||
def pack_lr_parameters(params, base_lr, lr_scaling):
|
||||
"""Pack each group of parameters with the respective scaled learning rate."""
|
||||
filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
|
||||
|
@ -310,22 +331,7 @@ def training(rank, conf, output_dir, args):
|
|||
|
||||
results = None # fix bug with it saving
|
||||
|
||||
def lr_fn(it): # noqa: E306
|
||||
if conf.train.lr_schedule.type is None:
|
||||
return 1
|
||||
if conf.train.lr_schedule.type == "factor":
|
||||
return (
|
||||
1.0
|
||||
if it < conf.train.lr_schedule.start
|
||||
else conf.train.lr_schedule.factor
|
||||
)
|
||||
if conf.train.lr_schedule.type == "exp":
|
||||
gam = 10 ** (-1 / conf.train.lr_schedule.exp_div_10)
|
||||
return 1.0 if it < conf.train.lr_schedule.start else gam
|
||||
else:
|
||||
raise ValueError(conf.train.lr_schedule.type)
|
||||
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
|
||||
lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_scheduler)
|
||||
if args.restore:
|
||||
optimizer.load_state_dict(init_cp["optimizer"])
|
||||
if "lr_scheduler" in init_cp:
|
||||
|
|
Loading…
Reference in New Issue