Browse Source

Merge pull request #26 from FlyingQianMM/develop_qh

reimplement detection visualize
Jason 5 years ago
parent
commit
afa3547129

BIN
docs/apis/images/insect_bbox_pr_curve(iou-0.5).png


BIN
docs/apis/images/xiaoduxiong_bbox_pr_curve(iou-0.5).png


BIN
docs/apis/images/xiaoduxiong_segm_pr_curve(iou-0.5).png


+ 18 - 32
docs/apis/visualize.md

@@ -40,8 +40,16 @@ paddlex.det.draw_pr_curve(eval_details_file=None, gt=None, pred_bbox=None, pred_
 **注意:**`eval_details_file`的优先级更高,只要`eval_details_file`不为None,就会从`eval_details_file`提取真值信息和预测结果做分析。当`eval_details_file`为None时,则用`gt`、`pred_mask`、`pred_mask`做分析。
 
 ### 使用示例
-> 示例一:
-点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/xiaoduxiong_epoch_12.tar.gz)和[数据集](https://bj.bcebos.com/paddlex/datasets/xiaoduxiong_ins_det.tar.gz)
+点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/insect_epoch_270.zip)和[数据集](https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz)
+
+> 方式一:分析训练过程中保存的模型文件夹中的评估结果文件`eval_details.json`,例如[模型](https://bj.bcebos.com/paddlex/models/insect_epoch_270.zip)中的`eval_details.json`。
+```
+import paddlex as pdx
+eval_details_file = 'insect_epoch_270/eval_details.json'
+pdx.det.draw_pr_curve(eval_details_file, save_dir='./insect')
+```
+> 方式二:分析模型评估函数返回的评估结果。
+
 ```
 import os
 # 选择使用0号卡
@@ -50,40 +58,18 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 from paddlex.det import transforms
 import paddlex as pdx
 
-eval_transforms = transforms.Compose([
-    transforms.Normalize(),
-    transforms.ResizeByShort(short_size=800, max_size=1333),
-    transforms.Padding(coarsest_stride=32)
-])
-
-eval_dataset = pdx.datasets.CocoDetection(
-    data_dir='xiaoduxiong_ins_det/JPEGImages',
-    ann_file='xiaoduxiong_ins_det/val.json',
-    transforms=eval_transforms)
-
-model = pdx.load_model('xiaoduxiong_epoch_12')
-metrics, evaluate_details = model.evaluate(eval_dataset, batch_size=1, return_details=True)
+model = pdx.load_model('insect_epoch_270')
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='insect_det',
+    file_list='insect_det/val_list.txt',
+    label_list='insect_det/labels.txt',
+    transforms=model.eval_transforms)
+metrics, evaluate_details = model.evaluate(eval_dataset, batch_size=8, return_details=True)
 gt = evaluate_details['gt']
 bbox = evaluate_details['bbox']
-mask = evaluate_details['mask']
-
-# 分别可视化bbox和mask的准召曲线
-pdx.det.draw_pr_curve(gt=gt, pred_bbox=bbox, pred_mask=mask, save_dir='./xiaoduxiong')
+pdx.det.draw_pr_curve(gt=gt, pred_bbox=bbox, save_dir='./insect')
 ```
-预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
-![](./images/xiaoduxiong_bbox_pr_curve(iou-0.5).png)
-
-预测mask的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
-![](./images/xiaoduxiong_segm_pr_curve(iou-0.5).png)
-
-> 示例二:
-使用[yolov3_darknet53.py示例代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/detection/yolov3_darknet53.py)训练完成后,加载模型评估结果文件进行分析:
 
-```
-import paddlex as pdx
-eval_details_file = 'output/yolov3_darknet53/best_model/eval_details.json'
-pdx.det.draw_pr_curve(eval_details_file, save_dir='./insect')
-```
 预测框的各个类别的准确率和召回率的对应关系、召回率和置信度阈值的对应关系可视化如下:
 ![](./images/insect_bbox_pr_curve(iou-0.5).png)
 

BIN
docs/images/visualized_maskrcnn.jpeg


+ 2 - 1
paddlex/cv/datasets/voc.py

@@ -17,6 +17,7 @@ import copy
 import os.path as osp
 import random
 import numpy as np
+from collections import OrderedDict
 import xml.etree.ElementTree as ET
 import paddlex.utils.logging as logging
 from .dataset import Dataset
@@ -66,7 +67,7 @@ class VOCDetection(Dataset):
         annotations['categories'] = []
         annotations['annotations'] = []
 
-        cname2cid = {}
+        cname2cid = OrderedDict()
         label_id = 1
         with open(label_list, 'r', encoding=get_encoding(label_list)) as fr:
             for line in fr.readlines():

+ 145 - 30
paddlex/cv/models/utils/visualize.py

@@ -14,9 +14,8 @@
 
 import os
 import cv2
+import colorsys
 import numpy as np
-from PIL import Image, ImageDraw
-
 import paddlex.utils.logging as logging
 from .detection_eval import fixed_linspace, backup_linspace, loadRes
 
@@ -27,13 +26,13 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
     """
 
     image_name = os.path.split(image)[-1]
-    image = Image.open(image).convert('RGB')
+    image = cv2.imread(image)
     image = draw_bbox_mask(image, result, threshold=threshold)
     if save_dir is not None:
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
         out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
-        image.save(out_path, quality=95)
+        cv2.imwrite(out_path, image)
         logging.info('The visualized result is saved as {}'.format(out_path))
     else:
         return image
@@ -122,49 +121,163 @@ def clip_bbox(bbox):
     return xmin, ymin, xmax, ymax
 
 
-def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
+def draw_bbox_mask(image, results, threshold=0.5):
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib as mpl
+    import matplotlib.figure as mplfigure
+    import matplotlib.colors as mplc
+    from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+    # refer to  https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
+    def _change_color_brightness(color, brightness_factor):
+        assert brightness_factor >= -1.0 and brightness_factor <= 1.0
+        color = mplc.to_rgb(color)
+        polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
+        modified_lightness = polygon_color[1] + (
+            brightness_factor * polygon_color[1])
+        modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
+        modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
+        modified_color = colorsys.hls_to_rgb(
+            polygon_color[0], modified_lightness, polygon_color[2])
+        return modified_color
+
+    _SMALL_OBJECT_AREA_THRESH = 1000
+    # setup figure
+    width, height = image.shape[1], image.shape[0]
+    scale = 1
+    fig = mplfigure.Figure(frameon=False)
+    dpi = fig.get_dpi()
+    fig.set_size_inches(
+        (width * scale + 1e-2) / dpi,
+        (height * scale + 1e-2) / dpi,
+    )
+    canvas = FigureCanvasAgg(fig)
+    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
+    ax.axis("off")
+    ax.set_xlim(0.0, width)
+    ax.set_ylim(height)
+    default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
+    linewidth = max(default_font_size / 4, 1)
+
     labels = list()
     for dt in np.array(results):
         if dt['category'] not in labels:
             labels.append(dt['category'])
-    color_map = get_color_map_list(len(labels))
+    color_map = get_color_map_list(256)
 
+    keep_results = []
+    areas = []
     for dt in np.array(results):
         cname, bbox, score = dt['category'], dt['bbox'], dt['score']
         if score < threshold:
             continue
-
+        keep_results.append(dt)
+        areas.append(bbox[2] * bbox[3])
+    areas = np.asarray(areas)
+    sorted_idxs = np.argsort(-areas).tolist()
+    keep_results = [keep_results[k]
+                    for k in sorted_idxs] if len(keep_results) > 0 else []
+
+    for dt in np.array(keep_results):
+        cname, bbox, score = dt['category'], dt['bbox'], dt['score']
         xmin, ymin, w, h = bbox
         xmax = xmin + w
         ymax = ymin + h
 
-        color = tuple(color_map[labels.index(cname)])
-
+        color = tuple(color_map[labels.index(cname) + 2])
+        color = [c / 255. for c in color]
         # draw bbox
-        draw = ImageDraw.Draw(image)
-        draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
-                   (xmin, ymin)],
-                  width=2,
-                  fill=color)
-
-        # draw label
-        text = "{} {:.2f}".format(cname, score)
-        tw, th = draw.textsize(text)
-        draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)],
-                       fill=color)
-        draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
+        ax.add_patch(
+            mpl.patches.Rectangle(
+                (xmin, ymin),
+                w,
+                h,
+                fill=False,
+                edgecolor=color,
+                linewidth=linewidth * scale,
+                alpha=0.8,
+                linestyle="-",
+            ))
 
         # draw mask
         if 'mask' in dt:
             mask = dt['mask']
-            color_mask = np.array(color_map[labels.index(
-                dt['category'])]).astype('float32')
-            img_array = np.array(image).astype('float32')
-            idx = np.nonzero(mask)
-            img_array[idx[0], idx[1], :] *= 1.0 - alpha
-            img_array[idx[0], idx[1], :] += alpha * color_mask
-            image = Image.fromarray(img_array.astype('uint8'))
-    return image
+            mask = np.ascontiguousarray(mask)
+            res = cv2.findContours(
+                mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+            hierarchy = res[-1]
+            alpha = 0.5
+            if hierarchy is not None:
+                has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
+                res = res[-2]
+                res = [x.flatten() for x in res]
+                res = [x for x in res if len(x) >= 6]
+                for segment in res:
+                    segment = segment.reshape(-1, 2)
+                    edge_color = mplc.to_rgb(color) + (1, )
+                    polygon = mpl.patches.Polygon(
+                        segment,
+                        fill=True,
+                        facecolor=mplc.to_rgb(color) + (alpha, ),
+                        edgecolor=edge_color,
+                        linewidth=max(default_font_size // 15 * scale, 1),
+                    )
+                    ax.add_patch(polygon)
+
+        # draw label
+        text_pos = (xmin, ymin)
+        horiz_align = "left"
+        instance_area = w * h
+        if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
+                or h < 40 * scale):
+            if ymin >= height - 5:
+                text_pos = (xmin, ymin)
+            else:
+                text_pos = (xmin, ymax)
+        height_ratio = h / np.sqrt(height * width)
+        font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
+                     default_font_size)
+        text = "{} {:.2f}".format(cname, score)
+        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
+        color[np.argmax(color)] = max(0.8, np.max(color))
+        color = _change_color_brightness(color, brightness_factor=0.7)
+        ax.text(
+            text_pos[0],
+            text_pos[1],
+            text,
+            size=font_size * scale,
+            family="sans-serif",
+            bbox={
+                "facecolor": "black",
+                "alpha": 0.8,
+                "pad": 0.7,
+                "edgecolor": "none"
+            },
+            verticalalignment="top",
+            horizontalalignment=horiz_align,
+            color=color,
+            zorder=10,
+            rotation=0,
+        )
+
+    s, (width, height) = canvas.print_to_buffer()
+    buffer = np.frombuffer(s, dtype="uint8")
+
+    img_rgba = buffer.reshape(height, width, 4)
+    rgb, alpha = np.split(img_rgba, [3], axis=2)
+
+    try:
+        import numexpr as ne
+        visualized_image = ne.evaluate(
+            "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
+    except ImportError:
+        alpha = alpha.astype("float32") / 255.0
+        visualized_image = image * (1 - alpha) + rgb * alpha
+
+    visualized_image = visualized_image.astype("uint8")
+
+    return visualized_image
 
 
 def draw_pr_curve(eval_details_file=None,
@@ -189,6 +302,9 @@ def draw_pr_curve(eval_details_file=None,
         raise Exception("There is no predicted bbox.")
     if pred_mask is not None and len(pred_mask) == 0:
         raise Exception("There is no predicted mask.")
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib.pyplot as plt
     from pycocotools.coco import COCO
     from pycocotools.cocoeval import COCOeval
     coco = COCO()
@@ -221,7 +337,6 @@ def draw_pr_curve(eval_details_file=None,
         return mean_s
 
     def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
-        import matplotlib.pyplot as plt
         from pycocotools.cocoeval import COCOeval
         coco_dt = loadRes(coco_gt, coco_dt)
         np.linspace = fixed_linspace