# split_dataset.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import shutil
  16. import random
  17. import json
  18. from tqdm import tqdm
  19. from .....utils.file_interface import custom_open, write_json_file
  20. from .....utils.logging import info
  21. def split_dataset(root_dir, train_rate, val_rate):
  22. """split dataset"""
  23. assert (
  24. train_rate + val_rate == 100
  25. ), f"The sum of train_rate({train_rate}), val_rate({val_rate}) should equal 100!"
  26. assert (
  27. train_rate > 0 and val_rate > 0
  28. ), f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
  29. all_image_info_list = []
  30. all_category_dict = {}
  31. max_image_id = 0
  32. for fn in ["instance_train.json", "instance_val.json"]:
  33. anno_path = os.path.join(root_dir, "annotations", fn)
  34. if not os.path.exists(anno_path):
  35. info(f"The annotation file {anno_path} don't exists, has been ignored!")
  36. continue
  37. image_info_list, category_list, max_image_id = json2list(
  38. anno_path, max_image_id
  39. )
  40. all_image_info_list.extend(image_info_list)
  41. for category in category_list:
  42. if category["id"] not in all_category_dict:
  43. all_category_dict[category["id"]] = category
  44. total_num = len(all_image_info_list)
  45. random.shuffle(all_image_info_list)
  46. all_category_list = [all_category_dict[k] for k in all_category_dict]
  47. start = 0
  48. for fn, rate in [
  49. ("instance_train.json", train_rate),
  50. ("instance_val.json", val_rate),
  51. ]:
  52. end = start + round(total_num * rate / 100)
  53. save_path = os.path.join(root_dir, "annotations", fn)
  54. if os.path.exists(save_path):
  55. bak_path = save_path + ".bak"
  56. shutil.move(save_path, bak_path)
  57. info(f"The original annotation file {fn} has been backed up to {bak_path}.")
  58. assemble_write(all_image_info_list[start:end], all_category_list, save_path)
  59. start = end
  60. return root_dir
  61. def json2list(json_path, base_image_num):
  62. """load json as list"""
  63. assert os.path.exists(json_path), json_path
  64. with custom_open(json_path, "r") as f:
  65. data = json.load(f)
  66. image_info_dict = {}
  67. max_image_id = 0
  68. for image_info in data["images"]:
  69. # 得到全局唯一的image_id
  70. global_image_id = image_info["id"] + base_image_num
  71. max_image_id = max(global_image_id, max_image_id)
  72. image_info["id"] = global_image_id
  73. image_info_dict[global_image_id] = {"img": image_info, "anno": []}
  74. image_info_dict = {
  75. image_info["id"]: {"img": image_info, "anno": []}
  76. for image_info in data["images"]
  77. }
  78. info(f"Start loading annotation file {json_path}...")
  79. for anno in tqdm(data["annotations"]):
  80. global_image_id = anno["image_id"] + base_image_num
  81. anno["image_id"] = global_image_id
  82. image_info_dict[global_image_id]["anno"].append(anno)
  83. image_info_list = [
  84. (image_info_dict[image_info]["img"], image_info_dict[image_info]["anno"])
  85. for image_info in image_info_dict
  86. ]
  87. return image_info_list, data["categories"], max_image_id
  88. def assemble_write(image_info_list, category_list, save_path):
  89. """assemble coco format and save to file"""
  90. coco_data = {"categories": category_list}
  91. image_list = [i[0] for i in image_info_list]
  92. all_anno_list = []
  93. for i in image_info_list:
  94. all_anno_list.extend(i[1])
  95. anno_list = []
  96. for i, anno in enumerate(all_anno_list):
  97. anno["id"] = i + 1
  98. anno_list.append(anno)
  99. coco_data["images"] = image_list
  100. coco_data["annotations"] = anno_list
  101. write_json_file(coco_data, save_path)
  102. info(f"The splited annotations has been save to {save_path}.")