Эх сурвалжийг харах

Merge pull request #877 from will-jl944/develop_jf

refine det_visualize
FlyingQianMM 4 жил өмнө
parent
commit
13bb67e636

+ 48 - 110
dygraph/paddlex/cv/models/utils/visualize.py

@@ -14,7 +14,6 @@
 
 import os
 import cv2
-import colorsys
 import numpy as np
 import time
 import pycocotools.mask as mask_util
@@ -37,7 +36,6 @@ def visualize_detection(image,
     else:
         image_name = os.path.split(image)[-1]
         image = cv2.imread(image)
-
     image = draw_bbox_mask(image, result, threshold=threshold, color_map=color)
     if save_dir is not None:
         if not os.path.exists(save_dir):
@@ -164,42 +162,10 @@ def clip_bbox(bbox):
 
 
 def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
-    import matplotlib
-    matplotlib.use('Agg')
-    import matplotlib as mpl
-    import matplotlib.figure as mplfigure
-    import matplotlib.colors as mplc
-    from matplotlib.backends.backend_agg import FigureCanvasAgg
-
-    # refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
-    def _change_color_brightness(color, brightness_factor):
-        assert brightness_factor >= -1.0 and brightness_factor <= 1.0
-        color = mplc.to_rgb(color)
-        polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
-        modified_lightness = polygon_color[1] + (brightness_factor *
-                                                 polygon_color[1])
-        modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
-        modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
-        modified_color = colorsys.hls_to_rgb(
-            polygon_color[0], modified_lightness, polygon_color[2])
-        return modified_color
-
     _SMALL_OBJECT_AREA_THRESH = 1000
-    # setup figure
-    width, height = image.shape[1], image.shape[0]
-    scale = 1
-    fig = mplfigure.Figure(frameon=False)
-    dpi = fig.get_dpi()
-    fig.set_size_inches(
-        (width * scale + 1e-2) / dpi,
-        (height * scale + 1e-2) / dpi, )
-    canvas = FigureCanvasAgg(fig)
-    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
-    ax.axis("off")
-    ax.set_xlim(0.0, width)
-    ax.set_ylim(height)
-    default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
-    linewidth = max(default_font_size / 4, 1)
+    height, width = image.shape[:2]
+    default_font_scale = max(np.sqrt(height * width) // 900, .5)
+    linewidth = max(default_font_scale / 40, 2)
 
     labels = list()
     for dt in results:
@@ -233,100 +199,72 @@ def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
 
     for dt in keep_results:
         cname, bbox, score = dt['category'], dt['bbox'], dt['score']
+        bbox = list(map(int, bbox))
         xmin, ymin, w, h = bbox
         xmax = xmin + w
         ymax = ymin + h
 
         color = tuple(color_map[labels.index(cname)])
-        color = [c / 255. for c in color]
         # draw bbox
-        ax.add_patch(
-            mpl.patches.Rectangle(
-                (xmin, ymin),
-                w,
-                h,
-                fill=False,
-                edgecolor=color,
-                linewidth=linewidth * scale,
-                alpha=0.8,
-                linestyle="-", ))
+        image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color,
+                              linewidth)
 
         # draw mask
         if 'mask' in dt:
-            mask = mask_util.decode(dt['mask'])
-            mask = np.ascontiguousarray(mask)
-            res = cv2.findContours(
-                mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
-            hierarchy = res[-1]
-            alpha = 0.5
-            if hierarchy is not None:
-                has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
-                res = res[-2]
-                res = [x.flatten() for x in res]
-                res = [x for x in res if len(x) >= 6]
-                for segment in res:
-                    segment = segment.reshape(-1, 2)
-                    edge_color = mplc.to_rgb(color) + (1, )
-                    polygon = mpl.patches.Polygon(
-                        segment,
-                        fill=True,
-                        facecolor=mplc.to_rgb(color) + (alpha, ),
-                        edgecolor=edge_color,
-                        linewidth=max(default_font_size // 15 * scale, 1), )
-                    ax.add_patch(polygon)
+            mask = mask_util.decode(dt['mask']) * 255
+            image = image.astype('float32')
+            alpha = .7
+            w_ratio = .4
+            color_mask = np.asarray(color, dtype=np.int)
+            for c in range(3):
+                color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+            idx = np.nonzero(mask)
+            image[idx[0], idx[1], :] *= 1.0 - alpha
+            image[idx[0], idx[1], :] += alpha * color_mask
+            image = image.astype("uint8")
+            contours = cv2.findContours(
+                mask.astype("uint8"), cv2.RETR_CCOMP,
+                cv2.CHAIN_APPROX_NONE)[-2]
+            image = cv2.drawContours(
+                image,
+                contours,
+                contourIdx=-1,
+                color=color,
+                thickness=1,
+                lineType=cv2.LINE_AA)
 
         # draw label
         text_pos = (xmin, ymin)
-        horiz_align = "left"
         instance_area = w * h
-        if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale or
-                h < 40 * scale):
+        if (instance_area < _SMALL_OBJECT_AREA_THRESH or h < 40):
             if ymin >= height - 5:
                 text_pos = (xmin, ymin)
             else:
                 text_pos = (xmin, ymax)
         height_ratio = h / np.sqrt(height * width)
-        font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2,
-                             2) * 0.5 * default_font_size)
+        font_scale = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2,
+                              2) * 0.5 * default_font_scale)
         text = "{} {:.2f}".format(cname, score)
-        color = np.maximum(list(mplc.to_rgb(color)), 0.2)
-        color[np.argmax(color)] = max(0.8, np.max(color))
-        color = _change_color_brightness(color, brightness_factor=0.7)
-        ax.text(
-            text_pos[0],
-            text_pos[1],
+        (tw, th), baseline = cv2.getTextSize(
             text,
-            size=font_size * scale,
-            family="sans-serif",
-            bbox={
-                "facecolor": "black",
-                "alpha": 0.8,
-                "pad": 0.7,
-                "edgecolor": "none"
-            },
-            verticalalignment="top",
-            horizontalalignment=horiz_align,
+            fontFace=cv2.FONT_HERSHEY_DUPLEX,
+            fontScale=font_scale,
+            thickness=1)
+        image = cv2.rectangle(
+            image,
+            text_pos, (text_pos[0] + tw, text_pos[1] + th + baseline),
             color=color,
-            zorder=10,
-            rotation=0, )
-
-    s, (width, height) = canvas.print_to_buffer()
-    buffer = np.frombuffer(s, dtype="uint8")
-
-    img_rgba = buffer.reshape(height, width, 4)
-    rgb, alpha = np.split(img_rgba, [3], axis=2)
-
-    try:
-        import numexpr as ne
-        visualized_image = ne.evaluate(
-            "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
-    except ImportError:
-        alpha = alpha.astype("float32") / 255.0
-        visualized_image = image * (1 - alpha) + rgb * alpha
-
-    visualized_image = visualized_image.astype("uint8")
-
-    return visualized_image
+            thickness=-1)
+        image = cv2.putText(
+            image,
+            text, (text_pos[0], text_pos[1] + th),
+            fontFace=cv2.FONT_HERSHEY_DUPLEX,
+            fontScale=font_scale,
+            color=(255, 255, 255),
+            thickness=1,
+            lineType=cv2.LINE_AA)
+
+    return image
 
 
 def draw_pr_curve(eval_details_file=None,