visualize.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. import matplotlib as mpl
  18. import matplotlib.figure as mplfigure
  19. import matplotlib.colors as mplc
  20. from matplotlib.backends.backend_agg import FigureCanvasAgg
  21. def visualize_detection(image, result, threshold=0.5, save_dir=None):
  22. """
  23. Visualize bbox and mask results
  24. """
  25. image_name = os.path.split(image)[-1]
  26. image = cv2.imread(image)
  27. image = draw_bbox_mask(image, result, threshold=threshold)
  28. if save_dir is not None:
  29. if not os.path.exists(save_dir):
  30. os.makedirs(save_dir)
  31. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  32. cv2.imwrite(out_path, image)
  33. else:
  34. return image
  35. def visualize_segmentation(image, result, weight=0.6, save_dir=None):
  36. """
  37. Convert segment result to color image, and save added image.
  38. Args:
  39. image: the path of origin image
  40. result: the predict result of image
  41. weight: the image weight of visual image, and the result weight is (1 - weight)
  42. save_dir: the directory for saving visual image
  43. """
  44. label_map = result['label_map']
  45. color_map = get_color_map_list(256)
  46. color_map = np.array(color_map).astype("uint8")
  47. # Use OpenCV LUT for color mapping
  48. c1 = cv2.LUT(label_map, color_map[:, 0])
  49. c2 = cv2.LUT(label_map, color_map[:, 1])
  50. c3 = cv2.LUT(label_map, color_map[:, 2])
  51. pseudo_img = np.dstack((c1, c2, c3))
  52. im = cv2.imread(image)
  53. vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
  54. if save_dir is not None:
  55. if not os.path.exists(save_dir):
  56. os.makedirs(save_dir)
  57. image_name = os.path.split(image)[-1]
  58. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  59. cv2.imwrite(out_path, vis_result)
  60. else:
  61. return vis_result
  62. def get_color_map_list(num_classes):
  63. """ Returns the color map for visualizing the segmentation mask,
  64. which can support arbitrary number of classes.
  65. Args:
  66. num_classes: Number of classes
  67. Returns:
  68. The color map
  69. """
  70. color_map = num_classes * [0, 0, 0]
  71. for i in range(0, num_classes):
  72. j = 0
  73. lab = i
  74. while lab:
  75. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  76. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  77. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  78. j += 1
  79. lab >>= 3
  80. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  81. return color_map
  82. # expand an array of boxes by a given scale.
  83. def expand_boxes(boxes, scale):
  84. """
  85. """
  86. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  87. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  88. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  89. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  90. w_half *= scale
  91. h_half *= scale
  92. boxes_exp = np.zeros(boxes.shape)
  93. boxes_exp[:, 0] = x_c - w_half
  94. boxes_exp[:, 2] = x_c + w_half
  95. boxes_exp[:, 1] = y_c - h_half
  96. boxes_exp[:, 3] = y_c + h_half
  97. return boxes_exp
  98. def clip_bbox(bbox):
  99. xmin = max(min(bbox[0], 1.), 0.)
  100. ymin = max(min(bbox[1], 1.), 0.)
  101. xmax = max(min(bbox[2], 1.), 0.)
  102. ymax = max(min(bbox[3], 1.), 0.)
  103. return xmin, ymin, xmax, ymax
  104. def draw_bbox_mask(image, results, threshold=0.5):
  105. # refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
  106. _SMALL_OBJECT_AREA_THRESH = 1000
  107. # setup figure
  108. width, height = image.shape[1], image.shape[0]
  109. scale = 1
  110. fig = mplfigure.Figure(frameon=False)
  111. dpi = fig.get_dpi()
  112. fig.set_size_inches(
  113. (width * scale + 1e-2) / dpi,
  114. (height * scale + 1e-2) / dpi,
  115. )
  116. canvas = FigureCanvasAgg(fig)
  117. ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
  118. ax.axis("off")
  119. ax.set_xlim(0.0, width)
  120. ax.set_ylim(height)
  121. default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
  122. linewidth = max(default_font_size / 4, 1)
  123. labels = list()
  124. for dt in np.array(results):
  125. if dt['category'] not in labels:
  126. labels.append(dt['category'])
  127. color_map = get_color_map_list(256)
  128. keep_results = []
  129. areas = []
  130. for dt in np.array(results):
  131. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  132. if score < threshold:
  133. continue
  134. keep_results.append(dt)
  135. areas.append(bbox[2] * bbox[3])
  136. areas = np.asarray(areas)
  137. sorted_idxs = np.argsort(-areas).tolist()
  138. keep_results = [keep_results[k]
  139. for k in sorted_idxs] if len(keep_results) > 0 else []
  140. for dt in np.array(keep_results):
  141. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  142. xmin, ymin, w, h = bbox
  143. xmax = xmin + w
  144. ymax = ymin + h
  145. color = tuple(color_map[labels.index(cname) + 2])
  146. color = [c / 255. for c in color]
  147. # draw bbox
  148. ax.add_patch(
  149. mpl.patches.Rectangle(
  150. (xmin, ymin),
  151. w,
  152. h,
  153. fill=False,
  154. edgecolor=color,
  155. linewidth=linewidth * scale,
  156. alpha=0.5,
  157. linestyle="-",
  158. ))
  159. # draw mask
  160. if 'mask' in dt:
  161. mask = dt['mask']
  162. mask = np.ascontiguousarray(mask)
  163. res = cv2.findContours(
  164. mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
  165. hierarchy = res[-1]
  166. alpha = 0.75
  167. if hierarchy is not None:
  168. has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
  169. res = res[-2]
  170. res = [x.flatten() for x in res]
  171. res = [x for x in res if len(x) >= 6]
  172. for segment in res:
  173. segment = segment.reshape(-1, 2)
  174. edge_color = mplc.to_rgb(color) + (1, )
  175. polygon = mpl.patches.Polygon(
  176. segment,
  177. fill=True,
  178. facecolor=mplc.to_rgb(color) + (alpha, ),
  179. edgecolor=edge_color,
  180. linewidth=max(default_font_size // 15 * scale, 1),
  181. )
  182. ax.add_patch(polygon)
  183. # draw label
  184. text_pos = (xmin, ymin)
  185. horiz_align = "left"
  186. instance_area = w * h
  187. if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
  188. or h < 40 * scale):
  189. if ymin >= height - 5:
  190. text_pos = (xmin, ymin)
  191. else:
  192. text_pos = (xmin, ymax)
  193. height_ratio = h / np.sqrt(height * width)
  194. font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
  195. default_font_size)
  196. text = "{} {:.2f}".format(cname, score)
  197. color = np.maximum(list(mplc.to_rgb(color)), 0.2)
  198. color[np.argmax(color)] = max(0.8, np.max(color))
  199. ax.text(
  200. text_pos[0],
  201. text_pos[1],
  202. text,
  203. size=font_size * scale,
  204. family="sans-serif",
  205. bbox={
  206. "facecolor": "black",
  207. "alpha": 0.8,
  208. "pad": 0.7,
  209. "edgecolor": "none"
  210. },
  211. verticalalignment="top",
  212. horizontalalignment=horiz_align,
  213. color=color,
  214. zorder=10,
  215. rotation=0,
  216. )
  217. s, (width, height) = canvas.print_to_buffer()
  218. buffer = np.frombuffer(s, dtype="uint8")
  219. img_rgba = buffer.reshape(height, width, 4)
  220. rgb, alpha = np.split(img_rgba, [3], axis=2)
  221. try:
  222. import numexpr as ne
  223. visualized_image = ne.evaluate(
  224. "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
  225. except ImportError:
  226. alpha = alpha.astype("float32") / 255.0
  227. visualized_image = image * (1 - alpha) + rgb * alpha
  228. visualized_image = visualized_image.astype("uint8")
  229. return visualized_image