split_dataset.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import random
import shutil

from .....utils.file_interface import custom_open
from .....utils import logging


def split_dataset(root_dir, train_percent, val_percent):
    """Split the dataset under ``root_dir`` into train and val subsets by percentage."""
    assert train_percent > 0, ValueError(
        f"The train_percent({train_percent}) must be greater than 0!")
    assert val_percent > 0, ValueError(
        f"The val_percent({val_percent}) must be greater than 0!")
    if train_percent + val_percent != 100:
        raise ValueError(
            f"The sum of train_percent({train_percent}) and val_percent({val_percent}) should be 100!"
        )

    img_dir = osp.join(root_dir, "images")
    assert osp.exists(img_dir), FileNotFoundError(
        f"The dir of images ({img_dir}) doesn't exist, please check!")
    ann_dir = osp.join(root_dir, "annotations")
    assert osp.exists(ann_dir), FileNotFoundError(
        f"The dir of annotations ({ann_dir}) doesn't exist, please check!")

    img_file_list = [
        osp.join("images", img_name) for img_name in os.listdir(img_dir)
    ]
    img_num = len(img_file_list)
    ann_file_list = [
        osp.join("annotations", ann_name) for ann_name in os.listdir(ann_dir)
    ]
    ann_num = len(ann_file_list)
    assert img_num == ann_num, ValueError(
        "The number of images and annotations must be equal!")

    # Collect all existing sample lines from the train/val mapping files.
    split_tags = ["train", "val"]
    mapping_line_list = []
    for tag in split_tags:
        mapping_file = osp.join(root_dir, f"{tag}.txt")
        if not osp.exists(mapping_file):
            logging.info(
                f"The mapping file ({mapping_file}) doesn't exist, ignored.")
            continue
        with custom_open(mapping_file, "r") as fp:
            lines = filter(None, (line.strip() for line in fp.readlines()))
            mapping_line_list.extend(lines)

    # Shuffle the pooled samples, then re-split them according to the
    # requested percentages.
    sample_num = len(mapping_line_list)
    random.shuffle(mapping_line_list)
    split_percents = [train_percent, val_percent]
    start_idx = 0
    for tag, percent in zip(split_tags, split_percents):
        # This branch only matters if a "test" tag is ever added to split_tags.
        if tag == "test" and percent == 0:
            continue
        end_idx = start_idx + round(sample_num * percent / 100)
        end_idx = min(end_idx, sample_num)
        mapping_file = osp.join(root_dir, f"{tag}.txt")
        if osp.exists(mapping_file):
            # Back up the original mapping file before overwriting it.
            shutil.move(mapping_file, mapping_file + ".bak")
            logging.info(f"The original mapping file ({mapping_file}) "
                         f"has been backed up to ({mapping_file}.bak)")
        with custom_open(mapping_file, "w") as fp:
            fp.write("\n".join(mapping_line_list[start_idx:end_idx]))
        start_idx = end_idx
    return root_dir
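
# Usage sketch (a minimal example; the dataset path and percentages below are
# illustrative, not taken from the source). It assumes ``root_dir`` contains an
# "images/" dir, an "annotations/" dir, and at least one existing mapping file
# such as "train.txt", as checked by the function above:
#
#     split_dataset("datasets/my_dataset", train_percent=80, val_percent=20)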