2023-10-05 16:53:51 +02:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from ..base_model import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
class DinoV2(BaseModel):
    """Feature extractor wrapping a DINOv2 backbone loaded through ``torch.hub``.

    Configuration keys (see ``default_conf``):
        weights: name of the ``facebookresearch/dinov2`` hub entry point to load.
        allow_resize: when true, the input image is resampled so that its
            height and width are multiples of the patch size before it is
            fed to the backbone.
    """

    default_conf = {"weights": "dinov2_vits14", "allow_resize": False}
    required_data_keys = ["image"]

    def _init(self, conf):
        """Load the backbone named by ``conf.weights`` from ``torch.hub``."""
        self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
        self.set_initialized()

    def _forward(self, data):
        """Run the backbone on ``data["image"]`` and collect its outputs.

        Args:
            data: mapping holding the input batch under the key ``"image"``.

        Returns:
            A dict with three entries:
                ``"features"``: the last intermediate layer's output as
                    returned by ``get_intermediate_layers(..., reshape=True)``.
                ``"global_descriptor"``: the class token of that layer.
                ``"descriptors"``: ``features`` with its last two dimensions
                    flattened into one and then swapped with the preceding
                    dimension (presumably ``(B, H*W, C)`` -- confirm against
                    the backbone's output layout).
        """
        img = data["image"]
        if self.conf.allow_resize:
            # Patch size the spatial dimensions are snapped to; 14 matches the
            # default "dinov2_vits14" weights.
            patch = 14
            size = [int(x // patch * patch) for x in img.shape[-2:]]
            # Nearest resampling to an unchanged size is an identity copy,
            # so skip it when both dimensions are already multiples of patch.
            if size != [int(x) for x in img.shape[-2:]]:
                # F.upsample is deprecated; F.interpolate with the (default)
                # nearest mode is what it dispatched to, so results match.
                img = F.interpolate(img, size=size, mode="nearest")

        # n=1 requests only the last layer, hence the single tuple at index 0.
        desc, cls_token = self.net.get_intermediate_layers(
            img, n=1, return_class_token=True, reshape=True
        )[0]

        return {
            "features": desc,
            "global_descriptor": cls_token,
            "descriptors": desc.flatten(-2).transpose(-2, -1),
        }

    def loss(self, pred, data):
        """Not implemented: this model defines no training loss."""
        raise NotImplementedError
|