2023-10-05 16:53:51 +02:00
|
|
|
from lightglue import LightGlue as LightGlue_
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
|
2023-10-09 08:32:43 +02:00
|
|
|
from ..base_model import BaseModel
|
|
|
|
|
2023-10-05 16:53:51 +02:00
|
|
|
|
|
|
|
class LightGlue(BaseModel):
    """Adapter exposing the `lightglue` package's matcher as a ``BaseModel``.

    The wrapped network is stored in ``self.net`` and receives the keypoints
    and descriptors of both views in the ``{"image0": ..., "image1": ...}``
    layout built by :meth:`_forward`.
    """

    # Upstream defaults are unpacked last, so an upstream "features" entry
    # (if one exists) takes precedence over the "superpoint" fallback.
    default_conf = {"features": "superpoint", **LightGlue_.default_conf}

    # Keys read from the input dictionary by `_forward`.
    required_data_keys = [
        "view0",
        "keypoints0",
        "descriptors0",
        "view1",
        "keypoints1",
        "descriptors1",
    ]

    def _init(self, conf):
        """Instantiate the wrapped LightGlue network from ``conf``.

        Args:
            conf: OmegaConf configuration; it is converted to a plain
                container before being forwarded to the upstream constructor.
        """
        dconf = OmegaConf.to_container(conf)
        # "features" is passed positionally; every remaining entry is
        # forwarded unchanged as a keyword argument.
        features = dconf.pop("features")
        self.net = LightGlue_(features, **dconf)
        # NOTE(review): presumably tells BaseModel that weights are already
        # loaded -- confirm against BaseModel.set_initialized.
        self.set_initialized()

    def _forward(self, data):
        """Run the wrapped matcher on a pair of views.

        Args:
            data: mapping providing ``keypoints{0,1}``, ``descriptors{0,1}``
                and ``view{0,1}``; each ``view`` entry must itself be a
                mapping.

        Returns:
            Whatever ``self.net`` returns for the assembled input.
        """
        net_input = {}
        for idx in ("0", "1"):
            # Entries of the view mapping are unpacked last, so they override
            # the keypoints/descriptors entries on a key clash.
            net_input["image" + idx] = {
                "keypoints": data["keypoints" + idx],
                "descriptors": data["descriptors" + idx],
                **data["view" + idx],
            }
        return self.net(net_input)

    def loss(self, pred, data):
        """Loss hook; not implemented for this wrapper.

        Raises:
            NotImplementedError: always.
        """
        # `self` was missing from the signature, so a call through an
        # instance failed with TypeError before reaching this line.
        raise NotImplementedError
|