Check for model initialization in eval (#9)

* Check for model initialization during eval

* Change from warning to assert

* change assertion to ValueError

* fix isort and rename are_weights_initialized

* cleanup set_initialized()

* fix variable name bug

* Make load_state_dict forward compatible

Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com>

* Remove unused imports

---------

Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com>
main
Philipp Lindenberger 2023-10-10 19:46:33 +02:00 committed by GitHub
parent f7b587e881
commit 22154a60bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 42 additions and 1 deletions

View File

@ -89,6 +89,11 @@ def load_model(model_conf, checkpoint):
model = load_experiment(checkpoint, conf=model_conf).eval()
else:
model = get_model("two_view_pipeline")(model_conf).eval()
if not model.is_initialized():
raise ValueError(
"The provided model has non-initialized parameters. "
+ "Try to load a checkpoint instead."
)
return model

View File

@ -10,6 +10,7 @@ class DinoV2(BaseModel):
def _init(self, conf):
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
self.set_initialized()
def _forward(self, data):
img = data["image"]

View File

@ -60,6 +60,8 @@ class BaseModel(nn.Module, metaclass=MetaModel):
required_data_keys = []
strict_conf = False
are_weights_initialized = False
def __init__(self, conf):
"""Perform some logic and call the _init method of the child model."""
super().__init__()
@ -125,3 +127,31 @@ class BaseModel(nn.Module, metaclass=MetaModel):
def loss(self, pred, data):
"""To be implemented by the child class."""
raise NotImplementedError
def load_state_dict(self, *args, **kwargs):
"""Load the state dict of the model, and set the model to initialized."""
ret = super().load_state_dict(*args, **kwargs)
self.set_initialized()
return ret
def is_initialized(self):
"""Recursively check if the model is initialized, i.e. weights are loaded"""
is_initialized = True # initialize to true and perform recursive and
for _, w in self.named_children():
if isinstance(w, BaseModel):
# if children is BaseModel, we perform recursive check
is_initialized = is_initialized and w.is_initialized()
else:
# else, we check if self is initialized or the children has no params
n_params = len(list(w.parameters()))
is_initialized = is_initialized and (
n_params == 0 or self.are_weights_initialized
)
return is_initialized
def set_initialized(self, to: bool = True):
"""Recursively set the initialization state."""
self.are_weights_initialized = to
for _, w in self.named_parameters():
if isinstance(w, BaseModel):
w.set_initialized(to)

View File

@ -21,6 +21,7 @@ class DISK(BaseModel):
def _init(self, conf):
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
self.set_initialized()
def _get_dense_outputs(self, images):
B = images.shape[0]

View File

@ -21,6 +21,7 @@ class KeyNetAffNetHardNet(BaseModel):
upright=conf.upright,
scale_laf=conf.scale_laf,
)
self.set_initialized()
def _forward(self, data):
image = data["image"]

View File

@ -19,6 +19,7 @@ class KorniaSIFT(BaseModel):
self.sift = kornia.feature.SIFTFeature(
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
)
self.set_initialized()
def _forward(self, data):
lafs, scores, descriptors = self.sift(data["image"])

View File

@ -34,6 +34,7 @@ class DeepLSD(BaseModel):
ckpt = torch.load(ckpt, map_location="cpu")
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
self.net.load_state_dict(ckpt["model"])
self.set_initialized()
def download_model(self, path):
import subprocess

View File

@ -13,6 +13,7 @@ class LoFTRModule(BaseModel):
def _init(self, conf):
self.net = kornia.feature.LoFTR(pretrained="outdoor")
self.set_initialized()
def _forward(self, data):
image0 = data["view0"]["image"]

View File

@ -18,7 +18,7 @@ class LightGlue(BaseModel):
def _init(self, conf):
dconf = OmegaConf.to_container(conf)
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
# self.net.compile()
self.set_initialized()
def _forward(self, data):
view0 = {