Эх сурвалжийг харах

visualize precision-recall curve

FlyingQianMM 5 жил өмнө
parent
commit
44a55f792b

+ 128 - 0
paddlex/cv/models/utils/visualize.py

@@ -15,7 +15,11 @@
 import os
 import cv2
 import numpy as np
+#import matplotlib
+#matplotlib.use('Agg')
+import matplotlib.pyplot as plt
 from PIL import Image, ImageDraw
+from .detection_eval import fixed_linspace, backup_linspace, loadRes
 
 
 def visualize_detection(image, result, threshold=0.5, save_dir=None):
@@ -160,3 +164,127 @@ def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
             img_array[idx[0], idx[1], :] += alpha * color_mask
             image = Image.fromarray(img_array.astype('uint8'))
     return image
+
+
+def draw_pr_curve(eval_details_file=None,
+                  gt=None,
+                  pred_bbox=None,
+                  pred_mask=None,
+                  iou_thresh=0.5,
+                  save_dir='./'):
+    if eval_details_file is not None:
+        import json
+        with open(eval_details_file, 'r') as f:
+            eval_details = json.load(f)
+            pred_bbox = eval_details['bbox']
+            if 'mask' in eval_details:
+                pred_mask = eval_details['mask']
+            gt = eval_details['gt']
+    if gt is None or pred_bbox is None:
+        raise Exception(
+            "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
+        )
+    if pred_bbox is not None and len(pred_bbox) == 0:
+        raise Exception("There is no predicted bbox.")
+    if pred_mask is not None and len(pred_mask) == 0:
+        raise Exception("There is no predicted mask.")
+    from pycocotools.coco import COCO
+    from pycocotools.cocoeval import COCOeval
+    coco = COCO()
+    coco.dataset = gt
+    coco.createIndex()
+
+    def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
+        p = coco_gt.params
+        aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
+        mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+        if ap == 1:
+            # dimension of precision: [TxRxKxAxM]
+            s = coco_gt.eval['precision']
+            # IoU
+            if iouThr is not None:
+                t = np.where(iouThr == p.iouThrs)[0]
+                s = s[t]
+            s = s[:, :, :, aind, mind]
+        else:
+            # dimension of recall: [TxKxAxM]
+            s = coco_gt.eval['recall']
+            if iouThr is not None:
+                t = np.where(iouThr == p.iouThrs)[0]
+                s = s[t]
+            s = s[:, :, aind, mind]
+        if len(s[s > -1]) == 0:
+            mean_s = -1
+        else:
+            mean_s = np.mean(s[s > -1])
+        return mean_s
+
+    def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
+        from pycocotools.cocoeval import COCOeval
+        coco_dt = loadRes(coco_gt, coco_dt)
+        np.linspace = fixed_linspace
+        coco_eval = COCOeval(coco_gt, coco_dt, style)
+        coco_eval.params.iouThrs = np.linspace(
+            iou_thresh, iou_thresh, 1, endpoint=True)
+        np.linspace = backup_linspace
+        coco_eval.evaluate()
+        coco_eval.accumulate()
+        stats = _summarize(coco_eval, iouThr=iou_thresh)
+        catIds = coco_gt.getCatIds()
+        if len(catIds) != coco_eval.eval['precision'].shape[2]:
+            raise Exception(
+                "The category number must be same as the third dimension of precisions."
+            )
+        x = np.arange(0.0, 1.01, 0.01)
+        color_map = get_color_map_list(256)[1:256]
+
+        plt.subplot(1, 2, 1)
+        plt.title(style + " precision-recall IoU={}".format(iou_thresh))
+        plt.xlabel("recall")
+        plt.ylabel("precision")
+        plt.xlim(0, 1.01)
+        plt.ylim(0, 1.01)
+        plt.grid(linestyle='--', linewidth=1)
+        plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
+        my_x_ticks = np.arange(0, 1.01, 0.1)
+        my_y_ticks = np.arange(0, 1.01, 0.1)
+        plt.xticks(my_x_ticks, fontsize=5)
+        plt.yticks(my_y_ticks, fontsize=5)
+        for idx, catId in enumerate(catIds):
+            pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
+            precision = pr_array[pr_array > -1]
+            ap = np.mean(precision) if precision.size else float('nan')
+            nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
+                float(ap * 100))
+            color = tuple(color_map[idx])
+            color = [float(c) / 255 for c in color]
+            color.append(0.75)
+            plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
+        plt.legend(loc="lower left", fontsize=5)
+
+        plt.subplot(1, 2, 2)
+        plt.title(style + " score-recall IoU={}".format(iou_thresh))
+        plt.xlabel('recall')
+        plt.ylabel('score')
+        plt.xlim(0, 1.01)
+        plt.ylim(0, 1.01)
+        plt.grid(linestyle='--', linewidth=1)
+        plt.xticks(my_x_ticks, fontsize=5)
+        plt.yticks(my_y_ticks, fontsize=5)
+        for idx, catId in enumerate(catIds):
+            nm = coco_gt.loadCats(catId)[0]['name']
+            sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
+            color = tuple(color_map[idx])
+            color = [float(c) / 255 for c in color]
+            color.append(0.75)
+            plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
+        plt.legend(loc="lower right", fontsize=5)
+        plt.savefig(
+            os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format(
+                style, iou_thresh)),
+            dpi=800)
+        plt.close()
+
+    cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
+    if pred_mask is not None:
+        cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')