"""Model that serves precomputed predictions loaded from HDF5 cache files."""

import string

import h5py
import torch

from ..datasets.base_dataset import collate
from ..settings import DATA_PATH
from ..utils.tensor import batch_to_device
from .base_model import BaseModel
from .utils.misc import pad_to_length


def pad_local_features(pred: dict, seq_l: int):
    """Pad the local-feature entries of ``pred`` to ``seq_l`` via ``pad_to_length``.

    ``"keypoints"`` is required. ``"keypoint_scores"``, ``"descriptors"``,
    ``"scales"`` and ``"oris"`` are padded only when present. ``pred`` is
    modified in place and also returned.

    NOTE(review): the third argument of ``pad_to_length`` (-2 / -1) is assumed
    to be the axis being padded; confirm against ``utils.misc``.
    """
    pred["keypoints"] = pad_to_length(
        pred["keypoints"],
        seq_l,
        -2,
        mode="random_c",
    )
    # (key, axis argument, padding mode) for entries that may be absent.
    optional_entries = (
        ("keypoint_scores", -1, "zeros"),
        ("descriptors", -2, "random"),
        ("scales", -1, "zeros"),
        ("oris", -1, "zeros"),
    )
    for key, dim, mode in optional_entries:
        if key in pred:
            pred[key] = pad_to_length(pred[key], seq_l, dim, mode=mode)
    return pred


def pad_line_features(pred, seq_l: int = None):
    """Placeholder padding function for line features (not implemented).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError


def recursive_load(grp, pkeys):
    """Load entries ``pkeys`` of an HDF5 group into a (nested) dict of tensors.

    ``h5py.Dataset`` items are read into NumPy arrays and wrapped with
    ``torch.from_numpy``. Sub-groups are loaded recursively with all of their
    own keys.
    """
    loaded = {}
    for k in pkeys:
        item = grp[k]
        if isinstance(item, h5py.Dataset):
            loaded[k] = torch.from_numpy(item.__array__())
        else:
            # Recurse with the *sub-group's* keys; using the parent's keys here
            # would look up names that do not exist in the child group.
            loaded[k] = recursive_load(item, list(item.keys()))
    return loaded


class CacheLoader(BaseModel):
    """Non-trainable model returning predictions read from HDF5 caches.

    For every sample ``data["name"][i]``:

    1. Formats ``conf.path`` with fields taken from ``data``.
    2. Reads the group named after the sample.
    3. Optionally casts floating-point tensors.
    4. Multiplies keys matching ``conf.scale`` prefixes by the sample's scales.
    5. Optionally pads via ``conf.padding_fn``.

    The per-sample dicts are collated if ``conf.collate`` is set.
    """

    default_conf = {
        "path": "???",  # can be a format string like exports/{scene}/
        "data_keys": None,  # load all keys
        "device": None,  # load to same device as data
        "trainable": False,
        "add_data_path": True,
        "collate": True,
        "scale": ["keypoints", "lines", "orig_lines"],
        "padding_fn": None,
        "padding_length": None,  # required for batching!
        "numeric_type": "float32",  # [None, "float16", "float32", "float64"]
    }

    required_data_keys = ["name"]  # we need an identifier

    def _init(self, conf):
        """Resolve the padding function and target float dtype from ``conf``."""
        self.hfiles = {}
        self.padding_fn = conf.padding_fn
        if self.padding_fn is not None:
            # NOTE(review): `eval` executes the configured string, e.g.
            # "pad_local_features" resolves to the function in this module.
            # Only use with trusted configuration files.
            self.padding_fn = eval(self.padding_fn)
        self.numeric_dtype = {
            None: None,
            "float16": torch.float16,
            "float32": torch.float32,
            "float64": torch.float64,
        }[conf.numeric_type]

    def _forward(self, data):
        """Load, cast, rescale and pad cached predictions for every sample."""
        device = self.conf.device
        if not device:
            # Follow the device of the incoming tensors; it must be unique.
            devices = {v.device for v in data.values() if isinstance(v, torch.Tensor)}
            if len(devices) == 0:
                device = "cpu"
            else:
                assert len(devices) == 1, f"Input tensors span devices: {devices}"
                device = devices.pop()

        # Names of the "{...}" fields in conf.path, filled per sample from data.
        var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]

        preds = []
        for i, name in enumerate(data["name"]):
            fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
            if self.conf.add_data_path:
                fpath = DATA_PATH / fpath
            # The context manager closes the file even if loading raises.
            # recursive_load copies datasets into memory, so the tensors remain
            # valid after the file is closed.
            with h5py.File(str(fpath), "r") as hfile:
                grp = hfile[name]
                pkeys = (
                    self.conf.data_keys
                    if self.conf.data_keys is not None
                    else grp.keys()
                )
                pred = recursive_load(grp, pkeys)

            if self.numeric_dtype is not None:
                # Cast only top-level floating-point tensors; ints/bools and
                # nested dicts are left untouched.
                pred = {
                    k: (
                        v.to(dtype=self.numeric_dtype)
                        if isinstance(v, torch.Tensor) and torch.is_floating_point(v)
                        else v
                    )
                    for k, v in pred.items()
                }
            pred = batch_to_device(pred, device)

            for k in list(pred):
                for pattern in self.conf.scale:
                    if k.startswith(pattern):
                        # The remainder after the prefix selects the view:
                        # "keypoints" -> data["scales"],
                        # "keypoints0" -> data["view0"]["scales"].
                        view_idx = k[len(pattern):]
                        scales = (
                            data[f"view{view_idx}"]["scales"]
                            if view_idx
                            else data["scales"]
                        )
                        pred[k] = pred[k] * scales[i]

            # use this function to fix number of keypoints etc.
            if self.padding_fn is not None:
                pred = self.padding_fn(pred, self.conf.padding_length)
            preds.append(pred)

        if self.conf.collate:
            return batch_to_device(collate(preds), device)
        assert len(preds) == 1, "collate=False requires a batch of size 1"
        return batch_to_device(preds[0], device)

    def loss(self, pred, data):
        """Not supported: this model only serves cached predictions."""
        raise NotImplementedError