|
@@ -14,7 +14,6 @@
|
|
|
|
|
|
|
|
import os
|
|
import os
|
|
|
import cv2
|
|
import cv2
|
|
|
-import colorsys
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import time
|
|
import time
|
|
|
import pycocotools.mask as mask_util
|
|
import pycocotools.mask as mask_util
|
|
@@ -37,7 +36,6 @@ def visualize_detection(image,
|
|
|
else:
|
|
else:
|
|
|
image_name = os.path.split(image)[-1]
|
|
image_name = os.path.split(image)[-1]
|
|
|
image = cv2.imread(image)
|
|
image = cv2.imread(image)
|
|
|
-
|
|
|
|
|
image = draw_bbox_mask(image, result, threshold=threshold, color_map=color)
|
|
image = draw_bbox_mask(image, result, threshold=threshold, color_map=color)
|
|
|
if save_dir is not None:
|
|
if save_dir is not None:
|
|
|
if not os.path.exists(save_dir):
|
|
if not os.path.exists(save_dir):
|
|
@@ -164,42 +162,10 @@ def clip_bbox(bbox):
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
|
|
def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
|
|
|
- import matplotlib
|
|
|
|
|
- matplotlib.use('Agg')
|
|
|
|
|
- import matplotlib as mpl
|
|
|
|
|
- import matplotlib.figure as mplfigure
|
|
|
|
|
- import matplotlib.colors as mplc
|
|
|
|
|
- from matplotlib.backends.backend_agg import FigureCanvasAgg
|
|
|
|
|
-
|
|
|
|
|
- # refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
|
|
|
|
|
- def _change_color_brightness(color, brightness_factor):
|
|
|
|
|
- assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
|
|
|
|
- color = mplc.to_rgb(color)
|
|
|
|
|
- polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
|
|
|
|
- modified_lightness = polygon_color[1] + (brightness_factor *
|
|
|
|
|
- polygon_color[1])
|
|
|
|
|
- modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
|
|
|
|
- modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
|
|
|
|
- modified_color = colorsys.hls_to_rgb(
|
|
|
|
|
- polygon_color[0], modified_lightness, polygon_color[2])
|
|
|
|
|
- return modified_color
|
|
|
|
|
-
|
|
|
|
|
_SMALL_OBJECT_AREA_THRESH = 1000
|
|
_SMALL_OBJECT_AREA_THRESH = 1000
|
|
|
- # setup figure
|
|
|
|
|
- width, height = image.shape[1], image.shape[0]
|
|
|
|
|
- scale = 1
|
|
|
|
|
- fig = mplfigure.Figure(frameon=False)
|
|
|
|
|
- dpi = fig.get_dpi()
|
|
|
|
|
- fig.set_size_inches(
|
|
|
|
|
- (width * scale + 1e-2) / dpi,
|
|
|
|
|
- (height * scale + 1e-2) / dpi, )
|
|
|
|
|
- canvas = FigureCanvasAgg(fig)
|
|
|
|
|
- ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
|
|
|
|
- ax.axis("off")
|
|
|
|
|
- ax.set_xlim(0.0, width)
|
|
|
|
|
- ax.set_ylim(height)
|
|
|
|
|
- default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
|
|
|
|
|
- linewidth = max(default_font_size / 4, 1)
|
|
|
|
|
|
|
+ height, width = image.shape[:2]
|
|
|
|
|
+ default_font_scale = max(np.sqrt(height * width) // 900, .5)
|
|
|
|
|
+ linewidth = max(default_font_scale / 40, 2)
|
|
|
|
|
|
|
|
labels = list()
|
|
labels = list()
|
|
|
for dt in results:
|
|
for dt in results:
|
|
@@ -233,100 +199,72 @@ def draw_bbox_mask(image, results, threshold=0.5, color_map=None):
|
|
|
|
|
|
|
|
for dt in keep_results:
|
|
for dt in keep_results:
|
|
|
cname, bbox, score = dt['category'], dt['bbox'], dt['score']
|
|
cname, bbox, score = dt['category'], dt['bbox'], dt['score']
|
|
|
|
|
+ bbox = list(map(int, bbox))
|
|
|
xmin, ymin, w, h = bbox
|
|
xmin, ymin, w, h = bbox
|
|
|
xmax = xmin + w
|
|
xmax = xmin + w
|
|
|
ymax = ymin + h
|
|
ymax = ymin + h
|
|
|
|
|
|
|
|
color = tuple(color_map[labels.index(cname)])
|
|
color = tuple(color_map[labels.index(cname)])
|
|
|
- color = [c / 255. for c in color]
|
|
|
|
|
# draw bbox
|
|
# draw bbox
|
|
|
- ax.add_patch(
|
|
|
|
|
- mpl.patches.Rectangle(
|
|
|
|
|
- (xmin, ymin),
|
|
|
|
|
- w,
|
|
|
|
|
- h,
|
|
|
|
|
- fill=False,
|
|
|
|
|
- edgecolor=color,
|
|
|
|
|
- linewidth=linewidth * scale,
|
|
|
|
|
- alpha=0.8,
|
|
|
|
|
- linestyle="-", ))
|
|
|
|
|
|
|
+ image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color,
|
|
|
|
|
+ linewidth)
|
|
|
|
|
|
|
|
# draw mask
|
|
# draw mask
|
|
|
if 'mask' in dt:
|
|
if 'mask' in dt:
|
|
|
- mask = mask_util.decode(dt['mask'])
|
|
|
|
|
- mask = np.ascontiguousarray(mask)
|
|
|
|
|
- res = cv2.findContours(
|
|
|
|
|
- mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
|
|
|
|
- hierarchy = res[-1]
|
|
|
|
|
- alpha = 0.5
|
|
|
|
|
- if hierarchy is not None:
|
|
|
|
|
- has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
|
|
|
|
|
- res = res[-2]
|
|
|
|
|
- res = [x.flatten() for x in res]
|
|
|
|
|
- res = [x for x in res if len(x) >= 6]
|
|
|
|
|
- for segment in res:
|
|
|
|
|
- segment = segment.reshape(-1, 2)
|
|
|
|
|
- edge_color = mplc.to_rgb(color) + (1, )
|
|
|
|
|
- polygon = mpl.patches.Polygon(
|
|
|
|
|
- segment,
|
|
|
|
|
- fill=True,
|
|
|
|
|
- facecolor=mplc.to_rgb(color) + (alpha, ),
|
|
|
|
|
- edgecolor=edge_color,
|
|
|
|
|
- linewidth=max(default_font_size // 15 * scale, 1), )
|
|
|
|
|
- ax.add_patch(polygon)
|
|
|
|
|
|
|
+ mask = mask_util.decode(dt['mask']) * 255
|
|
|
|
|
+ image = image.astype('float32')
|
|
|
|
|
+ alpha = .7
|
|
|
|
|
+ w_ratio = .4
|
|
|
|
|
+ color_mask = np.asarray(color, dtype=np.int)
|
|
|
|
|
+ for c in range(3):
|
|
|
|
|
+ color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
|
|
|
|
|
+ idx = np.nonzero(mask)
|
|
|
|
|
+ image[idx[0], idx[1], :] *= 1.0 - alpha
|
|
|
|
|
+ image[idx[0], idx[1], :] += alpha * color_mask
|
|
|
|
|
+ image = image.astype("uint8")
|
|
|
|
|
+ contours = cv2.findContours(
|
|
|
|
|
+ mask.astype("uint8"), cv2.RETR_CCOMP,
|
|
|
|
|
+ cv2.CHAIN_APPROX_NONE)[-2]
|
|
|
|
|
+ image = cv2.drawContours(
|
|
|
|
|
+ image,
|
|
|
|
|
+ contours,
|
|
|
|
|
+ contourIdx=-1,
|
|
|
|
|
+ color=color,
|
|
|
|
|
+ thickness=1,
|
|
|
|
|
+ lineType=cv2.LINE_AA)
|
|
|
|
|
|
|
|
# draw label
|
|
# draw label
|
|
|
text_pos = (xmin, ymin)
|
|
text_pos = (xmin, ymin)
|
|
|
- horiz_align = "left"
|
|
|
|
|
instance_area = w * h
|
|
instance_area = w * h
|
|
|
- if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale or
|
|
|
|
|
- h < 40 * scale):
|
|
|
|
|
|
|
+ if (instance_area < _SMALL_OBJECT_AREA_THRESH or h < 40):
|
|
|
if ymin >= height - 5:
|
|
if ymin >= height - 5:
|
|
|
text_pos = (xmin, ymin)
|
|
text_pos = (xmin, ymin)
|
|
|
else:
|
|
else:
|
|
|
text_pos = (xmin, ymax)
|
|
text_pos = (xmin, ymax)
|
|
|
height_ratio = h / np.sqrt(height * width)
|
|
height_ratio = h / np.sqrt(height * width)
|
|
|
- font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2,
|
|
|
|
|
- 2) * 0.5 * default_font_size)
|
|
|
|
|
|
|
+ font_scale = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2,
|
|
|
|
|
+ 2) * 0.5 * default_font_scale)
|
|
|
text = "{} {:.2f}".format(cname, score)
|
|
text = "{} {:.2f}".format(cname, score)
|
|
|
- color = np.maximum(list(mplc.to_rgb(color)), 0.2)
|
|
|
|
|
- color[np.argmax(color)] = max(0.8, np.max(color))
|
|
|
|
|
- color = _change_color_brightness(color, brightness_factor=0.7)
|
|
|
|
|
- ax.text(
|
|
|
|
|
- text_pos[0],
|
|
|
|
|
- text_pos[1],
|
|
|
|
|
|
|
+ (tw, th), baseline = cv2.getTextSize(
|
|
|
text,
|
|
text,
|
|
|
- size=font_size * scale,
|
|
|
|
|
- family="sans-serif",
|
|
|
|
|
- bbox={
|
|
|
|
|
- "facecolor": "black",
|
|
|
|
|
- "alpha": 0.8,
|
|
|
|
|
- "pad": 0.7,
|
|
|
|
|
- "edgecolor": "none"
|
|
|
|
|
- },
|
|
|
|
|
- verticalalignment="top",
|
|
|
|
|
- horizontalalignment=horiz_align,
|
|
|
|
|
|
|
+ fontFace=cv2.FONT_HERSHEY_DUPLEX,
|
|
|
|
|
+ fontScale=font_scale,
|
|
|
|
|
+ thickness=1)
|
|
|
|
|
+ image = cv2.rectangle(
|
|
|
|
|
+ image,
|
|
|
|
|
+ text_pos, (text_pos[0] + tw, text_pos[1] + th + baseline),
|
|
|
color=color,
|
|
color=color,
|
|
|
- zorder=10,
|
|
|
|
|
- rotation=0, )
|
|
|
|
|
-
|
|
|
|
|
- s, (width, height) = canvas.print_to_buffer()
|
|
|
|
|
- buffer = np.frombuffer(s, dtype="uint8")
|
|
|
|
|
-
|
|
|
|
|
- img_rgba = buffer.reshape(height, width, 4)
|
|
|
|
|
- rgb, alpha = np.split(img_rgba, [3], axis=2)
|
|
|
|
|
-
|
|
|
|
|
- try:
|
|
|
|
|
- import numexpr as ne
|
|
|
|
|
- visualized_image = ne.evaluate(
|
|
|
|
|
- "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
|
|
|
|
|
- except ImportError:
|
|
|
|
|
- alpha = alpha.astype("float32") / 255.0
|
|
|
|
|
- visualized_image = image * (1 - alpha) + rgb * alpha
|
|
|
|
|
-
|
|
|
|
|
- visualized_image = visualized_image.astype("uint8")
|
|
|
|
|
-
|
|
|
|
|
- return visualized_image
|
|
|
|
|
|
|
+ thickness=-1)
|
|
|
|
|
+ image = cv2.putText(
|
|
|
|
|
+ image,
|
|
|
|
|
+ text, (text_pos[0], text_pos[1] + th),
|
|
|
|
|
+ fontFace=cv2.FONT_HERSHEY_DUPLEX,
|
|
|
|
|
+ fontScale=font_scale,
|
|
|
|
|
+ color=(255, 255, 255),
|
|
|
|
|
+ thickness=1,
|
|
|
|
|
+ lineType=cv2.LINE_AA)
|
|
|
|
|
+
|
|
|
|
|
+ return image
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_pr_curve(eval_details_file=None,
|
|
def draw_pr_curve(eval_details_file=None,
|