split_dataset.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import glob
  15. import os.path
  16. import shutil
  17. import numpy as np
  18. from .....utils.file_interface import custom_open
  19. from .....utils.logging import info
  20. def split_dataset(dataset_root, train_rate, val_rate):
  21. """
  22. 将图像数据集按照比例分成训练集、验证集和测试集,并生成对应的.txt文件。
  23. Args:
  24. dataset_root (str): 数据集根目录路径。
  25. train_rate (int): 训练集占总数据集的比例(%)。
  26. val_rate (int): 验证集占总数据集的比例(%)。
  27. Returns:
  28. str: 数据划分结果信息。
  29. """
  30. sum_rate = train_rate + val_rate
  31. assert sum_rate == 100, f"训练集、验证集比例之和需要等于100,请修改后重试"
  32. assert (
  33. train_rate > 0 and val_rate > 0
  34. ), f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
  35. image_dir = os.path.join(dataset_root, "images")
  36. tags = ["train.txt", "val.txt"]
  37. image_files = get_files(image_dir, ["png", "jpg", "jpeg", "PNG", "JPG", "JPEG"])
  38. label_files = get_labels_files(dataset_root, ["train.txt", "val.txt"])
  39. for tag in tags:
  40. src_file = os.path.join(dataset_root, tag)
  41. dst_file = os.path.join(dataset_root, f"{tag}.bak")
  42. info(
  43. f"The original annotation file {src_file} has been backed up to {dst_file}."
  44. )
  45. shutil.move(src_file, dst_file)
  46. image_num = len(image_files)
  47. label_num = len(label_files)
  48. assert image_num != 0, f"原始图像数量({image_num})为0, 请检查后重试"
  49. assert (
  50. image_num == label_num
  51. ), f"原始图像数量({image_num})和标注图像数量({label_num})不相等,请检查后重试"
  52. image_files = np.array(image_files)
  53. label_files = np.array(label_files)
  54. state = np.random.get_state()
  55. np.random.shuffle(image_files)
  56. np.random.set_state(state)
  57. np.random.shuffle(label_files)
  58. start = 0
  59. rate_list = [train_rate, val_rate]
  60. name_list = ["train", "val"]
  61. for i, name in enumerate(name_list):
  62. info("Creating {}.txt...".format(name))
  63. rate = rate_list[i]
  64. if rate == 0:
  65. txt_file = os.path.join(dataset_root, name + ".txt")
  66. with custom_open(txt_file, "w") as f:
  67. f.write("")
  68. continue
  69. end = start + round(image_num * rate / 100)
  70. if sum(rate_list[i + 1 :]) == 0:
  71. end = image_num
  72. txt_file = os.path.join(dataset_root, name + ".txt")
  73. with custom_open(txt_file, "w") as f:
  74. for id in range(start, end):
  75. right = label_files[id]
  76. f.write(right)
  77. start = end
  78. return dataset_root
  79. def get_files(input_dir, format=["jpg", "png"]):
  80. """
  81. 在给定目录下获取符合指定文件格式的所有文件路径
  82. Args:
  83. input_dir (str): 目标文件夹路径
  84. format (Union[str, List[str]]): 需要获取的文件格式, 可以是字符串或者字符串列表
  85. Returns:
  86. List[str]: 符合格式的所有文件路径列表,返回排序后的结果
  87. """
  88. res = []
  89. if not isinstance(format, (list, tuple)):
  90. format = [format]
  91. for item in format:
  92. pattern = os.path.join(input_dir, f"**/*.{item}")
  93. files = glob.glob(pattern, recursive=True)
  94. res.extend(files)
  95. return sorted(res)
  96. def get_labels_files(input_dir, format=["train.txt", "val.txt"]):
  97. """
  98. 在给定目录下获取符合指定文件格式的所有文件路径
  99. Args:
  100. input_dir (str): 目标文件夹路径
  101. format (Union[str, List[str]]): 需要获取的文件格式, 可以是字符串或者字符串列表
  102. Returns:
  103. List[str]: 符合格式的所有文件路径列表,返回排序后的结果
  104. """
  105. res = []
  106. if not isinstance(format, (list, tuple)):
  107. format = [format]
  108. for tag in format:
  109. file_list = os.path.join(input_dir, f"{tag}")
  110. if os.path.exists(file_list):
  111. with custom_open(file_list, "r") as f:
  112. all_lines = f.readlines()
  113. res.extend(all_lines)
  114. return sorted(res)