visualize.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # -*- coding: utf-8 -*
  15. import os
  16. import cv2
  17. import colorsys
  18. import numpy as np
  19. import time
  20. import paddlex.utils.logging as logging
  21. from .detection_eval import fixed_linspace, backup_linspace, loadRes
  22. from paddlex.cv.datasets.dataset import is_pic
  23. def visualize_detection(image, result, threshold=0.5, save_dir='./'):
  24. """
  25. Visualize bbox and mask results
  26. """
  27. if isinstance(image, np.ndarray):
  28. image_name = str(int(time.time() * 1000)) + '.jpg'
  29. else:
  30. image_name = os.path.split(image)[-1]
  31. image = cv2.imread(image)
  32. image = draw_bbox_mask(image, result, threshold=threshold)
  33. if save_dir is not None:
  34. if not os.path.exists(save_dir):
  35. os.makedirs(save_dir)
  36. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  37. cv2.imwrite(out_path, image)
  38. logging.info('The visualized result is saved as {}'.format(out_path))
  39. else:
  40. return image
  41. def visualize_segmentation(image,
  42. result,
  43. weight=0.6,
  44. save_dir='./',
  45. color=None):
  46. """
  47. Convert segment result to color image, and save added image.
  48. Args:
  49. image: the path of origin image
  50. result: the predict result of image
  51. weight: the image weight of visual image, and the result weight is (1 - weight)
  52. save_dir: the directory for saving visual image
  53. color: the list of a BGR-mode color for each label.
  54. """
  55. label_map = result['label_map']
  56. color_map = get_color_map_list(256)
  57. if color is not None:
  58. for i in range(len(color) // 3):
  59. color_map[i] = color[i * 3:(i + 1) * 3]
  60. color_map = np.array(color_map).astype("uint8")
  61. # Use OpenCV LUT for color mapping
  62. c1 = cv2.LUT(label_map, color_map[:, 0])
  63. c2 = cv2.LUT(label_map, color_map[:, 1])
  64. c3 = cv2.LUT(label_map, color_map[:, 2])
  65. pseudo_img = np.dstack((c1, c2, c3))
  66. if isinstance(image, np.ndarray):
  67. im = image
  68. image_name = str(int(time.time() * 1000)) + '.jpg'
  69. if image.shape[2] != 3:
  70. logging.info(
  71. "The image is not 3-channel array, so predicted label map is shown as a pseudo color image."
  72. )
  73. weight = 0.
  74. else:
  75. image_name = os.path.split(image)[-1]
  76. if not is_pic(image):
  77. logging.info(
  78. "The image cannot be opened by opencv, so predicted label map is shown as a pseudo color image."
  79. )
  80. image_name = image_name.split('.')[0] + '.jpg'
  81. weight = 0.
  82. else:
  83. im = cv2.imread(image)
  84. if abs(weight) < 1e-5:
  85. vis_result = pseudo_img
  86. else:
  87. vis_result = cv2.addWeighted(im, weight,
  88. pseudo_img.astype(im.dtype), 1 - weight,
  89. 0)
  90. if save_dir is not None:
  91. if not os.path.exists(save_dir):
  92. os.makedirs(save_dir)
  93. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  94. cv2.imwrite(out_path, vis_result)
  95. logging.info('The visualized result is saved as {}'.format(out_path))
  96. else:
  97. return vis_result
  98. def get_color_map_list(num_classes):
  99. """ Returns the color map for visualizing the segmentation mask,
  100. which can support arbitrary number of classes.
  101. Args:
  102. num_classes: Number of classes
  103. Returns:
  104. The color map
  105. """
  106. color_map = num_classes * [0, 0, 0]
  107. for i in range(0, num_classes):
  108. j = 0
  109. lab = i
  110. while lab:
  111. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  112. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  113. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  114. j += 1
  115. lab >>= 3
  116. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  117. return color_map
  118. # expand an array of boxes by a given scale.
  119. def expand_boxes(boxes, scale):
  120. """
  121. """
  122. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  123. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  124. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  125. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  126. w_half *= scale
  127. h_half *= scale
  128. boxes_exp = np.zeros(boxes.shape)
  129. boxes_exp[:, 0] = x_c - w_half
  130. boxes_exp[:, 2] = x_c + w_half
  131. boxes_exp[:, 1] = y_c - h_half
  132. boxes_exp[:, 3] = y_c + h_half
  133. return boxes_exp
  134. def clip_bbox(bbox):
  135. xmin = max(min(bbox[0], 1.), 0.)
  136. ymin = max(min(bbox[1], 1.), 0.)
  137. xmax = max(min(bbox[2], 1.), 0.)
  138. ymax = max(min(bbox[3], 1.), 0.)
  139. return xmin, ymin, xmax, ymax
  140. def draw_bbox_mask(image, results, threshold=0.5):
  141. import matplotlib
  142. matplotlib.use('Agg')
  143. import matplotlib as mpl
  144. import matplotlib.figure as mplfigure
  145. import matplotlib.colors as mplc
  146. from matplotlib.backends.backend_agg import FigureCanvasAgg
  147. # refer to https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py
  148. def _change_color_brightness(color, brightness_factor):
  149. assert brightness_factor >= -1.0 and brightness_factor <= 1.0
  150. color = mplc.to_rgb(color)
  151. polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
  152. modified_lightness = polygon_color[1] + (brightness_factor *
  153. polygon_color[1])
  154. modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
  155. modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
  156. modified_color = colorsys.hls_to_rgb(
  157. polygon_color[0], modified_lightness, polygon_color[2])
  158. return modified_color
  159. _SMALL_OBJECT_AREA_THRESH = 1000
  160. # setup figure
  161. width, height = image.shape[1], image.shape[0]
  162. scale = 1
  163. fig = mplfigure.Figure(frameon=False)
  164. dpi = fig.get_dpi()
  165. fig.set_size_inches(
  166. (width * scale + 1e-2) / dpi,
  167. (height * scale + 1e-2) / dpi, )
  168. canvas = FigureCanvasAgg(fig)
  169. ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
  170. ax.axis("off")
  171. ax.set_xlim(0.0, width)
  172. ax.set_ylim(height)
  173. default_font_size = max(np.sqrt(height * width) // 90, 10 // scale)
  174. linewidth = max(default_font_size / 4, 1)
  175. labels = list()
  176. for dt in np.array(results):
  177. if dt['category'] not in labels:
  178. labels.append(dt['category'])
  179. color_map = get_color_map_list(256)
  180. keep_results = []
  181. areas = []
  182. for dt in np.array(results):
  183. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  184. if score < threshold:
  185. continue
  186. keep_results.append(dt)
  187. areas.append(bbox[2] * bbox[3])
  188. areas = np.asarray(areas)
  189. sorted_idxs = np.argsort(-areas).tolist()
  190. keep_results = [keep_results[k]
  191. for k in sorted_idxs] if len(keep_results) > 0 else []
  192. for dt in np.array(keep_results):
  193. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  194. xmin, ymin, w, h = bbox
  195. xmax = xmin + w
  196. ymax = ymin + h
  197. color = tuple(color_map[labels.index(cname) + 2])
  198. color = [c / 255. for c in color]
  199. # draw bbox
  200. ax.add_patch(
  201. mpl.patches.Rectangle(
  202. (xmin, ymin),
  203. w,
  204. h,
  205. fill=False,
  206. edgecolor=color,
  207. linewidth=linewidth * scale,
  208. alpha=0.8,
  209. linestyle="-", ))
  210. # draw mask
  211. if 'mask' in dt:
  212. mask = dt['mask']
  213. mask = np.ascontiguousarray(mask)
  214. res = cv2.findContours(
  215. mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
  216. hierarchy = res[-1]
  217. alpha = 0.5
  218. if hierarchy is not None:
  219. has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
  220. res = res[-2]
  221. res = [x.flatten() for x in res]
  222. res = [x for x in res if len(x) >= 6]
  223. for segment in res:
  224. segment = segment.reshape(-1, 2)
  225. edge_color = mplc.to_rgb(color) + (1, )
  226. polygon = mpl.patches.Polygon(
  227. segment,
  228. fill=True,
  229. facecolor=mplc.to_rgb(color) + (alpha, ),
  230. edgecolor=edge_color,
  231. linewidth=max(default_font_size // 15 * scale, 1), )
  232. ax.add_patch(polygon)
  233. # draw label
  234. text_pos = (xmin, ymin)
  235. horiz_align = "left"
  236. instance_area = w * h
  237. if (instance_area < _SMALL_OBJECT_AREA_THRESH * scale or
  238. h < 40 * scale):
  239. if ymin >= height - 5:
  240. text_pos = (xmin, ymin)
  241. else:
  242. text_pos = (xmin, ymax)
  243. height_ratio = h / np.sqrt(height * width)
  244. font_size = (np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2,
  245. 2) * 0.5 * default_font_size)
  246. text = "{} {:.2f}".format(cname, score)
  247. color = np.maximum(list(mplc.to_rgb(color)), 0.2)
  248. color[np.argmax(color)] = max(0.8, np.max(color))
  249. color = _change_color_brightness(color, brightness_factor=0.7)
  250. ax.text(
  251. text_pos[0],
  252. text_pos[1],
  253. text,
  254. size=font_size * scale,
  255. family="sans-serif",
  256. bbox={
  257. "facecolor": "black",
  258. "alpha": 0.8,
  259. "pad": 0.7,
  260. "edgecolor": "none"
  261. },
  262. verticalalignment="top",
  263. horizontalalignment=horiz_align,
  264. color=color,
  265. zorder=10,
  266. rotation=0, )
  267. s, (width, height) = canvas.print_to_buffer()
  268. buffer = np.frombuffer(s, dtype="uint8")
  269. img_rgba = buffer.reshape(height, width, 4)
  270. rgb, alpha = np.split(img_rgba, [3], axis=2)
  271. try:
  272. import numexpr as ne
  273. visualized_image = ne.evaluate(
  274. "image * (1 - alpha / 255.0) + rgb * (alpha / 255.0)")
  275. except ImportError:
  276. alpha = alpha.astype("float32") / 255.0
  277. visualized_image = image * (1 - alpha) + rgb * alpha
  278. visualized_image = visualized_image.astype("uint8")
  279. return visualized_image
  280. def draw_pr_curve(eval_details_file=None,
  281. gt=None,
  282. pred_bbox=None,
  283. pred_mask=None,
  284. iou_thresh=0.5,
  285. save_dir='./'):
  286. if eval_details_file is not None:
  287. import json
  288. with open(eval_details_file, 'r') as f:
  289. eval_details = json.load(f)
  290. pred_bbox = eval_details['bbox']
  291. if 'mask' in eval_details:
  292. pred_mask = eval_details['mask']
  293. gt = eval_details['gt']
  294. if gt is None or pred_bbox is None:
  295. raise Exception(
  296. "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
  297. )
  298. if pred_bbox is not None and len(pred_bbox) == 0:
  299. raise Exception("There is no predicted bbox.")
  300. if pred_mask is not None and len(pred_mask) == 0:
  301. raise Exception("There is no predicted mask.")
  302. import matplotlib
  303. matplotlib.use('Agg')
  304. import matplotlib.pyplot as plt
  305. from pycocotools.coco import COCO
  306. from pycocotools.cocoeval import COCOeval
  307. coco = COCO()
  308. coco.dataset = gt
  309. coco.createIndex()
  310. def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
  311. p = coco_gt.params
  312. aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
  313. mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
  314. if ap == 1:
  315. # dimension of precision: [TxRxKxAxM]
  316. s = coco_gt.eval['precision']
  317. # IoU
  318. if iouThr is not None:
  319. t = np.where(iouThr == p.iouThrs)[0]
  320. s = s[t]
  321. s = s[:, :, :, aind, mind]
  322. else:
  323. # dimension of recall: [TxKxAxM]
  324. s = coco_gt.eval['recall']
  325. if iouThr is not None:
  326. t = np.where(iouThr == p.iouThrs)[0]
  327. s = s[t]
  328. s = s[:, :, aind, mind]
  329. if len(s[s > -1]) == 0:
  330. mean_s = -1
  331. else:
  332. mean_s = np.mean(s[s > -1])
  333. return mean_s
  334. def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
  335. from pycocotools.cocoeval import COCOeval
  336. coco_dt = loadRes(coco_gt, coco_dt)
  337. np.linspace = fixed_linspace
  338. coco_eval = COCOeval(coco_gt, coco_dt, style)
  339. coco_eval.params.iouThrs = np.linspace(
  340. iou_thresh, iou_thresh, 1, endpoint=True)
  341. np.linspace = backup_linspace
  342. coco_eval.evaluate()
  343. coco_eval.accumulate()
  344. stats = _summarize(coco_eval, iouThr=iou_thresh)
  345. catIds = coco_gt.getCatIds()
  346. if len(catIds) != coco_eval.eval['precision'].shape[2]:
  347. raise Exception(
  348. "The category number must be same as the third dimension of precisions."
  349. )
  350. x = np.arange(0.0, 1.01, 0.01)
  351. color_map = get_color_map_list(256)[1:256]
  352. plt.subplot(1, 2, 1)
  353. plt.title(style + " precision-recall IoU={}".format(iou_thresh))
  354. plt.xlabel("recall")
  355. plt.ylabel("precision")
  356. plt.xlim(0, 1.01)
  357. plt.ylim(0, 1.01)
  358. plt.grid(linestyle='--', linewidth=1)
  359. plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
  360. my_x_ticks = np.arange(0, 1.01, 0.1)
  361. my_y_ticks = np.arange(0, 1.01, 0.1)
  362. plt.xticks(my_x_ticks, fontsize=5)
  363. plt.yticks(my_y_ticks, fontsize=5)
  364. for idx, catId in enumerate(catIds):
  365. pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
  366. precision = pr_array[pr_array > -1]
  367. ap = np.mean(precision) if precision.size else float('nan')
  368. nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
  369. float(ap * 100))
  370. color = tuple(color_map[idx])
  371. color = [float(c) / 255 for c in color]
  372. color.append(0.75)
  373. plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
  374. plt.legend(loc="lower left", fontsize=5)
  375. plt.subplot(1, 2, 2)
  376. plt.title(style + " score-recall IoU={}".format(iou_thresh))
  377. plt.xlabel('recall')
  378. plt.ylabel('score')
  379. plt.xlim(0, 1.01)
  380. plt.ylim(0, 1.01)
  381. plt.grid(linestyle='--', linewidth=1)
  382. plt.xticks(my_x_ticks, fontsize=5)
  383. plt.yticks(my_y_ticks, fontsize=5)
  384. for idx, catId in enumerate(catIds):
  385. nm = coco_gt.loadCats(catId)[0]['name']
  386. sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
  387. color = tuple(color_map[idx])
  388. color = [float(c) / 255 for c in color]
  389. color.append(0.75)
  390. plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
  391. plt.legend(loc="lower left", fontsize=5)
  392. plt.savefig(
  393. os.path.join(
  394. save_dir,
  395. "./{}_pr_curve(iou-{}).png".format(style, iou_thresh)),
  396. dpi=800)
  397. plt.close()
  398. if not os.path.exists(save_dir):
  399. os.makedirs(save_dir)
  400. cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
  401. if pred_mask is not None:
  402. cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')