check_dataset.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import random
  17. from collections import defaultdict
  18. from .....utils.errors import CheckFailedError, DatasetFileNotFoundError
  def check(dataset_dir, output, sample_num=10):
      """Validate a dataset directory and collect summary attributes.

      The directory must contain ``label.txt`` (one ``<int index> <name>`` pair
      per line, with the smallest index equal to 0) plus ``train.txt`` and
      ``val.txt`` (one ``<relative file path> <int label>`` pair per line,
      separated by a single space). Every listed file must exist under
      ``dataset_dir``.

      Args:
          dataset_dir: Path to the dataset root; made absolute before use.
          output: Directory the returned relative paths are computed against.
          sample_num: Maximum number of sample paths recorded per split.

      Returns:
          A dict with the keys ``label_file``, ``num_classes``,
          ``train_samples``, ``train_sample_paths``, ``val_samples`` and
          ``val_sample_paths``.

      Raises:
          DatasetFileNotFoundError: If the dataset directory, ``label.txt``,
              ``train.txt``, ``val.txt`` or any listed sample file is missing.
          CheckFailedError: If ``label.txt`` or a file list is malformed.
      """
      dataset_dir = osp.abspath(dataset_dir)
      # isdir() is already False for a path that does not exist.
      if not osp.isdir(dataset_dir):
          raise DatasetFileNotFoundError(file_path=dataset_dir)

      label_file = osp.join(dataset_dir, "label.txt")
      if not osp.exists(label_file):
          raise DatasetFileNotFoundError(
              file_path=label_file,
              solution=f"Ensure that `label.txt` exist in {dataset_dir}",
          )
      labels = _read_label_indices(label_file)

      sample_cnts = {}
      sample_paths = {}
      for tag in ("train", "val"):
          file_list = osp.join(dataset_dir, f"{tag}.txt")
          # Both splits are mandatory.
          if not osp.exists(file_list):
              raise DatasetFileNotFoundError(
                  file_path=file_list,
                  solution=f"Ensure that both `train.txt` and `val.txt` exist in {dataset_dir}",
              )
          sample_cnts[tag], sample_paths[tag] = _check_file_list(
              file_list, dataset_dir, output, sample_num
          )

      return {
          "label_file": osp.relpath(label_file, output),
          "num_classes": max(labels) + 1,
          "train_samples": sample_cnts["train"],
          "train_sample_paths": sample_paths["train"],
          "val_samples": sample_cnts["val"],
          "val_sample_paths": sample_paths["val"],
      }


  def _read_label_indices(label_file):
      """Parse ``label.txt`` and return the list of integer label indices."""
      labels = []
      with open(label_file, "r", encoding="utf-8") as f:
          for line in f:
              parts = line.strip("\n").split(" ", 1)
              try:
                  labels.append(int(parts[0]))
              except ValueError as e:
                  raise CheckFailedError(
                      f"Ensure that the first number in each line in {label_file} should be int."
                  ) from e
              # A line holding only an index has no label name to map to.
              if len(parts) < 2:
                  raise CheckFailedError(
                      f"Ensure that each line in {label_file} has a label name after the index."
                  )

      # Guard the empty file explicitly; min() on an empty list would raise a
      # bare ValueError instead of a dataset check failure.
      if not labels:
          raise CheckFailedError(f"Ensure that `{label_file}` is not empty.")
      if min(labels) != 0:
          raise CheckFailedError(
              f"Ensure that the index starts from 0 in `{label_file}`."
          )
      return labels


  def _check_file_list(file_list, dataset_dir, output, sample_num):
      """Validate one split file list; return ``(line_count, sample_paths)``."""
      delim = " "
      valid_num_parts = 2

      with open(file_list, "r", encoding="utf-8") as f:
          all_lines = f.readlines()

      # A private generator seeded with 123 yields the same order as seeding
      # the module-level generator, without resetting the caller's RNG state.
      random.Random(123).shuffle(all_lines)

      sample_paths = []
      for line in all_lines:
          parts = line.strip("\n").split(delim)
          if len(parts) != valid_num_parts:
              raise CheckFailedError(
                  f"The number of delimiter-separated items in each row in {file_list} "
                  f"should be {valid_num_parts} (current delimiter is '{delim}')."
              )
          file_name, label = parts

          video_path = osp.join(dataset_dir, file_name)
          if not osp.exists(video_path):
              raise DatasetFileNotFoundError(file_path=video_path)

          # Only the first `sample_num` shuffled entries are kept as samples.
          if len(sample_paths) < sample_num:
              sample_paths.append(
                  osp.join("check_dataset", osp.relpath(video_path, output))
              )

          try:
              int(label)
          except ValueError as e:
              raise CheckFailedError(
                  f"Ensure that the second number in each line in {file_list} should be int."
              ) from e

      return len(all_lines), sample_paths