2023-11-07 09:09:37 +01:00
|
|
|
import os
|
|
|
|
import cv2
|
|
|
|
|
|
|
|
|
2023-11-08 15:30:37 +01:00
|
|
|
class EarDataClass:
    """Ear image dataset backed by an identity annotation file.

    Each non-blank line of the annotation file holds an image path (relative
    to ``root_dir``) and an integer identity label separated by whitespace,
    e.g. ``0001.png 12``. Images whose numeric file stem is a multiple of 5
    form the ``"test"`` split; all other images form the ``"train"`` split.

    Every image must have a sibling file with the same stem and a ``.txt``
    extension. Its first token is skipped and the remaining tokens on its
    first line are parsed as floats and exposed as the sample's bbox.
    """

    def __init__(self, root_dir: str, annot_file: str, mode: str):
        """Validate arguments and load paths, labels and bboxes for *mode*.

        Args:
            root_dir: Directory containing the images, their ``.txt`` bbox
                files and the annotation file.
            annot_file: Annotation file name, resolved relative to
                ``root_dir``.
            mode: Either ``"train"`` or ``"test"``.

        Raises:
            ValueError: If ``root_dir`` is not a directory, the annotation
                file does not exist inside it, or ``mode`` is not one of
                ``"train"`` / ``"test"``.
        """
        if not os.path.isdir(root_dir):
            raise ValueError("root_dir must be a valid directory")
        # The annotation file is looked up inside root_dir, the same place
        # the image paths it lists are resolved against.
        if not os.path.isfile(os.path.join(root_dir, annot_file)):
            raise ValueError("annot_file must be a valid file")
        if mode not in ["train", "test"]:
            raise ValueError("mode must be either train or test")

        self.root_dir = root_dir
        self.annot_file = annot_file
        self.mode = mode

        self._set_paths()
        self._set_bboxes()

    def _set_paths(self):
        """Populate ``self.paths`` and ``self.labels`` for the selected split."""
        paths = []
        labels = []

        def _convert_path_to_number(path):
            # os.path handles the platform's separator (os.path.join may
            # produce "\\" on Windows), unlike a hard-coded split("/").
            stem = os.path.splitext(os.path.basename(path))[0]
            return int(stem)

        want_test = self.mode == "test"
        # Open the same file that __init__ validated (inside root_dir), not a
        # path relative to the current working directory.
        annot_path = os.path.join(self.root_dir, self.annot_file)

        with open(annot_path, "r") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Skip blank lines, e.g. a trailing newline at EOF.
                    continue
                path = os.path.join(self.root_dir, fields[0])
                # Deterministic split: every 5th image (by numeric stem) is
                # held out for testing.
                is_test = _convert_path_to_number(path) % 5 == 0
                if is_test == want_test:
                    paths.append(path)
                    labels.append(int(fields[1]))

        self.paths = paths
        self.labels = labels

    def _set_bboxes(self):
        """Populate ``self.bboxes`` from each image's sibling ``.txt`` file.

        The first token of the file is skipped; the remaining tokens on the
        first line are parsed as floats.
        """
        bboxes = []
        for path in self.paths:
            # Swap only the final extension. str.replace(".png", ".txt")
            # would also touch ".png" elsewhere in the path and would leave
            # non-.png images pointing at the image file itself.
            bbox_path = os.path.splitext(path)[0] + ".txt"
            with open(bbox_path, "r") as f:
                # NOTE(review): assumes one box per file; only the first
                # line is parsed. split() tolerates trailing whitespace.
                tokens = f.readline().split()
            bboxes.append([float(x) for x in tokens[1:]])
        self.bboxes = bboxes

    def __getitem__(self, idx):
        """Return ``(grayscale_image, label, bbox)`` for sample *idx*.

        Raises:
            IndexError: If *idx* is out of range for the loaded split.
            OSError: If OpenCV cannot read the image file.
        """
        image_path = self.paths[idx]
        label = self.labels[idx]
        bbox = self.bboxes[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cv2.imread failed to read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return image, label, bbox

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.paths)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Load the test split from ./ears and print each sample's image shape, label and bbox."""
    dataset = EarDataClass(root_dir="./ears", annot_file="identites.txt", mode="test")
    for sample_idx in range(len(dataset)):
        img, lbl, box = dataset[sample_idx]
        print(img.shape, lbl, box)
|
|
|
|
|
2023-11-08 15:30:37 +01:00
|
|
|
|
|
|
|
# Run the demo loader only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|