# sb-ass/a1/dataloader.py
#
# 65 lines
# 1.9 KiB
# Python

import os
import cv2
class EarDataClass:
    """Index-based dataset of ear images with a deterministic train/test split.

    The annotation file holds one sample per line: an image path relative to
    ``root_dir`` followed by a whitespace-separated label. The integer encoded
    in the image file name (the part before the first ``.``) decides the
    split: samples whose number is divisible by 5 belong to ``'test'``, all
    others belong to ``'train'``.

    After construction the instance exposes ``root_dir``, ``annot_file`` and
    ``mode`` exactly as passed in, plus ``paths`` and ``labels``: two parallel
    lists describing the selected split. Labels are kept as strings.
    """

    def __init__(self, root_dir: str, annot_file: str, mode: str):
        """Validate the arguments and build the sample index.

        Args:
            root_dir: Directory that the image paths are relative to.
            annot_file: Annotation file. It is looked up inside ``root_dir``
                first and, failing that, used as given.
            mode: Either ``'train'`` or ``'test'``.

        Raises:
            ValueError: If ``root_dir`` is not a directory, the annotation
                file cannot be found, ``mode`` is not a known split, or the
                annotation file contains a malformed line.
        """
        if not os.path.isdir(root_dir):
            raise ValueError('root_dir must be a valid directory')
        annot_path = self._resolve_annot_path(root_dir, annot_file)
        if annot_path is None:
            raise ValueError('annot_file must be a valid file')
        if mode not in ('train', 'test'):
            raise ValueError('mode must be either train or test')
        self.root_dir = root_dir
        self.annot_file = annot_file
        self.mode = mode
        # Remember where the file was actually found so that validation and
        # reading always agree on the same path.
        self._annot_path = annot_path
        self._set_paths()

    @staticmethod
    def _resolve_annot_path(root_dir, annot_file):
        """Return the existing annotation file path, or ``None`` if not found."""
        inside_root = os.path.join(root_dir, annot_file)
        if os.path.isfile(inside_root):
            return inside_root
        # Fall back to the path as given (e.g. relative to the working
        # directory) so callers that kept the file outside root_dir still work.
        if os.path.isfile(annot_file):
            return annot_file
        return None

    @staticmethod
    def _sample_number(path):
        """Return the integer encoded in the file name of ``path``."""
        # basename() copes with the platform's separator, unlike split('/').
        return int(os.path.basename(path).split('.')[0])

    def _set_paths(self):
        """Read the annotation file and fill ``self.paths`` / ``self.labels``.

        Raises:
            ValueError: If a non-blank line has fewer than two fields, or the
                image file name does not start with an integer.
        """
        want_test = self.mode == 'test'
        paths = []
        labels = []
        with open(self._annot_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                # split() without arguments drops the trailing newline and
                # tolerates repeated spaces or tabs between the fields.
                fields = line.split()
                if not fields:
                    continue  # blank line
                if len(fields) < 2:
                    raise ValueError(
                        'malformed annotation line {}: {!r}'.format(line_no, line)
                    )
                path = os.path.join(self.root_dir, fields[0])
                in_test = self._sample_number(path) % 5 == 0
                if in_test == want_test:
                    paths.append(path)
                    labels.append(fields[1])
        self.paths = paths
        self.labels = labels

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded with OpenCV and converted from BGR to RGB.

        Raises:
            IndexError: If ``idx`` is out of range.
            OSError: If the image file cannot be read.
        """
        image_path = self.paths[idx]
        label = self.labels[idx]
        image = cv2.imread(image_path)
        if image is None:
            # imread signals failure by returning None instead of raising.
            raise OSError('could not read image: {}'.format(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.paths)
def main():
    """Walk the training split under ``./ears`` and print each sample.

    For every sample the image array's shape and its label are printed on
    one line.
    """
    dataset = EarDataClass(
        root_dir='./ears',
        annot_file='identites.txt',
        mode='train',
    )
    total = len(dataset)
    for position in range(total):
        sample = dataset[position]
        picture, tag = sample
        print(picture.shape, tag)
# Run the demo only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()