convert_dataset.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from .....utils.deps import function_requires_deps, is_dep_available
  16. from .....utils.errors import ConvertFailedError
  17. from .....utils.logging import info, warning
  18. if is_dep_available("pycocotools"):
  19. from pycocotools.coco import COCO
  20. if is_dep_available("tqdm"):
  21. from tqdm import tqdm
  22. def check_src_dataset(root_dir, dataset_type):
  23. """check src dataset format validity"""
  24. if dataset_type in ("COCO"):
  25. pass
  26. else:
  27. raise ConvertFailedError(
  28. message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 COCO 格式。"
  29. )
  30. err_msg_prefix = f"数据格式转换失败!请参考上述`{dataset_type}格式数据集示例`检查待转换数据集格式。"
  31. for anno in ["annotations/instance_train.json", "annotations/instance_val.json"]:
  32. src_anno_path = os.path.join(root_dir, anno)
  33. if not os.path.exists(src_anno_path):
  34. raise ConvertFailedError(
  35. message=f"{err_msg_prefix}保证{src_anno_path}文件存在。"
  36. )
  37. return None
  38. def convert(dataset_type, input_dir):
  39. """convert dataset to multilabel format"""
  40. # check format validity
  41. check_src_dataset(input_dir, dataset_type)
  42. if dataset_type in ("COCO"):
  43. convert_coco_dataset(input_dir)
  44. else:
  45. raise ConvertFailedError(
  46. message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 COCO 格式。"
  47. )
  48. def convert_coco_dataset(root_dir):
  49. for anno in ["annotations/instance_train.json", "annotations/instance_val.json"]:
  50. src_img_dir = root_dir
  51. src_anno_path = os.path.join(root_dir, anno)
  52. coco2multilabels(src_img_dir, src_anno_path, root_dir)
  53. @function_requires_deps("pycocotools", "tqdm")
  54. def coco2multilabels(src_img_dir, src_anno_path, root_dir):
  55. image_dir = os.path.join(root_dir, "images")
  56. label_type = (
  57. os.path.basename(src_anno_path).replace("instance_", "").replace(".json", "")
  58. )
  59. anno_save_path = os.path.join(root_dir, "{}.txt".format(label_type))
  60. coco = COCO(src_anno_path)
  61. cat_id_map = {
  62. old_cat_id: new_cat_id for new_cat_id, old_cat_id in enumerate(coco.getCatIds())
  63. }
  64. num_classes = len(list(cat_id_map.keys()))
  65. with open(anno_save_path, "w") as fp:
  66. for img_id in tqdm(sorted(coco.getImgIds())):
  67. img_info = coco.loadImgs([img_id])[0]
  68. img_filename = img_info["file_name"]
  69. img_w = img_info["width"]
  70. img_h = img_info["height"]
  71. img_filepath = os.path.join(image_dir, img_filename)
  72. if not os.path.exists(img_filepath):
  73. warning(
  74. "Illegal image file: {}, "
  75. "and it will be ignored".format(img_filepath)
  76. )
  77. continue
  78. if img_w < 0 or img_h < 0:
  79. warning(
  80. "Illegal width: {} or height: {} in annotation, "
  81. "and im_id: {} will be ignored".format(img_w, img_h, img_id)
  82. )
  83. continue
  84. ins_anno_ids = coco.getAnnIds(imgIds=[img_id])
  85. instances = coco.loadAnns(ins_anno_ids)
  86. label = [0] * num_classes
  87. for instance in instances:
  88. label[cat_id_map[instance["category_id"]]] = 1
  89. img_filename = os.path.join("images", img_filename)
  90. fp.writelines("{}\t{}\n".format(img_filename, ",".join(map(str, label))))
  91. fp.close()
  92. if label_type == "train":
  93. label_txt_save_path = os.path.join(root_dir, "label.txt")
  94. with open(label_txt_save_path, "w") as fp:
  95. for cat in coco.cats.values():
  96. id = cat["id"]
  97. name = cat["name"]
  98. fp.writelines("{} {}\n".format(id, name))
  99. fp.close()
  100. info("Save label names to {}.".format(label_txt_save_path))