glue-factory-custom/gluefactory/utils/tensor.py

43 lines
1.0 KiB
Python

"""
Author: Paul-Edouard Sarlin (skydes)
"""
import collections.abc as collections
import numpy as np
import torch
# Types that are technically Sequences but must be treated as opaque leaves.
string_classes = (str, bytes)


def map_tensor(input_, func):
    """Recursively apply ``func`` to every leaf of a nested container.

    Traversal rules:
      * ``None`` and ``str``/``bytes`` values are returned untouched.
      * Any ``Mapping`` is rebuilt as a plain ``dict`` with mapped values.
      * Any other ``Sequence`` (list, tuple, ...) is rebuilt as a ``list``.
      * Everything else is considered a leaf and replaced by ``func(leaf)``.

    Args:
        input_: Arbitrary nesting of mappings/sequences with leaves at the bottom.
        func: Callable applied to each leaf.

    Returns:
        The mapped structure.
    """

    def _visit(node):
        # Guard clause: strings would otherwise match Sequence and be split.
        if node is None or isinstance(node, string_classes):
            return node
        if isinstance(node, collections.Mapping):
            return {key: _visit(child) for key, child in node.items()}
        if isinstance(node, collections.Sequence):
            return [_visit(child) for child in node]
        return func(node)

    return _visit(input_)
def batch_to_numpy(batch):
    """Convert every leaf of a nested batch to a NumPy array.

    The batch is traversed with ``map_tensor``, which leaves strings and
    ``None`` unchanged. It rebuilds mappings as ``dict`` and sequences as
    ``list``.

    Args:
        batch: A tensor, or a (nested) dict/list/tuple whose leaves are tensors.

    Returns:
        The same structure with each leaf converted via ``.cpu().numpy()``.
    """

    def _to_numpy(leaf):
        if isinstance(leaf, torch.Tensor):
            # Tensor.numpy() raises for tensors that require grad (e.g. model
            # outputs), so detach first; this is a no-op for other tensors.
            return leaf.detach().cpu().numpy()
        # Non-tensor leaves keep the original duck-typed conversion
        # (any object exposing ``.cpu().numpy()``).
        return leaf.cpu().numpy()

    return map_tensor(batch, _to_numpy)
def batch_to_device(batch, device, non_blocking=True):
    """Move every leaf of a nested batch to ``device``.

    Each leaf reached by ``map_tensor`` is replaced by
    ``leaf.to(device=device, non_blocking=non_blocking)``. Strings and
    ``None`` pass through untouched.

    Args:
        batch: A tensor, or a (nested) dict/list/tuple of tensors.
        device: Target device, forwarded to each leaf's ``.to`` call.
        non_blocking: Forwarded unchanged to each leaf's ``.to`` call.

    Returns:
        The mapped structure (mappings as ``dict``, sequences as ``list``).
    """
    return map_tensor(
        batch,
        lambda leaf: leaf.to(device=device, non_blocking=non_blocking),
    )
def rbd(data: dict) -> dict:
    """Remove the batch dimension from the values of ``data``.

    Every value that is a ``torch.Tensor``, ``np.ndarray`` or ``list`` is
    replaced by its first element (``value[0]``). All other values, including
    tuples, are copied as-is. Only element 0 is kept, so this presumably
    targets batches of size 1. The input dict is not modified.
    """

    def _first_if_batched(value):
        if isinstance(value, (torch.Tensor, np.ndarray, list)):
            return value[0]
        return value

    return {key: _first_if_batched(value) for key, value in data.items()}