51 lines
1.7 KiB
Python
51 lines
1.7 KiB
Python
|
import torch
|
||
|
|
||
|
|
||
|
def extract_patches(
    tensor: torch.Tensor,
    required_corners: torch.Tensor,
    ps: int,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Crop square ``ps`` x ``ps`` patches out of a single image tensor.

    Args:
        tensor: Image of shape ``(c, h, w)``.
        required_corners: ``(n, 2)`` tensor of requested top-left patch
            corners; column 0 is bounded against the width and column 1
            against the height. Values are converted to int64 (fractions are
            dropped) before use. The input tensor is never modified.
        ps: Patch side length in pixels.

    Returns:
        A pair ``(patches, corners)``:

        * ``patches`` has shape ``(n, c, ps, ps)`` with
          ``patches[m, :, a, b] == tensor[:, cy + a, cx + b]`` where
          ``(cx, cy)`` is the m-th clamped corner.
        * ``corners`` is the ``(n, 2)`` float tensor of the corners that were
          actually used, i.e. after clamping column 0 to ``[0, w - 1 - ps]``
          and column 1 to ``[0, h - 1 - ps]``.
    """
    _, h, w = tensor.shape

    # `.long()` hands back the very same tensor when the input is already
    # int64, so clamp out-of-place and build a fresh tensor instead of writing
    # into a possible alias of the caller's data.
    requested = required_corners.long()
    xs = requested[:, 0].clamp(min=0, max=w - 1 - ps)
    ys = requested[:, 1].clamp(min=0, max=h - 1 - ps)
    corner = torch.stack((xs, ys), dim=-1)

    # Broadcast the per-patch offsets against the corners rather than going
    # through torch.meshgrid: this needs no version-dependent `indexing`
    # keyword (comparing version strings lexicographically is unreliable,
    # e.g. "1.9" >= "1.10" is True).
    offset = torch.arange(ps, device=corner.device)
    rows = ys.reshape(-1, 1, 1) + offset.reshape(1, -1, 1)  # (n, ps, 1)
    cols = xs.reshape(-1, 1, 1) + offset.reshape(1, 1, -1)  # (n, 1, ps)

    # Adjacent advanced indices stay in place: the result is (c, n, ps, ps).
    sampled = tensor[:, rows, cols]
    return sampled.permute(1, 0, 2, 3), corner.float()
|
||
|
|
||
|
|
||
|
def batch_extract_patches(tensor: torch.Tensor, kpts: torch.Tensor, ps: int):
    """Run ``extract_patches`` on every image of a batch.

    Args:
        tensor: Batch of images, shape ``(b, c, h, w)``.
        kpts: Keypoints, shape ``(b, n, 2)``. Each keypoint is shifted by
            ``- ps / 2 - 1`` to obtain the requested patch corner.
        ps: Patch side length in pixels.

    Returns:
        ``(out, corners)`` where ``out`` has shape ``(b, n, c, ps, ps)`` and
        ``corners`` has shape ``(b, n, 2)``; both share the dtype and device
        of ``tensor``.
    """
    # The unpacking doubles as a rank check on both inputs; the batch size
    # that drives the loop is the one reported by `kpts`.
    _, channels, _, _ = tensor.shape
    batch, num_kpts, _ = kpts.shape

    like_input = {"dtype": tensor.dtype, "device": tensor.device}
    out = torch.zeros((batch, num_kpts, channels, ps, ps), **like_input)
    corners = torch.zeros((batch, num_kpts, 2), **like_input)

    for idx in range(batch):
        requested = kpts[idx] - ps / 2 - 1
        patch_stack, used_corners = extract_patches(tensor[idx], requested, ps)
        out[idx] = patch_stack
        corners[idx] = used_corners

    return out, corners
|
||
|
|
||
|
|
||
|
def draw_image_patches(img, patches, corners):
    """Paste square patches into ``img`` in place.

    Args:
        img: 4-D tensor ``(b, c, h, w)``; modified in place.
        patches: 5-D tensor ``(b, n, c, p, p)``.
        corners: 3-D tensor ``(b, n, 2)``. Entry 0 of each corner is the start
            index along the last image axis and entry 1 the start index along
            the second-to-last one. The values are used as slice bounds, so
            they presumably have to be integer-typed (``build_heatmap`` passes
            ``corners.long()``).

    Returns:
        None. Patches are written in index order, so a later patch overwrites
        an earlier one wherever they overlap.
    """
    # Unpacking doubles as a rank check on each argument.
    _, _, _, _ = img.shape
    _, _, _, _, size = patches.shape
    batch, count, _ = corners.shape

    for idx in range(batch):
        for k in range(count):
            col_start, row_start = corners[idx, k]
            row_span = slice(row_start, row_start + size)
            col_span = slice(col_start, col_start + size)
            img[idx, :, row_span, col_span] = patches[idx, k]
|
||
|
|
||
|
|
||
|
def build_heatmap(img, patches, corners):
    """Render patches onto a blank canvas shaped like ``img``.

    Args:
        img: Tensor used only as a template (shape, dtype, device) for the
            zero-filled canvas; it is not modified.
        patches: Patches forwarded to ``draw_image_patches``.
        corners: Patch corners; converted to int64 before drawing.

    Returns:
        ``(hmap, mask)`` where ``hmap`` is the canvas with axis 1 squeezed
        (``b x h x w`` when that axis has size 1) and ``mask`` is a float
        tensor that is 1.0 where ``hmap > 0`` and 0.0 elsewhere.
    """
    canvas = torch.zeros_like(img)
    int_corners = corners.long()
    draw_image_patches(canvas, patches, int_corners)

    heat = canvas.squeeze(1)
    mask = (heat > 0.0).float()
    return heat, mask
|