voc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import copy
  16. import os.path as osp
  17. import random
  18. import re
  19. import numpy as np
  20. from collections import OrderedDict
  21. import xml.etree.ElementTree as ET
  22. from paddle.io import Dataset
  23. from paddlex.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
  24. from paddlex.cv.transforms import Decode, MixupImage
  25. class VOCDetection(Dataset):
  26. """读取PascalVOC格式的检测数据集,并对样本进行相应的处理。
  27. Args:
  28. data_dir (str): 数据集所在的目录路径。
  29. file_list (str): 描述数据集图片文件和对应标注文件的文件路径(文本内每行路径为相对data_dir的相对路)。
  30. label_list (str): 描述数据集包含的类别信息文件路径。
  31. transforms (paddlex.det.transforms): 数据集中每个样本的预处理/增强算子。
  32. num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
  33. 系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
  34. 一半。
  35. shuffle (bool): 是否需要对数据集中样本打乱顺序。默认为False。
  36. """
  37. def __init__(self,
  38. data_dir,
  39. file_list,
  40. label_list,
  41. transforms=None,
  42. num_workers='auto',
  43. shuffle=False):
  44. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  45. # or matplotlib.backends is imported for the first time
  46. # pycocotools import matplotlib
  47. import matplotlib
  48. matplotlib.use('Agg')
  49. from pycocotools.coco import COCO
  50. super(VOCDetection, self).__init__()
  51. self.data_fields = None
  52. self.transforms = copy.deepcopy(transforms)
  53. self.num_max_boxes = 50
  54. self.use_mix = False
  55. if self.transforms is not None:
  56. for op in self.transforms.transforms:
  57. if isinstance(op, MixupImage):
  58. self.mixup_op = copy.deepcopy(op)
  59. self.use_mix = True
  60. self.num_max_boxes *= 2
  61. break
  62. self.batch_transforms = None
  63. self.num_workers = get_num_workers(num_workers)
  64. self.shuffle = shuffle
  65. self.file_list = list()
  66. self.labels = list()
  67. annotations = dict()
  68. annotations['images'] = list()
  69. annotations['categories'] = list()
  70. annotations['annotations'] = list()
  71. cname2cid = OrderedDict()
  72. label_id = 0
  73. with open(label_list, 'r', encoding=get_encoding(label_list)) as f:
  74. for line in f.readlines():
  75. cname2cid[line.strip()] = label_id
  76. label_id += 1
  77. self.labels.append(line.strip())
  78. logging.info("Starting to read file list from dataset...")
  79. for k, v in cname2cid.items():
  80. annotations['categories'].append({
  81. 'supercategory': 'component',
  82. 'id': v + 1,
  83. 'name': k
  84. })
  85. ct = 0
  86. ann_ct = 0
  87. with open(file_list, 'r', encoding=get_encoding(file_list)) as f:
  88. while True:
  89. line = f.readline()
  90. if not line:
  91. break
  92. if len(line.strip().split()) > 2:
  93. raise Exception("A space is defined as the separator, "
  94. "but it exists in image or label name {}."
  95. .format(line))
  96. img_file, xml_file = [
  97. osp.join(data_dir, x) for x in line.strip().split()[:2]
  98. ]
  99. img_file = path_normalization(img_file)
  100. xml_file = path_normalization(xml_file)
  101. if not is_pic(img_file):
  102. continue
  103. if not osp.isfile(xml_file):
  104. continue
  105. if not osp.exists(img_file):
  106. logging.warning('The image file {} does not exist!'.format(
  107. img_file))
  108. continue
  109. if not osp.exists(xml_file):
  110. logging.warning('The annotation file {} does not exist!'.
  111. format(xml_file))
  112. continue
  113. tree = ET.parse(xml_file)
  114. if tree.find('id') is None:
  115. im_id = np.array([ct])
  116. else:
  117. ct = int(tree.find('id').text)
  118. im_id = np.array([int(tree.find('id').text)])
  119. pattern = re.compile('<object>', re.IGNORECASE)
  120. obj_match = pattern.findall(
  121. str(ET.tostringlist(tree.getroot())))
  122. if len(obj_match) == 0:
  123. continue
  124. obj_tag = obj_match[0][1:-1]
  125. objs = tree.findall(obj_tag)
  126. pattern = re.compile('<size>', re.IGNORECASE)
  127. size_tag = pattern.findall(
  128. str(ET.tostringlist(tree.getroot())))
  129. if len(size_tag) > 0:
  130. size_tag = size_tag[0][1:-1]
  131. size_element = tree.find(size_tag)
  132. pattern = re.compile('<width>', re.IGNORECASE)
  133. width_tag = pattern.findall(
  134. str(ET.tostringlist(size_element)))[0][1:-1]
  135. im_w = float(size_element.find(width_tag).text)
  136. pattern = re.compile('<height>', re.IGNORECASE)
  137. height_tag = pattern.findall(
  138. str(ET.tostringlist(size_element)))[0][1:-1]
  139. im_h = float(size_element.find(height_tag).text)
  140. else:
  141. im_w = 0
  142. im_h = 0
  143. gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
  144. gt_class = np.zeros((len(objs), 1), dtype=np.int32)
  145. gt_score = np.ones((len(objs), 1), dtype=np.float32)
  146. is_crowd = np.zeros((len(objs), 1), dtype=np.int32)
  147. difficult = np.zeros((len(objs), 1), dtype=np.int32)
  148. skipped_indices = list()
  149. for i, obj in enumerate(objs):
  150. pattern = re.compile('<name>', re.IGNORECASE)
  151. name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
  152. 1:-1]
  153. cname = obj.find(name_tag).text.strip()
  154. gt_class[i][0] = cname2cid[cname]
  155. pattern = re.compile('<difficult>', re.IGNORECASE)
  156. diff_tag = pattern.findall(str(ET.tostringlist(obj)))
  157. if len(diff_tag) == 0:
  158. _difficult = 0
  159. else:
  160. diff_tag = diff_tag[0][1:-1]
  161. try:
  162. _difficult = int(obj.find(diff_tag).text)
  163. except Exception:
  164. _difficult = 0
  165. pattern = re.compile('<bndbox>', re.IGNORECASE)
  166. box_tag = pattern.findall(str(ET.tostringlist(obj)))
  167. if len(box_tag) == 0:
  168. logging.warning(
  169. "There's no field '<bndbox>' in one of object, "
  170. "so this object will be ignored. xml file: {}".
  171. format(xml_file))
  172. continue
  173. box_tag = box_tag[0][1:-1]
  174. box_element = obj.find(box_tag)
  175. pattern = re.compile('<xmin>', re.IGNORECASE)
  176. xmin_tag = pattern.findall(
  177. str(ET.tostringlist(box_element)))[0][1:-1]
  178. x1 = float(box_element.find(xmin_tag).text)
  179. pattern = re.compile('<ymin>', re.IGNORECASE)
  180. ymin_tag = pattern.findall(
  181. str(ET.tostringlist(box_element)))[0][1:-1]
  182. y1 = float(box_element.find(ymin_tag).text)
  183. pattern = re.compile('<xmax>', re.IGNORECASE)
  184. xmax_tag = pattern.findall(
  185. str(ET.tostringlist(box_element)))[0][1:-1]
  186. x2 = float(box_element.find(xmax_tag).text)
  187. pattern = re.compile('<ymax>', re.IGNORECASE)
  188. ymax_tag = pattern.findall(
  189. str(ET.tostringlist(box_element)))[0][1:-1]
  190. y2 = float(box_element.find(ymax_tag).text)
  191. x1 = max(0, x1)
  192. y1 = max(0, y1)
  193. if im_w > 0.5 and im_h > 0.5:
  194. x2 = min(im_w - 1, x2)
  195. y2 = min(im_h - 1, y2)
  196. if not (x2 >= x1 and y2 >= y1):
  197. skipped_indices.append(i)
  198. logging.warning(
  199. "Bounding box for object {} does not satisfy x1 <= x2 and y1 <= y2, "
  200. "so this object is skipped".format(i))
  201. continue
  202. gt_bbox[i] = [x1, y1, x2, y2]
  203. is_crowd[i][0] = 0
  204. difficult[i][0] = _difficult
  205. annotations['annotations'].append({
  206. 'iscrowd': 0,
  207. 'image_id': int(im_id[0]),
  208. 'bbox': [x1, y1, x2 - x1 + 1, y2 - y1 + 1],
  209. 'area': float((x2 - x1 + 1) * (y2 - y1 + 1)),
  210. 'category_id': cname2cid[cname] + 1,
  211. 'id': ann_ct,
  212. 'difficult': _difficult
  213. })
  214. ann_ct += 1
  215. if skipped_indices:
  216. gt_bbox = np.delete(gt_bbox, skipped_indices, axis=0)
  217. gt_class = np.delete(gt_class, skipped_indices, axis=0)
  218. gt_score = np.delete(gt_score, skipped_indices, axis=0)
  219. is_crowd = np.delete(is_crowd, skipped_indices, axis=0)
  220. difficult = np.delete(difficult, skipped_indices, axis=0)
  221. im_info = {
  222. 'im_id': im_id,
  223. 'image_shape': np.array([im_h, im_w]).astype('int32'),
  224. }
  225. label_info = {
  226. 'is_crowd': is_crowd,
  227. 'gt_class': gt_class,
  228. 'gt_bbox': gt_bbox,
  229. 'gt_score': gt_score,
  230. 'difficult': difficult
  231. }
  232. if gt_bbox.size != 0:
  233. self.file_list.append({
  234. 'image': img_file,
  235. **
  236. im_info,
  237. **
  238. label_info
  239. })
  240. ct += 1
  241. annotations['images'].append({
  242. 'height': im_h,
  243. 'width': im_w,
  244. 'id': int(im_id[0]),
  245. 'file_name': osp.split(img_file)[1]
  246. })
  247. if self.use_mix:
  248. self.num_max_boxes = max(self.num_max_boxes, 2 * len(objs))
  249. else:
  250. self.num_max_boxes = max(self.num_max_boxes, len(objs))
  251. if not len(self.file_list) > 0:
  252. raise Exception('not found any voc record in %s' % (file_list))
  253. logging.info("{} samples in file {}".format(
  254. len(self.file_list), file_list))
  255. self.num_samples = len(self.file_list)
  256. self.coco_gt = COCO()
  257. self.coco_gt.dataset = annotations
  258. self.coco_gt.createIndex()
  259. self._epoch = 0
  260. def __getitem__(self, idx):
  261. sample = copy.deepcopy(self.file_list[idx])
  262. if self.data_fields is not None:
  263. sample = {k: sample[k] for k in self.data_fields}
  264. if self.use_mix and (self.mixup_op.mixup_epoch == -1 or
  265. self._epoch < self.mixup_op.mixup_epoch):
  266. if self.num_samples > 1:
  267. mix_idx = random.randint(1, self.num_samples - 1)
  268. mix_pos = (mix_idx + idx) % self.num_samples
  269. else:
  270. mix_pos = 0
  271. sample_mix = copy.deepcopy(self.file_list[mix_pos])
  272. if self.data_fields is not None:
  273. sample_mix = {k: sample_mix[k] for k in self.data_fields}
  274. sample = self.mixup_op(sample=[
  275. Decode(to_rgb=False)(sample), Decode(to_rgb=False)(sample_mix)
  276. ])
  277. sample = self.transforms(sample)
  278. return sample
  279. def __len__(self):
  280. return self.num_samples
  281. def set_epoch(self, epoch_id):
  282. self._epoch = epoch_id