瀏覽代碼

reimplement detection visualize

FlyingQianMM 5 年之前
父節點
當前提交
52d9fcbdea
共有 1 個文件被更改,包括 125 次插入27 次删除
  1. 125 27
      paddlex/cv/models/utils/visualize.py

+ 125 - 27
paddlex/cv/models/utils/visualize.py

@@ -15,7 +15,10 @@
 import os
 import cv2
 import numpy as np
-from PIL import Image, ImageDraw
+import matplotlib as mpl
+import matplotlib.figure as mplfigure
+import matplotlib.colors as mplc
+from matplotlib.backends.backend_agg import FigureCanvasAgg
 
 
 def visualize_detection(image, result, threshold=0.5, save_dir=None):
@@ -24,13 +27,13 @@ def visualize_detection(image, result, threshold=0.5, save_dir=None):
     """
 
     image_name = os.path.split(image)[-1]
-    image = Image.open(image).convert('RGB')
+    image = cv2.imread(image)
     image = draw_bbox_mask(image, result, threshold=threshold)
     if save_dir is not None:
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
         out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
-        image.save(out_path, quality=95)
+        cv2.imwrite(out_path, image)
     else:
         return image
 
@@ -117,46 +120,141 @@ def clip_bbox(bbox):
     return xmin, ymin, xmax, ymax
 
 
-def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
+def draw_bbox_mask(image, results, threshold=0.5):
+    # refer to  https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
+    _SMALL_OBJECT_AREA_THRESH = 1000
+    # setup figure
+    width, height = image.shape[1], image.shape[0]
+    scale = 1
+    fig = mplfigure.Figure(frameon=False)
+    dpi = fig.get_dpi()
+    fig.set_size_inches(
+        (width * scale + 1e-2) / dpi,
+        (height * scale + 1e-2) / dpi,
+    )
+    canvas = FigureCanvasAgg(fig)
+    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
+    ax.axis("off")
+    ax.set_xlim(0.0, width)
+    ax.set_ylim(height)
+    default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
+    linewidth = max(default_font_size / 4, 1)
+
     labels = list()
     for dt in np.array(results):
         if dt['category'] not in labels:
             labels.append(dt['category'])
-    color_map = get_color_map_list(len(labels))
+    color_map = get_color_map_list(256)
 
+    keep_results = []
+    areas = []
     for dt in np.array(results):
         cname, bbox, score = dt['category'], dt['bbox'], dt['score']
         if score < threshold:
             continue
+        keep_results.append(dt)
+        areas.append(bbox[2] * bbox[3])
+    areas = np.asarray(areas)
+    sorted_idxs = np.argsort(-areas).tolist()
+    keep_results = [keep_results[k]
+                    for k in sorted_idxs] if len(keep_results) > 0 else []
 
+    for dt in np.array(keep_results):
+        cname, bbox, score = dt['category'], dt['bbox'], dt['score']
         xmin, ymin, w, h = bbox
         xmax = xmin + w
         ymax = ymin + h
 
-        color = tuple(color_map[labels.index(cname)])
-
+        color = tuple(color_map[labels.index(cname) + 2])
+        color = [c / 255. for c in color]
         # draw bbox
-        draw = ImageDraw.Draw(image)
-        draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
-                   (xmin, ymin)],
-                  width=2,
-                  fill=color)
-
-        # draw label
-        text = "{} {:.2f}".format(cname, score)
-        tw, th = draw.textsize(text)
-        draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)],
-                       fill=color)
-        draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
+        ax.add_patch(
+            mpl.patches.Rectangle(
+                (xmin, ymin),
+                w,
+                h,
+                fill=False,
+                edgecolor=color,
+                linewidth=linewidth * scale,
+                alpha=0.5,
+                linestyle="-",
+            ))
 
         # draw mask
         if 'mask' in dt:
             mask = dt['mask']
-            color_mask = np.array(color_map[labels.index(
-                dt['category'])]).astype('float32')
-            img_array = np.array(image).astype('float32')
-            idx = np.nonzero(mask)
-            img_array[idx[0], idx[1], :] *= 1.0 - alpha
-            img_array[idx[0], idx[1], :] += alpha * color_mask
-            image = Image.fromarray(img_array.astype('uint8'))
-    return image
+            mask = np.ascontiguousarray(mask)
+            res = cv2.findContours(
+                mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+            hierarchy = res[-1]
+            alpha = 0.75
+            if hierarchy is not None:
+                has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
+                res = res[-2]
+                res = [x.flatten() for x in res]
+                res = [x for x in res if len(x) >= 6]
+                for segment in res:
+                    segment = segment.reshape(-1, 2)
+                    edge_color = mplc.to_rgb(color) + (1, )
+                    polygon = mpl.patches.Polygon(
+                        segment,
+                        fill=True,
+                        facecolor=mplc.to_rgb(color) + (alpha, ),
+                        edgecolor=edge_color,
+                        linewidth=max(default_font_size // 15 * scale, 1),
+                    )
+                    ax.add_patch(polygon)
+
+        # draw label
+        text_pos = (xmin, ymin)
+        horiz_align = "left"
+        instance_area = w * h
+        if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
+                or h < 40 * scale):
+            if ymin >= height - 5:
+                text_pos = (xmin, ymin)
+            else:
+                text_pos = (xmin, ymax)
+        height_ratio = h / np.sqrt(height * width)
+        font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
+                     default_font_size)
+        text = "{} {:.2f}".format(cname, score)
+        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
+        color[np.argmax(color)] = max(0.8, np.max(color))
+
+        ax.text(
+            text_pos[0],
+            text_pos[1],
+            text,
+            size=font_size * scale,
+            family="sans-serif",
+            bbox={
+                "facecolor": "black",
+                "alpha": 0.8,
+                "pad": 0.7,
+                "edgecolor": "none"
+            },
+            verticalalignment="top",
+            horizontalalignment=horiz_align,
+            color=color,
+            zorder=10,
+            rotation=0,
+        )
+
+    s, (width, height) = canvas.print_to_buffer()
+    buffer = np.frombuffer(s, dtype="uint8")
+
+    img_rgba = buffer.reshape(height, width, 4)
+    rgb, alpha = np.split(img_rgba, [3], axis=2)
+
+    try:
+        import numexpr as ne
+        visualized_image = ne.evaluate(
+            "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
+    except ImportError:
+        alpha = alpha.astype("float32") / 255.0
+        visualized_image = image * (1 - alpha) + rgb * alpha
+
+    visualized_image = visualized_image.astype("uint8")
+
+    return visualized_image