detection_eval.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. import numpy as np
  16. import json
  17. import os
  18. import sys
  19. import cv2
  20. import copy
  21. import paddlex.utils.logging as logging
  22. # fix linspace problem for pycocotools while numpy > 1.17.2
  23. backup_linspace = np.linspace
  24. def fixed_linspace(start,
  25. stop,
  26. num=50,
  27. endpoint=True,
  28. retstep=False,
  29. dtype=None,
  30. axis=0):
  31. num = int(num)
  32. return backup_linspace(start, stop, num, endpoint, retstep, dtype, axis)
  33. def eval_results(results,
  34. metric,
  35. coco_gt,
  36. with_background=True,
  37. resolution=None,
  38. is_bbox_normalized=False,
  39. map_type='11point'):
  40. """Evaluation for evaluation program results"""
  41. box_ap_stats = []
  42. coco_gt_data = copy.deepcopy(coco_gt)
  43. eval_details = {'gt': copy.deepcopy(coco_gt.dataset)}
  44. if metric == 'COCO':
  45. np.linspace = fixed_linspace
  46. if 'proposal' in results[0]:
  47. proposal_eval(results, coco_gt_data)
  48. if 'bbox' in results[0]:
  49. box_ap_stats, xywh_results = coco_bbox_eval(
  50. results,
  51. coco_gt_data,
  52. with_background,
  53. is_bbox_normalized=is_bbox_normalized)
  54. if 'mask' in results[0]:
  55. mask_ap_stats, segm_results = mask_eval(results, coco_gt_data,
  56. resolution)
  57. ap_stats = [box_ap_stats, mask_ap_stats]
  58. eval_details['bbox'] = xywh_results
  59. eval_details['mask'] = segm_results
  60. return ap_stats, eval_details
  61. np.linspace = backup_linspace
  62. else:
  63. if 'accum_map' in results[-1]:
  64. res = np.mean(results[-1]['accum_map'][0])
  65. logging.debug('mAP: {:.2f}'.format(res * 100.))
  66. box_ap_stats.append(res * 100.)
  67. elif 'bbox' in results[0]:
  68. box_ap, xywh_results = voc_bbox_eval(
  69. results,
  70. coco_gt_data,
  71. with_background,
  72. is_bbox_normalized=is_bbox_normalized,
  73. map_type=map_type)
  74. box_ap_stats.append(box_ap)
  75. eval_details['bbox'] = xywh_results
  76. return box_ap_stats, eval_details
  77. def proposal_eval(results, coco_gt, outputfile, max_dets=(100, 300, 1000)):
  78. assert 'proposal' in results[0]
  79. assert outfile.endswith('.json')
  80. xywh_results = proposal2out(results)
  81. assert len(
  82. xywh_results) > 0, "The number of valid proposal detected is zero.\n \
  83. Please use reasonable model and check input data."
  84. with open(outfile, 'w') as f:
  85. json.dump(xywh_results, f)
  86. cocoapi_eval(xywh_results, 'proposal', coco_gt=coco_gt, max_dets=max_dets)
  87. # flush coco evaluation result
  88. sys.stdout.flush()
  89. def coco_bbox_eval(results,
  90. coco_gt,
  91. with_background=True,
  92. is_bbox_normalized=False):
  93. assert 'bbox' in results[0]
  94. from pycocotools.coco import COCO
  95. cat_ids = coco_gt.getCatIds()
  96. # when with_background = True, mapping category to classid, like:
  97. # background:0, first_class:1, second_class:2, ...
  98. clsid2catid = dict(
  99. {i + int(with_background): catid
  100. for i, catid in enumerate(cat_ids)})
  101. xywh_results = bbox2out(
  102. results, clsid2catid, is_bbox_normalized=is_bbox_normalized)
  103. results = copy.deepcopy(xywh_results)
  104. if len(xywh_results) == 0:
  105. logging.warning(
  106. "The number of valid bbox detected is zero.\n Please use reasonable model and check input data.\n stop eval!"
  107. )
  108. return [0.0], results
  109. map_stats = cocoapi_eval(xywh_results, 'bbox', coco_gt=coco_gt)
  110. # flush coco evaluation result
  111. sys.stdout.flush()
  112. return map_stats, results
  113. def loadRes(coco_obj, anns):
  114. """
  115. Load result file and return a result api object.
  116. :param resFile (str) : file name of result file
  117. :return: res (obj) : result api object
  118. """
  119. from pycocotools.coco import COCO
  120. import pycocotools.mask as maskUtils
  121. import time
  122. res = COCO()
  123. res.dataset['images'] = [img for img in coco_obj.dataset['images']]
  124. tic = time.time()
  125. assert type(anns) == list, 'results in not an array of objects'
  126. annsImgIds = [ann['image_id'] for ann in anns]
  127. assert set(annsImgIds) == (set(annsImgIds) & set(coco_obj.getImgIds())), \
  128. 'Results do not correspond to current coco set'
  129. if 'caption' in anns[0]:
  130. imgIds = set([img['id'] for img in res.dataset['images']]) & set(
  131. [ann['image_id'] for ann in anns])
  132. res.dataset['images'] = [
  133. img for img in res.dataset['images'] if img['id'] in imgIds
  134. ]
  135. for id, ann in enumerate(anns):
  136. ann['id'] = id + 1
  137. elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
  138. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  139. 'categories'])
  140. for id, ann in enumerate(anns):
  141. bb = ann['bbox']
  142. x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
  143. if not 'segmentation' in ann:
  144. ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
  145. ann['area'] = bb[2] * bb[3]
  146. ann['id'] = id + 1
  147. ann['iscrowd'] = 0
  148. elif 'segmentation' in anns[0]:
  149. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  150. 'categories'])
  151. for id, ann in enumerate(anns):
  152. # now only support compressed RLE format as segmentation results
  153. ann['area'] = maskUtils.area(ann['segmentation'])
  154. if not 'bbox' in ann:
  155. ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
  156. ann['id'] = id + 1
  157. ann['iscrowd'] = 0
  158. elif 'keypoints' in anns[0]:
  159. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  160. 'categories'])
  161. for id, ann in enumerate(anns):
  162. s = ann['keypoints']
  163. x = s[0::3]
  164. y = s[1::3]
  165. x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
  166. ann['area'] = (x1 - x0) * (y1 - y0)
  167. ann['id'] = id + 1
  168. ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
  169. res.dataset['annotations'] = anns
  170. res.createIndex()
  171. return res
  172. def mask_eval(results, coco_gt, resolution, thresh_binarize=0.5):
  173. assert 'mask' in results[0]
  174. from pycocotools.coco import COCO
  175. clsid2catid = {i + 1: v for i, v in enumerate(coco_gt.getCatIds())}
  176. segm_results = mask2out(results, clsid2catid, resolution, thresh_binarize)
  177. results = copy.deepcopy(segm_results)
  178. if len(segm_results) == 0:
  179. logging.warning(
  180. "The number of valid mask detected is zero.\n Please use reasonable model and check input data."
  181. )
  182. return None, results
  183. map_stats = cocoapi_eval(segm_results, 'segm', coco_gt=coco_gt)
  184. return map_stats, results
  185. def cocoapi_eval(anns,
  186. style,
  187. coco_gt=None,
  188. anno_file=None,
  189. max_dets=(100, 300, 1000)):
  190. """
  191. Args:
  192. anns: Evaluation result.
  193. style: COCOeval style, can be `bbox` , `segm` and `proposal`.
  194. coco_gt: Whether to load COCOAPI through anno_file,
  195. eg: coco_gt = COCO(anno_file)
  196. anno_file: COCO annotations file.
  197. max_dets: COCO evaluation maxDets.
  198. """
  199. assert coco_gt != None or anno_file != None
  200. from pycocotools.coco import COCO
  201. from pycocotools.cocoeval import COCOeval
  202. if coco_gt == None:
  203. coco_gt = COCO(anno_file)
  204. logging.debug("Start evaluate...")
  205. coco_dt = loadRes(coco_gt, anns)
  206. if style == 'proposal':
  207. coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
  208. coco_eval.params.useCats = 0
  209. coco_eval.params.maxDets = list(max_dets)
  210. else:
  211. coco_eval = COCOeval(coco_gt, coco_dt, style)
  212. coco_eval.evaluate()
  213. coco_eval.accumulate()
  214. coco_eval.summarize()
  215. return coco_eval.stats
  216. def proposal2out(results, is_bbox_normalized=False):
  217. xywh_res = []
  218. for t in results:
  219. bboxes = t['proposal'][0]
  220. lengths = t['proposal'][1][0]
  221. im_ids = np.array(t['im_id'][0]).flatten()
  222. assert len(lengths) == im_ids.size
  223. if bboxes.shape == (1, 1) or bboxes is None:
  224. continue
  225. k = 0
  226. for i in range(len(lengths)):
  227. num = lengths[i]
  228. im_id = int(im_ids[i])
  229. for j in range(num):
  230. dt = bboxes[k]
  231. xmin, ymin, xmax, ymax = dt.tolist()
  232. if is_bbox_normalized:
  233. xmin, ymin, xmax, ymax = \
  234. clip_bbox([xmin, ymin, xmax, ymax])
  235. w = xmax - xmin
  236. h = ymax - ymin
  237. else:
  238. w = xmax - xmin + 1
  239. h = ymax - ymin + 1
  240. bbox = [xmin, ymin, w, h]
  241. coco_res = {
  242. 'image_id': im_id,
  243. 'category_id': 1,
  244. 'bbox': bbox,
  245. 'score': 1.0
  246. }
  247. xywh_res.append(coco_res)
  248. k += 1
  249. return xywh_res
  250. def bbox2out(results, clsid2catid, is_bbox_normalized=False):
  251. """
  252. Args:
  253. results: request a dict, should include: `bbox`, `im_id`,
  254. if is_bbox_normalized=True, also need `im_shape`.
  255. clsid2catid: class id to category id map of COCO2017 dataset.
  256. is_bbox_normalized: whether or not bbox is normalized.
  257. """
  258. xywh_res = []
  259. for t in results:
  260. bboxes = t['bbox'][0]
  261. lengths = t['bbox'][1][0]
  262. im_ids = np.array(t['im_id'][0]).flatten()
  263. if bboxes.shape == (1, 1) or bboxes is None:
  264. continue
  265. k = 0
  266. for i in range(len(lengths)):
  267. num = lengths[i]
  268. im_id = int(im_ids[i])
  269. for j in range(num):
  270. dt = bboxes[k]
  271. clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
  272. catid = (clsid2catid[int(clsid)])
  273. if is_bbox_normalized:
  274. xmin, ymin, xmax, ymax = \
  275. clip_bbox([xmin, ymin, xmax, ymax])
  276. w = xmax - xmin
  277. h = ymax - ymin
  278. im_shape = t['im_shape'][0][i].tolist()
  279. im_height, im_width = int(im_shape[0]), int(im_shape[1])
  280. xmin *= im_width
  281. ymin *= im_height
  282. w *= im_width
  283. h *= im_height
  284. else:
  285. w = xmax - xmin + 1
  286. h = ymax - ymin + 1
  287. bbox = [xmin, ymin, w, h]
  288. coco_res = {
  289. 'image_id': im_id,
  290. 'category_id': catid,
  291. 'bbox': bbox,
  292. 'score': score
  293. }
  294. xywh_res.append(coco_res)
  295. k += 1
  296. return xywh_res
  297. def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5):
  298. import pycocotools.mask as mask_util
  299. scale = (resolution + 2.0) / resolution
  300. segm_res = []
  301. # for each batch
  302. for t in results:
  303. bboxes = t['bbox'][0]
  304. lengths = t['bbox'][1][0]
  305. im_ids = np.array(t['im_id'][0])
  306. if bboxes.shape == (1, 1) or bboxes is None:
  307. continue
  308. if len(bboxes.tolist()) == 0:
  309. continue
  310. masks = t['mask'][0]
  311. s = 0
  312. # for each sample
  313. for i in range(len(lengths)):
  314. num = lengths[i]
  315. im_id = int(im_ids[i][0])
  316. im_shape = t['im_shape'][0][i]
  317. bbox = bboxes[s:s + num][:, 2:]
  318. clsid_scores = bboxes[s:s + num][:, 0:2]
  319. mask = masks[s:s + num]
  320. s += num
  321. im_h = int(im_shape[0])
  322. im_w = int(im_shape[1])
  323. expand_bbox = expand_boxes(bbox, scale)
  324. expand_bbox = expand_bbox.astype(np.int32)
  325. padded_mask = np.zeros(
  326. (resolution + 2, resolution + 2), dtype=np.float32)
  327. for j in range(num):
  328. xmin, ymin, xmax, ymax = expand_bbox[j].tolist()
  329. clsid, score = clsid_scores[j].tolist()
  330. clsid = int(clsid)
  331. padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :]
  332. catid = clsid2catid[clsid]
  333. w = xmax - xmin + 1
  334. h = ymax - ymin + 1
  335. w = np.maximum(w, 1)
  336. h = np.maximum(h, 1)
  337. resized_mask = cv2.resize(padded_mask, (w, h))
  338. resized_mask = np.array(
  339. resized_mask > thresh_binarize, dtype=np.uint8)
  340. im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
  341. x0 = min(max(xmin, 0), im_w)
  342. x1 = min(max(xmax + 1, 0), im_w)
  343. y0 = min(max(ymin, 0), im_h)
  344. y1 = min(max(ymax + 1, 0), im_h)
  345. im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
  346. x0 - xmin):(x1 - xmin)]
  347. segm = mask_util.encode(
  348. np.array(
  349. im_mask[:, :, np.newaxis], order='F'))[0]
  350. catid = clsid2catid[clsid]
  351. segm['counts'] = segm['counts'].decode('utf8')
  352. coco_res = {
  353. 'image_id': im_id,
  354. 'category_id': catid,
  355. 'segmentation': segm,
  356. 'score': score
  357. }
  358. segm_res.append(coco_res)
  359. return segm_res
  360. def expand_boxes(boxes, scale):
  361. """
  362. Expand an array of boxes by a given scale.
  363. """
  364. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  365. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  366. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  367. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  368. w_half *= scale
  369. h_half *= scale
  370. boxes_exp = np.zeros(boxes.shape)
  371. boxes_exp[:, 0] = x_c - w_half
  372. boxes_exp[:, 2] = x_c + w_half
  373. boxes_exp[:, 1] = y_c - h_half
  374. boxes_exp[:, 3] = y_c + h_half
  375. return boxes_exp
  376. def voc_bbox_eval(results,
  377. coco_gt,
  378. with_background=False,
  379. overlap_thresh=0.5,
  380. map_type='11point',
  381. is_bbox_normalized=False,
  382. evaluate_difficult=False):
  383. """
  384. Bounding box evaluation for VOC dataset
  385. Args:
  386. results (list): prediction bounding box results.
  387. class_num (int): evaluation class number.
  388. overlap_thresh (float): the postive threshold of
  389. bbox overlap
  390. map_type (string): method for mAP calcualtion,
  391. can only be '11point' or 'integral'
  392. is_bbox_normalized (bool): whether bbox is normalized
  393. to range [0, 1].
  394. evaluate_difficult (bool): whether to evaluate
  395. difficult gt bbox.
  396. """
  397. assert 'bbox' in results[0]
  398. logging.debug("Start evaluate...")
  399. from pycocotools.coco import COCO
  400. cat_ids = coco_gt.getCatIds()
  401. # when with_background = True, mapping category to classid, like:
  402. # background:0, first_class:1, second_class:2, ...
  403. clsid2catid = dict(
  404. {i + int(with_background): catid
  405. for i, catid in enumerate(cat_ids)})
  406. class_num = len(clsid2catid) + int(with_background)
  407. detection_map = DetectionMAP(
  408. class_num=class_num,
  409. overlap_thresh=overlap_thresh,
  410. map_type=map_type,
  411. is_bbox_normalized=is_bbox_normalized,
  412. evaluate_difficult=evaluate_difficult)
  413. xywh_res = []
  414. det_nums = 0
  415. gt_nums = 0
  416. for t in results:
  417. bboxes = t['bbox'][0]
  418. bbox_lengths = t['bbox'][1][0]
  419. im_ids = np.array(t['im_id'][0]).flatten()
  420. if bboxes.shape == (1, 1) or bboxes is None:
  421. continue
  422. gt_boxes = t['gt_box'][0]
  423. gt_labels = t['gt_label'][0]
  424. difficults = t['is_difficult'][0] if not evaluate_difficult \
  425. else None
  426. if len(t['gt_box'][1]) == 0:
  427. # gt_box, gt_label, difficult read as zero padded Tensor
  428. bbox_idx = 0
  429. for i in range(len(gt_boxes)):
  430. gt_box = gt_boxes[i]
  431. gt_label = gt_labels[i]
  432. difficult = None if difficults is None \
  433. else difficults[i]
  434. bbox_num = bbox_lengths[i]
  435. bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
  436. gt_box, gt_label, difficult = prune_zero_padding(
  437. gt_box, gt_label, difficult)
  438. detection_map.update(bbox, gt_box, gt_label, difficult)
  439. bbox_idx += bbox_num
  440. det_nums += bbox_num
  441. gt_nums += gt_box.shape[0]
  442. im_id = int(im_ids[i])
  443. for b in bbox:
  444. clsid, score, xmin, ymin, xmax, ymax = b.tolist()
  445. w = xmax - xmin + 1
  446. h = ymax - ymin + 1
  447. bbox = [xmin, ymin, w, h]
  448. coco_res = {
  449. 'image_id': im_id,
  450. 'category_id': clsid2catid[clsid],
  451. 'bbox': bbox,
  452. 'score': score
  453. }
  454. xywh_res.append(coco_res)
  455. else:
  456. # gt_box, gt_label, difficult read as LoDTensor
  457. gt_box_lengths = t['gt_box'][1][0]
  458. bbox_idx = 0
  459. gt_box_idx = 0
  460. for i in range(len(bbox_lengths)):
  461. bbox_num = bbox_lengths[i]
  462. gt_box_num = gt_box_lengths[i]
  463. bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
  464. gt_box = gt_boxes[gt_box_idx:gt_box_idx + gt_box_num]
  465. gt_label = gt_labels[gt_box_idx:gt_box_idx + gt_box_num]
  466. difficult = None if difficults is None else \
  467. difficults[gt_box_idx: gt_box_idx + gt_box_num]
  468. detection_map.update(bbox, gt_box, gt_label, difficult)
  469. bbox_idx += bbox_num
  470. gt_box_idx += gt_box_num
  471. im_id = int(im_ids[i])
  472. for b in bbox:
  473. clsid, score, xmin, ymin, xmax, ymax = b.tolist()
  474. w = xmax - xmin + 1
  475. h = ymax - ymin + 1
  476. bbox = [xmin, ymin, w, h]
  477. coco_res = {
  478. 'image_id': im_id,
  479. 'category_id': clsid2catid[clsid],
  480. 'bbox': bbox,
  481. 'score': score
  482. }
  483. xywh_res.append(coco_res)
  484. logging.debug("Accumulating evaluatation results...")
  485. detection_map.accumulate()
  486. map_stat = 100. * detection_map.get_map()
  487. logging.debug("mAP({:.2f}, {}) = {:.2f}".format(overlap_thresh, map_type,
  488. map_stat))
  489. return map_stat, xywh_res
  490. def prune_zero_padding(gt_box, gt_label, difficult=None):
  491. valid_cnt = 0
  492. for i in range(len(gt_box)):
  493. if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \
  494. gt_box[i, 2] == 0 and gt_box[i, 3] == 0:
  495. break
  496. valid_cnt += 1
  497. return (gt_box[:valid_cnt], gt_label[:valid_cnt], difficult[:valid_cnt]
  498. if difficult is not None else None)
  499. def bbox_area(bbox, is_bbox_normalized):
  500. """
  501. Calculate area of a bounding box
  502. """
  503. norm = 1. - float(is_bbox_normalized)
  504. width = bbox[2] - bbox[0] + norm
  505. height = bbox[3] - bbox[1] + norm
  506. return width * height
  507. def jaccard_overlap(pred, gt, is_bbox_normalized=False):
  508. """
  509. Calculate jaccard overlap ratio between two bounding box
  510. """
  511. if pred[0] >= gt[2] or pred[2] <= gt[0] or \
  512. pred[1] >= gt[3] or pred[3] <= gt[1]:
  513. return 0.
  514. inter_xmin = max(pred[0], gt[0])
  515. inter_ymin = max(pred[1], gt[1])
  516. inter_xmax = min(pred[2], gt[2])
  517. inter_ymax = min(pred[3], gt[3])
  518. inter_size = bbox_area([inter_xmin, inter_ymin, inter_xmax, inter_ymax],
  519. is_bbox_normalized)
  520. pred_size = bbox_area(pred, is_bbox_normalized)
  521. gt_size = bbox_area(gt, is_bbox_normalized)
  522. overlap = float(inter_size) / (pred_size + gt_size - inter_size)
  523. return overlap
  524. class DetectionMAP(object):
  525. """
  526. Calculate detection mean average precision.
  527. Currently support two types: 11point and integral
  528. Args:
  529. class_num (int): the class number.
  530. overlap_thresh (float): The threshold of overlap
  531. ratio between prediction bounding box and
  532. ground truth bounding box for deciding
  533. true/false positive. Default 0.5.
  534. map_type (str): calculation method of mean average
  535. precision, currently support '11point' and
  536. 'integral'. Default '11point'.
  537. is_bbox_normalized (bool): whther bounding boxes
  538. is normalized to range[0, 1]. Default False.
  539. evaluate_difficult (bool): whether to evaluate
  540. difficult bounding boxes. Default False.
  541. """
  542. def __init__(self,
  543. class_num,
  544. overlap_thresh=0.5,
  545. map_type='11point',
  546. is_bbox_normalized=False,
  547. evaluate_difficult=False):
  548. self.class_num = class_num
  549. self.overlap_thresh = overlap_thresh
  550. assert map_type in ['11point', 'integral'], \
  551. "map_type currently only support '11point' "\
  552. "and 'integral'"
  553. self.map_type = map_type
  554. self.is_bbox_normalized = is_bbox_normalized
  555. self.evaluate_difficult = evaluate_difficult
  556. self.reset()
  557. def update(self, bbox, gt_box, gt_label, difficult=None):
  558. """
  559. Update metric statics from given prediction and ground
  560. truth infomations.
  561. """
  562. if difficult is None:
  563. difficult = np.zeros_like(gt_label)
  564. # record class gt count
  565. for gtl, diff in zip(gt_label, difficult):
  566. if self.evaluate_difficult or int(diff) == 0:
  567. self.class_gt_counts[int(np.array(gtl))] += 1
  568. # record class score positive
  569. visited = [False] * len(gt_label)
  570. for b in bbox:
  571. label, score, xmin, ymin, xmax, ymax = b.tolist()
  572. pred = [xmin, ymin, xmax, ymax]
  573. max_idx = -1
  574. max_overlap = -1.0
  575. for i, gl in enumerate(gt_label):
  576. if int(gl) == int(label):
  577. overlap = jaccard_overlap(pred, gt_box[i],
  578. self.is_bbox_normalized)
  579. if overlap > max_overlap:
  580. max_overlap = overlap
  581. max_idx = i
  582. if max_overlap > self.overlap_thresh:
  583. if self.evaluate_difficult or \
  584. int(np.array(difficult[max_idx])) == 0:
  585. if not visited[max_idx]:
  586. self.class_score_poss[int(label)].append([score, 1.0])
  587. visited[max_idx] = True
  588. else:
  589. self.class_score_poss[int(label)].append([score, 0.0])
  590. else:
  591. self.class_score_poss[int(label)].append([score, 0.0])
  592. def reset(self):
  593. """
  594. Reset metric statics
  595. """
  596. self.class_score_poss = [[] for _ in range(self.class_num)]
  597. self.class_gt_counts = [0] * self.class_num
  598. self.mAP = None
  599. self.APs = [None] * self.class_num
  600. def accumulate(self):
  601. """
  602. Accumulate metric results and calculate mAP
  603. """
  604. mAP = 0.
  605. valid_cnt = 0
  606. for id, (
  607. score_pos, count
  608. ) in enumerate(zip(self.class_score_poss, self.class_gt_counts)):
  609. if count == 0: continue
  610. if len(score_pos) == 0:
  611. valid_cnt += 1
  612. continue
  613. accum_tp_list, accum_fp_list = \
  614. self._get_tp_fp_accum(score_pos)
  615. precision = []
  616. recall = []
  617. for ac_tp, ac_fp in zip(accum_tp_list, accum_fp_list):
  618. precision.append(float(ac_tp) / (ac_tp + ac_fp))
  619. recall.append(float(ac_tp) / count)
  620. if self.map_type == '11point':
  621. max_precisions = [0.] * 11
  622. start_idx = len(precision) - 1
  623. for j in range(10, -1, -1):
  624. for i in range(start_idx, -1, -1):
  625. if recall[i] < float(j) / 10.:
  626. start_idx = i
  627. if j > 0:
  628. max_precisions[j - 1] = max_precisions[j]
  629. break
  630. else:
  631. if max_precisions[j] < precision[i]:
  632. max_precisions[j] = precision[i]
  633. mAP += sum(max_precisions) / 11.
  634. self.APs[id] = sum(max_precisions) / 11.
  635. valid_cnt += 1
  636. elif self.map_type == 'integral':
  637. import math
  638. ap = 0.
  639. prev_recall = 0.
  640. for i in range(len(precision)):
  641. recall_gap = math.fabs(recall[i] - prev_recall)
  642. if recall_gap > 1e-6:
  643. ap += precision[i] * recall_gap
  644. prev_recall = recall[i]
  645. mAP += ap
  646. self.APs[id] = sum(max_precisions) / 11.
  647. valid_cnt += 1
  648. else:
  649. raise Exception("Unspported mAP type {}".format(self.map_type))
  650. self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP
  651. def get_map(self):
  652. """
  653. Get mAP result
  654. """
  655. if self.mAP is None:
  656. raise Exception("mAP is not calculated.")
  657. return self.mAP
  658. def _get_tp_fp_accum(self, score_pos_list):
  659. """
  660. Calculate accumulating true/false positive results from
  661. [score, pos] records
  662. """
  663. sorted_list = sorted(score_pos_list, key=lambda s: s[0], reverse=True)
  664. accum_tp = 0
  665. accum_fp = 0
  666. accum_tp_list = []
  667. accum_fp_list = []
  668. for (score, pos) in sorted_list:
  669. accum_tp += int(pos)
  670. accum_tp_list.append(accum_tp)
  671. accum_fp += 1 - int(pos)
  672. accum_fp_list.append(accum_fp)
  673. return accum_tp_list, accum_fp_list
  674. def makeplot(rs, ps, outDir, class_name, iou_type):
  675. import matplotlib.pyplot as plt
  676. cs = np.vstack([
  677. np.ones((2, 3)), np.array([.31, .51, .74]), np.array([.75, .31, .30]),
  678. np.array([.36, .90, .38]), np.array([.50, .39, .64]),
  679. np.array([1, .6, 0])
  680. ])
  681. areaNames = ['allarea', 'small', 'medium', 'large']
  682. types = ['C75', 'C50', 'Loc', 'Sim', 'Oth', 'BG', 'FN']
  683. for i in range(len(areaNames)):
  684. area_ps = ps[..., i, 0]
  685. figure_tile = iou_type + '-' + class_name + '-' + areaNames[i]
  686. aps = [ps_.mean() for ps_ in area_ps]
  687. ps_curve = [
  688. ps_.mean(axis=1) if ps_.ndim > 1 else ps_ for ps_ in area_ps
  689. ]
  690. ps_curve.insert(0, np.zeros(ps_curve[0].shape))
  691. fig = plt.figure()
  692. ax = plt.subplot(111)
  693. for k in range(len(types)):
  694. ax.plot(rs, ps_curve[k + 1], color=[0, 0, 0], linewidth=0.5)
  695. ax.fill_between(
  696. rs,
  697. ps_curve[k],
  698. ps_curve[k + 1],
  699. color=cs[k],
  700. label=str('[{:.3f}'.format(aps[k]) + ']' + types[k]))
  701. plt.xlabel('recall')
  702. plt.ylabel('precision')
  703. plt.xlim(0, 1.)
  704. plt.ylim(0, 1.)
  705. plt.title(figure_tile)
  706. plt.legend()
  707. fig.savefig(outDir + '/{}.png'.format(figure_tile))
  708. plt.close(fig)
  709. def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type):
  710. from pycocotools.coco import COCO
  711. from pycocotools.cocoeval import COCOeval
  712. nm = cocoGt.loadCats(catId)[0]
  713. logging.info('--------------analyzing {}-{}---------------'.format(
  714. k + 1, nm['name']))
  715. ps_ = {}
  716. dt = copy.deepcopy(cocoDt)
  717. nm = cocoGt.loadCats(catId)[0]
  718. imgIds = cocoGt.getImgIds()
  719. dt_anns = dt.dataset['annotations']
  720. select_dt_anns = []
  721. for ann in dt_anns:
  722. if ann['category_id'] == catId:
  723. select_dt_anns.append(ann)
  724. dt.dataset['annotations'] = select_dt_anns
  725. dt.createIndex()
  726. # compute precision but ignore superclass confusion
  727. gt = copy.deepcopy(cocoGt)
  728. child_catIds = gt.getCatIds(supNms=[nm['supercategory']])
  729. for idx, ann in enumerate(gt.dataset['annotations']):
  730. if (ann['category_id'] in child_catIds and
  731. ann['category_id'] != catId):
  732. gt.dataset['annotations'][idx]['ignore'] = 1
  733. gt.dataset['annotations'][idx]['iscrowd'] = 1
  734. gt.dataset['annotations'][idx]['category_id'] = catId
  735. cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
  736. cocoEval.params.imgIds = imgIds
  737. cocoEval.params.maxDets = [100]
  738. cocoEval.params.iouThrs = [.1]
  739. cocoEval.params.useCats = 1
  740. cocoEval.evaluate()
  741. cocoEval.accumulate()
  742. ps_supercategory = cocoEval.eval['precision'][0, :, k, :, :]
  743. ps_['ps_supercategory'] = ps_supercategory
  744. # compute precision but ignore any class confusion
  745. gt = copy.deepcopy(cocoGt)
  746. for idx, ann in enumerate(gt.dataset['annotations']):
  747. if ann['category_id'] != catId:
  748. gt.dataset['annotations'][idx]['ignore'] = 1
  749. gt.dataset['annotations'][idx]['iscrowd'] = 1
  750. gt.dataset['annotations'][idx]['category_id'] = catId
  751. cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
  752. cocoEval.params.imgIds = imgIds
  753. cocoEval.params.maxDets = [100]
  754. cocoEval.params.iouThrs = [.1]
  755. cocoEval.params.useCats = 1
  756. cocoEval.evaluate()
  757. cocoEval.accumulate()
  758. ps_allcategory = cocoEval.eval['precision'][0, :, k, :, :]
  759. ps_['ps_allcategory'] = ps_allcategory
  760. return k, ps_
  761. def coco_error_analysis(eval_details_file=None,
  762. gt=None,
  763. pred_bbox=None,
  764. pred_mask=None,
  765. save_dir='./output'):
  766. """
  767. Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
  768. """
  769. from multiprocessing import Pool
  770. from pycocotools.coco import COCO
  771. from pycocotools.cocoeval import COCOeval
  772. if eval_details_file is not None:
  773. import json
  774. with open(eval_details_file, 'r') as f:
  775. eval_details = json.load(f)
  776. pred_bbox = eval_details['bbox']
  777. if 'mask' in eval_details:
  778. pred_mask = eval_details['mask']
  779. gt = eval_details['gt']
  780. if gt is None or pred_bbox is None:
  781. raise Exception(
  782. "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
  783. )
  784. if pred_bbox is not None and len(pred_bbox) == 0:
  785. raise Exception("There is no predicted bbox.")
  786. if pred_mask is not None and len(pred_mask) == 0:
  787. raise Exception("There is no predicted mask.")
  788. def _analyze_results(cocoGt, cocoDt, res_type, out_dir):
  789. directory = os.path.dirname(out_dir + '/')
  790. if not os.path.exists(directory):
  791. logging.info('-------------create {}-----------------'.format(
  792. out_dir))
  793. os.makedirs(directory)
  794. imgIds = cocoGt.getImgIds()
  795. res_out_dir = out_dir + '/' + res_type + '/'
  796. res_directory = os.path.dirname(res_out_dir)
  797. if not os.path.exists(res_directory):
  798. logging.info('-------------create {}-----------------'.format(
  799. res_out_dir))
  800. os.makedirs(res_directory)
  801. iou_type = res_type
  802. cocoEval = COCOeval(
  803. copy.deepcopy(cocoGt), copy.deepcopy(cocoDt), iou_type)
  804. cocoEval.params.imgIds = imgIds
  805. cocoEval.params.iouThrs = [.75, .5, .1]
  806. cocoEval.params.maxDets = [100]
  807. cocoEval.evaluate()
  808. cocoEval.accumulate()
  809. ps = cocoEval.eval['precision']
  810. ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))])
  811. catIds = cocoGt.getCatIds()
  812. recThrs = cocoEval.params.recThrs
  813. with Pool(processes=48) as pool:
  814. args = [(k, cocoDt, cocoGt, catId, iou_type)
  815. for k, catId in enumerate(catIds)]
  816. analyze_results = pool.starmap(analyze_individual_category, args)
  817. for k, catId in enumerate(catIds):
  818. nm = cocoGt.loadCats(catId)[0]
  819. logging.info('--------------saving {}-{}---------------'.format(
  820. k + 1, nm['name']))
  821. analyze_result = analyze_results[k]
  822. assert k == analyze_result[0], ""
  823. ps_supercategory = analyze_result[1]['ps_supercategory']
  824. ps_allcategory = analyze_result[1]['ps_allcategory']
  825. # compute precision but ignore superclass confusion
  826. ps[3, :, k, :, :] = ps_supercategory
  827. # compute precision but ignore any class confusion
  828. ps[4, :, k, :, :] = ps_allcategory
  829. # fill in background and false negative errors and plot
  830. T, _, _, A, _ = ps.shape
  831. for t in range(T):
  832. for a in range(A):
  833. if np.sum(ps[t, :, k, a, :] ==
  834. -1) != len(ps[t, :, k, :, :]):
  835. ps[t, :, k, a, :][ps[t, :, k, a, :] == -1] = 0
  836. ps[5, :, k, :, :] = (ps[4, :, k, :, :] > 0)
  837. ps[6, :, k, :, :] = 1.0
  838. makeplot(recThrs, ps[:, :, k], res_out_dir, nm['name'], iou_type)
  839. makeplot(recThrs, ps, res_out_dir, 'allclass', iou_type)
  840. coco_gt = COCO()
  841. coco_gt.dataset = gt
  842. coco_gt.createIndex()
  843. from pycocotools.cocoeval import COCOeval
  844. if pred_bbox is not None:
  845. coco_dt = loadRes(coco_gt, pred_bbox)
  846. _analyze_results(coco_gt, coco_dt, res_type='bbox', out_dir=save_dir)
  847. if pred_mask is not None:
  848. coco_dt = loadRes(coco_gt, pred_mask)
  849. _analyze_results(coco_gt, coco_dt, res_type='segm', out_dir=save_dir)
  850. logging.info("The analysis figures are saved in {}".format(save_dir))