split_dataset.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
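"""Utilities for re-splitting a COCO-format object detection dataset into train/val subsets."""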

import json
import os
import random
import shutil

from .....utils.deps import function_requires_deps
from .....utils.file_interface import custom_open, write_json_file
from .....utils.logging import info


def split_dataset(root_dir, train_rate, val_rate):
    """Split a COCO-format dataset into train/val subsets by the given percentages."""
    assert (
        train_rate + val_rate == 100
    ), f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
    assert (
        train_rate > 0 and val_rate > 0
    ), f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
    all_image_info_list = []
    all_category_dict = {}
    max_image_id = 0
    for fn in ["instance_train.json", "instance_val.json"]:
        anno_path = os.path.join(root_dir, "annotations", fn)
        if not os.path.exists(anno_path):
            info(f"The annotation file {anno_path} does not exist and has been ignored!")
            continue
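        # json2list offsets this file's image ids by the running maximum so ids stay unique across files.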
        image_info_list, category_list, max_image_id = json2list(
            anno_path, max_image_id
        )
        all_image_info_list.extend(image_info_list)
        for category in category_list:
            if category["id"] not in all_category_dict:
                all_category_dict[category["id"]] = category
    total_num = len(all_image_info_list)
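    # Shuffle once, then carve consecutive slices out of the shuffled list for each split.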
    random.shuffle(all_image_info_list)
    all_category_list = list(all_category_dict.values())
    start = 0
    for fn, rate in [
        ("instance_train.json", train_rate),
        ("instance_val.json", val_rate),
    ]:
        end = start + round(total_num * rate / 100)
        save_path = os.path.join(root_dir, "annotations", fn)
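        # Back up any existing annotation file before overwriting it with the new split.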
        if os.path.exists(save_path):
            bak_path = save_path + ".bak"
            shutil.move(save_path, bak_path)
            info(f"The original annotation file {fn} has been backed up to {bak_path}.")
        assemble_write(all_image_info_list[start:end], all_category_list, save_path)
        start = end
    return root_dir
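

# The decorator declares tqdm as a dependency this function needs; tqdm itself is
# imported lazily inside the function body.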
@function_requires_deps("tqdm")
def json2list(json_path, base_image_num):
    """Load a COCO annotation file as a list of (image, annotations) pairs."""
    from tqdm import tqdm

    assert os.path.exists(json_path), json_path
    with custom_open(json_path, "r") as f:
        data = json.load(f)
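    # Remap every image id to a globally unique id by offsetting it with base_image_num.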
    image_info_dict = {}
    max_image_id = 0
    for image_info in data["images"]:
        # Derive a globally unique image id.
        global_image_id = image_info["id"] + base_image_num
        max_image_id = max(global_image_id, max_image_id)
        image_info["id"] = global_image_id
        image_info_dict[global_image_id] = {"img": image_info, "anno": []}
  80. info(f"Start loading annotation file {json_path}...")
    for anno in tqdm(data["annotations"]):
        global_image_id = anno["image_id"] + base_image_num
        anno["image_id"] = global_image_id
        image_info_dict[global_image_id]["anno"].append(anno)
    image_info_list = [
        (image_info_dict[image_id]["img"], image_info_dict[image_id]["anno"])
        for image_id in image_info_dict
    ]
    return image_info_list, data["categories"], max_image_id


def assemble_write(image_info_list, category_list, save_path):
    """Assemble a COCO-format dict and save it to a file."""
    coco_data = {"categories": category_list}
    image_list = [i[0] for i in image_info_list]
    all_anno_list = []
    for i in image_info_list:
        all_anno_list.extend(i[1])
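    # Renumber annotation ids sequentially from 1 so they stay unique in the merged file.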
    anno_list = []
    for i, anno in enumerate(all_anno_list):
        anno["id"] = i + 1
        anno_list.append(anno)
    coco_data["images"] = image_list
    coco_data["annotations"] = anno_list
    write_json_file(coco_data, save_path)
    info(f"The split annotations have been saved to {save_path}.")