89 lines
3.0 KiB
Python
89 lines
3.0 KiB
Python
|
import numpy as np
|
||
|
import torch
|
||
|
from joblib import Parallel, delayed
|
||
|
from pytlsd import lsd
|
||
|
|
||
|
from ..base_model import BaseModel
|
||
|
|
||
|
|
||
|
class LSD(BaseModel):
    """Line segment detector built on ``pytlsd.lsd``.

    Each image of the batch is converted to an 8-bit grayscale array and
    passed to ``lsd``. The detected segments are filtered by length, ranked
    by a score, optionally truncated to ``max_num_lines`` and optionally
    zero-padded so that every image yields the same number of segments.
    """

    default_conf = {
        # Segments whose endpoint distance is below this value are dropped
        # (same unit as the coordinates returned by ``lsd`` -- presumably
        # pixels; confirm against pytlsd).
        "min_length": 15,
        # Keep at most this many best-scoring segments (None keeps all).
        "max_num_lines": None,
        # Zero-pad the outputs up to ``max_num_lines``; requires
        # ``max_num_lines`` to be set.
        "force_num_lines": False,
        # Number of joblib workers used when the batch holds several images.
        "n_jobs": 4,
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        """Validate the configuration.

        Raises:
            AssertionError: if ``force_num_lines`` is enabled while
                ``max_num_lines`` is None.
        """
        if self.conf.force_num_lines:
            assert (
                self.conf.max_num_lines is not None
            ), "Missing max_num_lines parameter"

    def detect_lines(self, img):
        """Detect, filter, rank and optionally pad the segments of one image.

        Args:
            img: single image handed as-is to ``pytlsd.lsd`` (``_forward``
                passes a 2D uint8 array).

        Returns:
            A tuple ``(segs, scores, valid_mask)`` of NumPy arrays:
            ``segs`` has shape (N, 2, 2) (two endpoints per segment),
            ``scores`` has shape (N,) and ``valid_mask`` is a boolean array
            of shape (N,) that is False on padded entries.
        """
        # Run LSD. The first four columns are used as the two endpoints and
        # the last column as a per-segment weight in the score below.
        segs = lsd(img)

        # Filter out segments that do not meet the minimum length criterion
        lengths = np.linalg.norm(segs[:, 2:4] - segs[:, 0:2], axis=1)
        to_keep = lengths >= self.conf.min_length
        segs, lengths = segs[to_keep], lengths[to_keep]

        # Keep the best lines: rank by (last column * sqrt(length)),
        # highest score first.
        scores = segs[:, -1] * np.sqrt(lengths)
        segs = segs[:, :4].reshape(-1, 2, 2)
        indices = np.argsort(-scores)
        if self.conf.max_num_lines is not None:
            indices = indices[: self.conf.max_num_lines]
        segs = segs[indices]
        scores = scores[indices]

        # Pad if necessary so that every image returns max_num_lines entries
        n = len(segs)
        valid_mask = np.ones(n, dtype=bool)
        if self.conf.force_num_lines:
            # n <= max_num_lines thanks to the truncation above, so pad >= 0
            pad = self.conf.max_num_lines - n
            segs = np.concatenate(
                [segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
            )
            scores = np.concatenate(
                [scores, np.zeros(pad, dtype=np.float32)], axis=0
            )
            valid_mask = np.concatenate(
                [valid_mask, np.zeros(pad, dtype=bool)], axis=0
            )

        return segs, scores, valid_mask

    def _forward(self, data):
        """Detect line segments for every image of the batch.

        Args:
            data: dict with an ``"image"`` tensor indexed as
                (batch, channels, height, width) with 1 or 3 channels.
                Assumes intensities lie in [0, 1] -- TODO confirm with the
                data pipeline; out-of-range values are clipped.

        Returns:
            A dict with keys ``"lines"``, ``"line_scores"`` and
            ``"valid_lines"``. When the batch holds a single image or
            ``force_num_lines`` is set, the values are tensors on the input
            device, stacked along the batch dimension. Otherwise the number
            of segments may differ per image and each value is a per-image
            sequence of NumPy arrays.
        """
        # Convert to the right data format
        image = data["image"]
        if image.shape[1] == 3:
            # Convert to grayscale with fixed per-channel weights
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)
        device = image.device
        b_size = len(image)
        # Clip before the cast: converting an out-of-range float to uint8
        # wraps around (or is undefined) instead of saturating.
        image = image.squeeze(1).cpu().numpy() * 255
        image = np.clip(image, 0, 255).astype(np.uint8)

        # LSD detection, in parallel when there is more than one image
        if b_size == 1:
            lines, line_scores, valid_lines = self.detect_lines(image[0])
            lines = [lines]
            line_scores = [line_scores]
            valid_lines = [valid_lines]
        else:
            lines, line_scores, valid_lines = zip(
                *Parallel(n_jobs=self.conf.n_jobs)(
                    delayed(self.detect_lines)(img) for img in image
                )
            )

        # Batch if possible. The per-image arrays are stacked into a single
        # ndarray first: building a tensor directly from a list of ndarrays
        # is very slow and triggers a PyTorch UserWarning.
        if b_size == 1 or self.conf.force_num_lines:
            lines = torch.tensor(
                np.stack(lines), dtype=torch.float, device=device
            )
            line_scores = torch.tensor(
                np.stack(line_scores), dtype=torch.float, device=device
            )
            valid_lines = torch.tensor(
                np.stack(valid_lines), dtype=torch.bool, device=device
            )

        return {
            "lines": lines,
            "line_scores": line_scores,
            "valid_lines": valid_lines,
        }

    def loss(self, pred, data):
        """Not implemented: this block always raises NotImplementedError."""
        raise NotImplementedError
|