"""
|
|
A set of utilities to manage and load checkpoints of training experiments.
|
|
|
|
Author: Paul-Edouard Sarlin (skydes)
|
|
"""

import logging
import os
import re
import shutil
from pathlib import Path

import torch
from omegaconf import OmegaConf

from ..models import get_model
from ..settings import TRAINING_PATH

logger = logging.getLogger(__name__)


def list_checkpoints(dir_):
    """List all valid checkpoints in a given directory."""
    checkpoints = []
    for p in dir_.glob("checkpoint_*.tar"):
        numbers = re.findall(r"(\d+)", p.name)
        assert len(numbers) <= 2
        if len(numbers) == 0:
            continue
        if len(numbers) == 1:
            checkpoints.append((int(numbers[0]), p))
        else:
            checkpoints.append((int(numbers[1]), p))
    return checkpoints


def get_last_checkpoint(exper, allow_interrupted=True):
    """Get the last saved checkpoint for a given experiment name."""
    ckpts = list_checkpoints(Path(TRAINING_PATH, exper))
    if not allow_interrupted:
        ckpts = [(n, p) for (n, p) in ckpts if "_interrupted" not in p.name]
    assert len(ckpts) > 0, f"No checkpoint found for experiment {exper}."
    return sorted(ckpts)[-1][1]


def get_best_checkpoint(exper):
    """Get the checkpoint with the best loss, for a given experiment name."""
    return Path(TRAINING_PATH, exper, "checkpoint_best.tar")


def delete_old_checkpoints(dir_, num_keep):
    """Delete all but the num_keep last saved checkpoints."""
    ckpts = list_checkpoints(dir_)
    ckpts = sorted(ckpts)[::-1]
    kept = 0
    for ckpt in ckpts:
        if ("_interrupted" in str(ckpt[1]) and kept > 0) or kept >= num_keep:
            logger.info(f"Deleting checkpoint {ckpt[1].name}")
            ckpt[1].unlink()
        else:
            kept += 1


def load_experiment(exper, conf={}, get_last=False, ckpt=None):
    """Load and return the model of a given experiment."""
    if ckpt is None:
        exper = Path(exper)
        if exper.suffix != ".tar":
            if get_last:
                ckpt = get_last_checkpoint(exper)
            else:
                ckpt = get_best_checkpoint(exper)
        else:
            ckpt = exper
    logger.info(f"Loading checkpoint {ckpt.name}")
    ckpt = torch.load(str(ckpt), map_location="cpu")
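
    # The model config stored in the checkpoint is merged with, and
    # overridden by, any user-provided entries in `conf`.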
    loaded_conf = OmegaConf.create(ckpt["conf"])
    OmegaConf.set_struct(loaded_conf, False)
    conf = OmegaConf.merge(loaded_conf.model, OmegaConf.create(conf))
    model = get_model(conf.name)(conf).eval()
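
    # Warn about parameters that the model defines but the checkpoint
    # lacks; loading is non-strict, so those keep their initial values.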
    state_dict = ckpt["model"]
    dict_params = set(state_dict.keys())
    model_params = set(map(lambda n: n[0], model.named_parameters()))
    diff = model_params - dict_params
    if len(diff) > 0:
        subs = os.path.commonprefix(list(diff)).rstrip(".")
        logger.warning(f"Missing {len(diff)} parameters in {subs}")
    model.load_state_dict(state_dict, strict=False)
    return model


# @TODO: also copy the respective module scripts (i.e. the code)
def save_experiment(
    model,
    optimizer,
    lr_scheduler,
    conf,
    losses,
    results,
    best_eval,
    epoch,
    iter_i,
    output_dir,
    stop=False,
    distributed=False,
    cp_name=None,
):
    """Save the current model to a checkpoint and return the best result so far."""
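    # When trained with DistributedDataParallel, the weights live under
    # model.module; unwrap it so the checkpoint stores plain parameter keys.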
    state = (model.module if distributed else model).state_dict()
    checkpoint = {
        "model": state,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "conf": OmegaConf.to_container(conf, resolve=True),
        "epoch": epoch,
        "losses": losses,
        "eval": results,
    }
    if cp_name is None:
        cp_name = (
            f"checkpoint_{epoch}_{iter_i}" + ("_interrupted" if stop else "") + ".tar"
        )
    logger.info(f"Saving checkpoint {cp_name}")
    cp_path = str(output_dir / cp_name)
    torch.save(checkpoint, cp_path)
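    # Lower is better: if the metric named by conf.train.best_key improved,
    # snapshot this checkpoint as checkpoint_best.tar.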
    if cp_name != "checkpoint_best.tar" and results[conf.train.best_key] < best_eval:
        best_eval = results[conf.train.best_key]
        logger.info(f"New best val: {conf.train.best_key}={best_eval}")
        shutil.copy(cp_path, str(output_dir / "checkpoint_best.tar"))
    delete_old_checkpoints(output_dir, conf.train.keep_last_checkpoints)
    return best_eval
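

# Minimal usage sketch ("my_experiment" is a hypothetical name; it assumes
# checkpoints were saved under TRAINING_PATH/my_experiment):
#     model = load_experiment("my_experiment")                 # best checkpoint
#     model = load_experiment("my_experiment", get_last=True)  # latest checkpoint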