# glue-factory-custom/gluefactory/utils/experiments.py
"""
A set of utilities to manage and load checkpoints of training experiments.
Author: Paul-Edouard Sarlin (skydes)
"""
import logging
import os
import re
import shutil
from pathlib import Path

import torch
from omegaconf import OmegaConf

from ..models import get_model
from ..settings import TRAINING_PATH

logger = logging.getLogger(__name__)


def list_checkpoints(dir_):
    """List all valid checkpoints in a given directory."""
    checkpoints = []
    for p in dir_.glob("checkpoint_*.tar"):
        numbers = re.findall(r"(\d+)", p.name)
        assert len(numbers) <= 2
        if len(numbers) == 0:
            continue
        # Index by the last number in the name, i.e. the training iteration
        # for names of the form checkpoint_{epoch}_{iter}[_interrupted].tar
        # produced by save_experiment() below.
        if len(numbers) == 1:
            checkpoints.append((int(numbers[0]), p))
        else:
            checkpoints.append((int(numbers[1]), p))
    return checkpoints
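
# Example: checkpoint_2_8000.tar is listed as (8000, path) and checkpoint_5.tar
# as (5, path), while checkpoint_best.tar contains no number and is skipped.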


def get_last_checkpoint(exper, allow_interrupted=True):
    """Get the last saved checkpoint for a given experiment name."""
    ckpts = list_checkpoints(Path(TRAINING_PATH, exper))
    if not allow_interrupted:
        ckpts = [(n, p) for (n, p) in ckpts if "_interrupted" not in p.name]
    assert len(ckpts) > 0
    return sorted(ckpts)[-1][1]


def get_best_checkpoint(exper):
    """Get the checkpoint with the best loss, for a given experiment name."""
    p = Path(TRAINING_PATH, exper, "checkpoint_best.tar")
    return p
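
# Note: get_best_checkpoint() simply builds the fixed path
# TRAINING_PATH/<exper>/checkpoint_best.tar; save_experiment() below refreshes
# that file whenever the monitored validation metric improves.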


def delete_old_checkpoints(dir_, num_keep):
    """Delete all but the num_keep last saved checkpoints."""
    ckpts = list_checkpoints(dir_)
    ckpts = sorted(ckpts)[::-1]
    kept = 0
    for ckpt in ckpts:
        # Keep an interrupted checkpoint only if it is the most recent one,
        # and keep at most num_keep checkpoints overall.
        if ("_interrupted" in str(ckpt[1]) and kept > 0) or kept >= num_keep:
            logger.info(f"Deleting checkpoint {ckpt[1].name}")
            ckpt[1].unlink()
        else:
            kept += 1
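
# Example: with num_keep=2 and checkpoints at iterations 100, 200, and 300,
# where 300 is interrupted, the pass keeps 300 (newest) and 200 and deletes
# 100; an interrupted checkpoint older than any kept one is always removed.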


def load_experiment(exper, conf={}, get_last=False, ckpt=None):
    """Load and return the model of a given experiment."""
    exper = Path(exper)
    if exper.suffix != ".tar":
        if get_last:
            ckpt = get_last_checkpoint(exper)
        else:
            ckpt = get_best_checkpoint(exper)
    else:
        ckpt = exper
    # Note: the ckpt argument is currently unused; it is always overwritten above.
    logger.info(f"Loading checkpoint {ckpt.name}")
    ckpt = torch.load(str(ckpt), map_location="cpu")

    # Merge the user-provided config over the one stored in the checkpoint.
    loaded_conf = OmegaConf.create(ckpt["conf"])
    OmegaConf.set_struct(loaded_conf, False)
    conf = OmegaConf.merge(loaded_conf.model, OmegaConf.create(conf))
    model = get_model(conf.name)(conf).eval()

    # Warn about parameters expected by the model but absent from the checkpoint.
    state_dict = ckpt["model"]
    dict_params = set(state_dict.keys())
    model_params = set(map(lambda n: n[0], model.named_parameters()))
    diff = model_params - dict_params
    if len(diff) > 0:
        subs = os.path.commonprefix(list(diff)).rstrip(".")
        logger.warning(f"Missing {len(diff)} parameters in {subs}")
    model.load_state_dict(state_dict, strict=False)
    return model
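
# Usage sketch (hypothetical experiment name "my_run"; the directory must exist
# under TRAINING_PATH and contain saved checkpoints):
#   model = load_experiment("my_run")                 # best checkpoint
#   model = load_experiment("my_run", get_last=True)  # most recent checkpoint
#   model = load_experiment("runs/my_run/checkpoint_1_100.tar")  # explicit file
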
# @TODO: also copy the respective module scripts (i.e. the code)
def save_experiment(
    model,
    optimizer,
    lr_scheduler,
    conf,
    losses,
    results,
    best_eval,
    epoch,
    iter_i,
    output_dir,
    stop=False,
    distributed=False,
    cp_name=None,
):
    """Save the current model to a checkpoint
    and return the best result so far."""
    # Unwrap DistributedDataParallel to save the raw module weights.
    state = (model.module if distributed else model).state_dict()
    checkpoint = {
        "model": state,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "conf": OmegaConf.to_container(conf, resolve=True),
        "epoch": epoch,
        "losses": losses,
        "eval": results,
    }
    if cp_name is None:
        cp_name = (
            f"checkpoint_{epoch}_{iter_i}" + ("_interrupted" if stop else "") + ".tar"
        )
    logger.info(f"Saving checkpoint {cp_name}")
    cp_path = str(output_dir / cp_name)
    torch.save(checkpoint, cp_path)
    # conf.train.best_key is lower-is-better; refresh checkpoint_best.tar
    # whenever the monitored metric improves.
    if cp_name != "checkpoint_best.tar" and results[conf.train.best_key] < best_eval:
        best_eval = results[conf.train.best_key]
        logger.info(f"New best val: {conf.train.best_key}={best_eval}")
        shutil.copy(cp_path, str(output_dir / "checkpoint_best.tar"))
    delete_old_checkpoints(output_dir, conf.train.keep_last_checkpoints)
    return best_eval
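
# Usage sketch inside a training loop (all names hypothetical; `results` must
# contain conf.train.best_key, a lower-is-better validation metric):
#   best_eval = save_experiment(
#       model, optimizer, lr_scheduler, conf, losses, results,
#       best_eval, epoch, iter_i, output_dir=Path(TRAINING_PATH, "my_run"),
#   )

if __name__ == "__main__":
    # Minimal smoke test: prints the checkpoints of an experiment by iteration.
    # Because of the relative imports, run it as a module, e.g.
    #   python -m gluefactory.utils.experiments <experiment_name>
    import sys

    for it, path in sorted(list_checkpoints(Path(TRAINING_PATH, sys.argv[1]))):
        print(it, path.name)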