visualize.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import cv2
  16. import colorsys
  17. import numpy as np
  18. import matplotlib
  19. matplotlib.use('Agg')
  20. import matplotlib as mpl
  21. import matplotlib.pyplot as plt
  22. import matplotlib.figure as mplfigure
  23. import matplotlib.colors as mplc
  24. from matplotlib.backends.backend_agg import FigureCanvasAgg
  25. import paddlex.utils.logging as logging
  26. from .detection_eval import fixed_linspace, backup_linspace, loadRes
  27. def visualize_detection(image, result, threshold=0.5, save_dir='./'):
  28. """
  29. Visualize bbox and mask results
  30. """
  31. image_name = os.path.split(image)[-1]
  32. image = cv2.imread(image)
  33. image = draw_bbox_mask(image, result, threshold=threshold)
  34. if save_dir is not None:
  35. if not os.path.exists(save_dir):
  36. os.makedirs(save_dir)
  37. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  38. cv2.imwrite(out_path, image)
  39. logging.info('The visualized result is saved as {}'.format(out_path))
  40. else:
  41. return image
  42. def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
  43. """
  44. Convert segment result to color image, and save added image.
  45. Args:
  46. image: the path of origin image
  47. result: the predict result of image
  48. weight: the image weight of visual image, and the result weight is (1 - weight)
  49. save_dir: the directory for saving visual image
  50. """
  51. label_map = result['label_map']
  52. color_map = get_color_map_list(256)
  53. color_map = np.array(color_map).astype("uint8")
  54. # Use OpenCV LUT for color mapping
  55. c1 = cv2.LUT(label_map, color_map[:, 0])
  56. c2 = cv2.LUT(label_map, color_map[:, 1])
  57. c3 = cv2.LUT(label_map, color_map[:, 2])
  58. pseudo_img = np.dstack((c1, c2, c3))
  59. im = cv2.imread(image)
  60. vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
  61. if save_dir is not None:
  62. if not os.path.exists(save_dir):
  63. os.makedirs(save_dir)
  64. image_name = os.path.split(image)[-1]
  65. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  66. cv2.imwrite(out_path, vis_result)
  67. logging.info('The visualized result is saved as {}'.format(out_path))
  68. else:
  69. return vis_result
  70. def get_color_map_list(num_classes):
  71. """ Returns the color map for visualizing the segmentation mask,
  72. which can support arbitrary number of classes.
  73. Args:
  74. num_classes: Number of classes
  75. Returns:
  76. The color map
  77. """
  78. color_map = num_classes * [0, 0, 0]
  79. for i in range(0, num_classes):
  80. j = 0
  81. lab = i
  82. while lab:
  83. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  84. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  85. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  86. j += 1
  87. lab >>= 3
  88. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  89. return color_map
  90. # expand an array of boxes by a given scale.
  91. def expand_boxes(boxes, scale):
  92. """
  93. """
  94. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  95. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  96. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  97. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  98. w_half *= scale
  99. h_half *= scale
  100. boxes_exp = np.zeros(boxes.shape)
  101. boxes_exp[:, 0] = x_c - w_half
  102. boxes_exp[:, 2] = x_c + w_half
  103. boxes_exp[:, 1] = y_c - h_half
  104. boxes_exp[:, 3] = y_c + h_half
  105. return boxes_exp
  106. def clip_bbox(bbox):
  107. xmin = max(min(bbox[0], 1.), 0.)
  108. ymin = max(min(bbox[1], 1.), 0.)
  109. xmax = max(min(bbox[2], 1.), 0.)
  110. ymax = max(min(bbox[3], 1.), 0.)
  111. return xmin, ymin, xmax, ymax
  112. def draw_bbox_mask(image, results, threshold=0.5):
  113. # refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
  114. def _change_color_brightness(color, brightness_factor):
  115. assert brightness_factor >= -1.0 and brightness_factor <= 1.0
  116. color = mplc.to_rgb(color)
  117. polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
  118. modified_lightness = polygon_color[1] + (
  119. brightness_factor * polygon_color[1])
  120. modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
  121. modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
  122. modified_color = colorsys.hls_to_rgb(
  123. polygon_color[0], modified_lightness, polygon_color[2])
  124. return modified_color
  125. _SMALL_OBJECT_AREA_THRESH = 1000
  126. # setup figure
  127. width, height = image.shape[1], image.shape[0]
  128. scale = 1
  129. fig = mplfigure.Figure(frameon=False)
  130. dpi = fig.get_dpi()
  131. fig.set_size_inches(
  132. (width * scale + 1e-2) / dpi,
  133. (height * scale + 1e-2) / dpi,
  134. )
  135. canvas = FigureCanvasAgg(fig)
  136. ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
  137. ax.axis("off")
  138. ax.set_xlim(0.0, width)
  139. ax.set_ylim(height)
  140. default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
  141. linewidth = max(default_font_size / 4, 1)
  142. labels = list()
  143. for dt in np.array(results):
  144. if dt['category'] not in labels:
  145. labels.append(dt['category'])
  146. color_map = get_color_map_list(256)
  147. keep_results = []
  148. areas = []
  149. for dt in np.array(results):
  150. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  151. if score < threshold:
  152. continue
  153. keep_results.append(dt)
  154. areas.append(bbox[2] * bbox[3])
  155. areas = np.asarray(areas)
  156. sorted_idxs = np.argsort(-areas).tolist()
  157. keep_results = [keep_results[k]
  158. for k in sorted_idxs] if len(keep_results) > 0 else []
  159. for dt in np.array(keep_results):
  160. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  161. xmin, ymin, w, h = bbox
  162. xmax = xmin + w
  163. ymax = ymin + h
  164. color = tuple(color_map[labels.index(cname) + 2])
  165. color = [c / 255. for c in color]
  166. # draw bbox
  167. ax.add_patch(
  168. mpl.patches.Rectangle(
  169. (xmin, ymin),
  170. w,
  171. h,
  172. fill=False,
  173. edgecolor=color,
  174. linewidth=linewidth * scale,
  175. alpha=0.8,
  176. linestyle="-",
  177. ))
  178. # draw mask
  179. if 'mask' in dt:
  180. mask = dt['mask']
  181. mask = np.ascontiguousarray(mask)
  182. res = cv2.findContours(
  183. mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
  184. hierarchy = res[-1]
  185. alpha = 0.5
  186. if hierarchy is not None:
  187. has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
  188. res = res[-2]
  189. res = [x.flatten() for x in res]
  190. res = [x for x in res if len(x) >= 6]
  191. for segment in res:
  192. segment = segment.reshape(-1, 2)
  193. edge_color = mplc.to_rgb(color) + (1, )
  194. polygon = mpl.patches.Polygon(
  195. segment,
  196. fill=True,
  197. facecolor=mplc.to_rgb(color) + (alpha, ),
  198. edgecolor=edge_color,
  199. linewidth=max(default_font_size // 15 * scale, 1),
  200. )
  201. ax.add_patch(polygon)
  202. # draw label
  203. text_pos = (xmin, ymin)
  204. horiz_align = "left"
  205. instance_area = w * h
  206. if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
  207. or h < 40 * scale):
  208. if ymin >= height - 5:
  209. text_pos = (xmin, ymin)
  210. else:
  211. text_pos = (xmin, ymax)
  212. height_ratio = h / np.sqrt(height * width)
  213. font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
  214. default_font_size)
  215. text = "{} {:.2f}".format(cname, score)
  216. color = np.maximum(list(mplc.to_rgb(color)), 0.2)
  217. color[np.argmax(color)] = max(0.8, np.max(color))
  218. color = _change_color_brightness(color, brightness_factor=0.7)
  219. ax.text(
  220. text_pos[0],
  221. text_pos[1],
  222. text,
  223. size=font_size * scale,
  224. family="sans-serif",
  225. bbox={
  226. "facecolor": "black",
  227. "alpha": 0.8,
  228. "pad": 0.7,
  229. "edgecolor": "none"
  230. },
  231. verticalalignment="top",
  232. horizontalalignment=horiz_align,
  233. color=color,
  234. zorder=10,
  235. rotation=0,
  236. )
  237. s, (width, height) = canvas.print_to_buffer()
  238. buffer = np.frombuffer(s, dtype="uint8")
  239. img_rgba = buffer.reshape(height, width, 4)
  240. rgb, alpha = np.split(img_rgba, [3], axis=2)
  241. try:
  242. import numexpr as ne
  243. visualized_image = ne.evaluate(
  244. "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
  245. except ImportError:
  246. alpha = alpha.astype("float32") / 255.0
  247. visualized_image = image * (1 - alpha) + rgb * alpha
  248. visualized_image = visualized_image.astype("uint8")
  249. return visualized_image
  250. def draw_pr_curve(eval_details_file=None,
  251. gt=None,
  252. pred_bbox=None,
  253. pred_mask=None,
  254. iou_thresh=0.5,
  255. save_dir='./'):
  256. if eval_details_file is not None:
  257. import json
  258. with open(eval_details_file, 'r') as f:
  259. eval_details = json.load(f)
  260. pred_bbox = eval_details['bbox']
  261. if 'mask' in eval_details:
  262. pred_mask = eval_details['mask']
  263. gt = eval_details['gt']
  264. if gt is None or pred_bbox is None:
  265. raise Exception(
  266. "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
  267. )
  268. if pred_bbox is not None and len(pred_bbox) == 0:
  269. raise Exception("There is no predicted bbox.")
  270. if pred_mask is not None and len(pred_mask) == 0:
  271. raise Exception("There is no predicted mask.")
  272. from pycocotools.coco import COCO
  273. from pycocotools.cocoeval import COCOeval
  274. coco = COCO()
  275. coco.dataset = gt
  276. coco.createIndex()
  277. def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
  278. p = coco_gt.params
  279. aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
  280. mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
  281. if ap == 1:
  282. # dimension of precision: [TxRxKxAxM]
  283. s = coco_gt.eval['precision']
  284. # IoU
  285. if iouThr is not None:
  286. t = np.where(iouThr == p.iouThrs)[0]
  287. s = s[t]
  288. s = s[:, :, :, aind, mind]
  289. else:
  290. # dimension of recall: [TxKxAxM]
  291. s = coco_gt.eval['recall']
  292. if iouThr is not None:
  293. t = np.where(iouThr == p.iouThrs)[0]
  294. s = s[t]
  295. s = s[:, :, aind, mind]
  296. if len(s[s > -1]) == 0:
  297. mean_s = -1
  298. else:
  299. mean_s = np.mean(s[s > -1])
  300. return mean_s
  301. def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
  302. from pycocotools.cocoeval import COCOeval
  303. coco_dt = loadRes(coco_gt, coco_dt)
  304. np.linspace = fixed_linspace
  305. coco_eval = COCOeval(coco_gt, coco_dt, style)
  306. coco_eval.params.iouThrs = np.linspace(
  307. iou_thresh, iou_thresh, 1, endpoint=True)
  308. np.linspace = backup_linspace
  309. coco_eval.evaluate()
  310. coco_eval.accumulate()
  311. stats = _summarize(coco_eval, iouThr=iou_thresh)
  312. catIds = coco_gt.getCatIds()
  313. if len(catIds) != coco_eval.eval['precision'].shape[2]:
  314. raise Exception(
  315. "The category number must be same as the third dimension of precisions."
  316. )
  317. x = np.arange(0.0, 1.01, 0.01)
  318. color_map = get_color_map_list(256)[1:256]
  319. plt.subplot(1, 2, 1)
  320. plt.title(style + " precision-recall IoU={}".format(iou_thresh))
  321. plt.xlabel("recall")
  322. plt.ylabel("precision")
  323. plt.xlim(0, 1.01)
  324. plt.ylim(0, 1.01)
  325. plt.grid(linestyle='--', linewidth=1)
  326. plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
  327. my_x_ticks = np.arange(0, 1.01, 0.1)
  328. my_y_ticks = np.arange(0, 1.01, 0.1)
  329. plt.xticks(my_x_ticks, fontsize=5)
  330. plt.yticks(my_y_ticks, fontsize=5)
  331. for idx, catId in enumerate(catIds):
  332. pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
  333. precision = pr_array[pr_array > -1]
  334. ap = np.mean(precision) if precision.size else float('nan')
  335. nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
  336. float(ap * 100))
  337. color = tuple(color_map[idx])
  338. color = [float(c) / 255 for c in color]
  339. color.append(0.75)
  340. plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
  341. plt.legend(loc="lower left", fontsize=5)
  342. plt.subplot(1, 2, 2)
  343. plt.title(style + " score-recall IoU={}".format(iou_thresh))
  344. plt.xlabel('recall')
  345. plt.ylabel('score')
  346. plt.xlim(0, 1.01)
  347. plt.ylim(0, 1.01)
  348. plt.grid(linestyle='--', linewidth=1)
  349. plt.xticks(my_x_ticks, fontsize=5)
  350. plt.yticks(my_y_ticks, fontsize=5)
  351. for idx, catId in enumerate(catIds):
  352. nm = coco_gt.loadCats(catId)[0]['name']
  353. sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
  354. color = tuple(color_map[idx])
  355. color = [float(c) / 255 for c in color]
  356. color.append(0.75)
  357. plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
  358. plt.legend(loc="lower left", fontsize=5)
  359. plt.savefig(
  360. os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format(
  361. style, iou_thresh)),
  362. dpi=800)
  363. plt.close()
  364. if not os.path.exists(save_dir):
  365. os.makedirs(save_dir)
  366. cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
  367. if pred_mask is not None:
  368. cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')