Check for model initialization in eval (#9)
* Check for model initialization during eval * Change from warning to assert * change assertion to ValueError * fix isort and rename are_weights_initialized * cleanup set_initialized() * fix variable name bug * Make load_state_dict forward compatible Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> * Remove unused imports --------- Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com>main
parent
f7b587e881
commit
22154a60bc
|
@ -89,6 +89,11 @@ def load_model(model_conf, checkpoint):
|
||||||
model = load_experiment(checkpoint, conf=model_conf).eval()
|
model = load_experiment(checkpoint, conf=model_conf).eval()
|
||||||
else:
|
else:
|
||||||
model = get_model("two_view_pipeline")(model_conf).eval()
|
model = get_model("two_view_pipeline")(model_conf).eval()
|
||||||
|
if not model.is_initialized():
|
||||||
|
raise ValueError(
|
||||||
|
"The provided model has non-initialized parameters. "
|
||||||
|
+ "Try to load a checkpoint instead."
|
||||||
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ class DinoV2(BaseModel):
|
||||||
|
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
|
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
img = data["image"]
|
img = data["image"]
|
||||||
|
|
|
@ -60,6 +60,8 @@ class BaseModel(nn.Module, metaclass=MetaModel):
|
||||||
required_data_keys = []
|
required_data_keys = []
|
||||||
strict_conf = False
|
strict_conf = False
|
||||||
|
|
||||||
|
are_weights_initialized = False
|
||||||
|
|
||||||
def __init__(self, conf):
|
def __init__(self, conf):
|
||||||
"""Perform some logic and call the _init method of the child model."""
|
"""Perform some logic and call the _init method of the child model."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -125,3 +127,31 @@ class BaseModel(nn.Module, metaclass=MetaModel):
|
||||||
def loss(self, pred, data):
|
def loss(self, pred, data):
|
||||||
"""To be implemented by the child class."""
|
"""To be implemented by the child class."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def load_state_dict(self, *args, **kwargs):
|
||||||
|
"""Load the state dict of the model, and set the model to initialized."""
|
||||||
|
ret = super().load_state_dict(*args, **kwargs)
|
||||||
|
self.set_initialized()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def is_initialized(self):
|
||||||
|
"""Recursively check if the model is initialized, i.e. weights are loaded"""
|
||||||
|
is_initialized = True # initialize to true and perform recursive and
|
||||||
|
for _, w in self.named_children():
|
||||||
|
if isinstance(w, BaseModel):
|
||||||
|
# if children is BaseModel, we perform recursive check
|
||||||
|
is_initialized = is_initialized and w.is_initialized()
|
||||||
|
else:
|
||||||
|
# else, we check if self is initialized or the children has no params
|
||||||
|
n_params = len(list(w.parameters()))
|
||||||
|
is_initialized = is_initialized and (
|
||||||
|
n_params == 0 or self.are_weights_initialized
|
||||||
|
)
|
||||||
|
return is_initialized
|
||||||
|
|
||||||
|
def set_initialized(self, to: bool = True):
|
||||||
|
"""Recursively set the initialization state."""
|
||||||
|
self.are_weights_initialized = to
|
||||||
|
for _, w in self.named_parameters():
|
||||||
|
if isinstance(w, BaseModel):
|
||||||
|
w.set_initialized(to)
|
||||||
|
|
|
@ -21,6 +21,7 @@ class DISK(BaseModel):
|
||||||
|
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
|
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _get_dense_outputs(self, images):
|
def _get_dense_outputs(self, images):
|
||||||
B = images.shape[0]
|
B = images.shape[0]
|
||||||
|
|
|
@ -21,6 +21,7 @@ class KeyNetAffNetHardNet(BaseModel):
|
||||||
upright=conf.upright,
|
upright=conf.upright,
|
||||||
scale_laf=conf.scale_laf,
|
scale_laf=conf.scale_laf,
|
||||||
)
|
)
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
image = data["image"]
|
image = data["image"]
|
||||||
|
|
|
@ -19,6 +19,7 @@ class KorniaSIFT(BaseModel):
|
||||||
self.sift = kornia.feature.SIFTFeature(
|
self.sift = kornia.feature.SIFTFeature(
|
||||||
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
|
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
|
||||||
)
|
)
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
lafs, scores, descriptors = self.sift(data["image"])
|
lafs, scores, descriptors = self.sift(data["image"])
|
||||||
|
|
|
@ -34,6 +34,7 @@ class DeepLSD(BaseModel):
|
||||||
ckpt = torch.load(ckpt, map_location="cpu")
|
ckpt = torch.load(ckpt, map_location="cpu")
|
||||||
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
|
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
|
||||||
self.net.load_state_dict(ckpt["model"])
|
self.net.load_state_dict(ckpt["model"])
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def download_model(self, path):
|
def download_model(self, path):
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
|
@ -13,6 +13,7 @@ class LoFTRModule(BaseModel):
|
||||||
|
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
self.net = kornia.feature.LoFTR(pretrained="outdoor")
|
self.net = kornia.feature.LoFTR(pretrained="outdoor")
|
||||||
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
image0 = data["view0"]["image"]
|
image0 = data["view0"]["image"]
|
||||||
|
|
|
@ -18,7 +18,7 @@ class LightGlue(BaseModel):
|
||||||
def _init(self, conf):
|
def _init(self, conf):
|
||||||
dconf = OmegaConf.to_container(conf)
|
dconf = OmegaConf.to_container(conf)
|
||||||
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
|
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
|
||||||
# self.net.compile()
|
self.set_initialized()
|
||||||
|
|
||||||
def _forward(self, data):
|
def _forward(self, data):
|
||||||
view0 = {
|
view0 = {
|
||||||
|
|
Loading…
Reference in New Issue