detection.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import yaml
  16. import copy
  17. import os.path as osp
  18. import numpy as np
  19. backup_linspace = np.linspace
  20. def fixed_linspace(start,
  21. stop,
  22. num=50,
  23. endpoint=True,
  24. retstep=False,
  25. dtype=None,
  26. axis=0):
  27. '''解决numpy > 1.17.2时pycocotools中linspace的报错问题。
  28. '''
  29. num = int(num)
  30. return backup_linspace(start, stop, num, endpoint, retstep, dtype, axis)
  31. def jaccard_overlap(pred, gt):
  32. '''计算两个框之间的IoU。
  33. '''
  34. def bbox_area(bbox):
  35. width = bbox[2] - bbox[0] + 1
  36. height = bbox[3] - bbox[1] + 1
  37. return width * height
  38. if pred[0] >= gt[2] or pred[2] <= gt[0] or \
  39. pred[1] >= gt[3] or pred[3] <= gt[1]:
  40. return 0.
  41. inter_xmin = max(pred[0], gt[0])
  42. inter_ymin = max(pred[1], gt[1])
  43. inter_xmax = min(pred[2], gt[2])
  44. inter_ymax = min(pred[3], gt[3])
  45. inter_size = bbox_area([inter_xmin, inter_ymin, inter_xmax, inter_ymax])
  46. pred_size = bbox_area(pred)
  47. gt_size = bbox_area(gt)
  48. overlap = float(inter_size) / (pred_size + gt_size - inter_size)
  49. return overlap
  50. def loadRes(coco_obj, anns):
  51. '''导入结果文件并返回pycocotools中的COCO对象。
  52. '''
  53. from pycocotools.coco import COCO
  54. import pycocotools.mask as maskUtils
  55. import time
  56. res = COCO()
  57. res.dataset['images'] = [img for img in coco_obj.dataset['images']]
  58. tic = time.time()
  59. assert type(anns) == list, 'results in not an array of objects'
  60. annsImgIds = [ann['image_id'] for ann in anns]
  61. assert set(annsImgIds) == (set(annsImgIds) & set(coco_obj.getImgIds())), \
  62. 'Results do not correspond to current coco set'
  63. if 'bbox' in anns[0] and not anns[0]['bbox'] == []:
  64. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  65. 'categories'])
  66. for id, ann in enumerate(anns):
  67. bb = ann['bbox']
  68. x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
  69. if not 'segmentation' in ann:
  70. ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
  71. ann['area'] = bb[2] * bb[3]
  72. ann['id'] = id + 1
  73. ann['iscrowd'] = 0
  74. elif 'segmentation' in anns[0]:
  75. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  76. 'categories'])
  77. for id, ann in enumerate(anns):
  78. ann['area'] = maskUtils.area(ann['segmentation'])
  79. if not 'bbox' in ann:
  80. ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
  81. ann['id'] = id + 1
  82. ann['iscrowd'] = 0
  83. res.dataset['annotations'] = anns
  84. res.createIndex()
  85. return res
  86. class DetectionMAP(object):
  87. def __init__(self,
  88. num_classes,
  89. overlap_thresh=0.5,
  90. map_type='11point',
  91. is_bbox_normalized=False,
  92. evaluate_difficult=False):
  93. self.num_classes = num_classes
  94. self.overlap_thresh = overlap_thresh
  95. assert map_type in ['11point', 'integral'], \
  96. "map_type currently only support '11point' "\
  97. "and 'integral'"
  98. self.map_type = map_type
  99. self.is_bbox_normalized = is_bbox_normalized
  100. self.evaluate_difficult = evaluate_difficult
  101. self.reset()
  102. def update(self, bbox, gt_box, gt_label, difficult=None):
  103. '''用预测值和真值更新指标。
  104. '''
  105. if difficult is None:
  106. difficult = np.zeros_like(gt_label)
  107. for gtl, diff in zip(gt_label, difficult):
  108. if self.evaluate_difficult or int(diff) == 0:
  109. self.class_gt_counts[int(np.array(gtl))] += 1
  110. visited = [False] * len(gt_label)
  111. for b in bbox:
  112. label, score, xmin, ymin, xmax, ymax = b.tolist()
  113. pred = [xmin, ymin, xmax, ymax]
  114. max_idx = -1
  115. max_overlap = -1.0
  116. for i, gl in enumerate(gt_label):
  117. if int(gl) == int(label):
  118. overlap = jaccard_overlap(pred, gt_box[i])
  119. if overlap > max_overlap:
  120. max_overlap = overlap
  121. max_idx = i
  122. if max_overlap > self.overlap_thresh:
  123. if self.evaluate_difficult or \
  124. int(np.array(difficult[max_idx])) == 0:
  125. if not visited[max_idx]:
  126. self.class_score_poss[int(label)].append([score, 1.0])
  127. visited[max_idx] = True
  128. else:
  129. self.class_score_poss[int(label)].append([score, 0.0])
  130. else:
  131. self.class_score_poss[int(label)].append([score, 0.0])
  132. def reset(self):
  133. '''初始化指标。
  134. '''
  135. self.class_score_poss = [[] for _ in range(self.num_classes)]
  136. self.class_gt_counts = [0] * self.num_classes
  137. self.mAP = None
  138. self.APs = [None] * self.num_classes
  139. def accumulate(self):
  140. '''汇总指标并由此计算mAP。
  141. '''
  142. mAP = 0.
  143. valid_cnt = 0
  144. for id, (
  145. score_pos, count
  146. ) in enumerate(zip(self.class_score_poss, self.class_gt_counts)):
  147. if count == 0: continue
  148. if len(score_pos) == 0:
  149. valid_cnt += 1
  150. continue
  151. accum_tp_list, accum_fp_list = \
  152. self._get_tp_fp_accum(score_pos)
  153. precision = []
  154. recall = []
  155. for ac_tp, ac_fp in zip(accum_tp_list, accum_fp_list):
  156. precision.append(float(ac_tp) / (ac_tp + ac_fp))
  157. recall.append(float(ac_tp) / count)
  158. if self.map_type == '11point':
  159. max_precisions = [0.] * 11
  160. start_idx = len(precision) - 1
  161. for j in range(10, -1, -1):
  162. for i in range(start_idx, -1, -1):
  163. if recall[i] < float(j) / 10.:
  164. start_idx = i
  165. if j > 0:
  166. max_precisions[j - 1] = max_precisions[j]
  167. break
  168. else:
  169. if max_precisions[j] < precision[i]:
  170. max_precisions[j] = precision[i]
  171. mAP += sum(max_precisions) / 11.
  172. self.APs[id] = sum(max_precisions) / 11.
  173. valid_cnt += 1
  174. elif self.map_type == 'integral':
  175. import math
  176. ap = 0.
  177. prev_recall = 0.
  178. for i in range(len(precision)):
  179. recall_gap = math.fabs(recall[i] - prev_recall)
  180. if recall_gap > 1e-6:
  181. ap += precision[i] * recall_gap
  182. prev_recall = recall[i]
  183. mAP += ap
  184. self.APs[id] = sum(max_precisions) / 11.
  185. valid_cnt += 1
  186. else:
  187. raise Exception("Unspported mAP type {}".format(self.map_type))
  188. self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP
  189. def get_map(self):
  190. '''获取mAP。
  191. '''
  192. if self.mAP is None:
  193. raise Exception("mAP is not calculated.")
  194. return self.mAP
  195. def _get_tp_fp_accum(self, score_pos_list):
  196. '''计算真阳/假阳。
  197. '''
  198. sorted_list = sorted(score_pos_list, key=lambda s: s[0], reverse=True)
  199. accum_tp = 0
  200. accum_fp = 0
  201. accum_tp_list = []
  202. accum_fp_list = []
  203. for (score, pos) in sorted_list:
  204. accum_tp += int(pos)
  205. accum_tp_list.append(accum_tp)
  206. accum_fp += 1 - int(pos)
  207. accum_fp_list.append(accum_fp)
  208. return accum_tp_list, accum_fp_list
  209. class DetConfusionMatrix(object):
  210. def __init__(self,
  211. num_classes,
  212. overlap_thresh=0.5,
  213. evaluate_difficult=False,
  214. score_threshold=0.3):
  215. self.overlap_thresh = overlap_thresh
  216. self.evaluate_difficult = evaluate_difficult
  217. self.confusion_matrix = np.zeros(shape=(num_classes, num_classes))
  218. self.score_threshold = score_threshold
  219. self.total_tp = [0] * num_classes
  220. self.total_gt = [0] * num_classes
  221. self.total_pred = [0] * num_classes
  222. def update(self, bbox, gt_box, gt_label, difficult=None):
  223. '''更新混淆矩阵。
  224. '''
  225. if difficult is None:
  226. difficult = np.zeros_like(gt_label)
  227. dtind = np.argsort([-d[1] for d in bbox], kind='mergesort')
  228. bbox = [bbox[i] for i in dtind]
  229. det_bbox = []
  230. det_label = []
  231. G = len(gt_box)
  232. D = len(bbox)
  233. gtm = np.full((G, ), -1)
  234. dtm = np.full((D, ), -1)
  235. for j, b in enumerate(bbox):
  236. label, score, xmin, ymin, xmax, ymax = b.tolist()
  237. if float(score) < self.score_threshold:
  238. continue
  239. det_label.append(int(label) - 1)
  240. self.total_pred[int(label) - 1] += 1
  241. det_bbox.append([xmin, ymin, xmax, ymax])
  242. for i, gl in enumerate(gt_label):
  243. self.total_gt[int(gl) - 1] += 1
  244. for j, pred in enumerate(det_bbox):
  245. m = -1
  246. for i, gt in enumerate(gt_box):
  247. overlap = jaccard_overlap(pred, gt)
  248. if overlap >= self.overlap_thresh:
  249. m = i
  250. if m == -1:
  251. continue
  252. gtm[m] = j
  253. dtm[j] = m
  254. for i, gl in enumerate(gt_label):
  255. if gtm[i] == -1:
  256. self.confusion_matrix[int(gl) - 1][self.confusion_matrix.shape[
  257. 1] - 1] += 1
  258. for i, b in enumerate(det_bbox):
  259. if dtm[i] > -1:
  260. gl = int(gt_label[dtm[i]]) - 1
  261. self.confusion_matrix[gl][int(det_label[i])] += 1
  262. if dtm[i] == -1:
  263. self.confusion_matrix[self.confusion_matrix.shape[0] - 1][int(
  264. det_label[i])] += 1
  265. gtm = np.full((G, ), -1)
  266. dtm = np.full((D, ), -1)
  267. for j, pred in enumerate(det_bbox):
  268. m = -1
  269. max_overlap = -1
  270. for i, gt in enumerate(gt_box):
  271. if int(gt_label[i]) - 1 == int(det_label[j]):
  272. overlap = jaccard_overlap(pred, gt)
  273. if overlap > max_overlap:
  274. max_overlap = overlap
  275. m = i
  276. if max_overlap < self.overlap_thresh:
  277. continue
  278. if difficult[m]:
  279. continue
  280. if m == -1 or gtm[m] > -1:
  281. continue
  282. gtm[m] = j
  283. dtm[j] = m
  284. self.total_tp[int(gt_label[m]) - 1] += 1
  285. def get_confusion_matrix(self):
  286. return self.confusion_matrix
  287. class InsSegConfusionMatrix(object):
  288. def __init__(self,
  289. num_classes,
  290. overlap_thresh=0.5,
  291. evaluate_difficult=False,
  292. score_threshold=0.3):
  293. self.overlap_thresh = overlap_thresh
  294. self.evaluate_difficult = evaluate_difficult
  295. self.confusion_matrix = np.zeros(shape=(num_classes, num_classes))
  296. self.score_threshold = score_threshold
  297. self.total_tp = [0] * num_classes
  298. self.total_gt = [0] * num_classes
  299. self.total_pred = [0] * num_classes
  300. def update(self, mask, gt_mask, gt_label, is_crowd=None):
  301. '''更新混淆矩阵。
  302. '''
  303. dtind = np.argsort([-d[1] for d in mask], kind='mergesort')
  304. mask = [mask[i] for i in dtind]
  305. det_mask = []
  306. det_label = []
  307. for j, b in enumerate(mask):
  308. label, score, d_b = b
  309. if float(score) < self.score_threshold:
  310. continue
  311. self.total_pred[int(label) - 1] += 1
  312. det_label.append(label - 1)
  313. det_mask.append(d_b)
  314. for i, gl in enumerate(gt_label):
  315. self.total_gt[int(gl) - 1] += 1
  316. g = [gt for gt in gt_mask]
  317. d = [dt for dt in det_mask]
  318. import pycocotools.mask as maskUtils
  319. ious = maskUtils.iou(d, g, is_crowd)
  320. G = len(gt_mask)
  321. D = len(det_mask)
  322. gtm = np.full((G, ), -1)
  323. dtm = np.full((D, ), -1)
  324. gtIg = np.array(is_crowd)
  325. dtIg = np.zeros((D, ))
  326. for dind, d in enumerate(det_mask):
  327. m = -1
  328. for gind, g in enumerate(gt_mask):
  329. if ious[dind, gind] >= self.overlap_thresh:
  330. m = gind
  331. if m == -1:
  332. continue
  333. dtIg[dind] = gtIg[m]
  334. dtm[dind] = m
  335. gtm[m] = dind
  336. for i, gl in enumerate(gt_label):
  337. if gtm[i] == -1 and gtIg[i] == 0:
  338. self.confusion_matrix[int(gl) - 1][self.confusion_matrix.shape[
  339. 1] - 1] += 1
  340. for i, b in enumerate(det_mask):
  341. if dtm[i] > -1 and dtIg[i] == 0:
  342. gl = int(gt_label[dtm[i]]) - 1
  343. self.confusion_matrix[gl][int(det_label[i])] += 1
  344. if dtm[i] == -1 and dtIg[i] == 0:
  345. self.confusion_matrix[self.confusion_matrix.shape[0] - 1][int(
  346. det_label[i])] += 1
  347. gtm = np.full((G, ), -1)
  348. dtm = np.full((D, ), -1)
  349. for dind, d in enumerate(det_mask):
  350. m = -1
  351. max_overlap = -1
  352. for gind, g in enumerate(gt_mask):
  353. if int(gt_label[gind]) - 1 == int(det_label[dind]):
  354. if ious[dind, gind] > max_overlap:
  355. max_overlap = ious[dind, gind]
  356. m = gind
  357. if max_overlap < self.overlap_thresh:
  358. continue
  359. if m == -1 or gtm[m] > -1:
  360. continue
  361. dtm[dind] = m
  362. gtm[m] = dind
  363. self.total_tp[int(gt_label[m]) - 1] += 1
  364. def get_confusion_matrix(self):
  365. return self.confusion_matrix
  366. class DetEvaluator(object):
  367. def __init__(self, model_path, overlap_thresh=0.5, score_threshold=0.3):
  368. self.model_path = model_path
  369. self.overlap_thresh = overlap_thresh
  370. self.score_threshold = score_threshold
  371. def _prepare_data(self):
  372. with open(osp.join(self.model_path, 'eval_details.json'), 'r') as f:
  373. eval_details = json.load(f)
  374. self.bbox = eval_details['bbox']
  375. self.mask = None
  376. if 'mask' in eval_details:
  377. self.mask = eval_details['mask']
  378. gt_dataset = eval_details['gt']
  379. from pycocotools.coco import COCO
  380. from pycocotools.cocoeval import COCOeval
  381. self.coco = COCO()
  382. self.coco.dataset = gt_dataset
  383. self.coco.createIndex()
  384. img_ids = self.coco.getImgIds()
  385. cat_ids = self.coco.getCatIds()
  386. self.catid2clsid = dict(
  387. {catid: i + 1
  388. for i, catid in enumerate(cat_ids)})
  389. self.cname2cid = dict({
  390. self.coco.loadCats(catid)[0]['name']: clsid
  391. for catid, clsid in self.catid2clsid.items()
  392. })
  393. self.cid2cname = dict(
  394. {cid: cname
  395. for cname, cid in self.cname2cid.items()})
  396. self.cid2cname[0] = 'back_ground'
  397. self.gt = dict()
  398. for img_id in img_ids:
  399. img_anno = self.coco.loadImgs(img_id)[0]
  400. im_fname = img_anno['file_name']
  401. im_w = float(img_anno['width'])
  402. im_h = float(img_anno['height'])
  403. ins_anno_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
  404. instances = self.coco.loadAnns(ins_anno_ids)
  405. bboxes = []
  406. for inst in instances:
  407. x, y, box_w, box_h = inst['bbox']
  408. x1 = max(0, x)
  409. y1 = max(0, y)
  410. x2 = min(im_w - 1, x1 + max(0, box_w - 1))
  411. y2 = min(im_h - 1, y1 + max(0, box_h - 1))
  412. if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
  413. inst['clean_bbox'] = [x1, y1, x2, y2]
  414. bboxes.append(inst)
  415. else:
  416. pass
  417. num_bbox = len(bboxes)
  418. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  419. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  420. gt_score = np.ones((num_bbox, 1), dtype=np.float32)
  421. is_crowd = np.zeros((num_bbox), dtype=np.int32)
  422. difficult = np.zeros((num_bbox, 1), dtype=np.int32)
  423. gt_poly = [None] * num_bbox
  424. for i, box in enumerate(bboxes):
  425. catid = box['category_id']
  426. gt_class[i][0] = self.catid2clsid[catid]
  427. gt_bbox[i, :] = box['clean_bbox']
  428. is_crowd[i] = box['iscrowd']
  429. if 'segmentation' in box:
  430. gt_poly[i] = self.coco.annToRLE(box)
  431. if 'difficult' in box:
  432. difficult[i][0] = box['difficult']
  433. coco_rec = {
  434. 'is_crowd': is_crowd,
  435. 'gt_class': gt_class,
  436. 'gt_bbox': gt_bbox,
  437. 'gt_score': gt_score,
  438. 'gt_poly': gt_poly,
  439. 'difficult': difficult
  440. }
  441. self.gt[img_id] = coco_rec
  442. self.gtimgids = list(self.gt.keys())
  443. self.detimgids = [ann['image_id'] for ann in self.bbox]
  444. self.det = dict()
  445. if len(self.bbox) > 0:
  446. if 'bbox' in self.bbox[0] and not self.bbox[0]['bbox'] == []:
  447. for id, ann in enumerate(self.bbox):
  448. im_id = ann['image_id']
  449. bb = ann['bbox']
  450. x1, x2, y1, y2 = [
  451. bb[0], bb[0] + bb[2] - 1, bb[1], bb[1] + bb[3] - 1
  452. ]
  453. score = ann['score']
  454. category_id = self.catid2clsid[ann['category_id']]
  455. if int(im_id) not in self.det:
  456. self.det[int(im_id)] = [[
  457. category_id, score, x1, y1, x2, y2
  458. ]]
  459. else:
  460. self.det[int(im_id)].extend(
  461. [[category_id, score, x1, y1, x2, y2]])
  462. if self.mask is not None:
  463. self.maskimgids = [ann['image_id'] for ann in self.mask]
  464. self.segm = dict()
  465. if len(self.mask) > 0:
  466. if 'segmentation' in self.mask[0]:
  467. for id, ann in enumerate(self.mask):
  468. im_id = ann['image_id']
  469. score = ann['score']
  470. segmentation = self.coco.annToRLE(ann)
  471. category_id = self.catid2clsid[ann['category_id']]
  472. if int(im_id) not in self.segm:
  473. self.segm[int(im_id)] = [[
  474. category_id, score, segmentation
  475. ]]
  476. else:
  477. self.segm[int(im_id)].extend(
  478. [[category_id, score, segmentation]])
  479. def cal_confusion_matrix(self):
  480. '''计算混淆矩阵。
  481. '''
  482. self._prepare_data()
  483. confusion_matrix = DetConfusionMatrix(
  484. num_classes=len(self.cid2cname.keys()),
  485. overlap_thresh=self.overlap_thresh,
  486. score_threshold=self.score_threshold)
  487. for im_id in self.gtimgids:
  488. if im_id not in set(self.detimgids):
  489. bbox = []
  490. else:
  491. bbox = np.array(self.det[im_id])
  492. gt_box = self.gt[im_id]['gt_bbox']
  493. gt_label = self.gt[im_id]['gt_class']
  494. difficult = self.gt[im_id]['difficult']
  495. confusion_matrix.update(bbox, gt_box, gt_label, difficult)
  496. self.confusion_matrix = confusion_matrix.get_confusion_matrix()
  497. self.precision_recall = dict()
  498. for id in range(len(self.cid2cname.keys()) - 1):
  499. if confusion_matrix.total_gt[id] == 0:
  500. recall = -1
  501. else:
  502. recall = confusion_matrix.total_tp[
  503. id] / confusion_matrix.total_gt[id]
  504. if confusion_matrix.total_pred[id] == 0:
  505. precision = -1
  506. else:
  507. precision = confusion_matrix.total_tp[
  508. id] / confusion_matrix.total_pred[id]
  509. self.precision_recall[self.cid2cname[id + 1]] = {
  510. "precision": precision,
  511. "recall": recall
  512. }
  513. return self.confusion_matrix
  514. def cal_precision_recall(self):
  515. '''计算precision、recall。
  516. '''
  517. return self.precision_recall
  518. def cal_map(self):
  519. '''计算mAP。
  520. '''
  521. detection_map = DetectionMAP(
  522. num_classes=len(self.cid2cname.keys()),
  523. overlap_thresh=self.overlap_thresh)
  524. for im_id in self.gtimgids:
  525. if im_id not in set(self.detimgids):
  526. bbox = []
  527. else:
  528. bbox = np.array(self.det[im_id])
  529. gt_box = self.gt[im_id]['gt_bbox']
  530. gt_label = self.gt[im_id]['gt_class']
  531. difficult = self.gt[im_id]['difficult']
  532. detection_map.update(bbox, gt_box, gt_label, difficult)
  533. detection_map.accumulate()
  534. self.map = detection_map.get_map()
  535. self.APs = detection_map.APs
  536. return self.map
  537. def cal_ap(self):
  538. '''计算各类AP。
  539. '''
  540. self.aps = dict()
  541. for id, ap in enumerate(self.APs):
  542. if id == 0:
  543. continue
  544. self.aps[self.cid2cname[id]] = ap
  545. return self.aps
  546. def generate_report(self):
  547. '''生成评估报告。
  548. '''
  549. report = dict()
  550. report['Confusion_Matrix'] = copy.deepcopy(self.cal_confusion_matrix())
  551. report['mAP'] = copy.deepcopy(self.cal_map())
  552. report['PRAP'] = copy.deepcopy(self.cal_precision_recall())
  553. report['label_list'] = copy.deepcopy(list(self.cname2cid.keys()))
  554. report['label_list'].append('back_ground')
  555. per_ap = copy.deepcopy(self.cal_ap())
  556. for k, v in per_ap.items():
  557. report['PRAP'][k]["AP"] = v
  558. return report
  559. class InsSegEvaluator(DetEvaluator):
  560. def __init__(self, model_path, overlap_thresh=0.5, score_threshold=0.3):
  561. super(DetEvaluator, self).__init__()
  562. self.model_path = model_path
  563. self.overlap_thresh = overlap_thresh
  564. self.score_threshold = score_threshold
  565. def cal_confusion_matrix_mask(self):
  566. '''计算Mask的混淆矩阵。
  567. '''
  568. confusion_matrix = InsSegConfusionMatrix(
  569. num_classes=len(self.cid2cname.keys()),
  570. overlap_thresh=self.overlap_thresh,
  571. score_threshold=self.score_threshold)
  572. for im_id in self.gtimgids:
  573. if im_id not in set(self.maskimgids):
  574. segm = []
  575. else:
  576. segm = self.segm[im_id]
  577. gt_segm = self.gt[im_id]['gt_poly']
  578. gt_label = self.gt[im_id]['gt_class']
  579. is_crowd = self.gt[im_id]['is_crowd']
  580. confusion_matrix.update(segm, gt_segm, gt_label, is_crowd)
  581. self.confusion_matrix_mask = confusion_matrix.get_confusion_matrix()
  582. self.precision_recall_mask = dict()
  583. for id in range(len(self.cid2cname.keys()) - 1):
  584. if confusion_matrix.total_gt[id] == 0:
  585. recall = -1
  586. else:
  587. recall = confusion_matrix.total_tp[
  588. id] / confusion_matrix.total_gt[id]
  589. if confusion_matrix.total_pred[id] == 0:
  590. precision = -1
  591. else:
  592. precision = confusion_matrix.total_tp[
  593. id] / confusion_matrix.total_pred[id]
  594. self.precision_recall_mask[self.cid2cname[id + 1]] = {
  595. "precision": precision,
  596. "recall": recall
  597. }
  598. return self.confusion_matrix_mask
  599. def cal_precision_recall_mask(self):
  600. '''计算Mask的precision、recall。
  601. '''
  602. return self.precision_recall_mask
  603. def _summarize(self,
  604. coco_gt,
  605. ap=1,
  606. iouThr=None,
  607. areaRng='all',
  608. maxDets=100):
  609. p = coco_gt.params
  610. aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
  611. mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
  612. if ap == 1:
  613. s = coco_gt.eval['precision']
  614. if iouThr is not None:
  615. t = np.where(iouThr == p.iouThrs)[0]
  616. s = s[t]
  617. s = s[:, :, :, aind, mind]
  618. else:
  619. s = coco_gt.eval['recall']
  620. if iouThr is not None:
  621. t = np.where(iouThr == p.iouThrs)[0]
  622. s = s[t]
  623. s = s[:, :, aind, mind]
  624. if len(s[s > -1]) == 0:
  625. mean_s = -1
  626. else:
  627. mean_s = np.mean(s[s > -1])
  628. return mean_s
  629. def cal_map(self):
  630. '''计算BBox的mAP。
  631. '''
  632. if len(self.bbox) > 0:
  633. from pycocotools.cocoeval import COCOeval
  634. coco_dt = loadRes(self.coco, self.bbox)
  635. np.linspace = fixed_linspace
  636. coco_eval = COCOeval(self.coco, coco_dt, 'bbox')
  637. coco_eval.params.iouThrs = np.linspace(
  638. self.overlap_thresh, self.overlap_thresh, 1, endpoint=True)
  639. np.linspace = backup_linspace
  640. coco_eval.evaluate()
  641. coco_eval.accumulate()
  642. self.map = self._summarize(coco_eval, iouThr=self.overlap_thresh)
  643. precision = coco_eval.eval['precision'][0, :, :, 0, 2]
  644. num_classes = len(coco_eval.params.catIds)
  645. self.APs = [None] * num_classes
  646. for i in range(num_classes):
  647. per = precision[:, i]
  648. per = per[per > -1]
  649. self.APs[i] = np.sum(per) / 101 if per.shape[0] > 0 else None
  650. else:
  651. self.map = None
  652. self.APs = [None] * len(self.catid2clsid)
  653. return self.map
  654. def cal_ap(self):
  655. '''计算BBox的各类AP。
  656. '''
  657. self.aps = dict()
  658. for id, ap in enumerate(self.APs):
  659. self.aps[self.cid2cname[id + 1]] = ap
  660. return self.aps
  661. def cal_map_mask(self):
  662. '''计算Mask的mAP。
  663. '''
  664. if len(self.mask) > 0:
  665. from pycocotools.cocoeval import COCOeval
  666. coco_dt = loadRes(self.coco, self.mask)
  667. np.linspace = fixed_linspace
  668. coco_eval = COCOeval(self.coco, coco_dt, 'segm')
  669. coco_eval.params.iouThrs = np.linspace(
  670. self.overlap_thresh, self.overlap_thresh, 1, endpoint=True)
  671. np.linspace = backup_linspace
  672. coco_eval.evaluate()
  673. coco_eval.accumulate()
  674. self.map_mask = self._summarize(
  675. coco_eval, iouThr=self.overlap_thresh)
  676. precision = coco_eval.eval['precision'][0, :, :, 0, 2]
  677. num_classes = len(coco_eval.params.catIds)
  678. self.mask_APs = [None] * num_classes
  679. for i in range(num_classes):
  680. per = precision[:, i]
  681. per = per[per > -1]
  682. self.mask_APs[i] = np.sum(per) / 101 if per.shape[
  683. 0] > 0 else None
  684. else:
  685. self.map_mask = None
  686. self.mask_APs = [None] * len(self.catid2clsid)
  687. return self.map_mask
  688. def cal_ap_mask(self):
  689. '''计算Mask的各类AP。
  690. '''
  691. self.mask_aps = dict()
  692. for id, ap in enumerate(self.mask_APs):
  693. self.mask_aps[self.cid2cname[id + 1]] = ap
  694. return self.mask_aps
  695. def generate_report(self):
  696. '''生成评估报告。
  697. '''
  698. report = dict()
  699. report['BBox_Confusion_Matrix'] = copy.deepcopy(
  700. self.cal_confusion_matrix())
  701. report['BBox_mAP'] = copy.deepcopy(self.cal_map())
  702. report['BBox_PRAP'] = copy.deepcopy(self.cal_precision_recall())
  703. report['label_list'] = copy.deepcopy(list(self.cname2cid.keys()))
  704. report['label_list'].append('back_ground')
  705. per_ap = copy.deepcopy(self.cal_ap())
  706. for k, v in per_ap.items():
  707. report['BBox_PRAP'][k]['AP'] = v
  708. report['Mask_Confusion_Matrix'] = copy.deepcopy(
  709. self.cal_confusion_matrix_mask())
  710. report['Mask_mAP'] = copy.deepcopy(self.cal_map_mask())
  711. report['Mask_PRAP'] = copy.deepcopy(self.cal_precision_recall_mask())
  712. per_ap_mask = copy.deepcopy(self.cal_ap_mask())
  713. for k, v in per_ap_mask.items():
  714. report['Mask_PRAP'][k]['AP'] = v
  715. return report