glue-factory-custom/gluefactory/models/lines/deeplsd.py

106 lines
3.6 KiB
Python
Raw Normal View History

import deeplsd.models.deeplsd_inference as deeplsd_inference
import numpy as np
import torch
from ...settings import DATA_PATH
from ..base_model import BaseModel
class DeepLSD(BaseModel):
    """Line-segment detector wrapping the pretrained DeepLSD network.

    Detected segments shorter than ``min_length`` are dropped and the rest are
    scored by the square root of their length. Optionally the output is capped
    at ``max_num_lines`` per image (longest first) and, with
    ``force_num_lines``, zero-padded to exactly that many lines.
    """

    default_conf = {
        # Segments shorter than this are discarded (units are those of the
        # coordinates returned by the network — presumably pixels; confirm).
        "min_length": 15,
        # Keep at most this many lines per image (highest score first);
        # None keeps all of them.
        "max_num_lines": None,
        # Pad every image's output to exactly `max_num_lines` lines so the
        # whole batch can be stacked into tensors. Requires max_num_lines.
        "force_num_lines": False,
        # Forwarded verbatim to deeplsd_inference.DeepLSD.
        "model_conf": {
            "detect_lines": True,
            "line_detection_params": {
                "merge": False,
                "grad_nfa": True,
                "filtering": "normal",
                "grad_thresh": 3,
            },
        },
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        """Build the network and load the pretrained checkpoint.

        The checkpoint is downloaded to ``DATA_PATH / "weights"`` on first use.
        """
        if self.conf.force_num_lines:
            assert (
                self.conf.max_num_lines is not None
            ), "Missing max_num_lines parameter"
        ckpt_path = DATA_PATH / "weights/deeplsd_md.tar"
        if not ckpt_path.is_file():
            self.download_model(ckpt_path)
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
        self.net.load_state_dict(ckpt["model"])

    def download_model(self, path):
        """Download the DeepLSD checkpoint to ``path``.

        ``torch.hub.download_url_to_file`` streams into a temporary file and
        only moves it to ``path`` once the transfer has completed. An
        interrupted or failed download therefore never leaves a truncated
        checkpoint behind, which ``_init``'s ``is_file()`` check would
        otherwise accept on the next run. It also avoids depending on an
        external ``wget`` binary.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
        print("Downloading DeepLSD model...")
        torch.hub.download_url_to_file(link, str(path))

    def _forward(self, data):
        """Detect line segments in ``data["image"]``.

        Returns a dict with:
            lines: per-image segments, each given by its two endpoints.
            line_scores: per-image scores, sqrt of the segment length.
            valid_lines: per-image boolean mask, False for padding entries.

        The three entries are stacked into tensors on the image's device when
        the batch holds a single image or ``force_num_lines`` is set.
        Otherwise they are lists of per-image numpy arrays, because images may
        yield different numbers of lines.
        """
        # NOTE(review): image is presumably (B, C, H, W) — channel axis 1,
        # as the shape check and the view(1, 3, 1, 1) below imply.
        image = data["image"]
        if image.shape[1] == 3:
            # Convert to grayscale using the BT.601 luma weights.
            weights = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * weights).sum(1, keepdim=True)

        with torch.no_grad():
            pred_segs = self.net({"image": image})["lines"]

        lines, line_scores, valid_lines = [], [], []
        for seg in pred_segs:
            # assumes seg is a numpy array of shape (N, 2, 2): N segments,
            # 2 endpoints, 2 coordinates (matches the padding shape below).
            lengths = np.linalg.norm(seg[:, 0] - seg[:, 1], axis=1)
            keep = lengths >= self.conf.min_length
            segs = seg[keep]
            # Line scores are the sqrt of the length.
            scores = np.sqrt(lengths[keep])

            # Keep the best lines, highest score first.
            order = np.argsort(-scores)
            if self.conf.max_num_lines is not None:
                order = order[: self.conf.max_num_lines]
            segs = segs[order]
            scores = scores[order]

            # Pad with zero segments if a fixed count is required; `pad` is
            # non-negative because the list was truncated to max_num_lines.
            n = len(segs)
            valid_mask = np.ones(n, dtype=bool)
            if self.conf.force_num_lines:
                pad = self.conf.max_num_lines - n
                segs = np.concatenate(
                    [segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
                )
                scores = np.concatenate(
                    [scores, np.zeros(pad, dtype=np.float32)], axis=0
                )
                valid_mask = np.concatenate(
                    [valid_mask, np.zeros(pad, dtype=bool)], axis=0
                )
            lines.append(segs)
            line_scores.append(scores)
            valid_lines.append(valid_mask)

        # Batch if possible. np.stack + from_numpy avoids the very slow
        # element-wise conversion torch.tensor performs on a list of ndarrays.
        if len(image) == 1 or self.conf.force_num_lines:
            device = image.device
            lines = torch.from_numpy(np.stack(lines)).to(
                device=device, dtype=torch.float
            )
            line_scores = torch.from_numpy(np.stack(line_scores)).to(
                device=device, dtype=torch.float
            )
            valid_lines = torch.from_numpy(np.stack(valid_lines)).to(
                device=device, dtype=torch.bool
            )
        return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}

    def loss(self, pred, data):
        """Not supported: this wrapper is inference-only."""
        raise NotImplementedError