From 22154a60bc0eb6a033b34956338310ee2a6fe486 Mon Sep 17 00:00:00 2001 From: Philipp Lindenberger Date: Tue, 10 Oct 2023 19:46:33 +0200 Subject: [PATCH] Check for model initialization in eval (#9) * Check for model initialization during eval * Change from warning to assert * change assertion to ValueError * fix isort and rename are_weights_initialized * cleanup set_initialized() * fix variable name bug * Make load_state_dict forward compatible Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> * Remove unused imports --------- Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> --- gluefactory/eval/io.py | 5 ++++ gluefactory/models/backbones/dinov2.py | 1 + gluefactory/models/base_model.py | 30 +++++++++++++++++++ gluefactory/models/extractors/disk_kornia.py | 1 + .../extractors/keynet_affnet_hardnet.py | 1 + gluefactory/models/extractors/sift_kornia.py | 1 + gluefactory/models/lines/deeplsd.py | 1 + gluefactory/models/matchers/kornia_loftr.py | 1 + .../models/matchers/lightglue_pretrained.py | 2 +- 9 files changed, 42 insertions(+), 1 deletion(-) diff --git a/gluefactory/eval/io.py b/gluefactory/eval/io.py index 067e845..6a55d59 100644 --- a/gluefactory/eval/io.py +++ b/gluefactory/eval/io.py @@ -89,6 +89,11 @@ def load_model(model_conf, checkpoint): model = load_experiment(checkpoint, conf=model_conf).eval() else: model = get_model("two_view_pipeline")(model_conf).eval() + if not model.is_initialized(): + raise ValueError( + "The provided model has non-initialized parameters. " + + "Try to load a checkpoint instead." + ) return model diff --git a/gluefactory/models/backbones/dinov2.py b/gluefactory/models/backbones/dinov2.py index 48a48b5..cf82852 100644 --- a/gluefactory/models/backbones/dinov2.py +++ b/gluefactory/models/backbones/dinov2.py @@ -10,6 +10,7 @@ class DinoV2(BaseModel): def _init(self, conf): self.net = torch.hub.load("facebookresearch/dinov2", conf.weights) + self.set_initialized() def _forward(self, data): img = data["image"] diff --git a/gluefactory/models/base_model.py b/gluefactory/models/base_model.py index 7313d98..b4f6628 100644 --- a/gluefactory/models/base_model.py +++ b/gluefactory/models/base_model.py @@ -60,6 +60,8 @@ class BaseModel(nn.Module, metaclass=MetaModel): required_data_keys = [] strict_conf = False + are_weights_initialized = False + def __init__(self, conf): """Perform some logic and call the _init method of the child model.""" super().__init__() @@ -125,3 +127,31 @@ class BaseModel(nn.Module, metaclass=MetaModel): def loss(self, pred, data): """To be implemented by the child class.""" raise NotImplementedError + + def load_state_dict(self, *args, **kwargs): + """Load the state dict of the model, and set the model to initialized.""" + ret = super().load_state_dict(*args, **kwargs) + self.set_initialized() + return ret + + def is_initialized(self): + """Recursively check if the model is initialized, i.e. weights are loaded""" + is_initialized = True # initialize to true and perform recursive and + for _, w in self.named_children(): + if isinstance(w, BaseModel): + # if children is BaseModel, we perform recursive check + is_initialized = is_initialized and w.is_initialized() + else: + # else, we check if self is initialized or the children has no params + n_params = len(list(w.parameters())) + is_initialized = is_initialized and ( + n_params == 0 or self.are_weights_initialized + ) + return is_initialized + + def set_initialized(self, to: bool = True): + """Recursively set the initialization state.""" + self.are_weights_initialized = to + for _, w in self.named_parameters(): + if isinstance(w, BaseModel): + w.set_initialized(to) diff --git a/gluefactory/models/extractors/disk_kornia.py b/gluefactory/models/extractors/disk_kornia.py index 4d60973..e01ab89 100644 --- a/gluefactory/models/extractors/disk_kornia.py +++ b/gluefactory/models/extractors/disk_kornia.py @@ -21,6 +21,7 @@ class DISK(BaseModel): def _init(self, conf): self.model = kornia.feature.DISK.from_pretrained(conf.weights) + self.set_initialized() def _get_dense_outputs(self, images): B = images.shape[0] diff --git a/gluefactory/models/extractors/keynet_affnet_hardnet.py b/gluefactory/models/extractors/keynet_affnet_hardnet.py index b9091ea..419ee97 100644 --- a/gluefactory/models/extractors/keynet_affnet_hardnet.py +++ b/gluefactory/models/extractors/keynet_affnet_hardnet.py @@ -21,6 +21,7 @@ class KeyNetAffNetHardNet(BaseModel): upright=conf.upright, scale_laf=conf.scale_laf, ) + self.set_initialized() def _forward(self, data): image = data["image"] diff --git a/gluefactory/models/extractors/sift_kornia.py b/gluefactory/models/extractors/sift_kornia.py index 78810e6..7a1e74d 100644 --- a/gluefactory/models/extractors/sift_kornia.py +++ b/gluefactory/models/extractors/sift_kornia.py @@ -19,6 +19,7 @@ class KorniaSIFT(BaseModel): self.sift = kornia.feature.SIFTFeature( num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift ) + self.set_initialized() def _forward(self, data): lafs, scores, descriptors = self.sift(data["image"]) diff --git a/gluefactory/models/lines/deeplsd.py b/gluefactory/models/lines/deeplsd.py index c35aa01..122f4b4 100644 --- a/gluefactory/models/lines/deeplsd.py +++ b/gluefactory/models/lines/deeplsd.py @@ -34,6 +34,7 @@ class DeepLSD(BaseModel): ckpt = torch.load(ckpt, map_location="cpu") self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval() self.net.load_state_dict(ckpt["model"]) + self.set_initialized() def download_model(self, path): import subprocess diff --git a/gluefactory/models/matchers/kornia_loftr.py b/gluefactory/models/matchers/kornia_loftr.py index 45a20b7..6fbd47b 100644 --- a/gluefactory/models/matchers/kornia_loftr.py +++ b/gluefactory/models/matchers/kornia_loftr.py @@ -13,6 +13,7 @@ class LoFTRModule(BaseModel): def _init(self, conf): self.net = kornia.feature.LoFTR(pretrained="outdoor") + self.set_initialized() def _forward(self, data): image0 = data["view0"]["image"] diff --git a/gluefactory/models/matchers/lightglue_pretrained.py b/gluefactory/models/matchers/lightglue_pretrained.py index 2e7c71b..b23976d 100644 --- a/gluefactory/models/matchers/lightglue_pretrained.py +++ b/gluefactory/models/matchers/lightglue_pretrained.py @@ -18,7 +18,7 @@ class LightGlue(BaseModel): def _init(self, conf): dconf = OmegaConf.to_container(conf) self.net = LightGlue_(dconf.pop("features"), **dconf).cuda() - # self.net.compile() + self.set_initialized() def _forward(self, data): view0 = {