check_dataset.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import random
  17. from collections import defaultdict
  18. from .....utils.errors import CheckFailedError, DatasetFileNotFoundError
  19. from .....utils.file_interface import custom_open
  20. def check(dataset_dir, output, sample_num=10):
  21. """check dataset"""
  22. dataset_dir = osp.abspath(dataset_dir)
  23. # Custom dataset
  24. if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
  25. raise DatasetFileNotFoundError(file_path=dataset_dir)
  26. tags = ["train", "val"]
  27. valid_num_parts = 5
  28. sample_cnts = dict()
  29. label_map_dict = dict()
  30. sample_paths = defaultdict(list)
  31. labels = []
  32. image_dir = osp.join(dataset_dir, "rgb-images")
  33. label_dir = osp.join(dataset_dir, "labels")
  34. if not osp.exists(image_dir):
  35. raise DatasetFileNotFoundError(file_path=image_dir)
  36. if not osp.exists(label_dir):
  37. raise DatasetFileNotFoundError(file_path=label_dir)
  38. label_map_file = osp.join(dataset_dir, "label_map.txt")
  39. if not osp.exists(label_map_file):
  40. raise DatasetFileNotFoundError(
  41. file_path=label_map_file,
  42. solution=f"Ensure that `label_map.txt` exist in {dataset_dir}",
  43. )
  44. with open(label_map_file, "r", encoding="utf-8") as f:
  45. all_lines = f.readlines()
  46. for line in all_lines:
  47. substr = line.strip("\n").split(" ", 1)
  48. try:
  49. label_idx = int(substr[1])
  50. labels.append(label_idx)
  51. label_map_dict[label_idx] = str(substr[0])
  52. except:
  53. raise CheckFailedError(
  54. f"Ensure that the second number in each line in {label_map_file} should be int."
  55. )
  56. if min(labels) != 1:
  57. raise CheckFailedError(
  58. f"Ensure that the index starts from 1 in `{label_map_file}`."
  59. )
  60. for tag in tags:
  61. file_list = osp.join(dataset_dir, f"{tag}.txt")
  62. if not osp.exists(file_list):
  63. if tag in ("train", "val"):
  64. # train and val file lists must exist
  65. raise DatasetFileNotFoundError(
  66. file_path=file_list,
  67. solution=f"Ensure that both `train.txt` and `val.txt` exist in {dataset_dir}",
  68. )
  69. else:
  70. # tag == 'test'
  71. continue
  72. else:
  73. with open(file_list, "r", encoding="utf-8") as f:
  74. all_lines = f.readlines()
  75. random.seed(123)
  76. random.shuffle(all_lines)
  77. sample_cnts[tag] = len(all_lines)
  78. for line in all_lines:
  79. substr = line.strip("\n")
  80. label_path = osp.join(dataset_dir, substr)
  81. img_path = (
  82. osp.join(dataset_dir, substr)
  83. .replace("labels", "rgb-images")
  84. .replace("txt", "jpg")
  85. )
  86. if not osp.exists(img_path):
  87. raise DatasetFileNotFoundError(file_path=img_path)
  88. if not osp.exists(label_path):
  89. raise DatasetFileNotFoundError(file_path=label_path)
  90. with custom_open(label_path, "r") as f:
  91. label_lines = f.readlines()
  92. for label_line in label_lines:
  93. label_info = label_line.strip().split(" ")
  94. try:
  95. int(label_info[0])
  96. except (ValueError, TypeError) as e:
  97. raise CheckFailedError(
  98. f"Ensure that the first number in each line in {label_info} should be int."
  99. ) from e
  100. if len(label_info) != valid_num_parts:
  101. raise CheckFailedError(
  102. f"Ensure that each line in {label_line} has exactly two numbers."
  103. )
  104. if len(sample_paths[tag]) < sample_num:
  105. sample_path = osp.join(
  106. "check_dataset", os.path.relpath(img_path, output)
  107. )
  108. sample_paths[tag].append(sample_path)
  109. num_classes = max(labels)
  110. attrs = {}
  111. attrs["label_file"] = osp.relpath(label_map_file, output)
  112. attrs["num_classes"] = num_classes
  113. attrs["train_samples"] = sample_cnts["train"]
  114. attrs["train_sample_paths"] = sample_paths["train"]
  115. attrs["val_samples"] = sample_cnts["val"]
  116. attrs["val_sample_paths"] = sample_paths["val"]
  117. return attrs