detection_eval.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067
  1. # coding: utf8
  2. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import absolute_import
  16. import numpy as np
  17. import json
  18. import os
  19. import sys
  20. import cv2
  21. import copy
  22. import paddlex.utils.logging as logging
  23. # fix linspace problem for pycocotools while numpy > 1.17.2
  24. backup_linspace = np.linspace
  25. def fixed_linspace(start,
  26. stop,
  27. num=50,
  28. endpoint=True,
  29. retstep=False,
  30. dtype=None,
  31. axis=0):
  32. num = int(num)
  33. return backup_linspace(start, stop, num, endpoint, retstep, dtype, axis)
  34. def eval_results(results,
  35. metric,
  36. coco_gt,
  37. with_background=True,
  38. resolution=None,
  39. is_bbox_normalized=False,
  40. map_type='11point'):
  41. """Evaluation for evaluation program results"""
  42. box_ap_stats = []
  43. coco_gt_data = copy.deepcopy(coco_gt)
  44. eval_details = {'gt': copy.deepcopy(coco_gt.dataset)}
  45. if metric == 'COCO':
  46. np.linspace = fixed_linspace
  47. if 'proposal' in results[0]:
  48. proposal_eval(results, coco_gt_data)
  49. if 'bbox' in results[0]:
  50. box_ap_stats, xywh_results = coco_bbox_eval(
  51. results,
  52. coco_gt_data,
  53. with_background,
  54. is_bbox_normalized=is_bbox_normalized)
  55. if 'mask' in results[0]:
  56. mask_ap_stats, segm_results = mask_eval(results, coco_gt_data,
  57. resolution)
  58. ap_stats = [box_ap_stats, mask_ap_stats]
  59. eval_details['bbox'] = xywh_results
  60. eval_details['mask'] = segm_results
  61. return ap_stats, eval_details
  62. np.linspace = backup_linspace
  63. else:
  64. if 'accum_map' in results[-1]:
  65. res = np.mean(results[-1]['accum_map'][0])
  66. logging.debug('mAP: {:.2f}'.format(res * 100.))
  67. box_ap_stats.append(res * 100.)
  68. elif 'bbox' in results[0]:
  69. box_ap, xywh_results = voc_bbox_eval(
  70. results,
  71. coco_gt_data,
  72. with_background,
  73. is_bbox_normalized=is_bbox_normalized,
  74. map_type=map_type)
  75. box_ap_stats.append(box_ap)
  76. eval_details['bbox'] = xywh_results
  77. return box_ap_stats, eval_details
  78. def proposal_eval(results, coco_gt, outputfile, max_dets=(100, 300, 1000)):
  79. assert 'proposal' in results[0]
  80. assert outfile.endswith('.json')
  81. xywh_results = proposal2out(results)
  82. assert len(
  83. xywh_results) > 0, "The number of valid proposal detected is zero.\n \
  84. Please use reasonable model and check input data."
  85. with open(outfile, 'w') as f:
  86. json.dump(xywh_results, f)
  87. cocoapi_eval(xywh_results, 'proposal', coco_gt=coco_gt, max_dets=max_dets)
  88. # flush coco evaluation result
  89. sys.stdout.flush()
  90. def coco_bbox_eval(results,
  91. coco_gt,
  92. with_background=True,
  93. is_bbox_normalized=False):
  94. assert 'bbox' in results[0]
  95. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  96. # or matplotlib.backends is imported for the first time
  97. # pycocotools import matplotlib
  98. import matplotlib
  99. matplotlib.use('Agg')
  100. from pycocotools.coco import COCO
  101. cat_ids = coco_gt.getCatIds()
  102. # when with_background = True, mapping category to classid, like:
  103. # background:0, first_class:1, second_class:2, ...
  104. clsid2catid = dict(
  105. {i + int(with_background): catid
  106. for i, catid in enumerate(cat_ids)})
  107. xywh_results = bbox2out(
  108. results, clsid2catid, is_bbox_normalized=is_bbox_normalized)
  109. results = copy.deepcopy(xywh_results)
  110. if len(xywh_results) == 0:
  111. logging.warning(
  112. "The number of valid bbox detected is zero.\n Please use reasonable model and check input data.\n stop eval!"
  113. )
  114. return [0.0], results
  115. map_stats = cocoapi_eval(xywh_results, 'bbox', coco_gt=coco_gt)
  116. # flush coco evaluation result
  117. sys.stdout.flush()
  118. return map_stats, results
  119. def loadRes(coco_obj, anns):
  120. """
  121. Load result file and return a result api object.
  122. :param resFile (str) : file name of result file
  123. :return: res (obj) : result api object
  124. """
  125. # This function has the same functionality as pycocotools.COCO.loadRes,
  126. # except that the input anns is list of results rather than a json file.
  127. # Refer to
  128. # https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/coco.py#L305,
  129. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  130. # or matplotlib.backends is imported for the first time
  131. # pycocotools import matplotlib
  132. import matplotlib
  133. matplotlib.use('Agg')
  134. from pycocotools.coco import COCO
  135. import pycocotools.mask as maskUtils
  136. import time
  137. res = COCO()
  138. res.dataset['images'] = [img for img in coco_obj.dataset['images']]
  139. tic = time.time()
  140. assert type(anns) == list, 'results in not an array of objects'
  141. annsImgIds = [ann['image_id'] for ann in anns]
  142. assert set(annsImgIds) == (set(annsImgIds) & set(coco_obj.getImgIds())), \
  143. 'Results do not correspond to current coco set'
  144. if 'caption' in anns[0]:
  145. imgIds = set([img['id'] for img in res.dataset['images']]) & set(
  146. [ann['image_id'] for ann in anns])
  147. res.dataset['images'] = [
  148. img for img in res.dataset['images'] if img['id'] in imgIds
  149. ]
  150. for id, ann in enumerate(anns):
  151. ann['id'] = id + 1
  152. elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
  153. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  154. 'categories'])
  155. for id, ann in enumerate(anns):
  156. bb = ann['bbox']
  157. x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
  158. if not 'segmentation' in ann:
  159. ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
  160. ann['area'] = bb[2] * bb[3]
  161. ann['id'] = id + 1
  162. ann['iscrowd'] = 0
  163. elif 'segmentation' in anns[0]:
  164. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  165. 'categories'])
  166. for id, ann in enumerate(anns):
  167. # now only support compressed RLE format as segmentation results
  168. ann['area'] = maskUtils.area(ann['segmentation'])
  169. if not 'bbox' in ann:
  170. ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
  171. ann['id'] = id + 1
  172. ann['iscrowd'] = 0
  173. elif 'keypoints' in anns[0]:
  174. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  175. 'categories'])
  176. for id, ann in enumerate(anns):
  177. s = ann['keypoints']
  178. x = s[0::3]
  179. y = s[1::3]
  180. x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
  181. ann['area'] = (x1 - x0) * (y1 - y0)
  182. ann['id'] = id + 1
  183. ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
  184. res.dataset['annotations'] = anns
  185. res.createIndex()
  186. return res
  187. def mask_eval(results, coco_gt, resolution, thresh_binarize=0.5):
  188. assert 'mask' in results[0]
  189. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  190. # or matplotlib.backends is imported for the first time
  191. # pycocotools import matplotlib
  192. import matplotlib
  193. matplotlib.use('Agg')
  194. from pycocotools.coco import COCO
  195. clsid2catid = {i + 1: v for i, v in enumerate(coco_gt.getCatIds())}
  196. segm_results = mask2out(results, clsid2catid, resolution, thresh_binarize)
  197. results = copy.deepcopy(segm_results)
  198. if len(segm_results) == 0:
  199. logging.warning(
  200. "The number of valid mask detected is zero.\n Please use reasonable model and check input data."
  201. )
  202. return None, results
  203. map_stats = cocoapi_eval(segm_results, 'segm', coco_gt=coco_gt)
  204. return map_stats, results
  205. def cocoapi_eval(anns,
  206. style,
  207. coco_gt=None,
  208. anno_file=None,
  209. max_dets=(100, 300, 1000)):
  210. """
  211. Args:
  212. anns: Evaluation result.
  213. style: COCOeval style, can be `bbox` , `segm` and `proposal`.
  214. coco_gt: Whether to load COCOAPI through anno_file,
  215. eg: coco_gt = COCO(anno_file)
  216. anno_file: COCO annotations file.
  217. max_dets: COCO evaluation maxDets.
  218. """
  219. assert coco_gt != None or anno_file != None
  220. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  221. # or matplotlib.backends is imported for the first time
  222. # pycocotools import matplotlib
  223. import matplotlib
  224. matplotlib.use('Agg')
  225. from pycocotools.coco import COCO
  226. from pycocotools.cocoeval import COCOeval
  227. if coco_gt == None:
  228. coco_gt = COCO(anno_file)
  229. logging.debug("Start evaluate...")
  230. coco_dt = loadRes(coco_gt, anns)
  231. if style == 'proposal':
  232. coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
  233. coco_eval.params.useCats = 0
  234. coco_eval.params.maxDets = list(max_dets)
  235. else:
  236. coco_eval = COCOeval(coco_gt, coco_dt, style)
  237. coco_eval.evaluate()
  238. coco_eval.accumulate()
  239. coco_eval.summarize()
  240. return coco_eval.stats
  241. def proposal2out(results, is_bbox_normalized=False):
  242. xywh_res = []
  243. for t in results:
  244. bboxes = t['proposal'][0]
  245. lengths = t['proposal'][1][0]
  246. im_ids = np.array(t['im_id'][0]).flatten()
  247. assert len(lengths) == im_ids.size
  248. if bboxes.shape == (1, 1) or bboxes is None:
  249. continue
  250. k = 0
  251. for i in range(len(lengths)):
  252. num = lengths[i]
  253. im_id = int(im_ids[i])
  254. for j in range(num):
  255. dt = bboxes[k]
  256. xmin, ymin, xmax, ymax = dt.tolist()
  257. if is_bbox_normalized:
  258. xmin, ymin, xmax, ymax = \
  259. clip_bbox([xmin, ymin, xmax, ymax])
  260. w = xmax - xmin
  261. h = ymax - ymin
  262. else:
  263. w = xmax - xmin + 1
  264. h = ymax - ymin + 1
  265. bbox = [xmin, ymin, w, h]
  266. coco_res = {
  267. 'image_id': im_id,
  268. 'category_id': 1,
  269. 'bbox': bbox,
  270. 'score': 1.0
  271. }
  272. xywh_res.append(coco_res)
  273. k += 1
  274. return xywh_res
  275. def bbox2out(results, clsid2catid, is_bbox_normalized=False):
  276. """
  277. Args:
  278. results: request a dict, should include: `bbox`, `im_id`,
  279. if is_bbox_normalized=True, also need `im_shape`.
  280. clsid2catid: class id to category id map of COCO2017 dataset.
  281. is_bbox_normalized: whether or not bbox is normalized.
  282. """
  283. xywh_res = []
  284. for t in results:
  285. bboxes = t['bbox'][0]
  286. lengths = t['bbox'][1][0]
  287. im_ids = np.array(t['im_id'][0]).flatten()
  288. if bboxes.shape == (1, 1) or bboxes is None:
  289. continue
  290. k = 0
  291. for i in range(len(lengths)):
  292. num = lengths[i]
  293. im_id = int(im_ids[i])
  294. for j in range(num):
  295. dt = bboxes[k]
  296. clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
  297. catid = (clsid2catid[int(clsid)])
  298. if is_bbox_normalized:
  299. xmin, ymin, xmax, ymax = \
  300. clip_bbox([xmin, ymin, xmax, ymax])
  301. w = xmax - xmin
  302. h = ymax - ymin
  303. im_shape = t['im_shape'][0][i].tolist()
  304. im_height, im_width = int(im_shape[0]), int(im_shape[1])
  305. xmin *= im_width
  306. ymin *= im_height
  307. w *= im_width
  308. h *= im_height
  309. else:
  310. w = xmax - xmin + 1
  311. h = ymax - ymin + 1
  312. bbox = [xmin, ymin, w, h]
  313. coco_res = {
  314. 'image_id': im_id,
  315. 'category_id': catid,
  316. 'bbox': bbox,
  317. 'score': score
  318. }
  319. xywh_res.append(coco_res)
  320. k += 1
  321. return xywh_res
  322. def mask2out(results, clsid2catid, resolution, thresh_binarize=0.5):
  323. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  324. # or matplotlib.backends is imported for the first time
  325. # pycocotools import matplotlib
  326. import matplotlib
  327. matplotlib.use('Agg')
  328. import pycocotools.mask as mask_util
  329. scale = (resolution + 2.0) / resolution
  330. segm_res = []
  331. # for each batch
  332. for t in results:
  333. bboxes = t['bbox'][0]
  334. lengths = t['bbox'][1][0]
  335. im_ids = np.array(t['im_id'][0])
  336. if bboxes.shape == (1, 1) or bboxes is None:
  337. continue
  338. if len(bboxes.tolist()) == 0:
  339. continue
  340. masks = t['mask'][0]
  341. s = 0
  342. # for each sample
  343. for i in range(len(lengths)):
  344. num = lengths[i]
  345. im_id = int(im_ids[i][0])
  346. im_shape = t['im_shape'][0][i]
  347. bbox = bboxes[s:s + num][:, 2:]
  348. clsid_scores = bboxes[s:s + num][:, 0:2]
  349. mask = masks[s:s + num]
  350. s += num
  351. im_h = int(im_shape[0])
  352. im_w = int(im_shape[1])
  353. expand_bbox = expand_boxes(bbox, scale)
  354. expand_bbox = expand_bbox.astype(np.int32)
  355. padded_mask = np.zeros(
  356. (resolution + 2, resolution + 2), dtype=np.float32)
  357. for j in range(num):
  358. xmin, ymin, xmax, ymax = expand_bbox[j].tolist()
  359. clsid, score = clsid_scores[j].tolist()
  360. clsid = int(clsid)
  361. padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :]
  362. catid = clsid2catid[clsid]
  363. w = xmax - xmin + 1
  364. h = ymax - ymin + 1
  365. w = np.maximum(w, 1)
  366. h = np.maximum(h, 1)
  367. resized_mask = cv2.resize(padded_mask, (w, h))
  368. resized_mask = np.array(
  369. resized_mask > thresh_binarize, dtype=np.uint8)
  370. im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
  371. x0 = min(max(xmin, 0), im_w)
  372. x1 = min(max(xmax + 1, 0), im_w)
  373. y0 = min(max(ymin, 0), im_h)
  374. y1 = min(max(ymax + 1, 0), im_h)
  375. im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
  376. x0 - xmin):(x1 - xmin)]
  377. segm = mask_util.encode(
  378. np.array(
  379. im_mask[:, :, np.newaxis], order='F'))[0]
  380. catid = clsid2catid[clsid]
  381. segm['counts'] = segm['counts'].decode('utf8')
  382. coco_res = {
  383. 'image_id': im_id,
  384. 'category_id': catid,
  385. 'segmentation': segm,
  386. 'score': score
  387. }
  388. segm_res.append(coco_res)
  389. return segm_res
  390. def expand_boxes(boxes, scale):
  391. """
  392. Expand an array of boxes by a given scale.
  393. """
  394. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  395. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  396. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  397. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  398. w_half *= scale
  399. h_half *= scale
  400. boxes_exp = np.zeros(boxes.shape)
  401. boxes_exp[:, 0] = x_c - w_half
  402. boxes_exp[:, 2] = x_c + w_half
  403. boxes_exp[:, 1] = y_c - h_half
  404. boxes_exp[:, 3] = y_c + h_half
  405. return boxes_exp
  406. def voc_bbox_eval(results,
  407. coco_gt,
  408. with_background=False,
  409. overlap_thresh=0.5,
  410. map_type='11point',
  411. is_bbox_normalized=False,
  412. evaluate_difficult=False):
  413. """
  414. Bounding box evaluation for VOC dataset
  415. Args:
  416. results (list): prediction bounding box results.
  417. class_num (int): evaluation class number.
  418. overlap_thresh (float): the postive threshold of
  419. bbox overlap
  420. map_type (string): method for mAP calcualtion,
  421. can only be '11point' or 'integral'
  422. is_bbox_normalized (bool): whether bbox is normalized
  423. to range [0, 1].
  424. evaluate_difficult (bool): whether to evaluate
  425. difficult gt bbox.
  426. """
  427. assert 'bbox' in results[0]
  428. logging.debug("Start evaluate...")
  429. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  430. # or matplotlib.backends is imported for the first time
  431. # pycocotools import matplotlib
  432. import matplotlib
  433. matplotlib.use('Agg')
  434. from pycocotools.coco import COCO
  435. cat_ids = coco_gt.getCatIds()
  436. # when with_background = True, mapping category to classid, like:
  437. # background:0, first_class:1, second_class:2, ...
  438. clsid2catid = dict(
  439. {i + int(with_background): catid
  440. for i, catid in enumerate(cat_ids)})
  441. class_num = len(clsid2catid) + int(with_background)
  442. detection_map = DetectionMAP(
  443. class_num=class_num,
  444. overlap_thresh=overlap_thresh,
  445. map_type=map_type,
  446. is_bbox_normalized=is_bbox_normalized,
  447. evaluate_difficult=evaluate_difficult)
  448. xywh_res = []
  449. det_nums = 0
  450. gt_nums = 0
  451. for t in results:
  452. bboxes = t['bbox'][0]
  453. bbox_lengths = t['bbox'][1][0]
  454. im_ids = np.array(t['im_id'][0]).flatten()
  455. if bboxes.shape == (1, 1) or bboxes is None:
  456. continue
  457. gt_boxes = t['gt_box'][0]
  458. gt_labels = t['gt_label'][0]
  459. difficults = t['is_difficult'][0] if not evaluate_difficult \
  460. else None
  461. if len(t['gt_box'][1]) == 0:
  462. # gt_box, gt_label, difficult read as zero padded Tensor
  463. bbox_idx = 0
  464. for i in range(len(gt_boxes)):
  465. gt_box = gt_boxes[i]
  466. gt_label = gt_labels[i]
  467. difficult = None if difficults is None \
  468. else difficults[i]
  469. bbox_num = bbox_lengths[i]
  470. bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
  471. gt_box, gt_label, difficult = prune_zero_padding(
  472. gt_box, gt_label, difficult)
  473. detection_map.update(bbox, gt_box, gt_label, difficult)
  474. bbox_idx += bbox_num
  475. det_nums += bbox_num
  476. gt_nums += gt_box.shape[0]
  477. im_id = int(im_ids[i])
  478. for b in bbox:
  479. clsid, score, xmin, ymin, xmax, ymax = b.tolist()
  480. w = xmax - xmin + 1
  481. h = ymax - ymin + 1
  482. bbox = [xmin, ymin, w, h]
  483. coco_res = {
  484. 'image_id': im_id,
  485. 'category_id': clsid2catid[clsid],
  486. 'bbox': bbox,
  487. 'score': score
  488. }
  489. xywh_res.append(coco_res)
  490. else:
  491. # gt_box, gt_label, difficult read as LoDTensor
  492. gt_box_lengths = t['gt_box'][1][0]
  493. bbox_idx = 0
  494. gt_box_idx = 0
  495. for i in range(len(bbox_lengths)):
  496. bbox_num = bbox_lengths[i]
  497. gt_box_num = gt_box_lengths[i]
  498. bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
  499. gt_box = gt_boxes[gt_box_idx:gt_box_idx + gt_box_num]
  500. gt_label = gt_labels[gt_box_idx:gt_box_idx + gt_box_num]
  501. difficult = None if difficults is None else \
  502. difficults[gt_box_idx: gt_box_idx + gt_box_num]
  503. detection_map.update(bbox, gt_box, gt_label, difficult)
  504. bbox_idx += bbox_num
  505. gt_box_idx += gt_box_num
  506. im_id = int(im_ids[i])
  507. for b in bbox:
  508. clsid, score, xmin, ymin, xmax, ymax = b.tolist()
  509. w = xmax - xmin + 1
  510. h = ymax - ymin + 1
  511. bbox = [xmin, ymin, w, h]
  512. coco_res = {
  513. 'image_id': im_id,
  514. 'category_id': clsid2catid[clsid],
  515. 'bbox': bbox,
  516. 'score': score
  517. }
  518. xywh_res.append(coco_res)
  519. logging.debug("Accumulating evaluatation results...")
  520. detection_map.accumulate()
  521. map_stat = 100. * detection_map.get_map()
  522. logging.debug("mAP({:.2f}, {}) = {:.2f}".format(overlap_thresh, map_type,
  523. map_stat))
  524. return map_stat, xywh_res
  525. def prune_zero_padding(gt_box, gt_label, difficult=None):
  526. valid_cnt = 0
  527. for i in range(len(gt_box)):
  528. if gt_box[i, 0] == 0 and gt_box[i, 1] == 0 and \
  529. gt_box[i, 2] == 0 and gt_box[i, 3] == 0:
  530. break
  531. valid_cnt += 1
  532. return (gt_box[:valid_cnt], gt_label[:valid_cnt], difficult[:valid_cnt]
  533. if difficult is not None else None)
  534. def bbox_area(bbox, is_bbox_normalized):
  535. """
  536. Calculate area of a bounding box
  537. """
  538. norm = 1. - float(is_bbox_normalized)
  539. width = bbox[2] - bbox[0] + norm
  540. height = bbox[3] - bbox[1] + norm
  541. return width * height
  542. def jaccard_overlap(pred, gt, is_bbox_normalized=False):
  543. """
  544. Calculate jaccard overlap ratio between two bounding box
  545. """
  546. if pred[0] >= gt[2] or pred[2] <= gt[0] or \
  547. pred[1] >= gt[3] or pred[3] <= gt[1]:
  548. return 0.
  549. inter_xmin = max(pred[0], gt[0])
  550. inter_ymin = max(pred[1], gt[1])
  551. inter_xmax = min(pred[2], gt[2])
  552. inter_ymax = min(pred[3], gt[3])
  553. inter_size = bbox_area([inter_xmin, inter_ymin, inter_xmax, inter_ymax],
  554. is_bbox_normalized)
  555. pred_size = bbox_area(pred, is_bbox_normalized)
  556. gt_size = bbox_area(gt, is_bbox_normalized)
  557. overlap = float(inter_size) / (pred_size + gt_size - inter_size)
  558. return overlap
  559. class DetectionMAP(object):
  560. """
  561. Calculate detection mean average precision.
  562. Currently support two types: 11point and integral
  563. Args:
  564. class_num (int): the class number.
  565. overlap_thresh (float): The threshold of overlap
  566. ratio between prediction bounding box and
  567. ground truth bounding box for deciding
  568. true/false positive. Default 0.5.
  569. map_type (str): calculation method of mean average
  570. precision, currently support '11point' and
  571. 'integral'. Default '11point'.
  572. is_bbox_normalized (bool): whther bounding boxes
  573. is normalized to range[0, 1]. Default False.
  574. evaluate_difficult (bool): whether to evaluate
  575. difficult bounding boxes. Default False.
  576. """
  577. def __init__(self,
  578. class_num,
  579. overlap_thresh=0.5,
  580. map_type='11point',
  581. is_bbox_normalized=False,
  582. evaluate_difficult=False):
  583. self.class_num = class_num
  584. self.overlap_thresh = overlap_thresh
  585. assert map_type in ['11point', 'integral'], \
  586. "map_type currently only support '11point' "\
  587. "and 'integral'"
  588. self.map_type = map_type
  589. self.is_bbox_normalized = is_bbox_normalized
  590. self.evaluate_difficult = evaluate_difficult
  591. self.reset()
  592. def update(self, bbox, gt_box, gt_label, difficult=None):
  593. """
  594. Update metric statics from given prediction and ground
  595. truth infomations.
  596. """
  597. if difficult is None:
  598. difficult = np.zeros_like(gt_label)
  599. # record class gt count
  600. for gtl, diff in zip(gt_label, difficult):
  601. if self.evaluate_difficult or int(diff) == 0:
  602. self.class_gt_counts[int(np.array(gtl))] += 1
  603. # record class score positive
  604. visited = [False] * len(gt_label)
  605. for b in bbox:
  606. label, score, xmin, ymin, xmax, ymax = b.tolist()
  607. pred = [xmin, ymin, xmax, ymax]
  608. max_idx = -1
  609. max_overlap = -1.0
  610. for i, gl in enumerate(gt_label):
  611. if int(gl) == int(label):
  612. overlap = jaccard_overlap(pred, gt_box[i],
  613. self.is_bbox_normalized)
  614. if overlap > max_overlap:
  615. max_overlap = overlap
  616. max_idx = i
  617. if max_overlap > self.overlap_thresh:
  618. if self.evaluate_difficult or \
  619. int(np.array(difficult[max_idx])) == 0:
  620. if not visited[max_idx]:
  621. self.class_score_poss[int(label)].append([score, 1.0])
  622. visited[max_idx] = True
  623. else:
  624. self.class_score_poss[int(label)].append([score, 0.0])
  625. else:
  626. self.class_score_poss[int(label)].append([score, 0.0])
  627. def reset(self):
  628. """
  629. Reset metric statics
  630. """
  631. self.class_score_poss = [[] for _ in range(self.class_num)]
  632. self.class_gt_counts = [0] * self.class_num
  633. self.mAP = None
  634. self.APs = [None] * self.class_num
  635. def accumulate(self):
  636. """
  637. Accumulate metric results and calculate mAP
  638. """
  639. mAP = 0.
  640. valid_cnt = 0
  641. for id, (
  642. score_pos, count
  643. ) in enumerate(zip(self.class_score_poss, self.class_gt_counts)):
  644. if count == 0: continue
  645. if len(score_pos) == 0:
  646. valid_cnt += 1
  647. continue
  648. accum_tp_list, accum_fp_list = \
  649. self._get_tp_fp_accum(score_pos)
  650. precision = []
  651. recall = []
  652. for ac_tp, ac_fp in zip(accum_tp_list, accum_fp_list):
  653. precision.append(float(ac_tp) / (ac_tp + ac_fp))
  654. recall.append(float(ac_tp) / count)
  655. if self.map_type == '11point':
  656. max_precisions = [0.] * 11
  657. start_idx = len(precision) - 1
  658. for j in range(10, -1, -1):
  659. for i in range(start_idx, -1, -1):
  660. if recall[i] < float(j) / 10.:
  661. start_idx = i
  662. if j > 0:
  663. max_precisions[j - 1] = max_precisions[j]
  664. break
  665. else:
  666. if max_precisions[j] < precision[i]:
  667. max_precisions[j] = precision[i]
  668. mAP += sum(max_precisions) / 11.
  669. self.APs[id] = sum(max_precisions) / 11.
  670. valid_cnt += 1
  671. elif self.map_type == 'integral':
  672. import math
  673. ap = 0.
  674. prev_recall = 0.
  675. for i in range(len(precision)):
  676. recall_gap = math.fabs(recall[i] - prev_recall)
  677. if recall_gap > 1e-6:
  678. ap += precision[i] * recall_gap
  679. prev_recall = recall[i]
  680. mAP += ap
  681. self.APs[id] = sum(max_precisions) / 11.
  682. valid_cnt += 1
  683. else:
  684. raise Exception("Unspported mAP type {}".format(self.map_type))
  685. self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP
  686. def get_map(self):
  687. """
  688. Get mAP result
  689. """
  690. if self.mAP is None:
  691. raise Exception("mAP is not calculated.")
  692. return self.mAP
  693. def _get_tp_fp_accum(self, score_pos_list):
  694. """
  695. Calculate accumulating true/false positive results from
  696. [score, pos] records
  697. """
  698. sorted_list = sorted(score_pos_list, key=lambda s: s[0], reverse=True)
  699. accum_tp = 0
  700. accum_fp = 0
  701. accum_tp_list = []
  702. accum_fp_list = []
  703. for (score, pos) in sorted_list:
  704. accum_tp += int(pos)
  705. accum_tp_list.append(accum_tp)
  706. accum_fp += 1 - int(pos)
  707. accum_fp_list.append(accum_fp)
  708. return accum_tp_list, accum_fp_list
  709. def makeplot(rs, ps, outDir, class_name, iou_type):
  710. """针对某个特定类别,绘制不同评估要求下的准确率和召回率。
  711. 绘制结果说明参考COCODataset官网给出分析工具说明https://cocodataset.org/#detection-eval。
  712. Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/coco_error_analysis.py#L13
  713. Args:
  714. rs (np.array): 在不同置信度阈值下计算得到的召回率。
  715. ps (np.array): 在不同置信度阈值下计算得到的准确率。ps与rs相同位置下的数值为同一个置信度阈值
  716. 计算得到的准确率与召回率。
  717. outDir (str): 图表保存的路径。
  718. class_name (str): 类别名。
  719. iou_type (str): iou计算方式,若为检测框,则设置为'bbox',若为像素级分割结果,则设置为'segm'。
  720. """
  721. import matplotlib.pyplot as plt
  722. cs = np.vstack([
  723. np.ones((2, 3)), np.array([.31, .51, .74]), np.array([.75, .31, .30]),
  724. np.array([.36, .90, .38]), np.array([.50, .39, .64]),
  725. np.array([1, .6, 0])
  726. ])
  727. areaNames = ['allarea', 'small', 'medium', 'large']
  728. types = ['C75', 'C50', 'Loc', 'Sim', 'Oth', 'BG', 'FN']
  729. for i in range(len(areaNames)):
  730. area_ps = ps[..., i, 0]
  731. figure_tile = iou_type + '-' + class_name + '-' + areaNames[i]
  732. aps = [ps_.mean() for ps_ in area_ps]
  733. ps_curve = [
  734. ps_.mean(axis=1) if ps_.ndim > 1 else ps_ for ps_ in area_ps
  735. ]
  736. ps_curve.insert(0, np.zeros(ps_curve[0].shape))
  737. fig = plt.figure()
  738. ax = plt.subplot(111)
  739. for k in range(len(types)):
  740. ax.plot(rs, ps_curve[k + 1], color=[0, 0, 0], linewidth=0.5)
  741. ax.fill_between(
  742. rs,
  743. ps_curve[k],
  744. ps_curve[k + 1],
  745. color=cs[k],
  746. label=str('[{:.3f}'.format(aps[k]) + ']' + types[k]))
  747. plt.xlabel('recall')
  748. plt.ylabel('precision')
  749. plt.xlim(0, 1.)
  750. plt.ylim(0, 1.)
  751. plt.title(figure_tile)
  752. plt.legend()
  753. fig.savefig(outDir + '/{}.png'.format(figure_tile))
  754. plt.close(fig)
  755. def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type):
  756. """针对某个特定类别,分析忽略亚类混淆和类别混淆时的准确率。
  757. Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
  758. Args:
  759. k (int): 待分析类别的序号。
  760. cocoDt (pycocotols.coco.COCO): 按COCO类存放的预测结果。
  761. cocoGt (pycocotols.coco.COCO): 按COCO类存放的真值。
  762. catId (int): 待分析类别在数据集中的类别id。
  763. iou_type (str): iou计算方式,若为检测框,则设置为'bbox',若为像素级分割结果,则设置为'segm'。
  764. Returns:
  765. int:
  766. dict: 有关键字'ps_supercategory'和'ps_allcategory'。关键字'ps_supercategory'的键值是忽略亚类间
  767. 混淆时的准确率,关键字'ps_allcategory'的键值是忽略类别间混淆时的准确率。
  768. """
  769. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  770. # or matplotlib.backends is imported for the first time
  771. # pycocotools import matplotlib
  772. import matplotlib
  773. matplotlib.use('Agg')
  774. from pycocotools.coco import COCO
  775. from pycocotools.cocoeval import COCOeval
  776. nm = cocoGt.loadCats(catId)[0]
  777. logging.info('--------------analyzing {}-{}---------------'.format(
  778. k + 1, nm['name']))
  779. ps_ = {}
  780. dt = copy.deepcopy(cocoDt)
  781. nm = cocoGt.loadCats(catId)[0]
  782. imgIds = cocoGt.getImgIds()
  783. dt_anns = dt.dataset['annotations']
  784. select_dt_anns = []
  785. for ann in dt_anns:
  786. if ann['category_id'] == catId:
  787. select_dt_anns.append(ann)
  788. dt.dataset['annotations'] = select_dt_anns
  789. dt.createIndex()
  790. # compute precision but ignore superclass confusion
  791. gt = copy.deepcopy(cocoGt)
  792. child_catIds = gt.getCatIds(supNms=[nm['supercategory']])
  793. for idx, ann in enumerate(gt.dataset['annotations']):
  794. if (ann['category_id'] in child_catIds and
  795. ann['category_id'] != catId):
  796. gt.dataset['annotations'][idx]['ignore'] = 1
  797. gt.dataset['annotations'][idx]['iscrowd'] = 1
  798. gt.dataset['annotations'][idx]['category_id'] = catId
  799. cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
  800. cocoEval.params.imgIds = imgIds
  801. cocoEval.params.maxDets = [100]
  802. cocoEval.params.iouThrs = [.1]
  803. cocoEval.params.useCats = 1
  804. cocoEval.evaluate()
  805. cocoEval.accumulate()
  806. ps_supercategory = cocoEval.eval['precision'][0, :, k, :, :]
  807. ps_['ps_supercategory'] = ps_supercategory
  808. # compute precision but ignore any class confusion
  809. gt = copy.deepcopy(cocoGt)
  810. for idx, ann in enumerate(gt.dataset['annotations']):
  811. if ann['category_id'] != catId:
  812. gt.dataset['annotations'][idx]['ignore'] = 1
  813. gt.dataset['annotations'][idx]['iscrowd'] = 1
  814. gt.dataset['annotations'][idx]['category_id'] = catId
  815. cocoEval = COCOeval(gt, copy.deepcopy(dt), iou_type)
  816. cocoEval.params.imgIds = imgIds
  817. cocoEval.params.maxDets = [100]
  818. cocoEval.params.iouThrs = [.1]
  819. cocoEval.params.useCats = 1
  820. cocoEval.evaluate()
  821. cocoEval.accumulate()
  822. ps_allcategory = cocoEval.eval['precision'][0, :, k, :, :]
  823. ps_['ps_allcategory'] = ps_allcategory
  824. return k, ps_
  825. def coco_error_analysis(eval_details_file=None,
  826. gt=None,
  827. pred_bbox=None,
  828. pred_mask=None,
  829. save_dir='./output'):
  830. """逐个分析模型预测错误的原因,并将分析结果以图表的形式展示。
  831. 分析结果说明参考COCODataset官网给出分析工具说明https://cocodataset.org/#detection-eval。
  832. Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
  833. Args:
  834. eval_details_file (str): 模型评估结果的保存路径,包含真值信息和预测结果。
  835. gt (list): 数据集的真值信息。默认值为None。
  836. pred_bbox (list): 模型在数据集上的预测框。默认值为None。
  837. pred_mask (list): 模型在数据集上的预测mask。默认值为None。
  838. save_dir (str): 可视化结果保存路径。默认值为'./output'。
  839. Note:
  840. eval_details_file的优先级更高,只要eval_details_file不为None,
  841. 就会从eval_details_file提取真值信息和预测结果做分析。
  842. 当eval_details_file为None时,则用gt、pred_mask、pred_mask做分析。
  843. """
  844. import multiprocessing as mp
  845. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  846. # or matplotlib.backends is imported for the first time
  847. # pycocotools import matplotlib
  848. import matplotlib
  849. matplotlib.use('Agg')
  850. from pycocotools.coco import COCO
  851. from pycocotools.cocoeval import COCOeval
  852. if eval_details_file is not None:
  853. import json
  854. with open(eval_details_file, 'r') as f:
  855. eval_details = json.load(f)
  856. pred_bbox = eval_details['bbox']
  857. if 'mask' in eval_details:
  858. pred_mask = eval_details['mask']
  859. gt = eval_details['gt']
  860. if gt is None or pred_bbox is None:
  861. raise Exception(
  862. "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
  863. )
  864. if pred_bbox is not None and len(pred_bbox) == 0:
  865. raise Exception("There is no predicted bbox.")
  866. if pred_mask is not None and len(pred_mask) == 0:
  867. raise Exception("There is no predicted mask.")
  868. def _analyze_results(cocoGt, cocoDt, res_type, out_dir):
  869. """
  870. Refer to
  871. https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/coco_error_analysis.py#L235
  872. """
  873. directory = os.path.dirname(out_dir + '/')
  874. if not os.path.exists(directory):
  875. logging.info('-------------create {}-----------------'.format(
  876. out_dir))
  877. os.makedirs(directory)
  878. imgIds = cocoGt.getImgIds()
  879. res_out_dir = out_dir + '/' + res_type + '/'
  880. res_directory = os.path.dirname(res_out_dir)
  881. if not os.path.exists(res_directory):
  882. logging.info('-------------create {}-----------------'.format(
  883. res_out_dir))
  884. os.makedirs(res_directory)
  885. iou_type = res_type
  886. cocoEval = COCOeval(
  887. copy.deepcopy(cocoGt), copy.deepcopy(cocoDt), iou_type)
  888. cocoEval.params.imgIds = imgIds
  889. cocoEval.params.iouThrs = [.75, .5, .1]
  890. cocoEval.params.maxDets = [100]
  891. cocoEval.evaluate()
  892. cocoEval.accumulate()
  893. ps = cocoEval.eval['precision']
  894. ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))])
  895. catIds = cocoGt.getCatIds()
  896. recThrs = cocoEval.params.recThrs
  897. thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
  898. thread_pool = mp.pool.ThreadPool(thread_num)
  899. args = [(k, cocoDt, cocoGt, catId, iou_type)
  900. for k, catId in enumerate(catIds)]
  901. analyze_results = thread_pool.starmap(analyze_individual_category,
  902. args)
  903. for k, catId in enumerate(catIds):
  904. nm = cocoGt.loadCats(catId)[0]
  905. logging.info('--------------saving {}-{}---------------'.format(
  906. k + 1, nm['name']))
  907. analyze_result = analyze_results[k]
  908. assert k == analyze_result[0], ""
  909. ps_supercategory = analyze_result[1]['ps_supercategory']
  910. ps_allcategory = analyze_result[1]['ps_allcategory']
  911. # compute precision but ignore superclass confusion
  912. ps[3, :, k, :, :] = ps_supercategory
  913. # compute precision but ignore any class confusion
  914. ps[4, :, k, :, :] = ps_allcategory
  915. # fill in background and false negative errors and plot
  916. T, _, _, A, _ = ps.shape
  917. for t in range(T):
  918. for a in range(A):
  919. if np.sum(ps[t, :, k, a, :] ==
  920. -1) != len(ps[t, :, k, :, :]):
  921. ps[t, :, k, a, :][ps[t, :, k, a, :] == -1] = 0
  922. ps[5, :, k, :, :] = (ps[4, :, k, :, :] > 0)
  923. ps[6, :, k, :, :] = 1.0
  924. makeplot(recThrs, ps[:, :, k], res_out_dir, nm['name'], iou_type)
  925. makeplot(recThrs, ps, res_out_dir, 'allclass', iou_type)
  926. np.linspace = fixed_linspace
  927. coco_gt = COCO()
  928. coco_gt.dataset = gt
  929. coco_gt.createIndex()
  930. # matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
  931. # or matplotlib.backends is imported for the first time
  932. # pycocotools import matplotlib
  933. import matplotlib
  934. matplotlib.use('Agg')
  935. from pycocotools.cocoeval import COCOeval
  936. if pred_bbox is not None:
  937. coco_dt = loadRes(coco_gt, pred_bbox)
  938. _analyze_results(coco_gt, coco_dt, res_type='bbox', out_dir=save_dir)
  939. if pred_mask is not None:
  940. coco_dt = loadRes(coco_gt, pred_mask)
  941. _analyze_results(coco_gt, coco_dt, res_type='segm', out_dir=save_dir)
  942. np.linspace = backup_linspace
  943. logging.info("The analysis figures are saved in {}".format(save_dir))