convert_dataset.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import json
import random
import xml.etree.ElementTree as ET

from tqdm import tqdm

from .....utils.file_interface import custom_open, write_json_file
from .....utils.errors import ConvertFailedError
from .....utils.logging import info, warning

class Indexer(object):
    """ Indexer """

    def __init__(self):
        """ init indexer """
        self._map = {}
        self.idx = 0

    def get_id(self, key):
        """ get id by key, assigning a new 1-based id on first sight """
        if key not in self._map:
            self.idx += 1
            self._map[key] = self.idx
        return self._map[key]

    def get_list(self, key_name):
        """ return a list of {key_name: key, 'id': id} dicts """
        map_list = []
        for key in self._map:
            val = self._map[key]
            map_list.append({key_name: key, 'id': val})
        return map_list

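# Illustrative usage (not part of the conversion pipeline): ids are 1-based
# and stable across repeated lookups, e.g.
#   idx = Indexer()
#   idx.get_id("cat")     # -> 1
#   idx.get_id("dog")     # -> 2
#   idx.get_id("cat")     # -> 1 (already indexed)
#   idx.get_list("name")  # -> [{'name': 'cat', 'id': 1}, {'name': 'dog', 'id': 2}]
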
class Extension(object):
    """ Extension """

    def __init__(self, exts_list):
        """ init extension """
        self._exts_list = ['.' + ext for ext in exts_list]

    def __iter__(self):
        """ iterator """
        return iter(self._exts_list)

    def update(self, ext):
        """ move the matched extension to the front of the list """
        self._exts_list.remove(ext)
        self._exts_list.insert(0, ext)

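# Illustrative sketch: `update` keeps the list in most-recently-matched order,
# so later image lookups try the suffix that matched last time first, e.g.
#   exts = Extension(['jpg', 'png'])
#   list(exts)           # -> ['.jpg', '.png']
#   exts.update('.png')
#   list(exts)           # -> ['.png', '.jpg']
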
def check_src_dataset(root_dir, dataset_type):
    """ check src dataset format validity """
    if dataset_type in ("VOC", "VOCWithUnlabeled"):
        anno_suffix = ".xml"
    elif dataset_type in ("LabelMe", "LabelMeWithUnlabeled"):
        anno_suffix = ".json"
    else:
        raise ConvertFailedError(
            message=f"Dataset format conversion failed! Datasets in {dataset_type} format are "
            "not supported. Only VOC, LabelMe, VOCWithUnlabeled and LabelMeWithUnlabeled "
            "formats are currently supported.")

    err_msg_prefix = (f"Dataset format conversion failed! Please check the dataset to be "
                      f"converted against the `{dataset_type} dataset example` above.")

    anno_map = {}
    for dst_anno, src_anno in [("instance_train.json", "train_anno_list.txt"),
                               ("instance_val.json", "val_anno_list.txt")]:
        src_anno_path = os.path.join(root_dir, src_anno)
        if not os.path.exists(src_anno_path):
            if dst_anno == "instance_train.json":
                raise ConvertFailedError(
                    message=f"{err_msg_prefix} Make sure that {src_anno_path} exists.")
            continue
        with custom_open(src_anno_path, 'r') as f:
            anno_list = f.readlines()
        for anno_fn in anno_list:
            anno_fn = anno_fn.strip().split(' ')[-1]
            anno_path = os.path.join(root_dir, anno_fn)
            if not os.path.exists(anno_path):
                raise ConvertFailedError(
                    message=f"{err_msg_prefix} Make sure that the file \"{anno_fn}\" listed in "
                    f"\"{src_anno_path}\" exists.")
        anno_map[dst_anno] = src_anno_path
    return anno_map

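# With both list files present, the returned map looks like
#   {"instance_train.json": "<root>/train_anno_list.txt",
#    "instance_val.json": "<root>/val_anno_list.txt"}
# i.e. it maps each COCO output file name to the source annotation list path.
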
def convert(dataset_type, input_dir):
    """ convert dataset to coco format """
    # check format validity
    anno_map = check_src_dataset(input_dir, dataset_type)

    if dataset_type in ("VOC", "VOCWithUnlabeled"):
        convert_voc_dataset(input_dir, anno_map)
    else:
        convert_labelme_dataset(input_dir, anno_map)

def split_anno_list(root_dir, anno_map):
    """ split anno list into 80% train and 20% val """
    train_anno_list = []
    val_anno_list = []
    anno_list_bak = os.path.join(root_dir, "train_anno_list.txt.bak")
    shutil.move(anno_map["instance_train.json"], anno_list_bak)
    with custom_open(anno_list_bak, 'r') as f:
        src_anno = f.readlines()
    random.shuffle(src_anno)
    train_anno_list = src_anno[:int(len(src_anno) * 0.8)]
    val_anno_list = src_anno[int(len(src_anno) * 0.8):]
    with custom_open(os.path.join(root_dir, "train_anno_list.txt"), 'w') as f:
        f.writelines(train_anno_list)
    with custom_open(os.path.join(root_dir, "val_anno_list.txt"), 'w') as f:
        f.writelines(val_anno_list)
    anno_map["instance_train.json"] = os.path.join(root_dir,
                                                   "train_anno_list.txt")
    anno_map["instance_val.json"] = os.path.join(root_dir, "val_anno_list.txt")
    msg = (f"{os.path.join(root_dir, 'val_anno_list.txt')} does not exist. The dataset has "
           "been split into 80% train and 20% val by default, and the original "
           "'train_anno_list.txt' has been renamed to 'train_anno_list.txt.bak'.")
    warning(msg)
    return anno_map

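# For example, a `train_anno_list.txt` with 10 entries is shuffled and split
# into the first 8 lines (train) and the remaining 2 (val), since
# int(10 * 0.8) == 8.
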
def convert_labelme_dataset(root_dir, anno_map):
    """ convert dataset labeled by LabelMe to coco format """
    label_indexer = Indexer()
    img_indexer = Indexer()

    annotations_dir = os.path.join(root_dir, "annotations")
    if not os.path.exists(annotations_dir):
        os.makedirs(annotations_dir)

    # FIXME(gaotingquan): support lmssl
    unlabeled_path = os.path.join(root_dir, "unlabeled.txt")
    if os.path.exists(unlabeled_path):
        shutil.move(unlabeled_path,
                    os.path.join(annotations_dir, "unlabeled.txt"))

    # val_anno_list does not exist, so split the original dataset
    if 'instance_val.json' not in anno_map:
        anno_map = split_anno_list(root_dir, anno_map)

    for dst_anno in anno_map:
        labelme2coco(label_indexer, img_indexer, root_dir, anno_map[dst_anno],
                     os.path.join(annotations_dir, dst_anno))

def labelme2coco(label_indexer, img_indexer, root_dir, anno_path, save_path):
    """ convert json files generated by LabelMe to coco format and save to file """
    with custom_open(anno_path, 'r') as f:
        json_list = f.readlines()

    anno_num = 0
    anno_list = []
    image_list = []
    info(f"Start loading json annotation files from {anno_path} ...")
    for json_path in tqdm(json_list):
        json_path = json_path.strip()
        if not json_path.endswith(".json"):
            info(f"An illegal json path (\"{json_path}\") was found and has been ignored.")
            continue
        with custom_open(os.path.join(root_dir, json_path), 'r') as f:
            labelme_data = json.load(f)
        img_id = img_indexer.get_id(labelme_data['imagePath'])
        image_list.append({
            'id': img_id,
            'file_name': labelme_data['imagePath'].split('/')[-1],
            'width': labelme_data['imageWidth'],
            'height': labelme_data['imageHeight']
        })
        for shape in labelme_data['shapes']:
            assert shape['shape_type'] == 'rectangle', \
                "Only rectangles are supported."
            category_id = label_indexer.get_id(shape['label'])
            (x1, y1), (x2, y2) = shape['points']
            x1, x2 = sorted([x1, x2])
            y1, y2 = sorted([y1, y2])
            bbox = list(map(float, [x1, y1, x2 - x1, y2 - y1]))
            anno_num += 1
            anno_list.append({
                'image_id': img_id,
                'bbox': bbox,
                'category_id': category_id,
                'id': anno_num,
                'iscrowd': 0,
                'area': bbox[2] * bbox[3],
                'ignore': 0
            })

    category_list = label_indexer.get_list(key_name="name")
    data_coco = {
        'images': image_list,
        'categories': category_list,
        'annotations': anno_list
    }
    write_json_file(data_coco, save_path)
    info(f"The converted annotations have been saved to {save_path}.")

def convert_voc_dataset(root_dir, anno_map):
    """ convert VOC format dataset to coco format """
    label_indexer = Indexer()
    img_indexer = Indexer()

    annotations_dir = os.path.join(root_dir, "annotations")
    if not os.path.exists(annotations_dir):
        os.makedirs(annotations_dir)

    # FIXME(gaotingquan): support lmssl
    unlabeled_path = os.path.join(root_dir, "unlabeled.txt")
    if os.path.exists(unlabeled_path):
        shutil.move(unlabeled_path,
                    os.path.join(annotations_dir, "unlabeled.txt"))

    # val_anno_list does not exist, so split the original dataset
    if 'instance_val.json' not in anno_map:
        anno_map = split_anno_list(root_dir, anno_map)

    for dst_anno in anno_map:
        ann_paths = voc_get_label_anno(root_dir, anno_map[dst_anno])
        voc_xmls_to_cocojson(
            root_dir=root_dir,
            annotation_paths=ann_paths,
            label_indexer=label_indexer,
            img_indexer=img_indexer,
            output=annotations_dir,
            output_file=dst_anno)

def voc_get_label_anno(root_dir, anno_path):
    """
    Read a VOC format annotation list file.

    Args:
        root_dir (str): The directory of the VOC dataset.
        anno_path (str): The annotation list file path.

    Returns:
        list: A list of paths to all annotation XML files listed in `anno_path`,
            or an empty list if `anno_path` does not exist.
    """
    if not os.path.exists(anno_path):
        info(f"The annotation file {anno_path} does not exist and has been ignored!")
        return []
    with custom_open(anno_path, 'r') as f:
        ann_ids = f.readlines()
    ann_paths = []
    info(f"Start loading xml annotation files from {anno_path} ...")
    for aid in ann_ids:
        aid = aid.strip().split(' ')[-1]
        if not aid.endswith('.xml'):
            info(f"An illegal xml path (\"{aid}\") was found and has been ignored.")
            continue
        ann_path = os.path.join(root_dir, aid)
        ann_paths.append(ann_path)
    return ann_paths

def voc_get_image_info(annotation_root, img_indexer):
    """
    Get the image info from a VOC annotation file.

    Args:
        annotation_root: The annotation root element.
        img_indexer: indexer to get image id by filename.

    Returns:
        dict: The image info.

    Raises:
        AssertionError: When filename cannot be found in `annotation_root`.
    """
    filename = annotation_root.findtext('filename')
    assert filename is not None, filename
    img_name = os.path.basename(filename)
    im_id = img_indexer.get_id(filename)

    size = annotation_root.find('size')
    width = float(size.findtext('width'))
    height = float(size.findtext('height'))

    image_info = {
        'file_name': filename,
        'height': height,
        'width': width,
        'id': im_id
    }
    return image_info

def voc_get_coco_annotation(obj, label_indexer):
    """
    Convert a VOC format annotation to COCO format.

    Args:
        obj: an object element in VOC.
        label_indexer: indexer to get category id by label name.

    Returns:
        dict: A dict with the COCO format annotation info.
    """
    label = obj.findtext('name')
    category_id = label_indexer.get_id(label)
    bndbox = obj.find('bndbox')
    xmin = float(bndbox.findtext('xmin'))
    ymin = float(bndbox.findtext('ymin'))
    xmax = float(bndbox.findtext('xmax'))
    ymax = float(bndbox.findtext('ymax'))
    # Swap reversed coordinates so that xmin <= xmax and ymin <= ymax.
    if xmin > xmax:
        xmin, xmax = xmax, xmin
    if ymin > ymax:
        ymin, ymax = ymax, ymin
    o_width = xmax - xmin
    o_height = ymax - ymin
    anno = {
        'area': o_width * o_height,
        'iscrowd': 0,
        'bbox': [xmin, ymin, o_width, o_height],
        'category_id': category_id,
        'ignore': 0,
    }
    return anno

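# Worked example of the conversion above: a VOC box
# (xmin=10, ymin=20, xmax=110, ymax=220) becomes the COCO entry
# bbox=[10, 20, 100, 200] (x, y, width, height) with area = 100 * 200 = 20000.
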
def voc_xmls_to_cocojson(root_dir, annotation_paths, label_indexer, img_indexer,
                         output, output_file):
    """
    Convert VOC format data to COCO format.

    Args:
        annotation_paths (list): A list of paths to the XML files.
        label_indexer: indexer to get category id by label name.
        img_indexer: indexer to get image id by filename.
        output (str): The directory to save the output JSON file.
        output_file (str): The output JSON file name.

    Returns:
        None
    """
    extension_list = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
    suffixs = Extension(extension_list)

    def match(root_dir, prefile_name, prexml_name):
        """ find an image under `images/` matching either name, trying extensions in most-recently-matched order """
        for ext in suffixs:
            if os.path.exists(
                    os.path.join(root_dir, 'images', prefile_name + ext)):
                suffixs.update(ext)
                return prefile_name + ext
            elif os.path.exists(
                    os.path.join(root_dir, 'images', prexml_name + ext)):
                suffixs.update(ext)
                return prexml_name + ext
        return None

    output_json_dict = {
        "images": [],
        "type": "instances",
        "annotations": [],
        "categories": []
    }
    bnd_id = 1  # bounding box start id
    info('Start converting!')
    for a_path in tqdm(annotation_paths):
        # Read the annotation xml
        ann_tree = ET.parse(a_path)
        ann_root = ann_tree.getroot()
        file_name = ann_root.find("filename")
        prefile_name = file_name.text.split('.')[0]
        prexml_name = os.path.basename(a_path).split('.')[0]
        # Look the image up by both the <filename> stem and the xml file stem
        f_name = match(root_dir, prefile_name, prexml_name)
        if f_name is not None:
            file_name.text = f_name
        else:
            prefile_name_set = ','.join({prefile_name, prexml_name})
            suffix_set = ','.join(extension_list)
            images_path = os.path.join(root_dir, 'images')
            info(f'{images_path}/{{{prefile_name_set}}}.{{{suffix_set}}} does not exist '
                 'and will be skipped.')
            continue
        img_info = voc_get_image_info(ann_root, img_indexer)
        output_json_dict['images'].append(img_info)
        for obj in ann_root.findall('object'):
            if obj.find('bndbox') is None:  # skip objects without a bndbox
                continue
            ann = voc_get_coco_annotation(obj=obj, label_indexer=label_indexer)
            ann.update({'image_id': img_info['id'], 'id': bnd_id})
            output_json_dict['annotations'].append(ann)
            bnd_id = bnd_id + 1

    output_json_dict['categories'] = label_indexer.get_list(key_name="name")
    output_file = os.path.join(output, output_file)
    write_json_file(output_json_dict, output_file)
    info(f"The converted annotations have been saved to {output_file}.")
