check_dataset.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import os.path as osp
  13. import random
  14. from PIL import Image, ImageOps
  15. from collections import defaultdict
  16. from .....utils.errors import DatasetFileNotFoundError, CheckFailedError
  17. from .utils.visualizer import draw_label
  18. def check(dataset_dir, output, sample_num=10):
  19. """ check dataset """
  20. dataset_dir = osp.abspath(dataset_dir)
  21. # Custom dataset
  22. if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
  23. raise DatasetFileNotFoundError(file_path=dataset_dir)
  24. tags = ['train', 'val']
  25. delim = ' '
  26. valid_num_parts = 2
  27. sample_cnts = dict()
  28. label_map_dict = dict()
  29. sample_paths = defaultdict(list)
  30. labels = []
  31. label_file = osp.join(dataset_dir, 'label.txt')
  32. if not osp.exists(label_file):
  33. raise DatasetFileNotFoundError(
  34. file_path=label_file,
  35. solution=f"Ensure that `label.txt` exist in {dataset_dir}")
  36. with open(label_file, 'r', encoding='utf-8') as f:
  37. all_lines = f.readlines()
  38. for line in all_lines:
  39. substr = line.strip("\n").split(delim, 1)
  40. try:
  41. label_idx = int(substr[0])
  42. labels.append(label_idx)
  43. label_map_dict[label_idx] = str(substr[1])
  44. except:
  45. raise CheckFailedError(
  46. f"Ensure that the first number in each line in {label_file} should be int."
  47. )
  48. if min(labels) != 0:
  49. raise CheckFailedError(
  50. f"Ensure that the index starts from 0 in `{label_file}`.")
  51. for tag in tags:
  52. file_list = osp.join(dataset_dir, f'{tag}.txt')
  53. if not osp.exists(file_list):
  54. if tag in ('train', 'val'):
  55. # train and val file lists must exist
  56. raise DatasetFileNotFoundError(
  57. file_path=file_list,
  58. solution=f"Ensure that both `train.txt` and `val.txt` exist in {dataset_dir}"
  59. )
  60. else:
  61. # tag == 'test'
  62. continue
  63. else:
  64. with open(file_list, 'r', encoding='utf-8') as f:
  65. all_lines = f.readlines()
  66. random.seed(123)
  67. random.shuffle(all_lines)
  68. sample_cnts[tag] = len(all_lines)
  69. for line in all_lines:
  70. substr = line.strip("\n").split(delim)
  71. if len(substr) != valid_num_parts:
  72. raise CheckFailedError(
  73. f"The number of delimiter-separated items in each row in {file_list} \
  74. should be {valid_num_parts} (current delimiter is '{delim}')."
  75. )
  76. file_name = substr[0]
  77. label = substr[1]
  78. img_path = osp.join(dataset_dir, file_name)
  79. if not osp.exists(img_path):
  80. raise DatasetFileNotFoundError(file_path=img_path)
  81. vis_save_dir = osp.join(output, 'demo_img')
  82. if not osp.exists(vis_save_dir):
  83. os.makedirs(vis_save_dir)
  84. if len(sample_paths[tag]) < sample_num:
  85. img = Image.open(img_path)
  86. img = ImageOps.exif_transpose(img)
  87. vis_im = draw_label(img, label, label_map_dict)
  88. vis_path = osp.join(vis_save_dir,
  89. osp.basename(file_name))
  90. vis_im.save(vis_path)
  91. sample_path = osp.join(
  92. 'check_dataset', os.path.relpath(vis_path, output))
  93. sample_paths[tag].append(sample_path)
  94. try:
  95. label = int(label)
  96. except (ValueError, TypeError) as e:
  97. raise CheckFailedError(
  98. f"Ensure that the second number in each line in {label_file} should be int."
  99. ) from e
  100. num_classes = max(labels) + 1
  101. attrs = {}
  102. attrs['label_file'] = osp.relpath(label_file, output)
  103. attrs['num_classes'] = num_classes
  104. attrs['train_samples'] = sample_cnts['train']
  105. attrs['train_sample_paths'] = sample_paths['train']
  106. attrs['val_samples'] = sample_cnts['val']
  107. attrs['val_sample_paths'] = sample_paths['val']
  108. return attrs