data_path_utils.py 1001 B

123456789101112131415161718192021222324252627
  1. import os
  2. def imagenet_val_files_and_labels(dataset_directory):
  3. classes = open(os.path.join(dataset_directory, 'imagenet_lsvrc_2015_synsets.txt')).readlines()
  4. class_to_indx = {classes[i].split('\n')[0]: i for i in range(len(classes))}
  5. images_path = os.path.join(dataset_directory, 'val')
  6. filenames = []
  7. labels = []
  8. lines = open(os.path.join(dataset_directory, 'imagenet_2012_validation_synset_labels.txt'), 'r').readlines()
  9. for i, line in enumerate(lines):
  10. class_name = line.split('\n')[0]
  11. a = 'ILSVRC2012_val_%08d.JPEG' % (i + 1)
  12. filenames.append(f'{images_path}/{a}')
  13. labels.append(class_to_indx[class_name])
  14. # print(filenames[-1], labels[-1])
  15. return filenames, labels
  16. def _find_classes(dir):
  17. # Faster and available in Python 3.5 and above
  18. classes = [d.name for d in os.scandir(dir) if d.is_dir()]
  19. classes.sort()
  20. class_to_idx = {classes[i]: i for i in range(len(classes))}
  21. return classes, class_to_idx