split_dataset.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
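
"""Randomly re-split a COCO-format instance dataset's annotations into train/val subsets."""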

import json
import os
import random
import shutil

from .....utils.deps import function_requires_deps, is_dep_available
from .....utils.file_interface import custom_open, write_json_file
from .....utils.logging import info

if is_dep_available("tqdm"):
    from tqdm import tqdm
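
# tqdm is an optional dependency: its import is guarded above, and json2list
# below is decorated with @function_requires_deps("tqdm"), which (judging by
# the name) enforces the dependency when the function is actually called.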


def split_dataset(root_dir, train_rate, val_rate):
    """Merge the existing train/val annotation files, shuffle the pooled
    images, and re-split them by the given percentages."""
    assert (
        train_rate + val_rate == 100
    ), f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
    assert (
        train_rate > 0 and val_rate > 0
    ), f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
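
    # Pass 1: load whichever annotation files exist, offsetting image ids by a
    # running maximum so they stay unique after the two files are merged.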
    all_image_info_list = []
    all_category_dict = {}
    max_image_id = 0
    for fn in ["instance_train.json", "instance_val.json"]:
        anno_path = os.path.join(root_dir, "annotations", fn)
        if not os.path.exists(anno_path):
            info(f"The annotation file {anno_path} does not exist and has been ignored!")
            continue
        image_info_list, category_list, max_image_id = json2list(
            anno_path, max_image_id
        )
        all_image_info_list.extend(image_info_list)
        # De-duplicate categories by id across the two files.
        for category in category_list:
            if category["id"] not in all_category_dict:
                all_category_dict[category["id"]] = category
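
    # Shuffle the pooled (image, annotations) pairs so the new split is random.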
    total_num = len(all_image_info_list)
    random.shuffle(all_image_info_list)
    all_category_list = list(all_category_dict.values())
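
    # Pass 2: carve the shuffled list into contiguous slices sized by the
    # requested percentages and write one COCO file per split.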
    start = 0
    for fn, rate in [
        ("instance_train.json", train_rate),
        ("instance_val.json", val_rate),
    ]:
        end = start + round(total_num * rate / 100)
        save_path = os.path.join(root_dir, "annotations", fn)
        if os.path.exists(save_path):
            # Keep a backup of the original annotation file before overwriting.
            bak_path = save_path + ".bak"
            shutil.move(save_path, bak_path)
            info(f"The original annotation file {fn} has been backed up to {bak_path}.")
        assemble_write(all_image_info_list[start:end], all_category_list, save_path)
        start = end
    return root_dir
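

# A minimal usage sketch (the paths follow the filenames this module expects;
# the dataset directory below is hypothetical):
#
#     split_dataset("path/to/dataset", train_rate=80, val_rate=20)
#     # -> rewrites annotations/instance_train.json with ~80% of the images
#     #    and annotations/instance_val.json with the rest, after backing up
#     #    the original files to *.bak.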


@function_requires_deps("tqdm")
def json2list(json_path, base_image_num):
    """Load a COCO annotation file as a list of (image, annotations) pairs,
    offsetting every image id by base_image_num."""
    assert os.path.exists(json_path), json_path
    with custom_open(json_path, "r") as f:
        data = json.load(f)

    image_info_dict = {}
    max_image_id = 0
    for image_info in data["images"]:
        # Obtain a globally unique image_id by offsetting the original id.
        global_image_id = image_info["id"] + base_image_num
        max_image_id = max(global_image_id, max_image_id)
        image_info["id"] = global_image_id
        image_info_dict[global_image_id] = {"img": image_info, "anno": []}
    info(f"Start loading annotation file {json_path}...")
    for anno in tqdm(data["annotations"]):
        # Re-point each annotation at the offset image id.
        global_image_id = anno["image_id"] + base_image_num
        anno["image_id"] = global_image_id
        image_info_dict[global_image_id]["anno"].append(anno)

    image_info_list = [
        (image_info_dict[image_id]["img"], image_info_dict[image_id]["anno"])
        for image_id in image_info_dict
    ]
    return image_info_list, data["categories"], max_image_id
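

# Note: annotation ids are renumbered from 1 in assemble_write so they remain
# unique within each output file after the merge and re-split.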
def assemble_write(image_info_list, category_list, save_path):
    """Assemble the COCO-format dict and save it to save_path."""
    coco_data = {"categories": category_list}
    image_list = [i[0] for i in image_info_list]
    all_anno_list = []
    for i in image_info_list:
        all_anno_list.extend(i[1])
    anno_list = []
    for i, anno in enumerate(all_anno_list):
        anno["id"] = i + 1
        anno_list.append(anno)
    coco_data["images"] = image_list
    coco_data["annotations"] = anno_list
    write_json_file(coco_data, save_path)
    info(f"The split annotations have been saved to {save_path}.")