detection_eval.py

# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import copy
import json
import os
import sys

import cv2
import numpy as np

import paddlex.utils.logging as logging

# Work around a pycocotools/numpy incompatibility: for numpy > 1.17.2,
# np.linspace rejects a non-integer `num`, and pycocotools passes a float.
# Patch in a wrapper that casts `num` to int.
backup_linspace = np.linspace


def fixed_linspace(start,
                   stop,
                   num=50,
                   endpoint=True,
                   retstep=False,
                   dtype=None,
                   axis=0):
    num = int(num)
    return backup_linspace(start, stop, num, endpoint, retstep, dtype, axis)
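
# Hedged illustration of the incompatibility this patch works around:
# pycocotools builds its IoU/recall thresholds by passing a numpy float as
# `num` to np.linspace, which newer numpy releases reject. Casting `num` to
# int, as fixed_linspace does, restores the old behavior:
#
#     np.linspace(0.5, 0.95, 10.0)       # fails on numpy > 1.17.2
#     fixed_linspace(0.5, 0.95, 10.0)    # works: num is cast to int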

def eval_results(results,
                 metric,
                 coco_gt,
                 with_background=True,
                 resolution=None,
                 is_bbox_normalized=False,
                 map_type='11point'):
    """Evaluate the results fetched by an evaluation program."""
    box_ap_stats = []
    coco_gt_data = copy.deepcopy(coco_gt)
    eval_details = {'gt': copy.deepcopy(coco_gt.dataset)}
    if metric == 'COCO':
        np.linspace = fixed_linspace
        if 'proposal' in results[0]:
            proposal_eval(results, coco_gt_data)
        if 'bbox' in results[0]:
            box_ap_stats, xywh_results = coco_bbox_eval(
                results,
                coco_gt_data,
                with_background,
                is_bbox_normalized=is_bbox_normalized)
        if 'mask' in results[0]:
            mask_ap_stats, segm_results = mask_eval(results, coco_gt_data,
                                                    resolution)
            ap_stats = [box_ap_stats, mask_ap_stats]
            eval_details['bbox'] = xywh_results
            eval_details['mask'] = segm_results
            # restore np.linspace before the early return so the patch does
            # not leak outside evaluation
            np.linspace = backup_linspace
            return ap_stats, eval_details
        np.linspace = backup_linspace
    else:
        if 'accum_map' in results[-1]:
            res = np.mean(results[-1]['accum_map'][0])
            logging.debug('mAP: {:.2f}'.format(res * 100.))
            box_ap_stats.append(res * 100.)
        elif 'bbox' in results[0]:
            box_ap, xywh_results = voc_bbox_eval(
                results,
                coco_gt_data,
                with_background,
                is_bbox_normalized=is_bbox_normalized,
                map_type=map_type)
            box_ap_stats.append(box_ap)
            eval_details['bbox'] = xywh_results
    return box_ap_stats, eval_details
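
# Hedged usage sketch for eval_results. `fetched_results` stands in for the
# list of per-batch fetch dicts produced by PaddleX's evaluation loop, and
# `annotation_file` for a COCO-format annotation path; both names are
# hypothetical here.
#
#     from pycocotools.coco import COCO
#     coco_gt = COCO(annotation_file)
#     ap_stats, eval_details = eval_results(
#         fetched_results, metric='COCO', coco_gt=coco_gt,
#         with_background=True, resolution=28)
#     # ap_stats: COCO-style AP statistics; eval_details: JSON-serializable
#     # dict with 'gt', 'bbox' and (for instance segmentation) 'mask'.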

def proposal_eval(results, coco_gt, outfile=None, max_dets=(100, 300, 1000)):
    assert 'proposal' in results[0]
    xywh_results = proposal2out(results)
    assert len(xywh_results) > 0, \
        "The number of valid proposals detected is zero.\n" \
        " Please use a reasonable model and check the input data."
    # `outfile` is optional so that callers that do not need the JSON dump
    # (e.g. the call in eval_results) can skip it
    if outfile is not None:
        assert outfile.endswith('.json')
        with open(outfile, 'w') as f:
            json.dump(xywh_results, f)
    cocoapi_eval(xywh_results, 'proposal', coco_gt=coco_gt, max_dets=max_dets)
    # flush the COCO evaluation result
    sys.stdout.flush()

def coco_bbox_eval(results,
                   coco_gt,
                   with_background=True,
                   is_bbox_normalized=False):
    assert 'bbox' in results[0]
    cat_ids = coco_gt.getCatIds()
    # when with_background is True, map category ids to class ids like:
    # background: 0, first_class: 1, second_class: 2, ...
    clsid2catid = dict(
        {i + int(with_background): catid
         for i, catid in enumerate(cat_ids)})
    xywh_results = bbox2out(
        results, clsid2catid, is_bbox_normalized=is_bbox_normalized)
    results = copy.deepcopy(xywh_results)
    if len(xywh_results) == 0:
        logging.warning(
            "The number of valid bboxes detected is zero.\n"
            " Please use a reasonable model and check the input data.\n"
            " Stop eval!")
        return [0.0], results
    map_stats = cocoapi_eval(xywh_results, 'bbox', coco_gt=coco_gt)
    # flush the COCO evaluation result
    sys.stdout.flush()
    return map_stats, results

def loadRes(coco_obj, anns):
    """
    Build a result COCO API object from an in-memory list of annotations.
    :param coco_obj (COCO): ground-truth COCO API object.
    :param anns (list): annotation dicts of the results.
    :return: res (COCO): result COCO API object.
    """
    from pycocotools.coco import COCO
    import pycocotools.mask as maskUtils
    res = COCO()
    res.dataset['images'] = [img for img in coco_obj.dataset['images']]
    assert type(anns) == list, 'results is not an array of objects'
    annsImgIds = [ann['image_id'] for ann in anns]
    assert set(annsImgIds) == (set(annsImgIds) & set(coco_obj.getImgIds())), \
        'Results do not correspond to the current coco set'
    if 'caption' in anns[0]:
        imgIds = set([img['id'] for img in res.dataset['images']]) & set(
            [ann['image_id'] for ann in anns])
        res.dataset['images'] = [
            img for img in res.dataset['images'] if img['id'] in imgIds
        ]
        for id, ann in enumerate(anns):
            ann['id'] = id + 1
    elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
        res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
            'categories'])
        for id, ann in enumerate(anns):
            bb = ann['bbox']
            x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
            if 'segmentation' not in ann:
                ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
            ann['area'] = bb[2] * bb[3]
            ann['id'] = id + 1
            ann['iscrowd'] = 0
    elif 'segmentation' in anns[0]:
        res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
            'categories'])
        for id, ann in enumerate(anns):
            # only compressed RLE format is supported as segmentation results
            ann['area'] = maskUtils.area(ann['segmentation'])
            if 'bbox' not in ann:
                ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
            ann['id'] = id + 1
            ann['iscrowd'] = 0
    elif 'keypoints' in anns[0]:
        res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
            'categories'])
        for id, ann in enumerate(anns):
            s = ann['keypoints']
            x = s[0::3]
            y = s[1::3]
            x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
            ann['area'] = (x1 - x0) * (y1 - y0)
            ann['id'] = id + 1
            ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
    res.dataset['annotations'] = anns
    res.createIndex()
    return res
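
# Hedged note: this loadRes mirrors pycocotools' COCO.loadRes but takes an
# in-memory list of annotation dicts plus an existing COCO object instead of
# a result-file path, so eval_details produced above can be evaluated
# without a round-trip to disk, e.g.:
#
#     coco_dt = loadRes(coco_gt, eval_details['bbox'])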

def mask_eval(results, coco_gt, resolution, thresh_binarize=0.5):
    assert 'mask' in results[0]
    clsid2catid = {i + 1: v for i, v in enumerate(coco_gt.getCatIds())}
    segm_results = mask2out(results, clsid2catid, resolution, thresh_binarize)
    results = copy.deepcopy(segm_results)
    if len(segm_results) == 0:
        logging.warning(
            "The number of valid masks detected is zero.\n"
            " Please use a reasonable model and check the input data.")
        return None, results
    map_stats = cocoapi_eval(segm_results, 'segm', coco_gt=coco_gt)
    return map_stats, results

def cocoapi_eval(anns,
                 style,
                 coco_gt=None,
                 anno_file=None,
                 max_dets=(100, 300, 1000)):
    """
    Args:
        anns: Evaluation result.
        style: COCOeval style, one of `bbox`, `segm` and `proposal`.
        coco_gt: A loaded COCO ground-truth API object. If None, it is
            built from anno_file, i.e. coco_gt = COCO(anno_file).
        anno_file: COCO annotation file path.
        max_dets: COCO evaluation maxDets.
    """
    assert coco_gt is not None or anno_file is not None
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    if coco_gt is None:
        coco_gt = COCO(anno_file)
    logging.debug("Start evaluating...")
    coco_dt = loadRes(coco_gt, anns)
    if style == 'proposal':
        # proposals are evaluated as class-agnostic boxes
        coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
        coco_eval.params.useCats = 0
        coco_eval.params.maxDets = list(max_dets)
    else:
        coco_eval = COCOeval(coco_gt, coco_dt, style)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    return coco_eval.stats
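
# Hedged usage sketch: evaluating a list of xywh detection dicts (as
# produced by bbox2out below) against a loaded ground truth:
#
#     stats = cocoapi_eval(xywh_results, 'bbox', coco_gt=coco_gt)
#     # per COCOeval.summarize(), stats[0] is mAP@[.50:.95] and
#     # stats[1] is mAP@.50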

def proposal2out(results, is_bbox_normalized=False):
    xywh_res = []
    for t in results:
        bboxes = t['proposal'][0]
        lengths = t['proposal'][1][0]
        im_ids = np.array(t['im_id'][0]).flatten()
        assert len(lengths) == im_ids.size
        if bboxes is None or bboxes.shape == (1, 1):
            continue
        k = 0
        for i in range(len(lengths)):
            num = lengths[i]
            im_id = int(im_ids[i])
            for j in range(num):
                dt = bboxes[k]
                xmin, ymin, xmax, ymax = dt.tolist()
                if is_bbox_normalized:
                    xmin, ymin, xmax, ymax = \
                        clip_bbox([xmin, ymin, xmax, ymax])
                    w = xmax - xmin
                    h = ymax - ymin
                else:
                    w = xmax - xmin + 1
                    h = ymax - ymin + 1
                bbox = [xmin, ymin, w, h]
                coco_res = {
                    'image_id': im_id,
                    'category_id': 1,
                    'bbox': bbox,
                    'score': 1.0
                }
                xywh_res.append(coco_res)
                k += 1
    return xywh_res
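
# `clip_bbox` is called above (and in bbox2out below) but is not defined in
# this file. A minimal sketch, assuming normalized [xmin, ymin, xmax, ymax]
# coordinates that should be clamped to the [0, 1] range, consistent with
# how the callers use it:
def clip_bbox(bbox):
    # clamp each normalized coordinate into [0, 1]
    xmin = max(min(bbox[0], 1.), 0.)
    ymin = max(min(bbox[1], 1.), 0.)
    xmax = max(min(bbox[2], 1.), 0.)
    ymax = max(min(bbox[3], 1.), 0.)
    return xmin, ymin, xmax, ymax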

def bbox2out(results, clsid2catid, is_bbox_normalized=False):
    """
    Args:
        results: a list of result dicts; each must include `bbox` and
            `im_id`, plus `im_shape` when is_bbox_normalized is True.
        clsid2catid: class id to category id map of the COCO2017 dataset.
        is_bbox_normalized: whether the bbox coordinates are normalized.
    """
    xywh_res = []
    for t in results:
        bboxes = t['bbox'][0]
        lengths = t['bbox'][1][0]
        im_ids = np.array(t['im_id'][0]).flatten()
        if bboxes is None or bboxes.shape == (1, 1):
            continue
        k = 0
        for i in range(len(lengths)):
            num = lengths[i]
            im_id = int(im_ids[i])
            for j in range(num):
                dt = bboxes[k]
                clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
                catid = clsid2catid[int(clsid)]
                if is_bbox_normalized:
                    xmin, ymin, xmax, ymax = \
                        clip_bbox([xmin, ymin, xmax, ymax])
                    w = xmax - xmin
                    h = ymax - ymin
                    # scale normalized coordinates back to pixels
                    im_shape = t['im_shape'][0][i].tolist()
                    im_height, im_width = int(im_shape[0]), int(im_shape[1])
                    xmin *= im_width
                    ymin *= im_height
                    w *= im_width
                    h *= im_height
                else:
                    w = xmax - xmin + 1
                    h = ymax - ymin + 1
                bbox = [xmin, ymin, w, h]
                coco_res = {
                    'image_id': im_id,
                    'category_id': catid,
                    'bbox': bbox,
                    'score': score
                }
                xywh_res.append(coco_res)
                k += 1
    return xywh_res
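
# Hedged sketch of the per-batch dict layout bbox2out reads (inferred from
# the indexing above; values are illustrative):
#
#     t = {
#         # detections and their LoD lengths: lengths[i] detections belong
#         # to image i of the batch
#         'bbox': (np.array([[clsid, score, xmin, ymin, xmax, ymax], ...]),
#                  [[num_dets_img0, num_dets_img1, ...]]),
#         'im_id': (np.array([[0], [1], ...]),),
#         'im_shape': ...,  # only required when is_bbox_normalized=True
#     }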

def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5):
    import pycocotools.mask as mask_util
    # masks are predicted as resolution x resolution maps per box; each map
    # is padded by one pixel on every side, so the box is expanded by the
    # matching factor before the mask is resized and pasted back
    scale = (resolution + 2.0) / resolution
    segm_res = []
    # for each batch
    for t in results:
        bboxes = t['bbox'][0]
        lengths = t['bbox'][1][0]
        im_ids = np.array(t['im_id'][0])
        if bboxes is None or bboxes.shape == (1, 1):
            continue
        if len(bboxes.tolist()) == 0:
            continue
        masks = t['mask'][0]
        s = 0
        # for each sample
        for i in range(len(lengths)):
            num = lengths[i]
            im_id = int(im_ids[i][0])
            im_shape = t['im_shape'][0][i]
            bbox = bboxes[s:s + num][:, 2:]
            clsid_scores = bboxes[s:s + num][:, 0:2]
            mask = masks[s:s + num]
            s += num
            im_h = int(im_shape[0])
            im_w = int(im_shape[1])
            expand_bbox = expand_boxes(bbox, scale)
            expand_bbox = expand_bbox.astype(np.int32)
            padded_mask = np.zeros(
                (resolution + 2, resolution + 2), dtype=np.float32)
            for j in range(num):
                xmin, ymin, xmax, ymax = expand_bbox[j].tolist()
                clsid, score = clsid_scores[j].tolist()
                clsid = int(clsid)
                padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :]
                catid = clsid2catid[clsid]
                w = xmax - xmin + 1
                h = ymax - ymin + 1
                w = np.maximum(w, 1)
                h = np.maximum(h, 1)
                # resize the padded mask to the expanded box, binarize it,
                # then paste the visible part into a full-image mask
                resized_mask = cv2.resize(padded_mask, (w, h))
                resized_mask = np.array(
                    resized_mask > thresh_binarize, dtype=np.uint8)
                im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
                x0 = min(max(xmin, 0), im_w)
                x1 = min(max(xmax + 1, 0), im_w)
                y0 = min(max(ymin, 0), im_h)
                y1 = min(max(ymax + 1, 0), im_h)
                im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
                    x0 - xmin):(x1 - xmin)]
                segm = mask_util.encode(
                    np.array(
                        im_mask[:, :, np.newaxis], order='F'))[0]
                segm['counts'] = segm['counts'].decode('utf8')
                coco_res = {
                    'image_id': im_id,
                    'category_id': catid,
                    'segmentation': segm,
                    'score': score
                }
                segm_res.append(coco_res)
    return segm_res

def expand_boxes(boxes, scale):
    """
    Expand an array of boxes by a given scale.
    """
    w_half = (boxes[:, 2] - boxes[:, 0]) * .5
    h_half = (boxes[:, 3] - boxes[:, 1]) * .5
    x_c = (boxes[:, 2] + boxes[:, 0]) * .5
    y_c = (boxes[:, 3] + boxes[:, 1]) * .5
    w_half *= scale
    h_half *= scale
    boxes_exp = np.zeros(boxes.shape)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp
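
# Worked example: with scale = 30 / 28 (as used for resolution=28 in
# mask2out), a box 28 pixels wide centered at x = 50 grows by one pixel on
# each side:
#
#     expand_boxes(np.array([[36., 36., 64., 64.]]), 30. / 28.)
#     # -> [[35., 35., 65., 65.]]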

def voc_bbox_eval(results,
                  coco_gt,
                  with_background=False,
                  overlap_thresh=0.5,
                  map_type='11point',
                  is_bbox_normalized=False,
                  evaluate_difficult=False):
    """
    Bounding box evaluation for the VOC dataset.
    Args:
        results (list): prediction bounding box results.
        coco_gt (pycocotools.coco.COCO): ground truth in COCO format.
        with_background (bool): whether class id 0 is reserved for
            the background.
        overlap_thresh (float): the positive threshold of
            bbox overlap.
        map_type (string): method for mAP calculation,
            can only be '11point' or 'integral'.
        is_bbox_normalized (bool): whether bboxes are normalized
            to the range [0, 1].
        evaluate_difficult (bool): whether to evaluate
            difficult gt bboxes.
    """
    assert 'bbox' in results[0]
    logging.debug("Start evaluating...")
    cat_ids = coco_gt.getCatIds()
    # when with_background is True, map category ids to class ids like:
    # background: 0, first_class: 1, second_class: 2, ...
    clsid2catid = dict(
        {i + int(with_background): catid
         for i, catid in enumerate(cat_ids)})
    class_num = len(clsid2catid) + int(with_background)
    detection_map = DetectionMAP(
        class_num=class_num,
        overlap_thresh=overlap_thresh,
        map_type=map_type,
        is_bbox_normalized=is_bbox_normalized,
        evaluate_difficult=evaluate_difficult)
    xywh_res = []
    det_nums = 0
    gt_nums = 0
    for t in results:
        bboxes = t['bbox'][0]
        bbox_lengths = t['bbox'][1][0]
        im_ids = np.array(t['im_id'][0]).flatten()
        if bboxes is None or bboxes.shape == (1, 1):
            continue
        gt_boxes = t['gt_box'][0]
        gt_labels = t['gt_label'][0]
        difficults = t['is_difficult'][0] if not evaluate_difficult \
            else None
        if len(t['gt_box'][1]) == 0:
            # gt_box, gt_label and difficult are read as zero-padded Tensors
            bbox_idx = 0
            for i in range(len(gt_boxes)):
                gt_box = gt_boxes[i]
                gt_label = gt_labels[i]
                difficult = None if difficults is None \
                    else difficults[i]
                bbox_num = bbox_lengths[i]
                bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
                gt_box, gt_label, difficult = prune_zero_padding(
                    gt_box, gt_label, difficult)
                detection_map.update(bbox, gt_box, gt_label, difficult)
                bbox_idx += bbox_num
                det_nums += bbox_num
                gt_nums += gt_box.shape[0]
                im_id = int(im_ids[i])
                for b in bbox:
                    clsid, score, xmin, ymin, xmax, ymax = b.tolist()
                    w = xmax - xmin + 1
                    h = ymax - ymin + 1
                    xywh_bbox = [xmin, ymin, w, h]
                    coco_res = {
                        'image_id': im_id,
                        'category_id': clsid2catid[int(clsid)],
                        'bbox': xywh_bbox,
                        'score': score
                    }
                    xywh_res.append(coco_res)
        else:
            # gt_box, gt_label and difficult are read as LoDTensors
            gt_box_lengths = t['gt_box'][1][0]
            bbox_idx = 0
            gt_box_idx = 0
            for i in range(len(bbox_lengths)):
                bbox_num = bbox_lengths[i]
                gt_box_num = gt_box_lengths[i]
                bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
                gt_box = gt_boxes[gt_box_idx:gt_box_idx + gt_box_num]
                gt_label = gt_labels[gt_box_idx:gt_box_idx + gt_box_num]
                difficult = None if difficults is None else \
                    difficults[gt_box_idx:gt_box_idx + gt_box_num]
                detection_map.update(bbox, gt_box, gt_label, difficult)
                bbox_idx += bbox_num
                gt_box_idx += gt_box_num
                im_id = int(im_ids[i])
                for b in bbox:
                    clsid, score, xmin, ymin, xmax, ymax = b.tolist()
                    w = xmax - xmin + 1
                    h = ymax - ymin + 1
                    xywh_bbox = [xmin, ymin, w, h]
                    coco_res = {
                        'image_id': im_id,
                        'category_id': clsid2catid[int(clsid)],
                        'bbox': xywh_bbox,
                        'score': score
                    }
                    xywh_res.append(coco_res)
    logging.debug("Accumulating evaluation results...")
    detection_map.accumulate()
    map_stat = 100. * detection_map.get_map()
    logging.debug("mAP({:.2f}, {}) = {:.2f}".format(overlap_thresh, map_type,
                                                    map_stat))
    return map_stat, xywh_res

def prune_zero_padding(gt_box, gt_label, difficult=None):
    valid_cnt = 0
    for i in range(len(gt_box)):
        if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \
                gt_box[i, 2] == 0 and gt_box[i, 3] == 0:
            break
        valid_cnt += 1
    return (gt_box[:valid_cnt], gt_label[:valid_cnt],
            difficult[:valid_cnt] if difficult is not None else None)
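
# Hedged example: an all-zero row marks the start of padding, so everything
# from the first such row on is dropped:
#
#     gt_box = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [0, 0, 0, 0]])
#     prune_zero_padding(gt_box, np.array([0, 1, 0]))
#     # -> (first two boxes, first two labels, None)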

def bbox_area(bbox, is_bbox_normalized):
    """
    Calculate the area of a bounding box. For pixel coordinates
    (is_bbox_normalized=False), both side lengths include a +1 term.
    """
    norm = 1. - float(is_bbox_normalized)
    width = bbox[2] - bbox[0] + norm
    height = bbox[3] - bbox[1] + norm
    return width * height

def jaccard_overlap(pred, gt, is_bbox_normalized=False):
    """
    Calculate the jaccard overlap ratio (IoU) between two bounding boxes.
    """
    if pred[0] >= gt[2] or pred[2] <= gt[0] or \
            pred[1] >= gt[3] or pred[3] <= gt[1]:
        return 0.
    inter_xmin = max(pred[0], gt[0])
    inter_ymin = max(pred[1], gt[1])
    inter_xmax = min(pred[2], gt[2])
    inter_ymax = min(pred[3], gt[3])
    inter_size = bbox_area([inter_xmin, inter_ymin, inter_xmax, inter_ymax],
                           is_bbox_normalized)
    pred_size = bbox_area(pred, is_bbox_normalized)
    gt_size = bbox_area(gt, is_bbox_normalized)
    overlap = float(inter_size) / (pred_size + gt_size - inter_size)
    return overlap
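
# Worked example (pixel coordinates, is_bbox_normalized=False, so +1 is
# added to each side length): boxes [0, 0, 9, 9] and [5, 0, 14, 9] are each
# 10x10 with a 5x10 overlap, giving IoU = 50 / (100 + 100 - 50) = 1/3:
#
#     jaccard_overlap([0, 0, 9, 9], [5, 0, 14, 9])   # -> 0.333...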

class DetectionMAP(object):
    """
    Calculate detection mean average precision.
    Currently supports two types: 11point and integral.
    Args:
        class_num (int): the class number.
        overlap_thresh (float): The threshold of overlap
            ratio between a prediction bounding box and a
            ground truth bounding box for deciding
            true/false positives. Default 0.5.
        map_type (str): calculation method of mean average
            precision, currently supporting '11point' and
            'integral'. Default '11point'.
        is_bbox_normalized (bool): whether bounding boxes
            are normalized to the range [0, 1]. Default False.
        evaluate_difficult (bool): whether to evaluate
            difficult bounding boxes. Default False.
    """

    def __init__(self,
                 class_num,
                 overlap_thresh=0.5,
                 map_type='11point',
                 is_bbox_normalized=False,
                 evaluate_difficult=False):
        self.class_num = class_num
        self.overlap_thresh = overlap_thresh
        assert map_type in ['11point', 'integral'], \
            "map_type currently only supports '11point' " \
            "and 'integral'"
        self.map_type = map_type
        self.is_bbox_normalized = is_bbox_normalized
        self.evaluate_difficult = evaluate_difficult
        self.reset()

    def update(self, bbox, gt_box, gt_label, difficult=None):
        """
        Update metric statistics from the given prediction and ground
        truth information.
        """
        if difficult is None:
            difficult = np.zeros_like(gt_label)
        # record the per-class gt count
        for gtl, diff in zip(gt_label, difficult):
            if self.evaluate_difficult or int(diff) == 0:
                self.class_gt_counts[int(np.array(gtl))] += 1
        # record per-class [score, positive] pairs: match each prediction
        # to the best-overlapping unvisited gt box of the same label
        visited = [False] * len(gt_label)
        for b in bbox:
            label, score, xmin, ymin, xmax, ymax = b.tolist()
            pred = [xmin, ymin, xmax, ymax]
            max_idx = -1
            max_overlap = -1.0
            for i, gl in enumerate(gt_label):
                if int(gl) == int(label):
                    overlap = jaccard_overlap(pred, gt_box[i],
                                              self.is_bbox_normalized)
                    if overlap > max_overlap:
                        max_overlap = overlap
                        max_idx = i
            if max_overlap > self.overlap_thresh:
                if self.evaluate_difficult or \
                        int(np.array(difficult[max_idx])) == 0:
                    if not visited[max_idx]:
                        # first match of this gt box: a true positive
                        self.class_score_poss[int(label)].append([score, 1.0])
                        visited[max_idx] = True
                    else:
                        # duplicate match: a false positive
                        self.class_score_poss[int(label)].append([score, 0.0])
            else:
                self.class_score_poss[int(label)].append([score, 0.0])

    def reset(self):
        """
        Reset metric statistics.
        """
        self.class_score_poss = [[] for _ in range(self.class_num)]
        self.class_gt_counts = [0] * self.class_num
        self.mAP = None
        self.APs = [None] * self.class_num

    def accumulate(self):
        """
        Accumulate metric results and calculate mAP.
        """
        mAP = 0.
        valid_cnt = 0
        for id, (
                score_pos, count
        ) in enumerate(zip(self.class_score_poss, self.class_gt_counts)):
            if count == 0:
                continue
            if len(score_pos) == 0:
                valid_cnt += 1
                continue
            accum_tp_list, accum_fp_list = \
                self._get_tp_fp_accum(score_pos)
            precision = []
            recall = []
            for ac_tp, ac_fp in zip(accum_tp_list, accum_fp_list):
                precision.append(float(ac_tp) / (ac_tp + ac_fp))
                recall.append(float(ac_tp) / count)
            if self.map_type == '11point':
                # take the max precision at 11 evenly spaced recall levels
                max_precisions = [0.] * 11
                start_idx = len(precision) - 1
                for j in range(10, -1, -1):
                    for i in range(start_idx, -1, -1):
                        if recall[i] < float(j) / 10.:
                            start_idx = i
                            if j > 0:
                                max_precisions[j - 1] = max_precisions[j]
                            break
                        else:
                            if max_precisions[j] < precision[i]:
                                max_precisions[j] = precision[i]
                mAP += sum(max_precisions) / 11.
                self.APs[id] = sum(max_precisions) / 11.
                valid_cnt += 1
            elif self.map_type == 'integral':
                # integrate precision over the recall gaps
                import math
                ap = 0.
                prev_recall = 0.
                for i in range(len(precision)):
                    recall_gap = math.fabs(recall[i] - prev_recall)
                    if recall_gap > 1e-6:
                        ap += precision[i] * recall_gap
                        prev_recall = recall[i]
                mAP += ap
                self.APs[id] = ap
                valid_cnt += 1
            else:
                raise Exception("Unsupported mAP type {}".format(
                    self.map_type))
        self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP

    def get_map(self):
        """
        Get the mAP result.
        """
        if self.mAP is None:
            raise Exception("mAP is not calculated.")
        return self.mAP

    def _get_tp_fp_accum(self, score_pos_list):
        """
        Calculate accumulated true/false positive results from
        [score, pos] records.
        """
        sorted_list = sorted(score_pos_list, key=lambda s: s[0], reverse=True)
        accum_tp = 0
        accum_fp = 0
        accum_tp_list = []
        accum_fp_list = []
        for (score, pos) in sorted_list:
            accum_tp += int(pos)
            accum_tp_list.append(accum_tp)
            accum_fp += 1 - int(pos)
            accum_fp_list.append(accum_fp)
        return accum_tp_list, accum_fp_list
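
# Hedged usage sketch for DetectionMAP. Each `pred` row is
# [label, score, xmin, ymin, xmax, ymax]; the loop variable names are
# illustrative, not an actual PaddleX API:
#
#     metric = DetectionMAP(class_num=21, overlap_thresh=0.5,
#                           map_type='11point')
#     for pred, gt_box, gt_label, difficult in batches:  # hypothetical loop
#         metric.update(pred, gt_box, gt_label, difficult)
#     metric.accumulate()
#     print(metric.get_map())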

def makeplot(rs, ps, outDir, class_name, iou_type):
    """For one class, plot precision against recall under the different
    evaluation settings. See the analysis-tool notes on the COCODataset
    website, https://cocodataset.org/#detection-eval, for how to read
    the plots.
    Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
    Args:
        rs (np.array): recall computed at the different score thresholds.
        ps (np.array): precision computed at the different score thresholds.
            Values at the same position in ps and rs are the precision and
            recall for the same score threshold.
        outDir (str): directory where the figures are saved.
        class_name (str): class name.
        iou_type (str): IoU type, 'bbox' for detection boxes or 'segm' for
            pixel-level segmentation results.
    """
    import matplotlib.pyplot as plt
    cs = np.vstack([
        np.ones((2, 3)), np.array([.31, .51, .74]), np.array([.75, .31, .30]),
        np.array([.36, .90, .38]), np.array([.50, .39, .64]),
        np.array([1, .6, 0])
    ])
    areaNames = ['allarea', 'small', 'medium', 'large']
    types = ['C75', 'C50', 'Loc', 'Sim', 'Oth', 'BG', 'FN']
    for i in range(len(areaNames)):
        area_ps = ps[..., i, 0]
        figure_title = iou_type + '-' + class_name + '-' + areaNames[i]
        aps = [ps_.mean() for ps_ in area_ps]
        ps_curve = [
            ps_.mean(axis=1) if ps_.ndim > 1 else ps_ for ps_ in area_ps
        ]
        ps_curve.insert(0, np.zeros(ps_curve[0].shape))
        fig = plt.figure()
        ax = plt.subplot(111)
        for k in range(len(types)):
            ax.plot(rs, ps_curve[k + 1], color=[0, 0, 0], linewidth=0.5)
            ax.fill_between(
                rs,
                ps_curve[k],
                ps_curve[k + 1],
                color=cs[k],
                label=str('[{:.3f}'.format(aps[k]) + ']' + types[k]))
        plt.xlabel('recall')
        plt.ylabel('precision')
        plt.xlim(0, 1.)
        plt.ylim(0, 1.)
        plt.title(figure_title)
        plt.legend()
        fig.savefig(outDir + '/{}.png'.format(figure_title))
        plt.close(fig)

def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type):
    """For one class, analyze the precision obtained when subcategory
    confusion is ignored and when all inter-class confusion is ignored.
    Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
    Args:
        k (int): index of the class to analyze.
        cocoDt (pycocotools.coco.COCO): predictions stored as a COCO object.
        cocoGt (pycocotools.coco.COCO): ground truth stored as a COCO object.
        catId (int): category id of the class to analyze in the dataset.
        iou_type (str): IoU type, 'bbox' for detection boxes or 'segm' for
            pixel-level segmentation results.
    Returns:
        int: the class index k, unchanged.
        dict: with keys 'ps_supercategory' and 'ps_allcategory'.
            'ps_supercategory' holds the precision when confusion between
            subcategories is ignored; 'ps_allcategory' holds the precision
            when confusion between any categories is ignored.
    """
    from pycocotools.cocoeval import COCOeval
    nm = cocoGt.loadCats(catId)[0]
    logging.info('--------------analyzing {}-{}---------------'.format(
        k + 1, nm['name']))
    ps_ = {}
    dt = copy.deepcopy(cocoDt)
    imgIds = cocoGt.getImgIds()
    dt_anns = dt.dataset['annotations']
    select_dt_anns = []
    for ann in dt_anns:
        if ann['category_id'] == catId:
            select_dt_anns.append(ann)
    dt.dataset['annotations'] = select_dt_anns
    dt.createIndex()
    # compute precision while ignoring superclass confusion
    gt = copy.deepcopy(cocoGt)
    child_catIds = gt.getCatIds(supNms=[nm['supercategory']])
    for idx, ann in enumerate(gt.dataset['annotations']):
        if (ann['category_id'] in child_catIds and
                ann['category_id'] != catId):
            gt.dataset['annotations'][idx]['ignore'] = 1
            gt.dataset['annotations'][idx]['iscrowd'] = 1
            gt.dataset['annotations'][idx]['category_id'] = catId
    cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
    cocoEval.params.imgIds = imgIds
    cocoEval.params.maxDets = [100]
    cocoEval.params.iouThrs = [.1]
    cocoEval.params.useCats = 1
    cocoEval.evaluate()
    cocoEval.accumulate()
    ps_supercategory = cocoEval.eval['precision'][0, :, k, :, :]
    ps_['ps_supercategory'] = ps_supercategory
    # compute precision while ignoring any class confusion
    gt = copy.deepcopy(cocoGt)
    for idx, ann in enumerate(gt.dataset['annotations']):
        if ann['category_id'] != catId:
            gt.dataset['annotations'][idx]['ignore'] = 1
            gt.dataset['annotations'][idx]['iscrowd'] = 1
            gt.dataset['annotations'][idx]['category_id'] = catId
    cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
    cocoEval.params.imgIds = imgIds
    cocoEval.params.maxDets = [100]
    cocoEval.params.iouThrs = [.1]
    cocoEval.params.useCats = 1
    cocoEval.evaluate()
    cocoEval.accumulate()
    ps_allcategory = cocoEval.eval['precision'][0, :, k, :, :]
    ps_['ps_allcategory'] = ps_allcategory
    return k, ps_

def coco_error_analysis(eval_details_file=None,
                        gt=None,
                        pred_bbox=None,
                        pred_mask=None,
                        save_dir='./output'):
    """Analyze the sources of the model's prediction errors class by class
    and present the analysis as figures. See the analysis-tool notes on the
    COCODataset website, https://cocodataset.org/#detection-eval, for how to
    read the results.
    Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
    Args:
        eval_details_file (str): path to the saved model evaluation results,
            containing both the ground truth and the predictions.
        gt (list): ground truth of the dataset. Defaults to None.
        pred_bbox (list): predicted boxes of the model on the dataset.
            Defaults to None.
        pred_mask (list): predicted masks of the model on the dataset.
            Defaults to None.
        save_dir (str): directory where the visualizations are saved.
            Defaults to './output'.
    Note:
        eval_details_file has higher priority: whenever it is not None, the
        ground truth and predictions are extracted from eval_details_file
        for the analysis. Only when eval_details_file is None are gt,
        pred_bbox and pred_mask used instead.
    """
    import multiprocessing as mp
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    if eval_details_file is not None:
        with open(eval_details_file, 'r') as f:
            eval_details = json.load(f)
            pred_bbox = eval_details['bbox']
            if 'mask' in eval_details:
                pred_mask = eval_details['mask']
            gt = eval_details['gt']
    if gt is None or pred_bbox is None:
        raise Exception(
            "gt/pred_bbox/pred_mask is None now, please set a valid eval_details_file or gt/pred_bbox/pred_mask."
        )
    if pred_bbox is not None and len(pred_bbox) == 0:
        raise Exception("There is no predicted bbox.")
    if pred_mask is not None and len(pred_mask) == 0:
        raise Exception("There is no predicted mask.")

    def _analyze_results(cocoGt, cocoDt, res_type, out_dir):
        directory = os.path.dirname(out_dir + '/')
        if not os.path.exists(directory):
            logging.info('-------------create {}-----------------'.format(
                out_dir))
            os.makedirs(directory)
        imgIds = cocoGt.getImgIds()
        res_out_dir = out_dir + '/' + res_type + '/'
        res_directory = os.path.dirname(res_out_dir)
        if not os.path.exists(res_directory):
            logging.info('-------------create {}-----------------'.format(
                res_out_dir))
            os.makedirs(res_directory)
        iou_type = res_type
        cocoEval = COCOeval(
            copy.deepcopy(cocoGt), copy.deepcopy(cocoDt), iou_type)
        cocoEval.params.imgIds = imgIds
        cocoEval.params.iouThrs = [.75, .5, .1]
        cocoEval.params.maxDets = [100]
        cocoEval.evaluate()
        cocoEval.accumulate()
        ps = cocoEval.eval['precision']
        # reserve four extra precision planes: supercategory confusion,
        # all-category confusion, background errors and false negatives
        ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))])
        catIds = cocoGt.getCatIds()
        recThrs = cocoEval.params.recThrs
        thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
        thread_pool = mp.pool.ThreadPool(thread_num)
        args = [(k, cocoDt, cocoGt, catId, iou_type)
                for k, catId in enumerate(catIds)]
        analyze_results = thread_pool.starmap(analyze_individual_category,
                                              args)
        for k, catId in enumerate(catIds):
            nm = cocoGt.loadCats(catId)[0]
            logging.info('--------------saving {}-{}---------------'.format(
                k + 1, nm['name']))
            analyze_result = analyze_results[k]
            assert k == analyze_result[0], ""
            ps_supercategory = analyze_result[1]['ps_supercategory']
            ps_allcategory = analyze_result[1]['ps_allcategory']
            # precision when ignoring superclass confusion
            ps[3, :, k, :, :] = ps_supercategory
            # precision when ignoring any class confusion
            ps[4, :, k, :, :] = ps_allcategory
            # fill in background and false negative errors and plot
            T, _, _, A, _ = ps.shape
            for t in range(T):
                for a in range(A):
                    if np.sum(ps[t, :, k, a, :] ==
                              -1) != len(ps[t, :, k, :, :]):
                        ps[t, :, k, a, :][ps[t, :, k, a, :] == -1] = 0
            ps[5, :, k, :, :] = (ps[4, :, k, :, :] > 0)
            ps[6, :, k, :, :] = 1.0
            makeplot(recThrs, ps[:, :, k], res_out_dir, nm['name'], iou_type)
        makeplot(recThrs, ps, res_out_dir, 'allclass', iou_type)

    np.linspace = fixed_linspace
    coco_gt = COCO()
    coco_gt.dataset = gt
    coco_gt.createIndex()
    if pred_bbox is not None:
        coco_dt = loadRes(coco_gt, pred_bbox)
        _analyze_results(coco_gt, coco_dt, res_type='bbox', out_dir=save_dir)
    if pred_mask is not None:
        coco_dt = loadRes(coco_gt, pred_mask)
        _analyze_results(coco_gt, coco_dt, res_type='segm', out_dir=save_dir)
    np.linspace = backup_linspace
    logging.info("The analysis figures are saved in {}".format(save_dir))
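
# Hedged usage sketch: running the error analysis from a saved eval_details
# file (the path is illustrative):
#
#     coco_error_analysis('output/eval_details.json', save_dir='./analysis')
#     # writes one precision-recall breakdown figure per class plus an
#     # 'allclass' summary, for 'bbox' and, if present, 'segm' results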