# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import json
import random

import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm

from .....utils.file_interface import custom_open, write_json_file
from .....utils.errors import ConvertFailedError
from .....utils.logging import info, warning


class Indexer(object):
    """Indexer"""

    def __init__(self):
        """init indexer"""
        self._map = {}
        self.idx = 0

    def get_id(self, key):
        """get id by key"""
        if key not in self._map:
            self.idx += 1
            self._map[key] = self.idx
        return self._map[key]

    def get_list(self, key_name):
        """return list containing key and id"""
        map_list = []
        for key in self._map:
            val = self._map[key]
            map_list.append({key_name: key, "id": val})
        return map_list
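
# Illustrative usage (not part of the conversion pipeline): ids are assigned
# in first-seen order, starting from 1, e.g.
#
#   indexer = Indexer()
#   indexer.get_id("cat")   # -> 1
#   indexer.get_id("dog")   # -> 2
#   indexer.get_id("cat")   # -> 1 (already indexed)
#   indexer.get_list(key_name="name")
#   # -> [{'name': 'cat', 'id': 1}, {'name': 'dog', 'id': 2}]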


class Extension(object):
    """Extension"""

    def __init__(self, exts_list):
        """init extension"""
        self._exts_list = ["." + ext for ext in exts_list]

    def __iter__(self):
        """iterator"""
        return iter(self._exts_list)

    def update(self, ext):
        """update extension"""
        self._exts_list.remove(ext)
        self._exts_list.insert(0, ext)
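
# Illustrative usage: iterating yields dotted extensions, and `update`
# promotes a matched extension to the front so it is tried first on later
# lookups, e.g.
#
#   exts = Extension(["jpg", "png", "bmp"])
#   list(exts)          # -> ['.jpg', '.png', '.bmp']
#   exts.update(".png")
#   list(exts)          # -> ['.png', '.jpg', '.bmp']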


def check_src_dataset(root_dir, dataset_type):
    """check src dataset format validity"""
    if dataset_type == "LabelMe":
        anno_suffix = ".json"
    else:
        raise ConvertFailedError(
            message=f"Dataset format conversion failed! The {dataset_type} dataset "
            "format is not supported. Only the LabelMe format is currently supported.")
    err_msg_prefix = (
        "Dataset format conversion failed! Please check the dataset to be converted "
        f"against the `{dataset_type} dataset example` above.")
    anno_map = {}
    for dst_anno, src_anno in [("instance_train.json", "train_anno_list.txt"),
                               ("instance_val.json", "val_anno_list.txt")]:
        src_anno_path = os.path.join(root_dir, src_anno)
        if not os.path.exists(src_anno_path):
            if dst_anno == "instance_train.json":
                raise ConvertFailedError(
                    message=f"{err_msg_prefix} Make sure the file {src_anno_path} exists.")
            continue
        with custom_open(src_anno_path, "r") as f:
            anno_list = f.readlines()
        for anno_fn in anno_list:
            anno_fn = anno_fn.strip().split(" ")[-1]
            anno_path = os.path.join(root_dir, anno_fn)
            if not os.path.exists(anno_path):
                raise ConvertFailedError(
                    message=f'{err_msg_prefix} Make sure the file "{anno_fn}" listed in '
                    f'"{src_anno_path}" exists.')
        anno_map[dst_anno] = src_anno_path
    return anno_map
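
# Expected source layout, inferred from the checks above (a sketch; the .json
# files may live anywhere under root_dir as long as the paths written in the
# list files resolve relative to root_dir):
#
#   root_dir/
#     train_anno_list.txt    # required; one LabelMe .json path per line
#     val_anno_list.txt      # optional; auto-split from train if missing
#     <anno>.json            # LabelMe files referenced by the lists above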


def convert(dataset_type, input_dir):
    """convert dataset to coco format"""
    # check format validity
    anno_map = check_src_dataset(input_dir, dataset_type)
    if dataset_type == "LabelMe":
        convert_labelme_dataset(input_dir, anno_map)
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")


def split_anno_list(root_dir, anno_map):
    """split anno list into 80% train and 20% val"""
    train_anno_list = []
    val_anno_list = []
    anno_list_bak = os.path.join(root_dir, "train_anno_list.txt.bak")
    # back up the original list before splitting
    shutil.move(anno_map["instance_train.json"], anno_list_bak)
    with custom_open(anno_list_bak, "r") as f:
        src_anno = f.readlines()
    random.shuffle(src_anno)
    train_anno_list = src_anno[:int(len(src_anno) * 0.8)]
    val_anno_list = src_anno[int(len(src_anno) * 0.8):]
    with custom_open(os.path.join(root_dir, "train_anno_list.txt"), "w") as f:
        f.writelines(train_anno_list)
    with custom_open(os.path.join(root_dir, "val_anno_list.txt"), "w") as f:
        f.writelines(val_anno_list)
    anno_map["instance_train.json"] = os.path.join(root_dir,
                                                   "train_anno_list.txt")
    anno_map["instance_val.json"] = os.path.join(root_dir, "val_anno_list.txt")
    msg = (
        f"{os.path.join(root_dir, 'val_anno_list.txt')} does not exist. The dataset "
        "has been split into 80% train and 20% val by default, and the original "
        "'train_anno_list.txt' has been renamed to 'train_anno_list.txt.bak'.")
    warning(msg)
    return anno_map
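
# After the split, anno_map maps both COCO targets to list files (illustrative):
#
#   {"instance_train.json": "<root_dir>/train_anno_list.txt",
#    "instance_val.json": "<root_dir>/val_anno_list.txt"}
#
# Note that random.shuffle is unseeded, so the split differs between runs.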


def convert_labelme_dataset(root_dir, anno_map):
    """convert dataset labeled by LabelMe to coco format"""
    label_indexer = Indexer()
    img_indexer = Indexer()
    annotations_dir = os.path.join(root_dir, "annotations")
    if not os.path.exists(annotations_dir):
        os.makedirs(annotations_dir)
    # val_anno_list.txt does not exist, so split the original dataset
    if "instance_val.json" not in anno_map:
        anno_map = split_anno_list(root_dir, anno_map)
    for dst_anno in anno_map:
        labelme2coco(label_indexer, img_indexer, root_dir, anno_map[dst_anno],
                     os.path.join(annotations_dir, dst_anno))


def labelme2coco(label_indexer, img_indexer, root_dir, anno_path, save_path):
    """convert json files generated by LabelMe to coco format and save to files"""
    import pycocotools.mask as mask_util

    with custom_open(anno_path, "r") as f:
        json_list = f.readlines()
    anno_num = 0
    anno_list = []
    image_list = []
    info(f"Start loading json annotation files from {anno_path} ...")
    for json_path in tqdm(json_list):
        json_path = json_path.strip()
        assert json_path.endswith(".json"), json_path
        with custom_open(os.path.join(root_dir, json_path), "r") as f:
            labelme_data = json.load(f)
        img_id = img_indexer.get_id(labelme_data["imagePath"])
        height = labelme_data["imageHeight"]
        width = labelme_data["imageWidth"]
        image_list.append({
            "id": img_id,
            "file_name": labelme_data["imagePath"].split("/")[-1],
            "width": width,
            "height": height,
        })
        for shape in labelme_data["shapes"]:
            assert shape["shape_type"] == "polygon", \
                "Only polygon shapes are supported."
            category_id = label_indexer.get_id(shape["label"])
            points = shape["points"]
            segmentation = [np.asarray(points).flatten().tolist()]
            # rasterize the polygon, RLE-encode it, then derive area and bbox
            mask = points_to_mask([height, width], points)
            mask = np.asfortranarray(mask.astype(np.uint8))
            mask = mask_util.encode(mask)
            area = float(mask_util.area(mask))
            bbox = mask_util.toBbox(mask).flatten().tolist()
            anno_num += 1
            anno_list.append({
                "image_id": img_id,
                "bbox": bbox,
                "segmentation": segmentation,
                "category_id": category_id,
                "id": anno_num,
                "iscrowd": 0,
                "area": area,
                "ignore": 0
            })
    category_list = label_indexer.get_list(key_name="name")
    data_coco = {
        "images": image_list,
        "categories": category_list,
        "annotations": anno_list
    }
    write_json_file(data_coco, save_path)
    info(f"The converted annotations have been saved to {save_path}.")


def points_to_mask(img_shape, points):
    """convert polygon points to binary mask"""
    mask = np.zeros(img_shape[:2], dtype=np.uint8)
    mask = Image.fromarray(mask)
    draw = ImageDraw.Draw(mask)
    xy = [tuple(point) for point in points]
    assert len(xy) > 2, "A polygon must have more than 2 points."
    draw.polygon(xy=xy, outline=1, fill=1)
    mask = np.asarray(mask, dtype=bool)
    return mask
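
# Illustrative end-to-end usage (hypothetical path; in practice this module is
# invoked through the PaddleX dataset pipeline rather than run directly, since
# it uses package-relative imports):
#
#   convert("LabelMe", "/path/to/labelme_dataset")
#   # -> writes annotations/instance_train.json and annotations/instance_val.json
#   #    under /path/to/labelme_dataset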