coco_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import sys
  18. import copy
  19. import os
  20. import os.path as osp
  21. import numpy as np
  22. import itertools
  23. from paddlex.ppdet.metrics.map_utils import draw_pr_curve
  24. from paddlex.ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res
  25. import paddlex.utils.logging as logging
  26. def get_infer_results(outs, catid, bias=0):
  27. """
  28. Get result at the stage of inference.
  29. The output format is dictionary containing bbox or mask result.
  30. For example, bbox result is a list and each element contains
  31. image_id, category_id, bbox and score.
  32. """
  33. if outs is None or len(outs) == 0:
  34. raise ValueError(
  35. 'The number of valid detection result if zero. Please use reasonable model and check input data.'
  36. )
  37. im_id = outs['im_id']
  38. infer_res = {}
  39. if 'bbox' in outs:
  40. if len(outs['bbox']) > 0 and len(outs['bbox'][0]) > 6:
  41. infer_res['bbox'] = get_det_poly_res(
  42. outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias)
  43. else:
  44. infer_res['bbox'] = get_det_res(
  45. outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias)
  46. if 'mask' in outs:
  47. # mask post process
  48. infer_res['mask'] = get_seg_res(outs['mask'], outs['bbox'],
  49. outs['bbox_num'], im_id, catid)
  50. if 'segm' in outs:
  51. infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid)
  52. return infer_res
  53. def cocoapi_eval(anns,
  54. style,
  55. coco_gt=None,
  56. anno_file=None,
  57. max_dets=(100, 300, 1000),
  58. classwise=False):
  59. """
  60. Args:
  61. anns: Evaluation result.
  62. style (str): COCOeval style, can be `bbox` , `segm` and `proposal`.
  63. coco_gt (str): Whether to load COCOAPI through anno_file,
  64. eg: coco_gt = COCO(anno_file)
  65. anno_file (str): COCO annotations file.
  66. max_dets (tuple): COCO evaluation maxDets.
  67. classwise (bool): Whether per-category AP and draw P-R Curve or not.
  68. """
  69. assert coco_gt is not None or anno_file is not None
  70. from pycocotools.coco import COCO
  71. from pycocotools.cocoeval import COCOeval
  72. if coco_gt is None:
  73. coco_gt = COCO(anno_file)
  74. logging.info("Start evaluate...")
  75. coco_dt = loadRes(coco_gt, anns)
  76. if style == 'proposal':
  77. coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
  78. coco_eval.params.useCats = 0
  79. coco_eval.params.maxDets = list(max_dets)
  80. else:
  81. coco_eval = COCOeval(coco_gt, coco_dt, style)
  82. coco_eval.evaluate()
  83. coco_eval.accumulate()
  84. coco_eval.summarize()
  85. if classwise:
  86. # Compute per-category AP and PR curve
  87. try:
  88. from terminaltables import AsciiTable
  89. except Exception as e:
  90. logging.error(
  91. 'terminaltables not found, plaese install terminaltables. '
  92. 'for example: `pip install terminaltables`.')
  93. raise e
  94. precisions = coco_eval.eval['precision']
  95. cat_ids = coco_gt.getCatIds()
  96. # precision: (iou, recall, cls, area range, max dets)
  97. assert len(cat_ids) == precisions.shape[2]
  98. results_per_category = []
  99. for idx, catId in enumerate(cat_ids):
  100. # area range index 0: all area ranges
  101. # max dets index -1: typically 100 per image
  102. nm = coco_gt.loadCats(catId)[0]
  103. precision = precisions[:, :, idx, 0, -1]
  104. precision = precision[precision > -1]
  105. if precision.size:
  106. ap = np.mean(precision)
  107. else:
  108. ap = float('nan')
  109. results_per_category.append(
  110. (str(nm["name"]), '{:0.3f}'.format(float(ap))))
  111. pr_array = precisions[0, :, idx, 0, 2]
  112. recall_array = np.arange(0.0, 1.01, 0.01)
  113. draw_pr_curve(
  114. pr_array,
  115. recall_array,
  116. out_dir=style + '_pr_curve',
  117. file_name='{}_precision_recall_curve.jpg'.format(nm["name"]))
  118. num_columns = min(6, len(results_per_category) * 2)
  119. results_flatten = list(itertools.chain(*results_per_category))
  120. headers = ['category', 'AP'] * (num_columns // 2)
  121. results_2d = itertools.zip_longest(
  122. *[results_flatten[i::num_columns] for i in range(num_columns)])
  123. table_data = [headers]
  124. table_data += [result for result in results_2d]
  125. table = AsciiTable(table_data)
  126. logging.info('Per-category of {} AP: \n{}'.format(style, table.table))
  127. logging.info("per-category PR curve has output to {} folder.".format(
  128. style + '_pr_curve'))
  129. # flush coco evaluation result
  130. sys.stdout.flush()
  131. return coco_eval.stats
  132. def loadRes(coco_obj, anns):
  133. """
  134. Load result file and return a result api object.
  135. :param resFile (str) : file name of result file
  136. :return: res (obj) : result api object
  137. """
  138. # This function has the same functionality as pycocotools.COCO.loadRes,
  139. # except that the input anns is list of results rather than a json file.
  140. # Refer to
  141. # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/coco.py#L305,
  142. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  143. # or matplotlib.backends is imported for the first time
  144. # pycocotools import matplotlib
  145. import matplotlib
  146. matplotlib.use('Agg')
  147. from pycocotools.coco import COCO
  148. import pycocotools.mask as maskUtils
  149. import time
  150. res = COCO()
  151. res.dataset['images'] = [img for img in coco_obj.dataset['images']]
  152. tic = time.time()
  153. assert type(anns) == list, 'results in not an array of objects'
  154. annsImgIds = [ann['image_id'] for ann in anns]
  155. assert set(annsImgIds) == (set(annsImgIds) & set(coco_obj.getImgIds())), \
  156. 'Results do not correspond to current coco set'
  157. if 'caption' in anns[0]:
  158. imgIds = set([img['id'] for img in res.dataset['images']]) & set(
  159. [ann['image_id'] for ann in anns])
  160. res.dataset['images'] = [
  161. img for img in res.dataset['images'] if img['id'] in imgIds
  162. ]
  163. for id, ann in enumerate(anns):
  164. ann['id'] = id + 1
  165. elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
  166. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  167. 'categories'])
  168. for id, ann in enumerate(anns):
  169. bb = ann['bbox']
  170. x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
  171. if not 'segmentation' in ann:
  172. ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
  173. ann['area'] = bb[2] * bb[3]
  174. ann['id'] = id + 1
  175. ann['iscrowd'] = 0
  176. elif 'segmentation' in anns[0]:
  177. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  178. 'categories'])
  179. for id, ann in enumerate(anns):
  180. # now only support compressed RLE format as segmentation results
  181. ann['area'] = maskUtils.area(ann['segmentation'])
  182. if not 'bbox' in ann:
  183. ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
  184. ann['id'] = id + 1
  185. ann['iscrowd'] = 0
  186. elif 'keypoints' in anns[0]:
  187. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  188. 'categories'])
  189. for id, ann in enumerate(anns):
  190. s = ann['keypoints']
  191. x = s[0::3]
  192. y = s[1::3]
  193. x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
  194. ann['area'] = (x1 - x0) * (y1 - y0)
  195. ann['id'] = id + 1
  196. ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
  197. res.dataset['annotations'] = anns
  198. res.createIndex()
  199. return res
  200. def makeplot(rs, ps, outDir, class_name, iou_type):
  201. import matplotlib.pyplot as plt
  202. cs = np.vstack([
  203. np.ones((2, 3)),
  204. np.array([0.31, 0.51, 0.74]),
  205. np.array([0.75, 0.31, 0.30]),
  206. np.array([0.36, 0.90, 0.38]),
  207. np.array([0.50, 0.39, 0.64]),
  208. np.array([1, 0.6, 0]),
  209. ])
  210. areaNames = ['allarea', 'small', 'medium', 'large']
  211. types = ['C75', 'C50', 'Loc', 'Sim', 'Oth', 'BG', 'FN']
  212. for i in range(len(areaNames)):
  213. area_ps = ps[..., i, 0]
  214. figure_title = iou_type + '-' + class_name + '-' + areaNames[i]
  215. aps = [ps_.mean() for ps_ in area_ps]
  216. ps_curve = [
  217. ps_.mean(axis=1) if ps_.ndim > 1 else ps_ for ps_ in area_ps
  218. ]
  219. ps_curve.insert(0, np.zeros(ps_curve[0].shape))
  220. fig = plt.figure()
  221. ax = plt.subplot(111)
  222. for k in range(len(types)):
  223. ax.plot(rs, ps_curve[k + 1], color=[0, 0, 0], linewidth=0.5)
  224. ax.fill_between(
  225. rs,
  226. ps_curve[k],
  227. ps_curve[k + 1],
  228. color=cs[k],
  229. label=str(f'[{aps[k]:.3f}]' + types[k]), )
  230. plt.xlabel('recall')
  231. plt.ylabel('precision')
  232. plt.xlim(0, 1.0)
  233. plt.ylim(0, 1.0)
  234. plt.title(figure_title)
  235. plt.legend()
  236. # plt.show()
  237. fig.savefig(osp.join(outDir, f'{figure_title}.png'))
  238. plt.close(fig)
  239. def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type,
  240. areas=None):
  241. """针对某个特定类别,分析忽略亚类混淆和类别混淆时的准确率。
  242. Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
  243. Args:
  244. k (int): 待分析类别的序号。
  245. cocoDt (pycocotols.coco.COCO): 按COCO类存放的预测结果。
  246. cocoGt (pycocotols.coco.COCO): 按COCO类存放的真值。
  247. catId (int): 待分析类别在数据集中的类别id。
  248. iou_type (str): iou计算方式,若为检测框,则设置为'bbox',若为像素级分割结果,则设置为'segm'。
  249. Returns:
  250. int:
  251. dict: 有关键字'ps_supercategory'和'ps_allcategory'。关键字'ps_supercategory'的键值是忽略亚类间
  252. 混淆时的准确率,关键字'ps_allcategory'的键值是忽略类别间混淆时的准确率。
  253. """
  254. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  255. # or matplotlib.backends is imported for the first time
  256. # pycocotools import matplotlib
  257. import matplotlib
  258. matplotlib.use('Agg')
  259. from pycocotools.coco import COCO
  260. from pycocotools.cocoeval import COCOeval
  261. nm = cocoGt.loadCats(catId)[0]
  262. print(f'--------------analyzing {k + 1}-{nm["name"]}---------------')
  263. ps_ = {}
  264. dt = copy.deepcopy(cocoDt)
  265. nm = cocoGt.loadCats(catId)[0]
  266. imgIds = cocoGt.getImgIds()
  267. dt_anns = dt.dataset['annotations']
  268. select_dt_anns = []
  269. for ann in dt_anns:
  270. if ann['category_id'] == catId:
  271. select_dt_anns.append(ann)
  272. dt.dataset['annotations'] = select_dt_anns
  273. dt.createIndex()
  274. # compute precision but ignore superclass confusion
  275. gt = copy.deepcopy(cocoGt)
  276. child_catIds = gt.getCatIds(supNms=[nm['supercategory']])
  277. for idx, ann in enumerate(gt.dataset['annotations']):
  278. if ann['category_id'] in child_catIds and ann['category_id'] != catId:
  279. gt.dataset['annotations'][idx]['ignore'] = 1
  280. gt.dataset['annotations'][idx]['iscrowd'] = 1
  281. gt.dataset['annotations'][idx]['category_id'] = catId
  282. cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
  283. cocoEval.params.imgIds = imgIds
  284. cocoEval.params.maxDets = [100]
  285. cocoEval.params.iouThrs = [0.1]
  286. cocoEval.params.useCats = 1
  287. if areas:
  288. cocoEval.params.areaRng = [[0**2, areas[2]], [0**2, areas[0]],
  289. [areas[0], areas[1]], [areas[1], areas[2]]]
  290. cocoEval.evaluate()
  291. cocoEval.accumulate()
  292. ps_supercategory = cocoEval.eval['precision'][0, :, k, :, :]
  293. ps_['ps_supercategory'] = ps_supercategory
  294. # compute precision but ignore any class confusion
  295. gt = copy.deepcopy(cocoGt)
  296. for idx, ann in enumerate(gt.dataset['annotations']):
  297. if ann['category_id'] != catId:
  298. gt.dataset['annotations'][idx]['ignore'] = 1
  299. gt.dataset['annotations'][idx]['iscrowd'] = 1
  300. gt.dataset['annotations'][idx]['category_id'] = catId
  301. cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
  302. cocoEval.params.imgIds = imgIds
  303. cocoEval.params.maxDets = [100]
  304. cocoEval.params.iouThrs = [0.1]
  305. cocoEval.params.useCats = 1
  306. if areas:
  307. cocoEval.params.areaRng = [[0**2, areas[2]], [0**2, areas[0]],
  308. [areas[0], areas[1]], [areas[1], areas[2]]]
  309. cocoEval.evaluate()
  310. cocoEval.accumulate()
  311. ps_allcategory = cocoEval.eval['precision'][0, :, k, :, :]
  312. ps_['ps_allcategory'] = ps_allcategory
  313. return k, ps_
  314. def coco_error_analysis(eval_details_file=None,
  315. gt=None,
  316. pred_bbox=None,
  317. pred_mask=None,
  318. save_dir='./output'):
  319. """逐个分析模型预测错误的原因,并将分析结果以图表的形式展示。
  320. 分析结果说明参考COCODataset官网给出分析工具说明https://cocodataset.org/#detection-eval。
  321. Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/coco_error_analysis.py
  322. Args:
  323. eval_details_file (str): 模型评估结果的保存路径,包含真值信息和预测结果。
  324. gt (list): 数据集的真值信息。默认值为None。
  325. pred_bbox (list): 模型在数据集上的预测框。默认值为None。
  326. pred_mask (list): 模型在数据集上的预测mask。默认值为None。
  327. save_dir (str): 可视化结果保存路径。默认值为'./output'。
  328. Note:
  329. eval_details_file的优先级更高,只要eval_details_file不为None,
  330. 就会从eval_details_file提取真值信息和预测结果做分析。
  331. 当eval_details_file为None时,则用gt、pred_mask、pred_mask做分析。
  332. """
  333. import multiprocessing as mp
  334. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  335. # or matplotlib.backends is imported for the first time
  336. # pycocotools import matplotlib
  337. import matplotlib
  338. matplotlib.use('Agg')
  339. from pycocotools.coco import COCO
  340. from pycocotools.cocoeval import COCOeval
  341. if eval_details_file is not None:
  342. import json
  343. with open(eval_details_file, 'r') as f:
  344. eval_details = json.load(f)
  345. pred_bbox = eval_details['bbox']
  346. if 'mask' in eval_details:
  347. pred_mask = eval_details['mask']
  348. gt = eval_details['gt']
  349. if gt is None or pred_bbox is None:
  350. raise Exception(
  351. "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
  352. )
  353. if pred_bbox is not None and len(pred_bbox) == 0:
  354. raise Exception("There is no predicted bbox.")
  355. if pred_mask is not None and len(pred_mask) == 0:
  356. raise Exception("There is no predicted mask.")
  357. def _analyze_results(cocoGt, cocoDt, res_type, out_dir):
  358. directory = osp.dirname(osp.join(out_dir, ''))
  359. if not osp.exists(directory):
  360. logging.info('-------------create {}-----------------'.format(
  361. out_dir))
  362. os.makedirs(directory)
  363. imgIds = cocoGt.getImgIds()
  364. res_out_dir = osp.join(out_dir, res_type, '')
  365. res_directory = os.path.dirname(res_out_dir)
  366. if not os.path.exists(res_directory):
  367. logging.info('-------------create {}-----------------'.format(
  368. res_out_dir))
  369. os.makedirs(res_directory)
  370. iou_type = res_type
  371. cocoEval = COCOeval(
  372. copy.deepcopy(cocoGt), copy.deepcopy(cocoDt), iou_type)
  373. cocoEval.params.imgIds = imgIds
  374. cocoEval.params.iouThrs = [.75, .5, .1]
  375. cocoEval.params.maxDets = [100]
  376. cocoEval.evaluate()
  377. cocoEval.accumulate()
  378. ps = cocoEval.eval['precision']
  379. ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))])
  380. catIds = cocoGt.getCatIds()
  381. recThrs = cocoEval.params.recThrs
  382. thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
  383. thread_pool = mp.pool.ThreadPool(thread_num)
  384. args = [(k, cocoDt, cocoGt, catId, iou_type)
  385. for k, catId in enumerate(catIds)]
  386. analyze_results = thread_pool.starmap(analyze_individual_category,
  387. args)
  388. for k, catId in enumerate(catIds):
  389. nm = cocoGt.loadCats(catId)[0]
  390. logging.info('--------------saving {}-{}---------------'.format(
  391. k + 1, nm['name']))
  392. analyze_result = analyze_results[k]
  393. assert k == analyze_result[0], ""
  394. ps_supercategory = analyze_result[1]['ps_supercategory']
  395. ps_allcategory = analyze_result[1]['ps_allcategory']
  396. # compute precision but ignore superclass confusion
  397. ps[3, :, k, :, :] = ps_supercategory
  398. # compute precision but ignore any class confusion
  399. ps[4, :, k, :, :] = ps_allcategory
  400. # fill in background and false negative errors and plot
  401. ps[ps == -1] = 0
  402. ps[5, :, k, :, :] = ps[4, :, k, :, :] > 0
  403. ps[6, :, k, :, :] = 1.0
  404. makeplot(recThrs, ps[:, :, k], res_out_dir, nm['name'], iou_type)
  405. makeplot(recThrs, ps, res_out_dir, 'allclass', iou_type)
  406. coco_gt = COCO()
  407. coco_gt.dataset = gt
  408. coco_gt.createIndex()
  409. if pred_bbox is not None:
  410. coco_dt = loadRes(coco_gt, pred_bbox)
  411. _analyze_results(coco_gt, coco_dt, res_type='bbox', out_dir=save_dir)
  412. if pred_mask is not None:
  413. coco_dt = loadRes(coco_gt, pred_mask)
  414. _analyze_results(coco_gt, coco_dt, res_type='segm', out_dir=save_dir)
  415. logging.info("The analysis figures are saved in {}".format(save_dir))