coco_utils.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import sys
  18. import copy
  19. import numpy as np
  20. import itertools
  21. from paddlex.ppdet.metrics.map_utils import draw_pr_curve
  22. from paddlex.ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res
  23. import paddlex.utils.logging as logging
  def get_infer_results(outs, catid, bias=0):
      """
      Convert raw inference outputs into per-task result records.

      The output is a dictionary holding one entry per result type found in
      ``outs``. For example, the bbox result is a list and each element
      contains image_id, category_id, bbox and score.

      Args:
          outs (dict): Inference outputs. Must contain 'im_id'; may contain
              'bbox' (together with 'bbox_num'), 'mask' and/or 'segm'.
          catid: Category-id mapping, forwarded unchanged to the result
              helpers.
          bias (int): Forwarded as ``bias`` to the bbox result helpers.

      Returns:
          dict: Keys are the subset of 'bbox', 'mask', 'segm' present in
          ``outs``; values are whatever the matching helper returns.

      Raises:
          ValueError: If ``outs`` is None or empty.
      """
      if outs is None or len(outs) == 0:
          raise ValueError(
              'The number of valid detection result is zero. Please use reasonable model and check input data.'
          )

      im_id = outs['im_id']
      infer_res = {}

      if 'bbox' in outs:
          bboxes = outs['bbox']
          # Rows with more than 6 values are routed to the poly helper
          # (presumably polygon / rotated boxes -- confirm against
          # json_results.get_det_poly_res).
          if len(bboxes) > 0 and len(bboxes[0]) > 6:
              infer_res['bbox'] = get_det_poly_res(
                  bboxes, outs['bbox_num'], im_id, catid, bias=bias)
          else:
              infer_res['bbox'] = get_det_res(
                  bboxes, outs['bbox_num'], im_id, catid, bias=bias)

      if 'mask' in outs:
          # mask post process; note that this path also reads outs['bbox']
          # and outs['bbox_num'].
          infer_res['mask'] = get_seg_res(outs['mask'], outs['bbox'],
                                          outs['bbox_num'], im_id, catid)

      if 'segm' in outs:
          infer_res['segm'] = get_solov2_segm_res(outs, im_id, catid)

      return infer_res
  def cocoapi_eval(anns,
                   style,
                   coco_gt=None,
                   anno_file=None,
                   max_dets=(100, 300, 1000),
                   classwise=False):
      """
      Evaluate results with the COCO API and return the summary statistics.

      Args:
          anns (list): Evaluation results, passed to ``loadRes``.
          style (str): COCOeval style, can be `bbox` , `segm` and `proposal`.
          coco_gt: Ground-truth COCO API object, eg: coco_gt = COCO(anno_file).
              When None, it is created from ``anno_file``.
          anno_file (str): COCO annotations file. Only read when ``coco_gt``
              is None.
          max_dets (tuple): COCO evaluation maxDets. Only applied when
              ``style`` is `proposal`.
          classwise (bool): Whether to log per-category AP and draw the
              per-category P-R curves.

      Returns:
          The ``stats`` attribute of the COCOeval object after
          ``summarize()``.

      Raises:
          AssertionError: If both ``coco_gt`` and ``anno_file`` are None.
          ImportError: If ``classwise`` is set and terminaltables is not
              installed.
      """
      assert coco_gt is not None or anno_file is not None, \
          'either coco_gt or anno_file must be provided'
      from pycocotools.coco import COCO
      from pycocotools.cocoeval import COCOeval

      if coco_gt is None:
          coco_gt = COCO(anno_file)
      logging.info("Start evaluate...")
      coco_dt = loadRes(coco_gt, anns)
      if style == 'proposal':
          # Proposals are scored as class-agnostic boxes.
          coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
          coco_eval.params.useCats = 0
          coco_eval.params.maxDets = list(max_dets)
      else:
          coco_eval = COCOeval(coco_gt, coco_dt, style)
      coco_eval.evaluate()
      coco_eval.accumulate()
      coco_eval.summarize()

      if classwise:
          # Compute per-category AP and PR curve
          try:
              from terminaltables import AsciiTable
          except ImportError:
              logging.error(
                  'terminaltables not found, please install terminaltables. '
                  'for example: `pip install terminaltables`.')
              raise
          precisions = coco_eval.eval['precision']
          cat_ids = coco_gt.getCatIds()
          # precision: (iou, recall, cls, area range, max dets)
          assert len(cat_ids) == precisions.shape[2]
          pr_curve_dir = style + '_pr_curve'
          results_per_category = []
          for idx, cat_id in enumerate(cat_ids):
              # area range index 0: all area ranges
              # max dets index -1: typically 100 per image
              category = coco_gt.loadCats(cat_id)[0]
              precision = precisions[:, :, idx, 0, -1]
              # -1 marks entries that could not be computed; drop them.
              precision = precision[precision > -1]
              ap = np.mean(precision) if precision.size else float('nan')
              results_per_category.append(
                  (str(category["name"]), '{:0.3f}'.format(float(ap))))
              # Use the same max-dets slot (-1) as the AP above so the curve
              # and the table always describe the same setting; index 0 on
              # the first axis selects the first IoU threshold.
              pr_array = precisions[0, :, idx, 0, -1]
              recall_array = np.arange(0.0, 1.01, 0.01)
              draw_pr_curve(
                  pr_array,
                  recall_array,
                  out_dir=pr_curve_dir,
                  file_name='{}_precision_recall_curve.jpg'.format(category[
                      "name"]))

          # Lay the (category, AP) pairs out in up to three pairs per row.
          num_columns = min(6, len(results_per_category) * 2)
          results_flatten = list(itertools.chain(*results_per_category))
          headers = ['category', 'AP'] * (num_columns // 2)
          results_2d = itertools.zip_longest(
              *[results_flatten[i::num_columns] for i in range(num_columns)])
          table_data = [headers]
          table_data.extend(results_2d)
          table = AsciiTable(table_data)
          logging.info('Per-category of {} AP: \n{}'.format(style, table.table))
          logging.info("per-category PR curve has output to {} folder.".format(
              pr_curve_dir))
      # flush coco evaluation result
      sys.stdout.flush()
      return coco_eval.stats
  def loadRes(coco_obj, anns):
      """
      Build a result COCO api object from an in-memory list of results.

      This function has the same functionality as pycocotools.COCO.loadRes,
      except that the input anns is a list of results rather than a json
      file. Refer to
      https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/coco.py#L305

      Args:
          coco_obj: Ground-truth COCO api object. Its image list, image ids
              and categories are read.
          anns (list): Result dicts. Each must have 'image_id'; the keys of
              the first element ('caption', 'bbox', 'segmentation' or
              'keypoints') decide how all elements are completed. The dicts
              are modified in place ('id' plus, depending on the type,
              'area', 'bbox', 'segmentation', 'iscrowd'). An empty list
              yields a result object without annotations.

      Returns:
          COCO: result api object with its index created.

      Raises:
          AssertionError: If ``anns`` is not a list, or if it references
              image ids that ``coco_obj`` does not contain.
      """
      # Validate before the heavy imports so bad input fails fast.
      assert isinstance(anns, list), 'results in not an array of objects'

      # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
      # or matplotlib.backends is imported for the first time
      # pycocotools import matplotlib
      import matplotlib
      matplotlib.use('Agg')
      from pycocotools.coco import COCO
      import pycocotools.mask as maskUtils

      res = COCO()
      res.dataset['images'] = list(coco_obj.dataset['images'])

      ann_img_ids = {ann['image_id'] for ann in anns}
      assert ann_img_ids <= set(coco_obj.getImgIds()), \
          'Results do not correspond to current coco set'

      # The first element decides the result type; with no results there is
      # nothing to complete and no branch below applies.
      first = anns[0] if anns else {}

      if 'caption' in first:
          # Keep only the images that actually received a caption.
          kept_ids = {img['id'] for img in res.dataset['images']} & ann_img_ids
          res.dataset['images'] = [
              img for img in res.dataset['images'] if img['id'] in kept_ids
          ]
          for ann_id, ann in enumerate(anns, start=1):
              ann['id'] = ann_id
      elif 'bbox' in first and not first['bbox'] == []:
          res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
              'categories'])
          for ann_id, ann in enumerate(anns, start=1):
              bb = ann['bbox']
              # bbox is read as [x, y, width, height].
              x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
              if 'segmentation' not in ann:
                  # Fall back to the box outline as the polygon.
                  ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
              ann['area'] = bb[2] * bb[3]
              ann['id'] = ann_id
              ann['iscrowd'] = 0
      elif 'segmentation' in first:
          res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
              'categories'])
          for ann_id, ann in enumerate(anns, start=1):
              # now only support compressed RLE format as segmentation results
              ann['area'] = maskUtils.area(ann['segmentation'])
              if 'bbox' not in ann:
                  ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
              ann['id'] = ann_id
              ann['iscrowd'] = 0
      elif 'keypoints' in first:
          res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
              'categories'])
          for ann_id, ann in enumerate(anns, start=1):
              kpts = ann['keypoints']
              # keypoints are read as flat triples; every third value
              # starting at 0 / 1 is taken as x / y.
              xs = kpts[0::3]
              ys = kpts[1::3]
              x0, x1, y0, y1 = np.min(xs), np.max(xs), np.min(ys), np.max(ys)
              ann['area'] = (x1 - x0) * (y1 - y0)
              ann['id'] = ann_id
              ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]

      res.dataset['annotations'] = anns
      res.createIndex()
      return res