|
|
@@ -15,7 +15,11 @@
|
|
|
import os
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
+#import matplotlib
|
|
|
+#matplotlib.use('Agg')
|
|
|
+import matplotlib.pyplot as plt
|
|
|
from PIL import Image, ImageDraw
|
|
|
+from .detection_eval import fixed_linspace, backup_linspace, loadRes
|
|
|
|
|
|
|
|
|
def visualize_detection(image, result, threshold=0.5, save_dir=None):
|
|
|
@@ -160,3 +164,127 @@ def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
|
|
|
img_array[idx[0], idx[1], :] += alpha * color_mask
|
|
|
image = Image.fromarray(img_array.astype('uint8'))
|
|
|
return image
|
|
|
+
|
|
|
+
|
|
|
+def draw_pr_curve(eval_details_file=None,
|
|
|
+ gt=None,
|
|
|
+ pred_bbox=None,
|
|
|
+ pred_mask=None,
|
|
|
+ iou_thresh=0.5,
|
|
|
+ save_dir='./'):
|
|
|
+ if eval_details_file is not None:
|
|
|
+ import json
|
|
|
+ with open(eval_details_file, 'r') as f:
|
|
|
+ eval_details = json.load(f)
|
|
|
+ pred_bbox = eval_details['bbox']
|
|
|
+ if 'mask' in eval_details:
|
|
|
+ pred_mask = eval_details['mask']
|
|
|
+ gt = eval_details['gt']
|
|
|
+ if gt is None or pred_bbox is None:
|
|
|
+ raise Exception(
|
|
|
+ "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
|
|
|
+ )
|
|
|
+ if pred_bbox is not None and len(pred_bbox) == 0:
|
|
|
+ raise Exception("There is no predicted bbox.")
|
|
|
+ if pred_mask is not None and len(pred_mask) == 0:
|
|
|
+ raise Exception("There is no predicted mask.")
|
|
|
+ from pycocotools.coco import COCO
|
|
|
+ from pycocotools.cocoeval import COCOeval
|
|
|
+ coco = COCO()
|
|
|
+ coco.dataset = gt
|
|
|
+ coco.createIndex()
|
|
|
+
|
|
|
+ def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
|
|
|
+ p = coco_gt.params
|
|
|
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
|
|
|
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
|
|
|
+ if ap == 1:
|
|
|
+ # dimension of precision: [TxRxKxAxM]
|
|
|
+ s = coco_gt.eval['precision']
|
|
|
+ # IoU
|
|
|
+ if iouThr is not None:
|
|
|
+ t = np.where(iouThr == p.iouThrs)[0]
|
|
|
+ s = s[t]
|
|
|
+ s = s[:, :, :, aind, mind]
|
|
|
+ else:
|
|
|
+ # dimension of recall: [TxKxAxM]
|
|
|
+ s = coco_gt.eval['recall']
|
|
|
+ if iouThr is not None:
|
|
|
+ t = np.where(iouThr == p.iouThrs)[0]
|
|
|
+ s = s[t]
|
|
|
+ s = s[:, :, aind, mind]
|
|
|
+ if len(s[s > -1]) == 0:
|
|
|
+ mean_s = -1
|
|
|
+ else:
|
|
|
+ mean_s = np.mean(s[s > -1])
|
|
|
+ return mean_s
|
|
|
+
|
|
|
+ def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
|
|
|
+ from pycocotools.cocoeval import COCOeval
|
|
|
+ coco_dt = loadRes(coco_gt, coco_dt)
|
|
|
+ np.linspace = fixed_linspace
|
|
|
+ coco_eval = COCOeval(coco_gt, coco_dt, style)
|
|
|
+ coco_eval.params.iouThrs = np.linspace(
|
|
|
+ iou_thresh, iou_thresh, 1, endpoint=True)
|
|
|
+ np.linspace = backup_linspace
|
|
|
+ coco_eval.evaluate()
|
|
|
+ coco_eval.accumulate()
|
|
|
+ stats = _summarize(coco_eval, iouThr=iou_thresh)
|
|
|
+ catIds = coco_gt.getCatIds()
|
|
|
+ if len(catIds) != coco_eval.eval['precision'].shape[2]:
|
|
|
+ raise Exception(
|
|
|
+ "The category number must be same as the third dimension of precisions."
|
|
|
+ )
|
|
|
+ x = np.arange(0.0, 1.01, 0.01)
|
|
|
+ color_map = get_color_map_list(256)[1:256]
|
|
|
+
|
|
|
+ plt.subplot(1, 2, 1)
|
|
|
+ plt.title(style + " precision-recall IoU={}".format(iou_thresh))
|
|
|
+ plt.xlabel("recall")
|
|
|
+ plt.ylabel("precision")
|
|
|
+ plt.xlim(0, 1.01)
|
|
|
+ plt.ylim(0, 1.01)
|
|
|
+ plt.grid(linestyle='--', linewidth=1)
|
|
|
+ plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
|
|
|
+ my_x_ticks = np.arange(0, 1.01, 0.1)
|
|
|
+ my_y_ticks = np.arange(0, 1.01, 0.1)
|
|
|
+ plt.xticks(my_x_ticks, fontsize=5)
|
|
|
+ plt.yticks(my_y_ticks, fontsize=5)
|
|
|
+ for idx, catId in enumerate(catIds):
|
|
|
+ pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
|
|
|
+ precision = pr_array[pr_array > -1]
|
|
|
+ ap = np.mean(precision) if precision.size else float('nan')
|
|
|
+ nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
|
|
|
+ float(ap * 100))
|
|
|
+ color = tuple(color_map[idx])
|
|
|
+ color = [float(c) / 255 for c in color]
|
|
|
+ color.append(0.75)
|
|
|
+ plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
|
|
|
+ plt.legend(loc="lower left", fontsize=5)
|
|
|
+
|
|
|
+ plt.subplot(1, 2, 2)
|
|
|
+ plt.title(style + " score-recall IoU={}".format(iou_thresh))
|
|
|
+ plt.xlabel('recall')
|
|
|
+ plt.ylabel('score')
|
|
|
+ plt.xlim(0, 1.01)
|
|
|
+ plt.ylim(0, 1.01)
|
|
|
+ plt.grid(linestyle='--', linewidth=1)
|
|
|
+ plt.xticks(my_x_ticks, fontsize=5)
|
|
|
+ plt.yticks(my_y_ticks, fontsize=5)
|
|
|
+ for idx, catId in enumerate(catIds):
|
|
|
+ nm = coco_gt.loadCats(catId)[0]['name']
|
|
|
+ sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
|
|
|
+ color = tuple(color_map[idx])
|
|
|
+ color = [float(c) / 255 for c in color]
|
|
|
+ color.append(0.75)
|
|
|
+ plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
|
|
|
+ plt.legend(loc="lower right", fontsize=5)
|
|
|
+ plt.savefig(
|
|
|
+ os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format(
|
|
|
+ style, iou_thresh)),
|
|
|
+ dpi=800)
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+ cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
|
|
|
+ if pred_mask is not None:
|
|
|
+ cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')
|