import os import cv2 class EarDataClass: def __init__(self, root_dir: str, annot_file: str, mode: str): if not os.path.isdir(root_dir): raise ValueError("root_dir must be a valid directory") if os.path.isfile(os.path.join(root_dir, annot_file)): raise ValueError("annot_file must be a valid file") if mode not in ["train", "test"]: raise ValueError("mode must be either train or test") self.root_dir = root_dir self.annot_file = annot_file self.mode = mode self._set_paths() self._set_bboxes() def _set_paths(self): paths = [] labels = [] def _convert_path_to_number(path): return int(path.split("/")[-1].split(".")[0]) with open(self.annot_file, "r") as f: lines = f.readlines() for line in lines: line = line.split(" ") path = os.path.join(self.root_dir, line[0]) p_int = _convert_path_to_number(path) if self.mode == "train": if p_int % 5 != 0: paths.append(path) labels.append(int(line[1].strip())) elif self.mode == "test": if p_int % 5 == 0: paths.append(path) labels.append(int(line[1].strip())) self.paths = paths self.labels = labels def _set_bboxes(self): bboxes = [] for path in self.paths: path = path.replace(".png", ".txt") with open(path, "r") as f: lines = f.read().split(sep=" ") bbox = [float(x) for x in lines[1:]] bboxes.append(bbox) self.bboxes = bboxes def __getitem__(self, idx): image_path = self.paths[idx] label = self.labels[idx] bbox = self.bboxes[idx] image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) return image, label, bbox def __len__(self): return len(self.paths) def main(): dat = EarDataClass(root_dir="./ears", annot_file="identites.txt", mode="test") for i in range(len(dat)): image, label, bbox = dat[i] print(image.shape, label, bbox) pass if __name__ == "__main__": main()