# (viewer metadata from extraction: 130 lines, 4.4 KiB, Python)
|
import torch
|
||
|
import string
|
||
|
import h5py
|
||
|
|
||
|
from .base_model import BaseModel
|
||
|
from ..settings import DATA_PATH
|
||
|
from ..datasets.base_dataset import collate
|
||
|
from ..utils.tensor import batch_to_device
|
||
|
from .utils.misc import pad_to_length
|
||
|
|
||
|
|
||
|
def pad_local_features(pred: dict, seq_l: int) -> dict:
    """Pad variable-length local-feature tensors to a fixed sequence length.

    Always pads ``pred["keypoints"]`` and, when present, the optional
    per-keypoint attributes (scores, descriptors, scales, orientations) so
    that every example shares length ``seq_l`` and can be batched.

    Args:
        pred: Dictionary of feature tensors from a local-feature extractor;
            mutated in place.
        seq_l: Target sequence length after padding.

    Returns:
        The same ``pred`` dictionary with padded entries.
    """
    # Keypoints are padded along the point axis (-2) with "random_c" values
    # so that padded entries remain plausible coordinates.
    pred["keypoints"] = pad_to_length(
        pred["keypoints"],
        seq_l,
        -2,
        mode="random_c",
    )
    # Optional per-keypoint attributes: (key, pad axis, pad mode).
    # Replaces four copy-pasted `if "x" in pred.keys():` branches.
    optional_specs = [
        ("keypoint_scores", -1, "zeros"),
        ("descriptors", -2, "random"),
        ("scales", -1, "zeros"),
        ("oris", -1, "zeros"),
    ]
    for key, dim, mode in optional_specs:
        if key in pred:
            pred[key] = pad_to_length(pred[key], seq_l, dim, mode=mode)
    return pred
|
||
|
|
||
|
|
||
|
def pad_line_features(pred, seq_l: int = None):
    """Pad line features to a fixed length — not implemented yet.

    Placeholder counterpart to :func:`pad_local_features` for line-based
    features; always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
|
||
|
|
||
|
|
||
|
def recursive_load(grp, pkeys):
    """Recursively load an h5py group into a (nested) dict of torch tensors.

    Args:
        grp: An ``h5py.Group`` (or ``h5py.File``) to read from.
        pkeys: Iterable of keys of ``grp`` to load.

    Returns:
        Dict mapping each key to a tensor (for ``h5py.Dataset`` entries) or
        to a nested dict (for sub-groups).
    """
    out = {}
    for k in pkeys:
        entry = grp[k]  # hoisted: the original performed this lookup 3 times
        if isinstance(entry, h5py.Dataset):
            out[k] = torch.from_numpy(entry.__array__())
        else:
            # BUGFIX: recurse with the *child* group's keys. The original
            # passed the parent's keys (`list(grp.keys())`), which loads the
            # wrong entries (or fails) for any nested sub-group.
            out[k] = recursive_load(entry, list(entry.keys()))
    return out
|
||
|
|
||
|
|
||
|
class CacheLoader(BaseModel):
    """Model that serves pre-computed predictions cached in HDF5 files.

    Instead of running an extractor, each item is looked up by its ``name``
    in an h5py file located via ``conf.path`` (optionally a format string
    filled from the data dict), and the stored arrays are returned as the
    prediction, optionally cast, rescaled, padded, and collated.
    """

    default_conf = {
        "path": "???",  # can be a format string like exports/{scene}/
        "data_keys": None,  # load all keys
        "device": None,  # load to same device as data
        "trainable": False,
        "add_data_path": True,  # prepend DATA_PATH to the cache file path
        "collate": True,  # collate per-item predictions into a batch
        "scale": ["keypoints", "lines", "orig_lines"],  # key prefixes to rescale
        "padding_fn": None,  # name of a padding function from this module
        "padding_length": None,  # required for batching!
        "numeric_type": "float32",  # [None, "float16", "float32", "float64"]
    }

    required_data_keys = ["name"]  # we need an identifier

    def _init(self, conf):
        # Cache of open h5py handles. NOTE(review): not used by _forward,
        # which opens and closes a file per item — confirm before removing.
        self.hfiles = {}
        self.padding_fn = conf.padding_fn
        if self.padding_fn is not None:
            # NOTE(review): eval() of a config string — resolves names like
            # "pad_local_features" defined in this module. Only safe with
            # trusted configuration files.
            self.padding_fn = eval(self.padding_fn)
        # Map the configured numeric type name to a torch dtype
        # (None keeps the stored dtype unchanged).
        self.numeric_dtype = {
            None: None,
            "float16": torch.float16,
            "float32": torch.float32,
            "float64": torch.float64,
        }[conf.numeric_type]

    def _forward(self, data):
        preds = []
        device = self.conf.device
        if not device:
            # Infer the device from the input tensors; fall back to CPU
            # when the batch contains no tensors at all.
            devices = set(
                [v.device for v in data.values() if isinstance(v, torch.Tensor)]
            )
            if len(devices) == 0:
                device = "cpu"
            else:
                # All input tensors must live on a single device.
                assert len(devices) == 1
                device = devices.pop()

        # Field names appearing in conf.path (e.g. "scene" in
        # "exports/{scene}/"); each is filled from the matching data entry.
        var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]
        for i, name in enumerate(data["name"]):
            fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
            if self.conf.add_data_path:
                fpath = DATA_PATH / fpath
            hfile = h5py.File(str(fpath), "r")
            grp = hfile[name]
            # Load either the configured subset of keys or everything stored.
            pkeys = (
                self.conf.data_keys if self.conf.data_keys is not None else grp.keys()
            )
            pred = recursive_load(grp, pkeys)
            if self.numeric_dtype is not None:
                # Cast floating-point tensors to the configured dtype; leave
                # integer tensors and non-tensors untouched.
                pred = {
                    k: v
                    if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
                    else v.to(dtype=self.numeric_dtype)
                    for k, v in pred.items()
                }
            pred = batch_to_device(pred, device)
            for k, v in pred.items():
                for pattern in self.conf.scale:
                    if k.startswith(pattern):
                        # Suffix after the prefix selects the view, e.g.
                        # "keypoints0" -> data["view0"]["scales"]; a bare key
                        # uses the top-level scales.
                        view_idx = k.replace(pattern, "")
                        scales = (
                            data["scales"]
                            if len(view_idx) == 0
                            else data[f"view{view_idx}"]["scales"]
                        )
                        pred[k] = pred[k] * scales[i]
            # use this function to fix number of keypoints etc.
            if self.padding_fn is not None:
                pred = self.padding_fn(pred, self.conf.padding_length)
            preds.append(pred)
            hfile.close()
        if self.conf.collate:
            return batch_to_device(collate(preds), device)
        else:
            # Without collation only a single-item batch can be returned.
            assert len(preds) == 1
            return batch_to_device(preds[0], device)

    def loss(self, pred, data):
        # Cached predictions carry no training signal.
        raise NotImplementedError
|