Support additional LR schedulers (#11)

---------

Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com>
Alexander Veicht, 2023-10-23 12:36:56 +02:00, committed by GitHub
parent 43cf81aa2f
commit e0104fd65b
1 changed file with 23 additions and 17 deletions

@@ -48,11 +48,12 @@ default_train_conf = {
     "optimizer_options": {},  # optional arguments passed to the optimizer
     "lr": 0.001,  # learning rate
     "lr_schedule": {
-        "type": None,
+        "type": None,  # string in {factor, exp, member of torch.optim.lr_scheduler}
         "start": 0,
         "exp_div_10": 0,
         "on_epoch": False,
         "factor": 1.0,
+        "options": {},  # add lr_scheduler arguments here
     },
     "lr_scaling": [(100, ["dampingnet.const"])],
     "eval_every_iter": 1000,  # interval for evaluation on the validation set
@@ -141,6 +142,26 @@ def filter_parameters(params, regexp):
     return params


+def get_lr_scheduler(optimizer, conf):
+    """Get lr scheduler specified by conf.train.lr_schedule."""
+    if conf.type not in ["factor", "exp", None]:
+        return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
+
+    # backward compatibility
+    def lr_fn(it):  # noqa: E306
+        if conf.type is None:
+            return 1
+        if conf.type == "factor":
+            return 1.0 if it < conf.start else conf.factor
+        if conf.type == "exp":
+            gam = 10 ** (-1 / conf.exp_div_10)
+            return 1.0 if it < conf.start else gam
+        else:
+            raise ValueError(conf.type)
+
+    return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
+
+
 def pack_lr_parameters(params, base_lr, lr_scaling):
     """Pack each group of parameters with the respective scaled learning rate."""
     filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
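
A quick sketch of how the helper dispatches, assuming a standalone OmegaConf node and a placeholder `model`:

    import torch
    from omegaconf import OmegaConf

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # New path: any torch scheduler, looked up by class name.
    conf = OmegaConf.create({"type": "ExponentialLR", "options": {"gamma": 0.95}})
    scheduler = get_lr_scheduler(optimizer, conf)

    # Legacy path: "factor", "exp", or None fall through to the MultiplicativeLR closure.
    conf = OmegaConf.create({"type": "exp", "start": 1000, "exp_div_10": 10000})
    scheduler = get_lr_scheduler(optimizer, conf)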
@@ -310,22 +331,7 @@ def training(rank, conf, output_dir, args):
     results = None  # fix bug with it saving

-    def lr_fn(it):  # noqa: E306
-        if conf.train.lr_schedule.type is None:
-            return 1
-        if conf.train.lr_schedule.type == "factor":
-            return (
-                1.0
-                if it < conf.train.lr_schedule.start
-                else conf.train.lr_schedule.factor
-            )
-        if conf.train.lr_schedule.type == "exp":
-            gam = 10 ** (-1 / conf.train.lr_schedule.exp_div_10)
-            return 1.0 if it < conf.train.lr_schedule.start else gam
-        else:
-            raise ValueError(conf.train.lr_schedule.type)
-
-    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
+    lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_schedule)

     if args.restore:
         optimizer.load_state_dict(init_cp["optimizer"])
         if "lr_scheduler" in init_cp: