49 lines
1.2 KiB
Python
49 lines
1.2 KiB
Python
"""
|
|
Author: Paul-Edouard Sarlin (skydes)
|
|
"""
|
|
|
|
import collections.abc as collections
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Text-like types are Sequences as well, so they are listed separately to be
# passed through untouched instead of being split element by element.
string_classes = (str, bytes)


def map_tensor(input_, func):
    """Recursively apply ``func`` to every leaf of a nested container.

    Args:
        input_: a leaf value, ``None``, a string/bytes object, or an
            arbitrarily nested combination of mappings and sequences.
        func: callable invoked on each leaf.

    Returns:
        A structure mirroring ``input_``: mappings are rebuilt as plain
        dicts, sequences (including tuples) as plain lists, strings, bytes
        and ``None`` are returned unchanged, and any other value is replaced
        by ``func(value)``.
    """
    # Pass-through values: nothing to transform.
    if input_ is None or isinstance(input_, string_classes):
        return input_

    if isinstance(input_, collections.Mapping):
        mapped = {}
        for key, value in input_.items():
            mapped[key] = map_tensor(value, func)
        return mapped

    if isinstance(input_, collections.Sequence):
        items = []
        for value in input_:
            items.append(map_tensor(value, func))
        return items

    # Anything else is considered a leaf.
    return func(input_)
|
|
|
|
|
|
def batch_to_numpy(batch):
    """Convert every leaf of ``batch`` with ``.cpu().numpy()``.

    The nested structure is traversed by ``map_tensor``; each leaf is
    expected to expose ``cpu()`` whose result exposes ``numpy()``.
    """

    def _as_array(leaf):
        on_host = leaf.cpu()
        return on_host.numpy()

    return map_tensor(batch, _as_array)
|
|
|
|
|
|
def batch_to_device(batch, device, non_blocking=True):
    """Call ``.to(device=..., non_blocking=...)`` on every leaf of ``batch``.

    Args:
        batch: nested structure traversed by ``map_tensor``.
        device: forwarded as the ``device`` keyword of each leaf's ``to``.
        non_blocking: forwarded as the ``non_blocking`` keyword.

    Returns:
        The structure produced by ``map_tensor`` with every leaf moved.
    """
    return map_tensor(
        batch,
        lambda leaf: leaf.to(device=device, non_blocking=non_blocking),
    )
|
|
|
|
|
|
def rbd(data: dict) -> dict:
    """Remove the batch dimension from the elements of ``data``.

    Values that are a ``torch.Tensor``, ``np.ndarray`` or ``list`` are
    replaced by their first element (``value[0]``); every other value is
    kept unchanged. A new dict is returned.
    """
    batched_types = (torch.Tensor, np.ndarray, list)
    unbatched = {}
    for key, value in data.items():
        if isinstance(value, batched_types):
            value = value[0]
        unbatched[key] = value
    return unbatched
|
|
|
|
|
|
def index_batch(tensor_dict):
    """Yield one un-batched copy of ``tensor_dict`` per batch element.

    The batch size is the ``len`` of the first value in ``tensor_dict``.
    For each index ``i`` in that range, the structure is traversed with
    ``map_tensor`` and every leaf is replaced by ``leaf[i]``.

    Args:
        tensor_dict: mapping whose first value supports ``len``.

    Yields:
        The result of ``map_tensor`` for each batch index, in order. An
        empty ``tensor_dict`` yields nothing.
    """
    # With no entries there is no batch to iterate. A bare ``next(iter(...))``
    # would raise StopIteration here, which a generator converts into a
    # RuntimeError (PEP 479) instead of simply ending the iteration.
    if not tensor_dict:
        return

    batch_size = len(next(iter(tensor_dict.values())))
    for i in range(batch_size):
        # Bind ``i`` at definition time so the callable never observes a
        # later value of the loop variable.
        yield map_tensor(tensor_dict, lambda t, i=i: t[i])
|