81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
|
"""
|
||
|
Export the predictions of a model for a given dataloader (e.g. ImageFolder).
|
||
|
Use as a standalone script with `python3 -m dsfm.scripts.export_predictions dir`
|
||
|
or call from another script.
|
||
|
"""
|
||
|
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
from pathlib import Path
|
||
|
from tqdm import tqdm
|
||
|
import h5py
|
||
|
|
||
|
from .tensor import batch_to_device
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def export_predictions(
    loader,
    model,
    output_file,
    as_half=False,
    keys="*",
    callback_fn=None,
    optional_keys=(),
):
    """Run `model` over every batch of `loader` and dump the outputs to HDF5.

    For each batch, one HDF5 group named ``data["name"][0]`` is created and
    every exported prediction is stored in it as a dataset keyed by its name.

    Args:
        loader: iterable of batches accepted by `batch_to_device`. Each batch
            must provide ``"name"`` and, when keypoint/line predictions are
            present, ``"scales"`` (or ``"view{idx}"["scales"]``).
        model: callable module; it is moved to CUDA when available (CPU
            otherwise) and put in eval mode.
        output_file: destination path. Missing parent directories are created
            and the file is opened in ``"w"`` mode.
        as_half: if True, float32 arrays are stored as float16.
        keys: ``"*"`` to export every prediction, or a tuple/list of keys that
            must all be present in the predictions.
        callback_fn: optional ``callback_fn(pred, data) -> dict`` whose result
            is merged into the predictions; on a key clash the model's own
            prediction wins.
        optional_keys: extra keys exported when present; only consulted when
            `keys` is not ``"*"``.

    Returns:
        `output_file`, unchanged.

    Raises:
        ValueError: if a key requested in `keys` is missing from a prediction.
        AssertionError: if `keys` has an unsupported type or nothing is left
            to export for a batch.
    """
    assert keys == "*" or isinstance(keys, (tuple, list))
    Path(output_file).parent.mkdir(exist_ok=True, parents=True)

    export_all = keys == "*"
    if not export_all:
        # Built once, outside the loop. Using sets (rather than
        # `keys + optional_keys`) also works when the two arguments are
        # different sequence types, e.g. a tuple and a list.
        required = set(keys)
        wanted = required | set(optional_keys)

    # The context manager guarantees the file is closed even if the model,
    # the callback or the key validation raises midway through the loader.
    with h5py.File(str(output_file), "w") as hfile:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device).eval()
        for data_ in tqdm(loader):
            data = batch_to_device(data_, device, non_blocking=True)
            pred = model(data)
            if callback_fn is not None:
                # `pred` is unpacked last so model outputs override callback
                # outputs that share a key.
                pred = {**callback_fn(pred, data), **pred}
            if not export_all:
                missing = required.difference(pred)
                if missing:
                    raise ValueError(f"Missing key {missing}")
                pred = {k: v for k, v in pred.items() if k in wanted}
            assert len(pred) > 0

            # renormalization
            _divide_by_view_scales(pred, data)

            # Only index 0 along the leading dimension is kept.
            # NOTE(review): presumably the loader uses batch size 1 - confirm.
            pred = {k: v[0].cpu().numpy() for k, v in pred.items()}

            if as_half:
                for k in pred:
                    if pred[k].dtype == np.float32:
                        pred[k] = pred[k].astype(np.float16)
            try:
                name = data["name"][0]
                grp = hfile.create_group(name)
                for k, v in pred.items():
                    grp.create_dataset(k, data=v)
            except RuntimeError:
                # Deliberate best effort: a batch that cannot be written is
                # skipped instead of aborting the whole export.
                continue

            del pred
    return output_file


def _divide_by_view_scales(pred, data):
    """Divide keypoint/line predictions in `pred` by their view's `scales`, in place.

    A key equal to a prefix (e.g. ``"keypoints"``) uses ``data["scales"]``; a
    key with a suffix (e.g. ``"keypoints0"``) uses ``data["view0"]["scales"]``.
    NOTE(review): presumably this maps coordinates back to the original image
    resolution - confirm against the dataset code.
    """
    # The prefixes are mutually exclusive under `startswith`, so at most one
    # of them applies to any given key.
    prefixes = ("keypoints", "lines", "orig_lines")
    for k in pred:
        for prefix in prefixes:
            if not k.startswith(prefix):
                continue
            idx = k.replace(prefix, "")
            view = data if len(idx) == 0 else data[f"view{idx}"]
            scales = 1.0 / view["scales"]
            pred[k] = pred[k] * scales[None]
            break
|