فهرست منبع

refine det_visualize

will-jl944 4 سال پیش
والد
کامیت
34a628c187
2فایلهای تغییر یافته به همراه48 افزوده شده و 95 حذف شده
  1. 1 1
      dygraph/paddlex/cv/models/detector.py
  2. 47 94
      dygraph/paddlex/cv/models/utils/visualize.py

+ 1 - 1
dygraph/paddlex/cv/models/detector.py

@@ -531,7 +531,7 @@ class BaseDetector(BaseModel):
                     category = self.labels[int(num_id)]
                     w = xmax - xmin
                     h = ymax - ymin
-                    bbox = [xmin, ymin, w, h]
+                    bbox = list(map(int, [xmin, ymin, w, h]))
                     dt_res = {
                         'category_id': int(num_id),
                         'category': category,

+ 47 - 94
dygraph/paddlex/cv/models/utils/visualize.py

@@ -37,7 +37,6 @@ def visualize_detection(image,
     else:
         image_name = os.path.split(image)[-1]
         image = cv2.imread(image)
-
     image = draw_bbox_mask(image, result, threshold=threshold, color_map=color)
     if save_dir is not None:
         if not os.path.exists(save_dir):
@@ -164,12 +163,7 @@ def clip_bbox(bbox):
 
 
 def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
-    import matplotlib
-    matplotlib.use('Agg')
-    import matplotlib as mpl
-    import matplotlib.figure as mplfigure
     import matplotlib.colors as mplc
-    from matplotlib.backends.backend_agg import FigureCanvasAgg
 
     # refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
     def _change_color_brightness(color, brightness_factor):
@@ -185,21 +179,9 @@ def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
         return modified_color
 
     _SMALL_OBJECT_AREA_THRESH = 1000
-    # setup figure
-    width, height = image.shape[1], image.shape[0]
-    scale = 1
-    fig = mplfigure.Figure(frameon=False)
-    dpi = fig.get_dpi()
-    fig.set_size_inches(
-        (width * scale + 1e-2) / dpi,
-        (height * scale + 1e-2) / dpi, )
-    canvas = FigureCanvasAgg(fig)
-    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
-    ax.axis("off")
-    ax.set_xlim(0.0, width)
-    ax.set_ylim(height)
-    default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
-    linewidth = max(default_font_size / 4, 1)
+    height, width = image.shape[:2]
+    default_font_scale = max(np.sqrt(height * width) // 900, .5)
+    linewidth = max(default_font_scale / 40, 2)
 
     labels = list()
     for dt in results:
@@ -238,95 +220,66 @@ def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
         ymax = ymin + h
 
         color = tuple(color_map[labels.index(cname)])
-        color = [c / 255. for c in color]
         # draw bbox
-        ax.add_patch(
-            mpl.patches.Rectangle(
-                (xmin, ymin),
-                w,
-                h,
-                fill=False,
-                edgecolor=color,
-                linewidth=linewidth * scale,
-                alpha=0.8,
-                linestyle="-", ))
+        image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color,
+                              linewidth)
 
         # draw mask
         if 'mask' in dt:
-            mask = mask_util.decode(dt['mask'])
-            mask = np.ascontiguousarray(mask)
-            res = cv2.findContours(
-                mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
-            hierarchy = res[-1]
-            alpha = 0.5
-            if hierarchy is not None:
-                has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
-                res = res[-2]
-                res = [x.flatten() for x in res]
-                res = [x for x in res if len(x) >= 6]
-                for segment in res:
-                    segment = segment.reshape(-1, 2)
-                    edge_color = mplc.to_rgb(color) + (1, )
-                    polygon = mpl.patches.Polygon(
-                        segment,
-                        fill=True,
-                        facecolor=mplc.to_rgb(color) + (alpha, ),
-                        edgecolor=edge_color,
-                        linewidth=max(default_font_size // 15 * scale, 1), )
-                    ax.add_patch(polygon)
+            mask = mask_util.decode(dt['mask']) * 255
+            image = image.astype('float32')
+            alpha = .7
+            w_ratio = .4
+            color_mask = np.asarray(color, dtype=np.int)
+            for c in range(3):
+                color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+            idx = np.nonzero(mask)
+            image[idx[0], idx[1], :] *= 1.0 - alpha
+            image[idx[0], idx[1], :] += alpha * color_mask
+            image = image.astype("uint8")
+            contours = cv2.findContours(
+                mask.astype("uint8"), cv2.RETR_CCOMP,
+                cv2.CHAIN_APPROX_NONE)[-2]
+            image = cv2.drawContours(
+                image,
+                contours,
+                contourIdx=-1,
+                color=color,
+                thickness=1,
+                lineType=cv2.LINE_AA)
 
         # draw label
         text_pos = (xmin, ymin)
-        horiz_align = "left"
         instance_area = w * h
-        if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale or
-                h < 40 * scale):
+        if (instance_area < _SMALL_OBJECT_AREA_THRESH or h < 40):
             if ymin >= height - 5:
                 text_pos = (xmin, ymin)
             else:
                 text_pos = (xmin, ymax)
         height_ratio = h / np.sqrt(height * width)
-        font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2,
-                             2) * 0.5 * default_font_size)
+        font_scale = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2,
+                              2) * 0.5 * default_font_scale)
         text = "{} {:.2f}".format(cname, score)
-        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
-        color[np.argmax(color)] = max(0.8, np.max(color))
-        color = _change_color_brightness(color, brightness_factor=0.7)
-        ax.text(
-            text_pos[0],
-            text_pos[1],
+        (tw, th), baseline = cv2.getTextSize(
             text,
-            size=font_size * scale,
-            family="sans-serif",
-            bbox={
-                "facecolor": "black",
-                "alpha": 0.8,
-                "pad": 0.7,
-                "edgecolor": "none"
-            },
-            verticalalignment="top",
-            horizontalalignment=horiz_align,
+            fontFace=cv2.FONT_HERSHEY_DUPLEX,
+            fontScale=font_scale,
+            thickness=1)
+        image = cv2.rectangle(
+            image,
+            text_pos, (text_pos[0] + tw, text_pos[1] + th + baseline),
             color=color,
-            zorder=10,
-            rotation=0, )
-
-    s, (width, height) = canvas.print_to_buffer()
-    buffer = np.frombuffer(s, dtype="uint8")
-
-    img_rgba = buffer.reshape(height, width, 4)
-    rgb, alpha = np.split(img_rgba, [3], axis=2)
-
-    try:
-        import numexpr as ne
-        visualized_image = ne.evaluate(
-            "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
-    except ImportError:
-        alpha = alpha.astype("float32") / 255.0
-        visualized_image = image * (1 - alpha) + rgb * alpha
-
-    visualized_image = visualized_image.astype("uint8")
-
-    return visualized_image
+            thickness=-1)
+        image = cv2.putText(
+            image,
+            text, (text_pos[0], text_pos[1] + th),
+            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+            fontScale=default_font_scale,
+            color=(255, 255, 255),
+            thickness=1,
+            lineType=cv2.LINE_AA)
+
+    return image
 
 
 def draw_pr_curve(eval_details_file=None,