106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
|
import numpy as np
|
||
|
import torch
|
||
|
import deeplsd.models.deeplsd_inference as deeplsd_inference
|
||
|
|
||
|
from ..base_model import BaseModel
|
||
|
from ...settings import DATA_PATH
|
||
|
|
||
|
|
||
|
class DeepLSD(BaseModel):
    """Line segment detector wrapping a pretrained DeepLSD network.

    Configuration keys (see ``default_conf``):
        min_length: segments strictly shorter than this are discarded.
        max_num_lines: if not None, keep at most this many segments per
            image, preferring the highest-scoring (longest) ones.
        force_num_lines: if True, zero-pad every image's output up to
            ``max_num_lines`` so the batch can be stacked; requires
            ``max_num_lines`` to be set.
        model_conf: configuration forwarded to ``deeplsd_inference.DeepLSD``.
    """

    default_conf = {
        "min_length": 15,
        "max_num_lines": None,
        "force_num_lines": False,
        "model_conf": {
            "detect_lines": True,
            "line_detection_params": {
                "merge": False,
                "grad_nfa": True,
                "filtering": "normal",
                "grad_thresh": 3,
            },
        },
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        """Build the DeepLSD network and load its pretrained weights.

        The checkpoint is looked up under ``DATA_PATH / "weights"`` and is
        downloaded first if it is not there yet.
        """
        if self.conf.force_num_lines:
            assert (
                self.conf.max_num_lines is not None
            ), "Missing max_num_lines parameter"
        ckpt_path = DATA_PATH / "weights/deeplsd_md.tar"
        if not ckpt_path.is_file():
            self.download_model(ckpt_path)
        state = torch.load(ckpt_path, map_location="cpu")
        self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
        self.net.load_state_dict(state["model"])

    def download_model(self, path):
        """Download the pretrained DeepLSD checkpoint to ``path`` with wget.

        Args:
            path: ``pathlib.Path`` of the destination file; missing parent
                directories are created.

        Raises:
            subprocess.CalledProcessError: if wget exits with an error.
        """
        import subprocess

        path.parent.mkdir(parents=True, exist_ok=True)
        link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
        cmd = ["wget", link, "-O", str(path)]
        print("Downloading DeepLSD model...")
        try:
            subprocess.run(cmd, check=True)
        except BaseException:
            # `wget -O` creates the output file before any data arrives, so a
            # failed or interrupted download leaves a truncated file behind.
            # Remove it: otherwise the `is_file()` check in `_init` would skip
            # the download next time and loading the checkpoint would fail.
            if path.exists():
                path.unlink()
            raise

    def _forward(self, data):
        """Detect line segments in a batch of images.

        Args:
            data: dict with an ``"image"`` tensor whose dim 1 is the channel
                axis; 3-channel input is converted to grayscale.

        Returns:
            Dict with ``"lines"`` (segments as pairs of 2D endpoints),
            ``"line_scores"`` (square root of each segment length) and
            ``"valid_lines"`` (False for padded entries). The values are
            tensors on the image's device when the batch has a single image
            or ``force_num_lines`` is set; otherwise they are per-image lists
            of arrays, since the number of lines may differ between images.
        """
        image = data["image"]
        lines, line_scores, valid_lines = [], [], []
        if image.shape[1] == 3:
            # Convert to grayscale
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)

        # Forward pass
        with torch.no_grad():
            batch_segs = self.net({"image": image})["lines"]

        # NOTE(review): each entry is presumably an (N, 2, 2) NumPy array of
        # endpoint pairs -- confirm against deeplsd_inference.
        for seg in batch_segs:
            # Line scores are the sqrt of the length
            lengths = np.linalg.norm(seg[:, 0] - seg[:, 1], axis=1)
            long_enough = lengths >= self.conf.min_length
            kept = seg[long_enough]
            scores = np.sqrt(lengths[long_enough])

            # Keep the best lines
            order = np.argsort(-scores)
            if self.conf.max_num_lines is not None:
                order = order[: self.conf.max_num_lines]
            kept = kept[order]
            scores = scores[order]

            # Pad if necessary
            n = len(kept)
            valid_mask = np.ones(n, dtype=bool)
            if self.conf.force_num_lines:
                # n <= max_num_lines after the truncation above, so pad >= 0.
                pad = self.conf.max_num_lines - n
                kept = np.concatenate(
                    [kept, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
                )
                scores = np.concatenate(
                    [scores, np.zeros(pad, dtype=np.float32)], axis=0
                )
                valid_mask = np.concatenate(
                    [valid_mask, np.zeros(pad, dtype=bool)], axis=0
                )

            lines.append(kept)
            line_scores.append(scores)
            valid_lines.append(valid_mask)

        # Batch if possible
        if len(image) == 1 or self.conf.force_num_lines:
            # Stack in NumPy first: building a tensor straight from a list of
            # arrays is much slower and triggers a PyTorch warning.
            lines = torch.tensor(
                np.stack(lines), dtype=torch.float, device=image.device
            )
            line_scores = torch.tensor(
                np.stack(line_scores), dtype=torch.float, device=image.device
            )
            valid_lines = torch.tensor(
                np.stack(valid_lines), dtype=torch.bool, device=image.device
            )

        return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}

    def loss(self, pred, data):
        """Not supported: this model is only used for inference."""
        raise NotImplementedError
|