split_dataset.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import shutil
  16. import random
  17. import json
  18. from tqdm import tqdm
  19. from .....utils.file_interface import custom_open, write_json_file
  20. from .....utils.logging import info
  21. def split_dataset(root_dir, train_rate, val_rate):
  22. """ split dataset """
  23. assert train_rate + val_rate == 100, \
  24. f"The sum of train_rate({train_rate}), val_rate({val_rate}) should equal 100!"
  25. assert train_rate > 0 and val_rate > 0, \
  26. f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
  27. all_image_info_list = []
  28. all_category_dict = {}
  29. max_image_id = 0
  30. for fn in ["instance_train.json", "instance_val.json"]:
  31. anno_path = os.path.join(root_dir, "annotations", fn)
  32. if not os.path.exists(anno_path):
  33. info(
  34. f"The annotation file {anno_path} don't exists, has been ignored!"
  35. )
  36. continue
  37. image_info_list, category_list, max_image_id = json2list(anno_path,
  38. max_image_id)
  39. all_image_info_list.extend(image_info_list)
  40. for category in category_list:
  41. if category['id'] not in all_category_dict:
  42. all_category_dict[category['id']] = category
  43. total_num = len(all_image_info_list)
  44. random.shuffle(all_image_info_list)
  45. all_category_list = [all_category_dict[k] for k in all_category_dict]
  46. start = 0
  47. for fn, rate in [("instance_train.json", train_rate),
  48. ("instance_val.json", val_rate)]:
  49. end = start + round(total_num * rate / 100)
  50. save_path = os.path.join(root_dir, "annotations", fn)
  51. if os.path.exists(save_path):
  52. bak_path = save_path + ".bak"
  53. shutil.move(save_path, bak_path)
  54. info(
  55. f"The original annotation file {fn} has been backed up to {bak_path}."
  56. )
  57. assemble_write(all_image_info_list[start:end], all_category_list,
  58. save_path)
  59. start = end
  60. return root_dir
  61. def json2list(json_path, base_image_num):
  62. """ load json as list """
  63. assert os.path.exists(json_path), json_path
  64. with custom_open(json_path, 'r') as f:
  65. data = json.load(f)
  66. image_info_dict = {}
  67. max_image_id = 0
  68. for image_info in data['images']:
  69. # 得到全局唯一的image_id
  70. global_image_id = image_info['id'] + base_image_num
  71. max_image_id = max(global_image_id, max_image_id)
  72. image_info['id'] = global_image_id
  73. image_info_dict[global_image_id] = {"img": image_info, 'anno': []}
  74. image_info_dict = {
  75. image_info['id']: {
  76. "img": image_info,
  77. 'anno': []
  78. }
  79. for image_info in data['images']
  80. }
  81. info(f"Start loading annotation file {json_path}...")
  82. for anno in tqdm(data['annotations']):
  83. global_image_id = anno['image_id'] + base_image_num
  84. anno['image_id'] = global_image_id
  85. image_info_dict[global_image_id]['anno'].append(anno)
  86. image_info_list = [(image_info_dict[image_info]['img'],
  87. image_info_dict[image_info]['anno'])
  88. for image_info in image_info_dict]
  89. return image_info_list, data['categories'], max_image_id
  90. def assemble_write(image_info_list, category_list, save_path):
  91. """ assemble coco format and save to file """
  92. coco_data = {'categories': category_list}
  93. image_list = [i[0] for i in image_info_list]
  94. all_anno_list = []
  95. for i in image_info_list:
  96. all_anno_list.extend(i[1])
  97. anno_list = []
  98. for i, anno in enumerate(all_anno_list):
  99. anno['id'] = i + 1
  100. anno_list.append(anno)
  101. coco_data['images'] = image_list
  102. coco_data['annotations'] = anno_list
  103. write_json_file(coco_data, save_path)
  104. info(f"The splited annotations has been save to {save_path}.")