70 lines
1.8 KiB
Python
70 lines
1.8 KiB
Python
import math
|
|
from typing import List, Optional, Tuple
|
|
import torch
|
|
|
|
|
|
def to_sequence(map):
    """Flatten a spatial map into a token sequence.

    Collapses the last two (spatial) dimensions of ``map`` into one and
    moves the channel dimension last, i.e. ``(..., C, H, W)`` becomes
    ``(..., H*W, C)``.
    """
    flat = map.flatten(-2)
    return flat.transpose(-1, -2)
|
|
|
|
|
|
def to_map(sequence):
    """Reshape a token sequence back into a square spatial map.

    Inverse of ``to_sequence`` for square maps: ``(..., N, C)`` becomes
    ``(..., C, E, E)`` where ``E = sqrt(N)``.

    Raises
    ------
    AssertionError
        If the sequence length ``N`` is not a perfect square.
    """
    n = sequence.shape[-2]
    e = math.isqrt(n)
    # The map must be square; a non-square length cannot be unflattened.
    assert e * e == n
    # BUG FIX: the original computed this expression but never returned it,
    # so the function always returned None.
    return sequence.transpose(-1, -2).unflatten(-1, [e, e])
|
|
|
|
|
|
def pad_to_length(
    x,
    length: int,
    pad_dim: int = -2,
    mode: str = "zeros",  # zeros, ones, random, random_c
    bounds: Tuple[Optional[float], Optional[float]] = (None, None),
):
    """Pad tensor ``x`` along ``pad_dim`` up to ``length`` entries.

    Parameters
    ----------
    x : torch.Tensor
        Tensor to pad; returned unchanged if already of size ``length``.
    length : int
        Target size of dimension ``pad_dim``; must be >= current size.
    pad_dim : int
        Dimension to pad (default ``-2``, the sequence dimension).
    mode : str
        Fill strategy: "zeros", "ones", "random" (uniform between
        ``bounds``, falling back to ``x.min()``/``x.max()``), or
        "random_c" (uniform per channel over the last dim, using each
        channel's min/max; ``bounds`` is the fallback for empty input).
    bounds : tuple of (float or None, float or None)
        Optional (low, high) for the random modes.

    Raises
    ------
    AssertionError
        If ``x`` is longer than ``length`` along ``pad_dim``.
    ValueError
        If ``mode`` is not one of the supported strategies.
    """
    shape = list(x.shape)
    d = x.shape[pad_dim]
    assert d <= length
    if d == length:
        return x
    shape[pad_dim] = length - d

    low, high = bounds

    if mode == "zeros":
        xn = torch.zeros(*shape, device=x.device, dtype=x.dtype)
    elif mode == "ones":
        xn = torch.ones(*shape, device=x.device, dtype=x.dtype)
    elif mode == "random":
        low = low if low is not None else x.min()
        high = high if high is not None else x.max()
        # BUG FIX: dtype=x.dtype was missing, so padding defaulted to
        # float32 and the final torch.cat failed for non-float32 inputs.
        xn = torch.empty(*shape, device=x.device, dtype=x.dtype).uniform_(low, high)
    elif mode == "random_c":
        # Per-channel uniform noise; bounds are the fallback for an
        # empty sequence where min()/max() would be undefined.
        # (Redundant second `low, high = bounds` unpack removed.)
        xn = torch.cat(
            [
                torch.empty(
                    *shape[:-1], 1, device=x.device, dtype=x.dtype
                ).uniform_(
                    x[..., i].min() if d > 0 else low,
                    x[..., i].max() if d > 0 else high,
                )
                for i in range(shape[-1])
            ],
            dim=-1,
        )
    else:
        raise ValueError(mode)
    return torch.cat([x, xn], dim=pad_dim)
|
|
|
|
|
|
def pad_and_stack(
    sequences: List[torch.Tensor],
    length: Optional[int] = None,
    pad_dim: int = -2,
    **kwargs,
):
    """Pad every sequence to a common length and stack into one batch.

    If ``length`` is not given, the longest sequence along ``pad_dim``
    determines the target. Extra keyword arguments are forwarded to
    ``pad_to_length`` (e.g. ``mode`` and ``bounds``).
    """
    if length is None:
        length = max(seq.shape[pad_dim] for seq in sequences)
    padded = [pad_to_length(seq, length, pad_dim, **kwargs) for seq in sequences]
    return torch.stack(padded, 0)
|