visualize.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. import matplotlib.pyplot as plt
  18. from PIL import Image, ImageDraw
  19. import paddlex.utils.logging as logging
  20. from .detection_eval import fixed_linspace, backup_linspace, loadRes
  21. def visualize_detection(image, result, threshold=0.5, save_dir='./'):
  22. """
  23. Visualize bbox and mask results
  24. """
  25. image_name = os.path.split(image)[-1]
  26. image = Image.open(image).convert('RGB')
  27. image = draw_bbox_mask(image, result, threshold=threshold)
  28. if save_dir is not None:
  29. if not os.path.exists(save_dir):
  30. os.makedirs(save_dir)
  31. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  32. image.save(out_path, quality=95)
  33. logging.info('The visualized result is saved as {}'.format(out_path))
  34. else:
  35. return image
  36. def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
  37. """
  38. Convert segment result to color image, and save added image.
  39. Args:
  40. image: the path of origin image
  41. result: the predict result of image
  42. weight: the image weight of visual image, and the result weight is (1 - weight)
  43. save_dir: the directory for saving visual image
  44. """
  45. label_map = result['label_map']
  46. color_map = get_color_map_list(256)
  47. color_map = np.array(color_map).astype("uint8")
  48. # Use OpenCV LUT for color mapping
  49. c1 = cv2.LUT(label_map, color_map[:, 0])
  50. c2 = cv2.LUT(label_map, color_map[:, 1])
  51. c3 = cv2.LUT(label_map, color_map[:, 2])
  52. pseudo_img = np.dstack((c1, c2, c3))
  53. im = cv2.imread(image)
  54. vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
  55. if save_dir is not None:
  56. if not os.path.exists(save_dir):
  57. os.makedirs(save_dir)
  58. image_name = os.path.split(image)[-1]
  59. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  60. cv2.imwrite(out_path, vis_result)
  61. logging.info('The visualized result is saved as {}'.format(out_path))
  62. else:
  63. return vis_result
  64. def get_color_map_list(num_classes):
  65. """ Returns the color map for visualizing the segmentation mask,
  66. which can support arbitrary number of classes.
  67. Args:
  68. num_classes: Number of classes
  69. Returns:
  70. The color map
  71. """
  72. color_map = num_classes * [0, 0, 0]
  73. for i in range(0, num_classes):
  74. j = 0
  75. lab = i
  76. while lab:
  77. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  78. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  79. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  80. j += 1
  81. lab >>= 3
  82. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  83. return color_map
  84. # expand an array of boxes by a given scale.
  85. def expand_boxes(boxes, scale):
  86. """
  87. """
  88. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  89. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  90. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  91. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  92. w_half *= scale
  93. h_half *= scale
  94. boxes_exp = np.zeros(boxes.shape)
  95. boxes_exp[:, 0] = x_c - w_half
  96. boxes_exp[:, 2] = x_c + w_half
  97. boxes_exp[:, 1] = y_c - h_half
  98. boxes_exp[:, 3] = y_c + h_half
  99. return boxes_exp
  100. def clip_bbox(bbox):
  101. xmin = max(min(bbox[0], 1.), 0.)
  102. ymin = max(min(bbox[1], 1.), 0.)
  103. xmax = max(min(bbox[2], 1.), 0.)
  104. ymax = max(min(bbox[3], 1.), 0.)
  105. return xmin, ymin, xmax, ymax
  106. def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
  107. labels = list()
  108. for dt in np.array(results):
  109. if dt['category'] not in labels:
  110. labels.append(dt['category'])
  111. color_map = get_color_map_list(len(labels))
  112. for dt in np.array(results):
  113. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  114. if score < threshold:
  115. continue
  116. xmin, ymin, w, h = bbox
  117. xmax = xmin + w
  118. ymax = ymin + h
  119. color = tuple(color_map[labels.index(cname)])
  120. # draw bbox
  121. draw = ImageDraw.Draw(image)
  122. draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
  123. (xmin, ymin)],
  124. width=2,
  125. fill=color)
  126. # draw label
  127. text = "{} {:.2f}".format(cname, score)
  128. tw, th = draw.textsize(text)
  129. draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)],
  130. fill=color)
  131. draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
  132. # draw mask
  133. if 'mask' in dt:
  134. mask = dt['mask']
  135. color_mask = np.array(color_map[labels.index(
  136. dt['category'])]).astype('float32')
  137. img_array = np.array(image).astype('float32')
  138. idx = np.nonzero(mask)
  139. img_array[idx[0], idx[1], :] *= 1.0 - alpha
  140. img_array[idx[0], idx[1], :] += alpha * color_mask
  141. image = Image.fromarray(img_array.astype('uint8'))
  142. return image
  143. def draw_pr_curve(eval_details_file=None,
  144. gt=None,
  145. pred_bbox=None,
  146. pred_mask=None,
  147. iou_thresh=0.5,
  148. save_dir='./'):
  149. if eval_details_file is not None:
  150. import json
  151. with open(eval_details_file, 'r') as f:
  152. eval_details = json.load(f)
  153. pred_bbox = eval_details['bbox']
  154. if 'mask' in eval_details:
  155. pred_mask = eval_details['mask']
  156. gt = eval_details['gt']
  157. if gt is None or pred_bbox is None:
  158. raise Exception(
  159. "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
  160. )
  161. if pred_bbox is not None and len(pred_bbox) == 0:
  162. raise Exception("There is no predicted bbox.")
  163. if pred_mask is not None and len(pred_mask) == 0:
  164. raise Exception("There is no predicted mask.")
  165. from pycocotools.coco import COCO
  166. from pycocotools.cocoeval import COCOeval
  167. coco = COCO()
  168. coco.dataset = gt
  169. coco.createIndex()
  170. def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
  171. p = coco_gt.params
  172. aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
  173. mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
  174. if ap == 1:
  175. # dimension of precision: [TxRxKxAxM]
  176. s = coco_gt.eval['precision']
  177. # IoU
  178. if iouThr is not None:
  179. t = np.where(iouThr == p.iouThrs)[0]
  180. s = s[t]
  181. s = s[:, :, :, aind, mind]
  182. else:
  183. # dimension of recall: [TxKxAxM]
  184. s = coco_gt.eval['recall']
  185. if iouThr is not None:
  186. t = np.where(iouThr == p.iouThrs)[0]
  187. s = s[t]
  188. s = s[:, :, aind, mind]
  189. if len(s[s > -1]) == 0:
  190. mean_s = -1
  191. else:
  192. mean_s = np.mean(s[s > -1])
  193. return mean_s
  194. def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
  195. from pycocotools.cocoeval import COCOeval
  196. coco_dt = loadRes(coco_gt, coco_dt)
  197. np.linspace = fixed_linspace
  198. coco_eval = COCOeval(coco_gt, coco_dt, style)
  199. coco_eval.params.iouThrs = np.linspace(
  200. iou_thresh, iou_thresh, 1, endpoint=True)
  201. np.linspace = backup_linspace
  202. coco_eval.evaluate()
  203. coco_eval.accumulate()
  204. stats = _summarize(coco_eval, iouThr=iou_thresh)
  205. catIds = coco_gt.getCatIds()
  206. if len(catIds) != coco_eval.eval['precision'].shape[2]:
  207. raise Exception(
  208. "The category number must be same as the third dimension of precisions."
  209. )
  210. x = np.arange(0.0, 1.01, 0.01)
  211. color_map = get_color_map_list(256)[1:256]
  212. plt.subplot(1, 2, 1)
  213. plt.title(style + " precision-recall IoU={}".format(iou_thresh))
  214. plt.xlabel("recall")
  215. plt.ylabel("precision")
  216. plt.xlim(0, 1.01)
  217. plt.ylim(0, 1.01)
  218. plt.grid(linestyle='--', linewidth=1)
  219. plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
  220. my_x_ticks = np.arange(0, 1.01, 0.1)
  221. my_y_ticks = np.arange(0, 1.01, 0.1)
  222. plt.xticks(my_x_ticks, fontsize=5)
  223. plt.yticks(my_y_ticks, fontsize=5)
  224. for idx, catId in enumerate(catIds):
  225. pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
  226. precision = pr_array[pr_array > -1]
  227. ap = np.mean(precision) if precision.size else float('nan')
  228. nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
  229. float(ap * 100))
  230. color = tuple(color_map[idx])
  231. color = [float(c) / 255 for c in color]
  232. color.append(0.75)
  233. plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
  234. plt.legend(loc="lower left", fontsize=5)
  235. plt.subplot(1, 2, 2)
  236. plt.title(style + " score-recall IoU={}".format(iou_thresh))
  237. plt.xlabel('recall')
  238. plt.ylabel('score')
  239. plt.xlim(0, 1.01)
  240. plt.ylim(0, 1.01)
  241. plt.grid(linestyle='--', linewidth=1)
  242. plt.xticks(my_x_ticks, fontsize=5)
  243. plt.yticks(my_y_ticks, fontsize=5)
  244. for idx, catId in enumerate(catIds):
  245. nm = coco_gt.loadCats(catId)[0]['name']
  246. sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
  247. color = tuple(color_map[idx])
  248. color = [float(c) / 255 for c in color]
  249. color.append(0.75)
  250. plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
  251. plt.legend(loc="lower left", fontsize=5)
  252. plt.savefig(
  253. os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format(
  254. style, iou_thresh)),
  255. dpi=800)
  256. plt.close()
  257. if not os.path.exists(save_dir):
  258. os.makedirs(save_dir)
  259. cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
  260. if pred_mask is not None:
  261. cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')