127 lines
4.4 KiB
Python
127 lines
4.4 KiB
Python
"""
Base class for trainable models.
"""
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
import omegaconf
|
|
from omegaconf import OmegaConf
|
|
from torch import nn
|
|
from copy import copy
|
|
|
|
|
|
class MetaModel(ABCMeta):
    """Metaclass that collects the default configurations of the base classes.

    Every class created with this metaclass starts with a `base_default_conf`
    attribute in its namespace: the merge of the `base_default_conf` and
    `default_conf` attributes found on each of its direct bases.
    """

    @classmethod
    def __prepare__(mcs, name, bases, **kwds):
        """Return the initial namespace for the class being created.

        Args:
            name: name of the class being created (unused).
            bases: tuple of direct base classes whose configurations are merged.
            **kwds: extra class-creation keywords (unused).

        Returns:
            A dict with a single `base_default_conf` entry holding the merged
            OmegaConf configuration.
        """
        total_conf = OmegaConf.create()
        for base in bases:
            # Merge the inherited configuration first, then the base's own
            # `default_conf`, so that the latter is applied on top.
            for key in ("base_default_conf", "default_conf"):
                update = getattr(base, key, {})
                if isinstance(update, dict):
                    # Plain dicts (or a missing attribute) become OmegaConf
                    # objects before merging.
                    update = OmegaConf.create(update)
                total_conf = OmegaConf.merge(total_conf, update)
        return {"base_default_conf": total_conf}
|
|
|
|
|
|
class BaseModel(nn.Module, metaclass=MetaModel):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes, and
        it is updated by the user-provided configuration passed to __init__.
        Configurations can be nested.

        required_data_keys: list of expected keys in the input data dictionary.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,).
        The total loss to be optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars. Unlike `_init`, `_forward` and `loss`, it
        is not declared as an abstract method by this class.
    """

    default_conf = {
        "name": None,
        "trainable": True,  # if false: do not optimize this model parameters
        "freeze_batch_normalization": False,  # use test-time statistics
        "timeit": False,  # time forward pass
    }
    required_data_keys = []
    strict_conf = False

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model.

        Args:
            conf: user-provided configuration, either a plain dict or an
                OmegaConf object. It is merged on top of the default
                configuration accumulated over the class hierarchy.
        """
        super().__init__()
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            # Struct mode on the defaults: unknown user entries are rejected
            # when the user configuration is merged below.
            OmegaConf.set_struct(default_conf, True)

        # Convert plain dicts first: the compatibility shim below relies on
        # OmegaConf context managers, which do not accept a plain dict.
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)

        # fixme: backward compatibility
        # A legacy top-level "pad" entry is moved under "interpolation".
        # NOTE: when `conf` was passed as an OmegaConf object, it is
        # rewritten in place.
        if "pad" in conf and "pad" not in default_conf:  # backward compat.
            with omegaconf.read_write(conf), omegaconf.open_dict(conf):
                conf["interpolation"] = {"pad": conf.pop("pad")}

        self.conf = conf = OmegaConf.merge(default_conf, conf)
        # Freeze the final configuration: no writes and no unknown keys.
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        # Give the instance its own (shallow) copy so that changes made on
        # the instance do not rebind the class-level attribute.
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

    def train(self, mode=True):
        """Set the training mode of the model and return `self`.

        When `conf.freeze_batch_normalization` is set, every batch
        normalization submodule is put back in evaluation mode after the
        mode has been applied, regardless of `mode`.
        """
        super().train(mode)

        def freeze_bn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check the data and call the _forward method of the child model.

        Args:
            data: input data container; it must hold every key listed in
                `required_data_keys` (checked recursively when the
                expected keys are given as a nested dict).

        Returns:
            Whatever the child's `_forward(data)` returns.

        Raises:
            AssertionError: if a required key is missing from `data`.
        """

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data"
                if isinstance(expected, dict):
                    # Nested specification: descend into the sub-keys.
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError
|