|
|
@@ -15,7 +15,10 @@
|
|
|
import os
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
-from PIL import Image, ImageDraw
|
|
|
+import matplotlib as mpl
|
|
|
+import matplotlib.figure as mplfigure
|
|
|
+import matplotlib.colors as mplc
|
|
|
+from matplotlib.backends.backend_agg import FigureCanvasAgg
|
|
|
|
|
|
|
|
|
def visualize_detection(image, result, threshold=0.5, save_dir=None):
|
|
|
@@ -24,13 +27,13 @@ def visualize_detection(image, result, threshold=0.5, save_dir=None):
|
|
|
"""
|
|
|
|
|
|
image_name = os.path.split(image)[-1]
|
|
|
- image = Image.open(image).convert('RGB')
|
|
|
+ image = cv2.imread(image)
|
|
|
image = draw_bbox_mask(image, result, threshold=threshold)
|
|
|
if save_dir is not None:
|
|
|
if not os.path.exists(save_dir):
|
|
|
os.makedirs(save_dir)
|
|
|
out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
|
|
|
- image.save(out_path, quality=95)
|
|
|
+ cv2.imwrite(out_path, image)
|
|
|
else:
|
|
|
return image
|
|
|
|
|
|
@@ -117,46 +120,141 @@ def clip_bbox(bbox):
|
|
|
return xmin, ymin, xmax, ymax
|
|
|
|
|
|
|
|
|
-def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
|
|
|
+def draw_bbox_mask(image, results, threshold=0.5):
|
|
|
+ # refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
|
|
|
+ _SMALL_OBJECT_AREA_THRESH = 1000
|
|
|
+ # setup figure
|
|
|
+ width, height = image.shape[1], image.shape[0]
|
|
|
+ scale = 1
|
|
|
+ fig = mplfigure.Figure(frameon=False)
|
|
|
+ dpi = fig.get_dpi()
|
|
|
+ fig.set_size_inches(
|
|
|
+ (width * scale + 1e-2) / dpi,
|
|
|
+ (height * scale + 1e-2) / dpi,
|
|
|
+ )
|
|
|
+ canvas = FigureCanvasAgg(fig)
|
|
|
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
|
|
+ ax.axis("off")
|
|
|
+ ax.set_xlim(0.0, width)
|
|
|
+ ax.set_ylim(height)
|
|
|
+ default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
|
|
|
+ linewidth = max(default_font_size / 4, 1)
|
|
|
+
|
|
|
labels = list()
|
|
|
for dt in np.array(results):
|
|
|
if dt['category'] not in labels:
|
|
|
labels.append(dt['category'])
|
|
|
- color_map = get_color_map_list(len(labels))
|
|
|
+ color_map = get_color_map_list(256)
|
|
|
|
|
|
+ keep_results = []
|
|
|
+ areas = []
|
|
|
for dt in np.array(results):
|
|
|
cname, bbox, score = dt['category'], dt['bbox'], dt['score']
|
|
|
if score < threshold:
|
|
|
continue
|
|
|
+ keep_results.append(dt)
|
|
|
+ areas.append(bbox[2] * bbox[3])
|
|
|
+ areas = np.asarray(areas)
|
|
|
+ sorted_idxs = np.argsort(-areas).tolist()
|
|
|
+ keep_results = [keep_results[k]
|
|
|
+ for k in sorted_idxs] if len(keep_results) > 0 else []
|
|
|
|
|
|
+ for dt in np.array(keep_results):
|
|
|
+ cname, bbox, score = dt['category'], dt['bbox'], dt['score']
|
|
|
xmin, ymin, w, h = bbox
|
|
|
xmax = xmin + w
|
|
|
ymax = ymin + h
|
|
|
|
|
|
- color = tuple(color_map[labels.index(cname)])
|
|
|
-
|
|
|
+ color = tuple(color_map[labels.index(cname) + 2])
|
|
|
+ color = [c / 255. for c in color]
|
|
|
# draw bbox
|
|
|
- draw = ImageDraw.Draw(image)
|
|
|
- draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
|
|
|
- (xmin, ymin)],
|
|
|
- width=2,
|
|
|
- fill=color)
|
|
|
-
|
|
|
- # draw label
|
|
|
- text = "{} {:.2f}".format(cname, score)
|
|
|
- tw, th = draw.textsize(text)
|
|
|
- draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)],
|
|
|
- fill=color)
|
|
|
- draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
|
|
|
+ ax.add_patch(
|
|
|
+ mpl.patches.Rectangle(
|
|
|
+ (xmin, ymin),
|
|
|
+ w,
|
|
|
+ h,
|
|
|
+ fill=False,
|
|
|
+ edgecolor=color,
|
|
|
+ linewidth=linewidth * scale,
|
|
|
+ alpha=0.5,
|
|
|
+ linestyle="-",
|
|
|
+ ))
|
|
|
|
|
|
# draw mask
|
|
|
if 'mask' in dt:
|
|
|
mask = dt['mask']
|
|
|
- color_mask = np.array(color_map[labels.index(
|
|
|
- dt['category'])]).astype('float32')
|
|
|
- img_array = np.array(image).astype('float32')
|
|
|
- idx = np.nonzero(mask)
|
|
|
- img_array[idx[0], idx[1], :] *= 1.0 - alpha
|
|
|
- img_array[idx[0], idx[1], :] += alpha * color_mask
|
|
|
- image = Image.fromarray(img_array.astype('uint8'))
|
|
|
- return image
|
|
|
+ mask = np.ascontiguousarray(mask)
|
|
|
+ res = cv2.findContours(
|
|
|
+ mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
|
|
+ hierarchy = res[-1]
|
|
|
+ alpha = 0.75
|
|
|
+ if hierarchy is not None:
|
|
|
+ has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
|
|
|
+ res = res[-2]
|
|
|
+ res = [x.flatten() for x in res]
|
|
|
+ res = [x for x in res if len(x) >= 6]
|
|
|
+ for segment in res:
|
|
|
+ segment = segment.reshape(-1, 2)
|
|
|
+ edge_color = mplc.to_rgb(color) + (1, )
|
|
|
+ polygon = mpl.patches.Polygon(
|
|
|
+ segment,
|
|
|
+ fill=True,
|
|
|
+ facecolor=mplc.to_rgb(color) + (alpha, ),
|
|
|
+ edgecolor=edge_color,
|
|
|
+ linewidth=max(default_font_size // 15 * scale, 1),
|
|
|
+ )
|
|
|
+ ax.add_patch(polygon)
|
|
|
+
|
|
|
+ # draw label
|
|
|
+ text_pos = (xmin, ymin)
|
|
|
+ horiz_align = "left"
|
|
|
+ instance_area = w * h
|
|
|
+ if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
|
|
|
+ or h < 40 * scale):
|
|
|
+ if ymin >= height - 5:
|
|
|
+ text_pos = (xmin, ymin)
|
|
|
+ else:
|
|
|
+ text_pos = (xmin, ymax)
|
|
|
+ height_ratio = h / np.sqrt(height * width)
|
|
|
+ font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
|
|
|
+ default_font_size)
|
|
|
+ text = "{} {:.2f}".format(cname, score)
|
|
|
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
|
|
|
+ color[np.argmax(color)] = max(0.8, np.max(color))
|
|
|
+
|
|
|
+ ax.text(
|
|
|
+ text_pos[0],
|
|
|
+ text_pos[1],
|
|
|
+ text,
|
|
|
+ size=font_size * scale,
|
|
|
+ family="sans-serif",
|
|
|
+ bbox={
|
|
|
+ "facecolor": "black",
|
|
|
+ "alpha": 0.8,
|
|
|
+ "pad": 0.7,
|
|
|
+ "edgecolor": "none"
|
|
|
+ },
|
|
|
+ verticalalignment="top",
|
|
|
+ horizontalalignment=horiz_align,
|
|
|
+ color=color,
|
|
|
+ zorder=10,
|
|
|
+ rotation=0,
|
|
|
+ )
|
|
|
+
|
|
|
+ s, (width, height) = canvas.print_to_buffer()
|
|
|
+ buffer = np.frombuffer(s, dtype="uint8")
|
|
|
+
|
|
|
+ img_rgba = buffer.reshape(height, width, 4)
|
|
|
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
|
|
|
+
|
|
|
+ try:
|
|
|
+ import numexpr as ne
|
|
|
+ visualized_image = ne.evaluate(
|
|
|
+ "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
|
|
|
+ except ImportError:
|
|
|
+ alpha = alpha.astype("float32") / 255.0
|
|
|
+ visualized_image = image * (1 - alpha) + rgb * alpha
|
|
|
+
|
|
|
+ visualized_image = visualized_image.astype("uint8")
|
|
|
+
|
|
|
+ return visualized_image
|