coco_utils.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import sys
  18. import numpy as np
  19. from .map_utils import draw_pr_curve
  20. from .json_results import (
  21. get_det_res,
  22. get_det_poly_res,
  23. get_seg_res,
  24. get_solov2_segm_res,
  25. )
  26. from . import fd_logging as logging
  27. import copy
  28. def loadRes(coco_obj, anns):
  29. """
  30. Load result file and return a result api object.
  31. :param resFile (str) : file name of result file
  32. :return: res (obj) : result api object
  33. """
  34. # This function has the same functionality as pycocotools.COCO.loadRes,
  35. # except that the input anns is list of results rather than a json file.
  36. # Refer to
  37. # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/coco.py#L305,
  38. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  39. # or matplotlib.backends is imported for the first time
  40. # pycocotools import matplotlib
  41. import matplotlib
  42. matplotlib.use("Agg")
  43. from pycocotools.coco import COCO
  44. import pycocotools.mask as maskUtils
  45. import time
  46. res = COCO()
  47. res.dataset["images"] = [img for img in coco_obj.dataset["images"]]
  48. tic = time.time()
  49. assert isinstance(anns) == list, "results in not an array of objects"
  50. annsImgIds = [ann["image_id"] for ann in anns]
  51. assert set(annsImgIds) == (
  52. set(annsImgIds) & set(coco_obj.getImgIds())
  53. ), "Results do not correspond to current coco set"
  54. if "caption" in anns[0]:
  55. imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
  56. [ann["image_id"] for ann in anns]
  57. )
  58. res.dataset["images"] = [
  59. img for img in res.dataset["images"] if img["id"] in imgIds
  60. ]
  61. for id, ann in enumerate(anns):
  62. ann["id"] = id + 1
  63. elif "bbox" in anns[0] and not anns[0]["bbox"] == []:
  64. res.dataset["categories"] = copy.deepcopy(coco_obj.dataset["categories"])
  65. for id, ann in enumerate(anns):
  66. bb = ann["bbox"]
  67. x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
  68. if not "segmentation" in ann:
  69. ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
  70. ann["area"] = bb[2] * bb[3]
  71. ann["id"] = id + 1
  72. ann["iscrowd"] = 0
  73. elif "segmentation" in anns[0]:
  74. res.dataset["categories"] = copy.deepcopy(coco_obj.dataset["categories"])
  75. for id, ann in enumerate(anns):
  76. # now only support compressed RLE format as segmentation results
  77. ann["area"] = maskUtils.area(ann["segmentation"])
  78. if not "bbox" in ann:
  79. ann["bbox"] = maskUtils.toBbox(ann["segmentation"])
  80. ann["id"] = id + 1
  81. ann["iscrowd"] = 0
  82. elif "keypoints" in anns[0]:
  83. res.dataset["categories"] = copy.deepcopy(coco_obj.dataset["categories"])
  84. for id, ann in enumerate(anns):
  85. s = ann["keypoints"]
  86. x = s[0::3]
  87. y = s[1::3]
  88. x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
  89. ann["area"] = (x1 - x0) * (y1 - y0)
  90. ann["id"] = id + 1
  91. ann["bbox"] = [x0, y0, x1 - x0, y1 - y0]
  92. res.dataset["annotations"] = anns
  93. res.createIndex()
  94. return res
  95. def get_infer_results(outs, catid, bias=0):
  96. """
  97. Get result at the stage of inference.
  98. The output format is dictionary containing bbox or mask result.
  99. For example, bbox result is a list and each element contains
  100. image_id, category_id, bbox and score.
  101. """
  102. if outs is None or len(outs) == 0:
  103. raise ValueError(
  104. "The number of valid detection result if zero. Please use reasonable model and check input data."
  105. )
  106. im_id = outs["im_id"]
  107. infer_res = {}
  108. if "bbox" in outs:
  109. if len(outs["bbox"]) > 0 and len(outs["bbox"][0]) > 6:
  110. infer_res["bbox"] = get_det_poly_res(
  111. outs["bbox"], outs["bbox_num"], im_id, catid, bias=bias
  112. )
  113. else:
  114. infer_res["bbox"] = get_det_res(
  115. outs["bbox"], outs["bbox_num"], im_id, catid, bias=bias
  116. )
  117. if "mask" in outs:
  118. # mask post process
  119. infer_res["mask"] = get_seg_res(
  120. outs["mask"], outs["bbox"], outs["bbox_num"], im_id, catid
  121. )
  122. if "segm" in outs:
  123. infer_res["segm"] = get_solov2_segm_res(outs, im_id, catid)
  124. return infer_res
  125. def cocoapi_eval(
  126. anns,
  127. style,
  128. coco_gt=None,
  129. anno_file=None,
  130. max_dets=(100, 300, 1000),
  131. classwise=False,
  132. ):
  133. """
  134. Args:
  135. anns: Evaluation result.
  136. style (str): COCOeval style, can be `bbox` , `segm` and `proposal`.
  137. coco_gt (str): Whether to load COCOAPI through anno_file,
  138. eg: coco_gt = COCO(anno_file)
  139. anno_file (str): COCO annotations file.
  140. max_dets (tuple): COCO evaluation maxDets.
  141. classwise (bool): Whether per-category AP and draw P-R Curve or not.
  142. """
  143. assert coco_gt is not None or anno_file is not None
  144. from pycocotools.coco import COCO
  145. from pycocotools.cocoeval import COCOeval
  146. if coco_gt is None:
  147. coco_gt = COCO(anno_file)
  148. logging.info("Start evaluate...")
  149. coco_dt = loadRes(coco_gt, anns)
  150. if style == "proposal":
  151. coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
  152. coco_eval.params.useCats = 0
  153. coco_eval.params.maxDets = list(max_dets)
  154. else:
  155. coco_eval = COCOeval(coco_gt, coco_dt, style)
  156. coco_eval.evaluate()
  157. coco_eval.accumulate()
  158. coco_eval.summarize()
  159. if classwise:
  160. # Compute per-category AP and PR curve
  161. try:
  162. from terminaltables import AsciiTable
  163. except Exception as e:
  164. logging.error(
  165. "terminaltables not found, plaese install terminaltables. "
  166. "for example: `pip install terminaltables`."
  167. )
  168. raise e
  169. precisions = coco_eval.eval["precision"]
  170. cat_ids = coco_gt.getCatIds()
  171. # precision: (iou, recall, cls, area range, max dets)
  172. assert len(cat_ids) == precisions.shape[2]
  173. results_per_category = []
  174. for idx, catId in enumerate(cat_ids):
  175. # area range index 0: all area ranges
  176. # max dets index -1: typically 100 per image
  177. nm = coco_gt.loadCats(catId)[0]
  178. precision = precisions[:, :, idx, 0, -1]
  179. precision = precision[precision > -1]
  180. if precision.size:
  181. ap = np.mean(precision)
  182. else:
  183. ap = float("nan")
  184. results_per_category.append((str(nm["name"]), "{:0.3f}".format(float(ap))))
  185. pr_array = precisions[0, :, idx, 0, 2]
  186. recall_array = np.arange(0.0, 1.01, 0.01)
  187. draw_pr_curve(
  188. pr_array,
  189. recall_array,
  190. out_dir=style + "_pr_curve",
  191. file_name="{}_precision_recall_curve.jpg".format(nm["name"]),
  192. )
  193. num_columns = min(6, len(results_per_category) * 2)
  194. import itertools
  195. results_flatten = list(itertools.chain(*results_per_category))
  196. headers = ["category", "AP"] * (num_columns // 2)
  197. results_2d = itertools.zip_longest(
  198. *[results_flatten[i::num_columns] for i in range(num_columns)]
  199. )
  200. table_data = [headers]
  201. table_data += [result for result in results_2d]
  202. table = AsciiTable(table_data)
  203. logging.info("Per-category of {} AP: \n{}".format(style, table.table))
  204. logging.info(
  205. "per-category PR curve has output to {} folder.".format(style + "_pr_curve")
  206. )
  207. # flush coco evaluation result
  208. sys.stdout.flush()
  209. return coco_eval.stats