convert_dataset.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import random
  17. import shutil
  18. import numpy as np
  19. from PIL import Image, ImageDraw
  20. from .....utils.deps import function_requires_deps
  21. from .....utils.errors import ConvertFailedError
  22. from .....utils.file_interface import custom_open, write_json_file
  23. from .....utils.logging import info, warning
  24. class Indexer(object):
  25. """Indexer"""
  26. def __init__(self):
  27. """init indexer"""
  28. self._map = {}
  29. self.idx = 0
  30. def get_id(self, key):
  31. """get id by key"""
  32. if key not in self._map:
  33. self.idx += 1
  34. self._map[key] = self.idx
  35. return self._map[key]
  36. def get_list(self, key_name):
  37. """return list containing key and id"""
  38. map_list = []
  39. for key in self._map:
  40. val = self._map[key]
  41. map_list.append({key_name: key, "id": val})
  42. return map_list
  43. class Extension(object):
  44. """Extension"""
  45. def __init__(self, exts_list):
  46. """init extension"""
  47. self._exts_list = ["." + ext for ext in exts_list]
  48. def __iter__(self):
  49. """iterator"""
  50. return iter(self._exts_list)
  51. def update(self, ext):
  52. """update extension"""
  53. self._exts_list.remove(ext)
  54. self._exts_list.insert(0, ext)
  55. def check_src_dataset(root_dir, dataset_type):
  56. """check src dataset format validity"""
  57. if dataset_type == "LabelMe":
  58. pass
  59. else:
  60. raise ConvertFailedError(
  61. message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 LabelMe 格式。"
  62. )
  63. err_msg_prefix = f"数据格式转换失败!请参考上述`{dataset_type}格式数据集示例`检查待转换数据集格式。"
  64. anno_map = {}
  65. for dst_anno, src_anno in [
  66. ("instance_train.json", "train_anno_list.txt"),
  67. ("instance_val.json", "val_anno_list.txt"),
  68. ]:
  69. src_anno_path = os.path.join(root_dir, src_anno)
  70. if not os.path.exists(src_anno_path):
  71. if dst_anno == "instance_train.json":
  72. raise ConvertFailedError(
  73. message=f"{err_msg_prefix}保证{src_anno_path}文件存在。"
  74. )
  75. continue
  76. with custom_open(src_anno_path, "r") as f:
  77. anno_list = f.readlines()
  78. for anno_fn in anno_list:
  79. anno_fn = anno_fn.strip().split(" ")[-1]
  80. anno_path = os.path.join(root_dir, anno_fn)
  81. if not os.path.exists(anno_path):
  82. raise ConvertFailedError(
  83. message=f'{err_msg_prefix}保证"{src_anno_path}"中的"{anno_fn}"文件存在。'
  84. )
  85. anno_map[dst_anno] = src_anno_path
  86. return anno_map
  87. def convert(dataset_type, input_dir):
  88. """convert dataset to coco format"""
  89. # check format validity
  90. anno_map = check_src_dataset(input_dir, dataset_type)
  91. if dataset_type == "LabelMe":
  92. convert_labelme_dataset(input_dir, anno_map)
  93. else:
  94. raise ValueError
  95. def split_anno_list(root_dir, anno_map):
  96. """Split anno list to 80% train and 20% val"""
  97. train_anno_list = []
  98. val_anno_list = []
  99. anno_list_bak = os.path.join(root_dir, "train_anno_list.txt.bak")
  100. shutil.move(anno_map["instance_train.json"], anno_list_bak),
  101. with custom_open(anno_list_bak, "r") as f:
  102. src_anno = f.readlines()
  103. random.shuffle(src_anno)
  104. train_anno_list = src_anno[: int(len(src_anno) * 0.8)]
  105. val_anno_list = src_anno[int(len(src_anno) * 0.8) :]
  106. with custom_open(os.path.join(root_dir, "train_anno_list.txt"), "w") as f:
  107. f.writelines(train_anno_list)
  108. with custom_open(os.path.join(root_dir, "val_anno_list.txt"), "w") as f:
  109. f.writelines(val_anno_list)
  110. anno_map["instance_train.json"] = os.path.join(root_dir, "train_anno_list.txt")
  111. anno_map["instance_val.json"] = os.path.join(root_dir, "val_anno_list.txt")
  112. msg = f"{os.path.join(root_dir,'val_anno_list.txt')}不存在,数据集已默认按照80%训练集,20%验证集划分,\
  113. 且将原始'train_anno_list.txt'重命名为'train_anno_list.txt.bak'."
  114. warning(msg)
  115. return anno_map
  116. def convert_labelme_dataset(root_dir, anno_map):
  117. """convert dataset labeled by LabelMe to coco format"""
  118. label_indexer = Indexer()
  119. img_indexer = Indexer()
  120. annotations_dir = os.path.join(root_dir, "annotations")
  121. if not os.path.exists(annotations_dir):
  122. os.makedirs(annotations_dir)
  123. # 不存在val_anno_list,对原始数据集进行划分
  124. if "instance_val.json" not in anno_map:
  125. anno_map = split_anno_list(root_dir, anno_map)
  126. for dst_anno in anno_map:
  127. labelme2coco(
  128. label_indexer,
  129. img_indexer,
  130. root_dir,
  131. anno_map[dst_anno],
  132. os.path.join(annotations_dir, dst_anno),
  133. )
  134. @function_requires_deps("tqdm", "pycocotools")
  135. def labelme2coco(label_indexer, img_indexer, root_dir, anno_path, save_path):
  136. """convert json files generated by LabelMe to coco format and save to files"""
  137. import pycocotools.mask as mask_util
  138. from tqdm import tqdm
  139. with custom_open(anno_path, "r") as f:
  140. json_list = f.readlines()
  141. anno_num = 0
  142. anno_list = []
  143. image_list = []
  144. info(f"Start loading json annotation files from {anno_path} ...")
  145. for json_path in tqdm(json_list):
  146. json_path = json_path.strip()
  147. assert json_path.endswith(".json"), json_path
  148. with custom_open(os.path.join(root_dir, json_path.strip()), "r") as f:
  149. labelme_data = json.load(f)
  150. img_id = img_indexer.get_id(labelme_data["imagePath"])
  151. height = labelme_data["imageHeight"]
  152. width = labelme_data["imageWidth"]
  153. image_list.append(
  154. {
  155. "id": img_id,
  156. "file_name": labelme_data["imagePath"].split("/")[-1],
  157. "width": width,
  158. "height": height,
  159. }
  160. )
  161. for shape in labelme_data["shapes"]:
  162. assert shape["shape_type"] == "polygon", "Only polygon are supported."
  163. category_id = label_indexer.get_id(shape["label"])
  164. points = shape["points"]
  165. segmentation = [np.asarray(points).flatten().tolist()]
  166. mask = points_to_mask([height, width], points)
  167. mask = np.asfortranarray(mask.astype(np.uint8))
  168. mask = mask_util.encode(mask)
  169. area = float(mask_util.area(mask))
  170. bbox = mask_util.toBbox(mask).flatten().tolist()
  171. anno_num += 1
  172. anno_list.append(
  173. {
  174. "image_id": img_id,
  175. "bbox": bbox,
  176. "segmentation": segmentation,
  177. "category_id": category_id,
  178. "id": anno_num,
  179. "iscrowd": 0,
  180. "area": area,
  181. "ignore": 0,
  182. }
  183. )
  184. category_list = label_indexer.get_list(key_name="name")
  185. data_coco = {
  186. "images": image_list,
  187. "categories": category_list,
  188. "annotations": anno_list,
  189. }
  190. write_json_file(data_coco, save_path)
  191. info(f"The converted annotations has been save to {save_path}.")
  192. def points_to_mask(img_shape, points):
  193. """convert polygon points to binary mask"""
  194. mask = np.zeros(img_shape[:2], dtype=np.uint8)
  195. mask = Image.fromarray(mask)
  196. draw = ImageDraw.Draw(mask)
  197. xy = [tuple(point) for point in points]
  198. assert len(xy) > 2, "Polygon must have points more than 2"
  199. draw.polygon(xy=xy, outline=1, fill=1)
  200. mask = np.asarray(mask, dtype=bool)
  201. return mask