visualize.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import cv2
  16. import colorsys
  17. import numpy as np
  18. import matplotlib as mpl
  19. import matplotlib.pyplot as plt
  20. import matplotlib.figure as mplfigure
  21. import matplotlib.colors as mplc
  22. from matplotlib.backends.backend_agg import FigureCanvasAgg
  23. import paddlex.utils.logging as logging
  24. from .detection_eval import fixed_linspace, backup_linspace, loadRes
  25. def visualize_detection(image, result, threshold=0.5, save_dir='./'):
  26. """
  27. Visualize bbox and mask results
  28. """
  29. image_name = os.path.split(image)[-1]
  30. image = cv2.imread(image)
  31. image = draw_bbox_mask(image, result, threshold=threshold)
  32. if save_dir is not None:
  33. if not os.path.exists(save_dir):
  34. os.makedirs(save_dir)
  35. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  36. cv2.imwrite(out_path, image)
  37. logging.info('The visualized result is saved as {}'.format(out_path))
  38. else:
  39. return image
  40. def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
  41. """
  42. Convert segment result to color image, and save added image.
  43. Args:
  44. image: the path of origin image
  45. result: the predict result of image
  46. weight: the image weight of visual image, and the result weight is (1 - weight)
  47. save_dir: the directory for saving visual image
  48. """
  49. label_map = result['label_map']
  50. color_map = get_color_map_list(256)
  51. color_map = np.array(color_map).astype("uint8")
  52. # Use OpenCV LUT for color mapping
  53. c1 = cv2.LUT(label_map, color_map[:, 0])
  54. c2 = cv2.LUT(label_map, color_map[:, 1])
  55. c3 = cv2.LUT(label_map, color_map[:, 2])
  56. pseudo_img = np.dstack((c1, c2, c3))
  57. im = cv2.imread(image)
  58. vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
  59. if save_dir is not None:
  60. if not os.path.exists(save_dir):
  61. os.makedirs(save_dir)
  62. image_name = os.path.split(image)[-1]
  63. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  64. cv2.imwrite(out_path, vis_result)
  65. logging.info('The visualized result is saved as {}'.format(out_path))
  66. else:
  67. return vis_result
  68. def get_color_map_list(num_classes):
  69. """ Returns the color map for visualizing the segmentation mask,
  70. which can support arbitrary number of classes.
  71. Args:
  72. num_classes: Number of classes
  73. Returns:
  74. The color map
  75. """
  76. color_map = num_classes * [0, 0, 0]
  77. for i in range(0, num_classes):
  78. j = 0
  79. lab = i
  80. while lab:
  81. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  82. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  83. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  84. j += 1
  85. lab >>= 3
  86. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  87. return color_map
  88. # expand an array of boxes by a given scale.
  89. def expand_boxes(boxes, scale):
  90. """
  91. """
  92. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  93. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  94. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  95. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  96. w_half *= scale
  97. h_half *= scale
  98. boxes_exp = np.zeros(boxes.shape)
  99. boxes_exp[:, 0] = x_c - w_half
  100. boxes_exp[:, 2] = x_c + w_half
  101. boxes_exp[:, 1] = y_c - h_half
  102. boxes_exp[:, 3] = y_c + h_half
  103. return boxes_exp
  104. def clip_bbox(bbox):
  105. xmin = max(min(bbox[0], 1.), 0.)
  106. ymin = max(min(bbox[1], 1.), 0.)
  107. xmax = max(min(bbox[2], 1.), 0.)
  108. ymax = max(min(bbox[3], 1.), 0.)
  109. return xmin, ymin, xmax, ymax
  110. def draw_bbox_mask(image, results, threshold=0.5):
  111. # refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
  112. def _change_color_brightness(color, brightness_factor):
  113. assert brightness_factor >= -1.0 and brightness_factor <= 1.0
  114. color = mplc.to_rgb(color)
  115. polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
  116. modified_lightness = polygon_color[1] + (
  117. brightness_factor * polygon_color[1])
  118. modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
  119. modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
  120. modified_color = colorsys.hls_to_rgb(
  121. polygon_color[0], modified_lightness, polygon_color[2])
  122. return modified_color
  123. _SMALL_OBJECT_AREA_THRESH = 1000
  124. # setup figure
  125. width, height = image.shape[1], image.shape[0]
  126. scale = 1
  127. fig = mplfigure.Figure(frameon=False)
  128. dpi = fig.get_dpi()
  129. fig.set_size_inches(
  130. (width * scale + 1e-2) / dpi,
  131. (height * scale + 1e-2) / dpi,
  132. )
  133. canvas = FigureCanvasAgg(fig)
  134. ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
  135. ax.axis("off")
  136. ax.set_xlim(0.0, width)
  137. ax.set_ylim(height)
  138. default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
  139. linewidth = max(default_font_size / 4, 1)
  140. labels = list()
  141. for dt in np.array(results):
  142. if dt['category'] not in labels:
  143. labels.append(dt['category'])
  144. color_map = get_color_map_list(256)
  145. keep_results = []
  146. areas = []
  147. for dt in np.array(results):
  148. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  149. if score < threshold:
  150. continue
  151. keep_results.append(dt)
  152. areas.append(bbox[2] * bbox[3])
  153. areas = np.asarray(areas)
  154. sorted_idxs = np.argsort(-areas).tolist()
  155. keep_results = [keep_results[k]
  156. for k in sorted_idxs] if len(keep_results) > 0 else []
  157. for dt in np.array(keep_results):
  158. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  159. xmin, ymin, w, h = bbox
  160. xmax = xmin + w
  161. ymax = ymin + h
  162. color = tuple(color_map[labels.index(cname) + 2])
  163. color = [c / 255. for c in color]
  164. # draw bbox
  165. ax.add_patch(
  166. mpl.patches.Rectangle(
  167. (xmin, ymin),
  168. w,
  169. h,
  170. fill=False,
  171. edgecolor=color,
  172. linewidth=linewidth * scale,
  173. alpha=0.8,
  174. linestyle="-",
  175. ))
  176. # draw mask
  177. if 'mask' in dt:
  178. mask = dt['mask']
  179. mask = np.ascontiguousarray(mask)
  180. res = cv2.findContours(
  181. mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
  182. hierarchy = res[-1]
  183. alpha = 0.5
  184. if hierarchy is not None:
  185. has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
  186. res = res[-2]
  187. res = [x.flatten() for x in res]
  188. res = [x for x in res if len(x) >= 6]
  189. for segment in res:
  190. segment = segment.reshape(-1, 2)
  191. edge_color = mplc.to_rgb(color) + (1, )
  192. polygon = mpl.patches.Polygon(
  193. segment,
  194. fill=True,
  195. facecolor=mplc.to_rgb(color) + (alpha, ),
  196. edgecolor=edge_color,
  197. linewidth=max(default_font_size // 15 * scale, 1),
  198. )
  199. ax.add_patch(polygon)
  200. # draw label
  201. text_pos = (xmin, ymin)
  202. horiz_align = "left"
  203. instance_area = w * h
  204. if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale
  205. or h < 40 * scale):
  206. if ymin >= height - 5:
  207. text_pos = (xmin, ymin)
  208. else:
  209. text_pos = (xmin, ymax)
  210. height_ratio = h / np.sqrt(height * width)
  211. font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 *
  212. default_font_size)
  213. text = "{} {:.2f}".format(cname, score)
  214. color = np.maximum(list(mplc.to_rgb(color)), 0.2)
  215. color[np.argmax(color)] = max(0.8, np.max(color))
  216. color = _change_color_brightness(color, brightness_factor=0.7)
  217. ax.text(
  218. text_pos[0],
  219. text_pos[1],
  220. text,
  221. size=font_size * scale,
  222. family="sans-serif",
  223. bbox={
  224. "facecolor": "black",
  225. "alpha": 0.8,
  226. "pad": 0.7,
  227. "edgecolor": "none"
  228. },
  229. verticalalignment="top",
  230. horizontalalignment=horiz_align,
  231. color=color,
  232. zorder=10,
  233. rotation=0,
  234. )
  235. s, (width, height) = canvas.print_to_buffer()
  236. buffer = np.frombuffer(s, dtype="uint8")
  237. img_rgba = buffer.reshape(height, width, 4)
  238. rgb, alpha = np.split(img_rgba, [3], axis=2)
  239. try:
  240. import numexpr as ne
  241. visualized_image = ne.evaluate(
  242. "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
  243. except ImportError:
  244. alpha = alpha.astype("float32") / 255.0
  245. visualized_image = image * (1 - alpha) + rgb * alpha
  246. visualized_image = visualized_image.astype("uint8")
  247. return visualized_image
  248. def draw_pr_curve(eval_details_file=None,
  249. gt=None,
  250. pred_bbox=None,
  251. pred_mask=None,
  252. iou_thresh=0.5,
  253. save_dir='./'):
  254. if eval_details_file is not None:
  255. import json
  256. with open(eval_details_file, 'r') as f:
  257. eval_details = json.load(f)
  258. pred_bbox = eval_details['bbox']
  259. if 'mask' in eval_details:
  260. pred_mask = eval_details['mask']
  261. gt = eval_details['gt']
  262. if gt is None or pred_bbox is None:
  263. raise Exception(
  264. "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
  265. )
  266. if pred_bbox is not None and len(pred_bbox) == 0:
  267. raise Exception("There is no predicted bbox.")
  268. if pred_mask is not None and len(pred_mask) == 0:
  269. raise Exception("There is no predicted mask.")
  270. from pycocotools.coco import COCO
  271. from pycocotools.cocoeval import COCOeval
  272. coco = COCO()
  273. coco.dataset = gt
  274. coco.createIndex()
  275. def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
  276. p = coco_gt.params
  277. aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
  278. mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
  279. if ap == 1:
  280. # dimension of precision: [TxRxKxAxM]
  281. s = coco_gt.eval['precision']
  282. # IoU
  283. if iouThr is not None:
  284. t = np.where(iouThr == p.iouThrs)[0]
  285. s = s[t]
  286. s = s[:, :, :, aind, mind]
  287. else:
  288. # dimension of recall: [TxKxAxM]
  289. s = coco_gt.eval['recall']
  290. if iouThr is not None:
  291. t = np.where(iouThr == p.iouThrs)[0]
  292. s = s[t]
  293. s = s[:, :, aind, mind]
  294. if len(s[s > -1]) == 0:
  295. mean_s = -1
  296. else:
  297. mean_s = np.mean(s[s > -1])
  298. return mean_s
  299. def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
  300. import matplotlib.pyplot as plt
  301. from pycocotools.cocoeval import COCOeval
  302. coco_dt = loadRes(coco_gt, coco_dt)
  303. np.linspace = fixed_linspace
  304. coco_eval = COCOeval(coco_gt, coco_dt, style)
  305. coco_eval.params.iouThrs = np.linspace(
  306. iou_thresh, iou_thresh, 1, endpoint=True)
  307. np.linspace = backup_linspace
  308. coco_eval.evaluate()
  309. coco_eval.accumulate()
  310. stats = _summarize(coco_eval, iouThr=iou_thresh)
  311. catIds = coco_gt.getCatIds()
  312. if len(catIds) != coco_eval.eval['precision'].shape[2]:
  313. raise Exception(
  314. "The category number must be same as the third dimension of precisions."
  315. )
  316. x = np.arange(0.0, 1.01, 0.01)
  317. color_map = get_color_map_list(256)[1:256]
  318. plt.subplot(1, 2, 1)
  319. plt.title(style + " precision-recall IoU={}".format(iou_thresh))
  320. plt.xlabel("recall")
  321. plt.ylabel("precision")
  322. plt.xlim(0, 1.01)
  323. plt.ylim(0, 1.01)
  324. plt.grid(linestyle='--', linewidth=1)
  325. plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
  326. my_x_ticks = np.arange(0, 1.01, 0.1)
  327. my_y_ticks = np.arange(0, 1.01, 0.1)
  328. plt.xticks(my_x_ticks, fontsize=5)
  329. plt.yticks(my_y_ticks, fontsize=5)
  330. for idx, catId in enumerate(catIds):
  331. pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
  332. precision = pr_array[pr_array > -1]
  333. ap = np.mean(precision) if precision.size else float('nan')
  334. nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
  335. float(ap * 100))
  336. color = tuple(color_map[idx])
  337. color = [float(c) / 255 for c in color]
  338. color.append(0.75)
  339. plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
  340. plt.legend(loc="lower left", fontsize=5)
  341. plt.subplot(1, 2, 2)
  342. plt.title(style + " score-recall IoU={}".format(iou_thresh))
  343. plt.xlabel('recall')
  344. plt.ylabel('score')
  345. plt.xlim(0, 1.01)
  346. plt.ylim(0, 1.01)
  347. plt.grid(linestyle='--', linewidth=1)
  348. plt.xticks(my_x_ticks, fontsize=5)
  349. plt.yticks(my_y_ticks, fontsize=5)
  350. for idx, catId in enumerate(catIds):
  351. nm = coco_gt.loadCats(catId)[0]['name']
  352. sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
  353. color = tuple(color_map[idx])
  354. color = [float(c) / 255 for c in color]
  355. color.append(0.75)
  356. plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
  357. plt.legend(loc="lower left", fontsize=5)
  358. plt.savefig(
  359. os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format(
  360. style, iou_thresh)),
  361. dpi=800)
  362. plt.close()
  363. if not os.path.exists(save_dir):
  364. os.makedirs(save_dir)
  365. cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
  366. if pred_mask is not None:
  367. cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')