split_dataset.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import glob
  15. import os.path
  16. import numpy as np
  17. import shutil
  18. from .....utils.file_interface import custom_open
  19. from .....utils.logging import info
  20. def split_dataset(dataset_root, train_rate, val_rate):
  21. """
  22. 将图像数据集按照比例分成训练集、验证集和测试集,并生成对应的.txt文件。
  23. Args:
  24. dataset_root (str): 数据集根目录路径。
  25. train_rate (int): 训练集占总数据集的比例(%)。
  26. val_rate (int): 验证集占总数据集的比例(%)。
  27. Returns:
  28. str: 数据划分结果信息。
  29. """
  30. sum_rate = train_rate + val_rate
  31. assert sum_rate == 100, \
  32. f"训练集、验证集比例之和需要等于100,请修改后重试"
  33. assert train_rate > 0 and val_rate > 0, \
  34. f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
  35. image_dir = os.path.join(dataset_root, 'images')
  36. tags = ['train.txt', 'val.txt']
  37. image_files = get_files(image_dir,
  38. ['png', 'jpg', 'jpeg', 'PNG', 'JPG', 'JPEG'])
  39. label_files = get_labels_files(dataset_root, ['train.txt', 'val.txt'])
  40. for tag in tags:
  41. src_file = os.path.join(dataset_root, tag)
  42. dst_file = os.path.join(dataset_root, f"{tag}.bak")
  43. info(
  44. f"The original annotation file {src_file} has been backed up to {dst_file}."
  45. )
  46. shutil.move(src_file, dst_file)
  47. image_num = len(image_files)
  48. label_num = len(label_files)
  49. assert image_num != 0, f"原始图像数量({image_num})为0, 请检查后重试"
  50. assert image_num == label_num, \
  51. f"原始图像数量({image_num})和标注图像数量({label_num})不相等,请检查后重试"
  52. image_files = np.array(image_files)
  53. label_files = np.array(label_files)
  54. state = np.random.get_state()
  55. np.random.shuffle(image_files)
  56. np.random.set_state(state)
  57. np.random.shuffle(label_files)
  58. start = 0
  59. rate_list = [train_rate, val_rate]
  60. name_list = ['train', 'val']
  61. separator = " "
  62. for i, name in enumerate(name_list):
  63. info("Creating {}.txt...".format(name))
  64. rate = rate_list[i]
  65. if rate == 0:
  66. txt_file = os.path.join(dataset_root, name + '.txt')
  67. with custom_open(txt_file, "w") as f:
  68. f.write("")
  69. continue
  70. end = start + round(image_num * rate / 100)
  71. if sum(rate_list[i + 1:]) == 0:
  72. end = image_num
  73. txt_file = os.path.join(dataset_root, name + '.txt')
  74. with custom_open(txt_file, "w") as f:
  75. for id in range(start, end):
  76. right = label_files[id]
  77. f.write(right)
  78. start = end
  79. return dataset_root
  80. def get_files(input_dir, format=['jpg', 'png']):
  81. """
  82. 在给定目录下获取符合指定文件格式的所有文件路径
  83. Args:
  84. input_dir (str): 目标文件夹路径
  85. format (Union[str, List[str]]): 需要获取的文件格式, 可以是字符串或者字符串列表
  86. Returns:
  87. List[str]: 符合格式的所有文件路径列表,返回排序后的结果
  88. """
  89. res = []
  90. if not isinstance(format, (list, tuple)):
  91. format = [format]
  92. for item in format:
  93. pattern = os.path.join(input_dir, f'**/*.{item}')
  94. files = glob.glob(pattern, recursive=True)
  95. res.extend(files)
  96. return sorted(res)
  97. def get_labels_files(input_dir, format=['train.txt', 'val.txt']):
  98. """
  99. 在给定目录下获取符合指定文件格式的所有文件路径
  100. Args:
  101. input_dir (str): 目标文件夹路径
  102. format (Union[str, List[str]]): 需要获取的文件格式, 可以是字符串或者字符串列表
  103. Returns:
  104. List[str]: 符合格式的所有文件路径列表,返回排序后的结果
  105. """
  106. res = []
  107. if not isinstance(format, (list, tuple)):
  108. format = [format]
  109. for tag in format:
  110. file_list = os.path.join(input_dir, f'{tag}')
  111. if os.path.exists(file_list):
  112. with custom_open(file_list, 'r') as f:
  113. all_lines = f.readlines()
  114. res.extend(all_lines)
  115. return sorted(res)