detection.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os.path as osp

import numpy as np
import yaml

from paddlex_restful.restful.dataset.utils import get_encoding
  20. backup_linspace = np.linspace
  21. def fixed_linspace(start,
  22. stop,
  23. num=50,
  24. endpoint=True,
  25. retstep=False,
  26. dtype=None,
  27. axis=0):
  28. '''解决numpy > 1.17.2时pycocotools中linspace的报错问题。
  29. '''
  30. num = int(num)
  31. return backup_linspace(start, stop, num, endpoint, retstep, dtype, axis)
  32. def jaccard_overlap(pred, gt):
  33. '''计算两个框之间的IoU。
  34. '''
  35. def bbox_area(bbox):
  36. width = bbox[2] - bbox[0] + 1
  37. height = bbox[3] - bbox[1] + 1
  38. return width * height
  39. if pred[0] >= gt[2] or pred[2] <= gt[0] or \
  40. pred[1] >= gt[3] or pred[3] <= gt[1]:
  41. return 0.
  42. inter_xmin = max(pred[0], gt[0])
  43. inter_ymin = max(pred[1], gt[1])
  44. inter_xmax = min(pred[2], gt[2])
  45. inter_ymax = min(pred[3], gt[3])
  46. inter_size = bbox_area([inter_xmin, inter_ymin, inter_xmax, inter_ymax])
  47. pred_size = bbox_area(pred)
  48. gt_size = bbox_area(gt)
  49. overlap = float(inter_size) / (pred_size + gt_size - inter_size)
  50. return overlap
  51. def loadRes(coco_obj, anns):
  52. '''导入结果文件并返回pycocotools中的COCO对象。
  53. '''
  54. from pycocotools.coco import COCO
  55. import pycocotools.mask as maskUtils
  56. import time
  57. res = COCO()
  58. res.dataset['images'] = [img for img in coco_obj.dataset['images']]
  59. tic = time.time()
  60. assert type(anns) == list, 'results in not an array of objects'
  61. annsImgIds = [ann['image_id'] for ann in anns]
  62. assert set(annsImgIds) == (set(annsImgIds) & set(coco_obj.getImgIds())), \
  63. 'Results do not correspond to current coco set'
  64. if 'bbox' in anns[0] and not anns[0]['bbox'] == []:
  65. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  66. 'categories'])
  67. for id, ann in enumerate(anns):
  68. bb = ann['bbox']
  69. x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
  70. if not 'segmentation' in ann:
  71. ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
  72. ann['area'] = bb[2] * bb[3]
  73. ann['id'] = id + 1
  74. ann['iscrowd'] = 0
  75. elif 'segmentation' in anns[0]:
  76. res.dataset['categories'] = copy.deepcopy(coco_obj.dataset[
  77. 'categories'])
  78. for id, ann in enumerate(anns):
  79. ann['area'] = maskUtils.area(ann['segmentation'])
  80. if not 'bbox' in ann:
  81. ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
  82. ann['id'] = id + 1
  83. ann['iscrowd'] = 0
  84. res.dataset['annotations'] = anns
  85. res.createIndex()
  86. return res
  87. class DetectionMAP(object):
  88. def __init__(self,
  89. num_classes,
  90. overlap_thresh=0.5,
  91. map_type='11point',
  92. is_bbox_normalized=False,
  93. evaluate_difficult=False):
  94. self.num_classes = num_classes
  95. self.overlap_thresh = overlap_thresh
  96. assert map_type in ['11point', 'integral'], \
  97. "map_type currently only support '11point' "\
  98. "and 'integral'"
  99. self.map_type = map_type
  100. self.is_bbox_normalized = is_bbox_normalized
  101. self.evaluate_difficult = evaluate_difficult
  102. self.reset()
  103. def update(self, bbox, gt_box, gt_label, difficult=None):
  104. '''用预测值和真值更新指标。
  105. '''
  106. if difficult is None:
  107. difficult = np.zeros_like(gt_label)
  108. for gtl, diff in zip(gt_label, difficult):
  109. if self.evaluate_difficult or int(diff) == 0:
  110. self.class_gt_counts[int(np.array(gtl))] += 1
  111. visited = [False] * len(gt_label)
  112. for b in bbox:
  113. label, score, xmin, ymin, xmax, ymax = b.tolist()
  114. pred = [xmin, ymin, xmax, ymax]
  115. max_idx = -1
  116. max_overlap = -1.0
  117. for i, gl in enumerate(gt_label):
  118. if int(gl) == int(label):
  119. overlap = jaccard_overlap(pred, gt_box[i])
  120. if overlap > max_overlap:
  121. max_overlap = overlap
  122. max_idx = i
  123. if max_overlap > self.overlap_thresh:
  124. if self.evaluate_difficult or \
  125. int(np.array(difficult[max_idx])) == 0:
  126. if not visited[max_idx]:
  127. self.class_score_poss[int(label)].append([score, 1.0])
  128. visited[max_idx] = True
  129. else:
  130. self.class_score_poss[int(label)].append([score, 0.0])
  131. else:
  132. self.class_score_poss[int(label)].append([score, 0.0])
  133. def reset(self):
  134. '''初始化指标。
  135. '''
  136. self.class_score_poss = [[] for _ in range(self.num_classes)]
  137. self.class_gt_counts = [0] * self.num_classes
  138. self.mAP = None
  139. self.APs = [None] * self.num_classes
  140. def accumulate(self):
  141. '''汇总指标并由此计算mAP。
  142. '''
  143. mAP = 0.
  144. valid_cnt = 0
  145. for id, (
  146. score_pos, count
  147. ) in enumerate(zip(self.class_score_poss, self.class_gt_counts)):
  148. if count == 0: continue
  149. if len(score_pos) == 0:
  150. valid_cnt += 1
  151. continue
  152. accum_tp_list, accum_fp_list = \
  153. self._get_tp_fp_accum(score_pos)
  154. precision = []
  155. recall = []
  156. for ac_tp, ac_fp in zip(accum_tp_list, accum_fp_list):
  157. precision.append(float(ac_tp) / (ac_tp + ac_fp))
  158. recall.append(float(ac_tp) / count)
  159. if self.map_type == '11point':
  160. max_precisions = [0.] * 11
  161. start_idx = len(precision) - 1
  162. for j in range(10, -1, -1):
  163. for i in range(start_idx, -1, -1):
  164. if recall[i] < float(j) / 10.:
  165. start_idx = i
  166. if j > 0:
  167. max_precisions[j - 1] = max_precisions[j]
  168. break
  169. else:
  170. if max_precisions[j] < precision[i]:
  171. max_precisions[j] = precision[i]
  172. mAP += sum(max_precisions) / 11.
  173. self.APs[id] = sum(max_precisions) / 11.
  174. valid_cnt += 1
  175. elif self.map_type == 'integral':
  176. import math
  177. ap = 0.
  178. prev_recall = 0.
  179. for i in range(len(precision)):
  180. recall_gap = math.fabs(recall[i] - prev_recall)
  181. if recall_gap > 1e-6:
  182. ap += precision[i] * recall_gap
  183. prev_recall = recall[i]
  184. mAP += ap
  185. self.APs[id] = sum(max_precisions) / 11.
  186. valid_cnt += 1
  187. else:
  188. raise Exception("Unspported mAP type {}".format(self.map_type))
  189. self.mAP = mAP / float(valid_cnt) if valid_cnt > 0 else mAP
  190. def get_map(self):
  191. '''获取mAP。
  192. '''
  193. if self.mAP is None:
  194. raise Exception("mAP is not calculated.")
  195. return self.mAP
  196. def _get_tp_fp_accum(self, score_pos_list):
  197. '''计算真阳/假阳。
  198. '''
  199. sorted_list = sorted(score_pos_list, key=lambda s: s[0], reverse=True)
  200. accum_tp = 0
  201. accum_fp = 0
  202. accum_tp_list = []
  203. accum_fp_list = []
  204. for (score, pos) in sorted_list:
  205. accum_tp += int(pos)
  206. accum_tp_list.append(accum_tp)
  207. accum_fp += 1 - int(pos)
  208. accum_fp_list.append(accum_fp)
  209. return accum_tp_list, accum_fp_list
  210. class DetConfusionMatrix(object):
  211. def __init__(self,
  212. num_classes,
  213. overlap_thresh=0.5,
  214. evaluate_difficult=False,
  215. score_threshold=0.3):
  216. self.overlap_thresh = overlap_thresh
  217. self.evaluate_difficult = evaluate_difficult
  218. self.confusion_matrix = np.zeros(shape=(num_classes, num_classes))
  219. self.score_threshold = score_threshold
  220. self.total_tp = [0] * num_classes
  221. self.total_gt = [0] * num_classes
  222. self.total_pred = [0] * num_classes
  223. def update(self, bbox, gt_box, gt_label, difficult=None):
  224. '''更新混淆矩阵。
  225. '''
  226. if difficult is None:
  227. difficult = np.zeros_like(gt_label)
  228. dtind = np.argsort([-d[1] for d in bbox], kind='mergesort')
  229. bbox = [bbox[i] for i in dtind]
  230. det_bbox = []
  231. det_label = []
  232. G = len(gt_box)
  233. D = len(bbox)
  234. gtm = np.full((G, ), -1)
  235. dtm = np.full((D, ), -1)
  236. for j, b in enumerate(bbox):
  237. label, score, xmin, ymin, xmax, ymax = b.tolist()
  238. if float(score) < self.score_threshold:
  239. continue
  240. det_label.append(int(label) - 1)
  241. self.total_pred[int(label) - 1] += 1
  242. det_bbox.append([xmin, ymin, xmax, ymax])
  243. for i, gl in enumerate(gt_label):
  244. self.total_gt[int(gl) - 1] += 1
  245. for j, pred in enumerate(det_bbox):
  246. m = -1
  247. for i, gt in enumerate(gt_box):
  248. overlap = jaccard_overlap(pred, gt)
  249. if overlap >= self.overlap_thresh:
  250. m = i
  251. if m == -1:
  252. continue
  253. gtm[m] = j
  254. dtm[j] = m
  255. for i, gl in enumerate(gt_label):
  256. if gtm[i] == -1:
  257. self.confusion_matrix[int(gl) - 1][self.confusion_matrix.shape[
  258. 1] - 1] += 1
  259. for i, b in enumerate(det_bbox):
  260. if dtm[i] > -1:
  261. gl = int(gt_label[dtm[i]]) - 1
  262. self.confusion_matrix[gl][int(det_label[i])] += 1
  263. if dtm[i] == -1:
  264. self.confusion_matrix[self.confusion_matrix.shape[0] - 1][int(
  265. det_label[i])] += 1
  266. gtm = np.full((G, ), -1)
  267. dtm = np.full((D, ), -1)
  268. for j, pred in enumerate(det_bbox):
  269. m = -1
  270. max_overlap = -1
  271. for i, gt in enumerate(gt_box):
  272. if int(gt_label[i]) - 1 == int(det_label[j]):
  273. overlap = jaccard_overlap(pred, gt)
  274. if overlap > max_overlap:
  275. max_overlap = overlap
  276. m = i
  277. if max_overlap < self.overlap_thresh:
  278. continue
  279. if difficult[m]:
  280. continue
  281. if m == -1 or gtm[m] > -1:
  282. continue
  283. gtm[m] = j
  284. dtm[j] = m
  285. self.total_tp[int(gt_label[m]) - 1] += 1
  286. def get_confusion_matrix(self):
  287. return self.confusion_matrix
  288. class InsSegConfusionMatrix(object):
  289. def __init__(self,
  290. num_classes,
  291. overlap_thresh=0.5,
  292. evaluate_difficult=False,
  293. score_threshold=0.3):
  294. self.overlap_thresh = overlap_thresh
  295. self.evaluate_difficult = evaluate_difficult
  296. self.confusion_matrix = np.zeros(shape=(num_classes, num_classes))
  297. self.score_threshold = score_threshold
  298. self.total_tp = [0] * num_classes
  299. self.total_gt = [0] * num_classes
  300. self.total_pred = [0] * num_classes
  301. def update(self, mask, gt_mask, gt_label, is_crowd=None):
  302. '''更新混淆矩阵。
  303. '''
  304. dtind = np.argsort([-d[1] for d in mask], kind='mergesort')
  305. mask = [mask[i] for i in dtind]
  306. det_mask = []
  307. det_label = []
  308. for j, b in enumerate(mask):
  309. label, score, d_b = b
  310. if float(score) < self.score_threshold:
  311. continue
  312. self.total_pred[int(label) - 1] += 1
  313. det_label.append(label - 1)
  314. det_mask.append(d_b)
  315. for i, gl in enumerate(gt_label):
  316. self.total_gt[int(gl) - 1] += 1
  317. g = [gt for gt in gt_mask]
  318. d = [dt for dt in det_mask]
  319. import pycocotools.mask as maskUtils
  320. ious = maskUtils.iou(d, g, is_crowd)
  321. G = len(gt_mask)
  322. D = len(det_mask)
  323. gtm = np.full((G, ), -1)
  324. dtm = np.full((D, ), -1)
  325. gtIg = np.array(is_crowd)
  326. dtIg = np.zeros((D, ))
  327. for dind, d in enumerate(det_mask):
  328. m = -1
  329. for gind, g in enumerate(gt_mask):
  330. if ious[dind, gind] >= self.overlap_thresh:
  331. m = gind
  332. if m == -1:
  333. continue
  334. dtIg[dind] = gtIg[m]
  335. dtm[dind] = m
  336. gtm[m] = dind
  337. for i, gl in enumerate(gt_label):
  338. if gtm[i] == -1 and gtIg[i] == 0:
  339. self.confusion_matrix[int(gl) - 1][self.confusion_matrix.shape[
  340. 1] - 1] += 1
  341. for i, b in enumerate(det_mask):
  342. if dtm[i] > -1 and dtIg[i] == 0:
  343. gl = int(gt_label[dtm[i]]) - 1
  344. self.confusion_matrix[gl][int(det_label[i])] += 1
  345. if dtm[i] == -1 and dtIg[i] == 0:
  346. self.confusion_matrix[self.confusion_matrix.shape[0] - 1][int(
  347. det_label[i])] += 1
  348. gtm = np.full((G, ), -1)
  349. dtm = np.full((D, ), -1)
  350. for dind, d in enumerate(det_mask):
  351. m = -1
  352. max_overlap = -1
  353. for gind, g in enumerate(gt_mask):
  354. if int(gt_label[gind]) - 1 == int(det_label[dind]):
  355. if ious[dind, gind] > max_overlap:
  356. max_overlap = ious[dind, gind]
  357. m = gind
  358. if max_overlap < self.overlap_thresh:
  359. continue
  360. if m == -1 or gtm[m] > -1:
  361. continue
  362. dtm[dind] = m
  363. gtm[m] = dind
  364. self.total_tp[int(gt_label[m]) - 1] += 1
  365. def get_confusion_matrix(self):
  366. return self.confusion_matrix
  367. class DetEvaluator(object):
  368. def __init__(self, model_path, overlap_thresh=0.5, score_threshold=0.3):
  369. self.model_path = model_path
  370. self.overlap_thresh = overlap_thresh if overlap_thresh is not None else .5
  371. self.score_threshold = score_threshold if score_threshold is not None else .3
  372. def _prepare_data(self):
  373. eval_details_file = osp.join(self.model_path, 'eval_details.json')
  374. with open(
  375. eval_details_file, 'r',
  376. encoding=get_encoding(eval_details_file)) as f:
  377. eval_details = json.load(f)
  378. self.bbox = eval_details['bbox']
  379. self.mask = None
  380. if 'mask' in eval_details:
  381. self.mask = eval_details['mask']
  382. gt_dataset = eval_details['gt']
  383. from pycocotools.coco import COCO
  384. from pycocotools.cocoeval import COCOeval
  385. self.coco = COCO()
  386. self.coco.dataset = gt_dataset
  387. self.coco.createIndex()
  388. img_ids = self.coco.getImgIds()
  389. cat_ids = self.coco.getCatIds()
  390. self.catid2clsid = dict(
  391. {catid: i + 1
  392. for i, catid in enumerate(cat_ids)})
  393. self.cname2cid = dict({
  394. self.coco.loadCats(catid)[0]['name']: clsid
  395. for catid, clsid in self.catid2clsid.items()
  396. })
  397. self.cid2cname = dict(
  398. {cid: cname
  399. for cname, cid in self.cname2cid.items()})
  400. self.cid2cname[0] = 'back_ground'
  401. self.gt = dict()
  402. for img_id in img_ids:
  403. img_anno = self.coco.loadImgs(img_id)[0]
  404. im_fname = img_anno['file_name']
  405. im_w = float(img_anno['width'])
  406. im_h = float(img_anno['height'])
  407. ins_anno_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
  408. instances = self.coco.loadAnns(ins_anno_ids)
  409. bboxes = []
  410. for inst in instances:
  411. x, y, box_w, box_h = inst['bbox']
  412. x1 = max(0, x)
  413. y1 = max(0, y)
  414. x2 = min(im_w - 1, x1 + max(0, box_w - 1))
  415. y2 = min(im_h - 1, y1 + max(0, box_h - 1))
  416. if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
  417. inst['clean_bbox'] = [x1, y1, x2, y2]
  418. bboxes.append(inst)
  419. else:
  420. pass
  421. num_bbox = len(bboxes)
  422. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  423. gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
  424. gt_score = np.ones((num_bbox, 1), dtype=np.float32)
  425. is_crowd = np.zeros((num_bbox), dtype=np.int32)
  426. difficult = np.zeros((num_bbox, 1), dtype=np.int32)
  427. gt_poly = [None] * num_bbox
  428. for i, box in enumerate(bboxes):
  429. catid = box['category_id']
  430. gt_class[i][0] = self.catid2clsid[catid]
  431. gt_bbox[i, :] = box['clean_bbox']
  432. is_crowd[i] = box['iscrowd']
  433. if 'segmentation' in box:
  434. gt_poly[i] = self.coco.annToRLE(box)
  435. if 'difficult' in box:
  436. difficult[i][0] = box['difficult']
  437. coco_rec = {
  438. 'is_crowd': is_crowd,
  439. 'gt_class': gt_class,
  440. 'gt_bbox': gt_bbox,
  441. 'gt_score': gt_score,
  442. 'gt_poly': gt_poly,
  443. 'difficult': difficult
  444. }
  445. self.gt[img_id] = coco_rec
  446. self.gtimgids = list(self.gt.keys())
  447. self.detimgids = [ann['image_id'] for ann in self.bbox]
  448. self.det = dict()
  449. if len(self.bbox) > 0:
  450. if 'bbox' in self.bbox[0] and not self.bbox[0]['bbox'] == []:
  451. for id, ann in enumerate(self.bbox):
  452. im_id = ann['image_id']
  453. bb = ann['bbox']
  454. x1, x2, y1, y2 = [
  455. bb[0], bb[0] + bb[2] - 1, bb[1], bb[1] + bb[3] - 1
  456. ]
  457. score = ann['score']
  458. category_id = self.catid2clsid[ann['category_id']]
  459. if int(im_id) not in self.det:
  460. self.det[int(im_id)] = [[
  461. category_id, score, x1, y1, x2, y2
  462. ]]
  463. else:
  464. self.det[int(im_id)].extend(
  465. [[category_id, score, x1, y1, x2, y2]])
  466. if self.mask is not None:
  467. self.maskimgids = [ann['image_id'] for ann in self.mask]
  468. self.segm = dict()
  469. if len(self.mask) > 0:
  470. if 'segmentation' in self.mask[0]:
  471. for id, ann in enumerate(self.mask):
  472. im_id = ann['image_id']
  473. score = ann['score']
  474. segmentation = self.coco.annToRLE(ann)
  475. category_id = self.catid2clsid[ann['category_id']]
  476. if int(im_id) not in self.segm:
  477. self.segm[int(im_id)] = [[
  478. category_id, score, segmentation
  479. ]]
  480. else:
  481. self.segm[int(im_id)].extend(
  482. [[category_id, score, segmentation]])
  483. def cal_confusion_matrix(self):
  484. '''计算混淆矩阵。
  485. '''
  486. self._prepare_data()
  487. confusion_matrix = DetConfusionMatrix(
  488. num_classes=len(self.cid2cname.keys()),
  489. overlap_thresh=self.overlap_thresh,
  490. score_threshold=self.score_threshold)
  491. for im_id in self.gtimgids:
  492. if im_id not in set(self.detimgids):
  493. bbox = []
  494. else:
  495. bbox = np.array(self.det[im_id])
  496. gt_box = self.gt[im_id]['gt_bbox']
  497. gt_label = self.gt[im_id]['gt_class']
  498. difficult = self.gt[im_id]['difficult']
  499. confusion_matrix.update(bbox, gt_box, gt_label, difficult)
  500. self.confusion_matrix = confusion_matrix.get_confusion_matrix()
  501. self.precision_recall = dict()
  502. for id in range(len(self.cid2cname.keys()) - 1):
  503. if confusion_matrix.total_gt[id] == 0:
  504. recall = -1
  505. else:
  506. recall = confusion_matrix.total_tp[
  507. id] / confusion_matrix.total_gt[id]
  508. if confusion_matrix.total_pred[id] == 0:
  509. precision = -1
  510. else:
  511. precision = confusion_matrix.total_tp[
  512. id] / confusion_matrix.total_pred[id]
  513. self.precision_recall[self.cid2cname[id + 1]] = {
  514. "precision": precision,
  515. "recall": recall
  516. }
  517. return self.confusion_matrix
  518. def cal_precision_recall(self):
  519. '''计算precision、recall。
  520. '''
  521. return self.precision_recall
  522. def cal_map(self):
  523. '''计算mAP。
  524. '''
  525. detection_map = DetectionMAP(
  526. num_classes=len(self.cid2cname.keys()),
  527. overlap_thresh=self.overlap_thresh)
  528. for im_id in self.gtimgids:
  529. if im_id not in set(self.detimgids):
  530. bbox = []
  531. else:
  532. bbox = np.array(self.det[im_id])
  533. gt_box = self.gt[im_id]['gt_bbox']
  534. gt_label = self.gt[im_id]['gt_class']
  535. difficult = self.gt[im_id]['difficult']
  536. detection_map.update(bbox, gt_box, gt_label, difficult)
  537. detection_map.accumulate()
  538. self.map = detection_map.get_map()
  539. self.APs = detection_map.APs
  540. return self.map
  541. def cal_ap(self):
  542. '''计算各类AP。
  543. '''
  544. self.aps = dict()
  545. for id, ap in enumerate(self.APs):
  546. if id == 0:
  547. continue
  548. self.aps[self.cid2cname[id]] = ap
  549. return self.aps
  550. def generate_report(self):
  551. '''生成评估报告。
  552. '''
  553. report = dict()
  554. report['Confusion_Matrix'] = copy.deepcopy(self.cal_confusion_matrix()
  555. .tolist())
  556. report['mAP'] = copy.deepcopy(self.cal_map())
  557. report['PRAP'] = copy.deepcopy(self.cal_precision_recall())
  558. report['label_list'] = copy.deepcopy(list(self.cname2cid.keys()))
  559. report['label_list'].append('back_ground')
  560. per_ap = copy.deepcopy(self.cal_ap())
  561. for k, v in per_ap.items():
  562. report['PRAP'][k]["AP"] = v
  563. return report
  564. class InsSegEvaluator(DetEvaluator):
  565. def __init__(self, model_path, overlap_thresh=0.5, score_threshold=0.3):
  566. super(DetEvaluator, self).__init__()
  567. self.model_path = model_path
  568. self.overlap_thresh = overlap_thresh if overlap_thresh is not None else .5
  569. self.score_threshold = score_threshold if score_threshold is not None else .3
  570. def cal_confusion_matrix_mask(self):
  571. '''计算Mask的混淆矩阵。
  572. '''
  573. confusion_matrix = InsSegConfusionMatrix(
  574. num_classes=len(self.cid2cname.keys()),
  575. overlap_thresh=self.overlap_thresh,
  576. score_threshold=self.score_threshold)
  577. for im_id in self.gtimgids:
  578. if im_id not in set(self.maskimgids):
  579. segm = []
  580. else:
  581. segm = self.segm[im_id]
  582. gt_segm = self.gt[im_id]['gt_poly']
  583. gt_label = self.gt[im_id]['gt_class']
  584. is_crowd = self.gt[im_id]['is_crowd']
  585. confusion_matrix.update(segm, gt_segm, gt_label, is_crowd)
  586. self.confusion_matrix_mask = confusion_matrix.get_confusion_matrix()
  587. self.precision_recall_mask = dict()
  588. for id in range(len(self.cid2cname.keys()) - 1):
  589. if confusion_matrix.total_gt[id] == 0:
  590. recall = -1
  591. else:
  592. recall = confusion_matrix.total_tp[
  593. id] / confusion_matrix.total_gt[id]
  594. if confusion_matrix.total_pred[id] == 0:
  595. precision = -1
  596. else:
  597. precision = confusion_matrix.total_tp[
  598. id] / confusion_matrix.total_pred[id]
  599. self.precision_recall_mask[self.cid2cname[id + 1]] = {
  600. "precision": precision,
  601. "recall": recall
  602. }
  603. return self.confusion_matrix_mask
  604. def cal_precision_recall_mask(self):
  605. '''计算Mask的precision、recall。
  606. '''
  607. return self.precision_recall_mask
  608. def _summarize(self,
  609. coco_gt,
  610. ap=1,
  611. iouThr=None,
  612. areaRng='all',
  613. maxDets=100):
  614. p = coco_gt.params
  615. aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
  616. mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
  617. if ap == 1:
  618. s = coco_gt.eval['precision']
  619. if iouThr is not None:
  620. t = np.where(iouThr == p.iouThrs)[0]
  621. s = s[t]
  622. s = s[:, :, :, aind, mind]
  623. else:
  624. s = coco_gt.eval['recall']
  625. if iouThr is not None:
  626. t = np.where(iouThr == p.iouThrs)[0]
  627. s = s[t]
  628. s = s[:, :, aind, mind]
  629. if len(s[s > -1]) == 0:
  630. mean_s = -1
  631. else:
  632. mean_s = np.mean(s[s > -1])
  633. return mean_s
  634. def cal_map(self):
  635. '''计算BBox的mAP。
  636. '''
  637. if len(self.bbox) > 0:
  638. from pycocotools.cocoeval import COCOeval
  639. coco_dt = loadRes(self.coco, self.bbox)
  640. np.linspace = fixed_linspace
  641. coco_eval = COCOeval(self.coco, coco_dt, 'bbox')
  642. coco_eval.params.iouThrs = np.linspace(
  643. self.overlap_thresh, self.overlap_thresh, 1, endpoint=True)
  644. np.linspace = backup_linspace
  645. coco_eval.evaluate()
  646. coco_eval.accumulate()
  647. self.map = self._summarize(coco_eval, iouThr=self.overlap_thresh)
  648. precision = coco_eval.eval['precision'][0, :, :, 0, 2]
  649. num_classes = len(coco_eval.params.catIds)
  650. self.APs = [None] * num_classes
  651. for i in range(num_classes):
  652. per = precision[:, i]
  653. per = per[per > -1]
  654. self.APs[i] = np.sum(per) / 101 if per.shape[0] > 0 else None
  655. else:
  656. self.map = None
  657. self.APs = [None] * len(self.catid2clsid)
  658. return self.map
  659. def cal_ap(self):
  660. '''计算BBox的各类AP。
  661. '''
  662. self.aps = dict()
  663. for id, ap in enumerate(self.APs):
  664. self.aps[self.cid2cname[id + 1]] = ap
  665. return self.aps
  666. def cal_map_mask(self):
  667. '''计算Mask的mAP。
  668. '''
  669. if len(self.mask) > 0:
  670. from pycocotools.cocoeval import COCOeval
  671. coco_dt = loadRes(self.coco, self.mask)
  672. np.linspace = fixed_linspace
  673. coco_eval = COCOeval(self.coco, coco_dt, 'segm')
  674. coco_eval.params.iouThrs = np.linspace(
  675. self.overlap_thresh, self.overlap_thresh, 1, endpoint=True)
  676. np.linspace = backup_linspace
  677. coco_eval.evaluate()
  678. coco_eval.accumulate()
  679. self.map_mask = self._summarize(
  680. coco_eval, iouThr=self.overlap_thresh)
  681. precision = coco_eval.eval['precision'][0, :, :, 0, 2]
  682. num_classes = len(coco_eval.params.catIds)
  683. self.mask_APs = [None] * num_classes
  684. for i in range(num_classes):
  685. per = precision[:, i]
  686. per = per[per > -1]
  687. self.mask_APs[i] = np.sum(per) / 101 if per.shape[
  688. 0] > 0 else None
  689. else:
  690. self.map_mask = None
  691. self.mask_APs = [None] * len(self.catid2clsid)
  692. return self.map_mask
  693. def cal_ap_mask(self):
  694. '''计算Mask的各类AP。
  695. '''
  696. self.mask_aps = dict()
  697. for id, ap in enumerate(self.mask_APs):
  698. self.mask_aps[self.cid2cname[id + 1]] = ap
  699. return self.mask_aps
  700. def generate_report(self):
  701. '''生成评估报告。
  702. '''
  703. report = dict()
  704. report['BBox_Confusion_Matrix'] = copy.deepcopy(
  705. self.cal_confusion_matrix().tolist())
  706. report['BBox_mAP'] = copy.deepcopy(self.cal_map())
  707. report['BBox_PRAP'] = copy.deepcopy(self.cal_precision_recall())
  708. report['label_list'] = copy.deepcopy(list(self.cname2cid.keys()))
  709. report['label_list'].append('back_ground')
  710. per_ap = copy.deepcopy(self.cal_ap())
  711. for k, v in per_ap.items():
  712. report['BBox_PRAP'][k]['AP'] = v
  713. report['Mask_Confusion_Matrix'] = copy.deepcopy(
  714. self.cal_confusion_matrix_mask().tolist())
  715. report['Mask_mAP'] = copy.deepcopy(self.cal_map_mask())
  716. report['Mask_PRAP'] = copy.deepcopy(self.cal_precision_recall_mask())
  717. per_ap_mask = copy.deepcopy(self.cal_ap_mask())
  718. for k, v in per_ap_mask.items():
  719. report['Mask_PRAP'][k]['AP'] = v
  720. return report