visualize.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. from PIL import Image, ImageDraw
  18. import paddlex.utils.logging as logging
  19. from .detection_eval import fixed_linspace, backup_linspace, loadRes
  20. def visualize_detection(image, result, threshold=0.5, save_dir='./'):
  21. """
  22. Visualize bbox and mask results
  23. """
  24. image_name = os.path.split(image)[-1]
  25. image = Image.open(image).convert('RGB')
  26. image = draw_bbox_mask(image, result, threshold=threshold)
  27. if save_dir is not None:
  28. if not os.path.exists(save_dir):
  29. os.makedirs(save_dir)
  30. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  31. image.save(out_path, quality=95)
  32. logging.info('The visualized result is saved as {}'.format(out_path))
  33. else:
  34. return image
  35. def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
  36. """
  37. Convert segment result to color image, and save added image.
  38. Args:
  39. image: the path of origin image
  40. result: the predict result of image
  41. weight: the image weight of visual image, and the result weight is (1 - weight)
  42. save_dir: the directory for saving visual image
  43. """
  44. label_map = result['label_map']
  45. color_map = get_color_map_list(256)
  46. color_map = np.array(color_map).astype("uint8")
  47. # Use OpenCV LUT for color mapping
  48. c1 = cv2.LUT(label_map, color_map[:, 0])
  49. c2 = cv2.LUT(label_map, color_map[:, 1])
  50. c3 = cv2.LUT(label_map, color_map[:, 2])
  51. pseudo_img = np.dstack((c1, c2, c3))
  52. im = cv2.imread(image)
  53. vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
  54. if save_dir is not None:
  55. if not os.path.exists(save_dir):
  56. os.makedirs(save_dir)
  57. image_name = os.path.split(image)[-1]
  58. out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
  59. cv2.imwrite(out_path, vis_result)
  60. logging.info('The visualized result is saved as {}'.format(out_path))
  61. else:
  62. return vis_result
  63. def get_color_map_list(num_classes):
  64. """ Returns the color map for visualizing the segmentation mask,
  65. which can support arbitrary number of classes.
  66. Args:
  67. num_classes: Number of classes
  68. Returns:
  69. The color map
  70. """
  71. color_map = num_classes * [0, 0, 0]
  72. for i in range(0, num_classes):
  73. j = 0
  74. lab = i
  75. while lab:
  76. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  77. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  78. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  79. j += 1
  80. lab >>= 3
  81. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  82. return color_map
  83. # expand an array of boxes by a given scale.
  84. def expand_boxes(boxes, scale):
  85. """
  86. """
  87. w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  88. h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  89. x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  90. y_c = (boxes[:, 3] + boxes[:, 1]) * .5
  91. w_half *= scale
  92. h_half *= scale
  93. boxes_exp = np.zeros(boxes.shape)
  94. boxes_exp[:, 0] = x_c - w_half
  95. boxes_exp[:, 2] = x_c + w_half
  96. boxes_exp[:, 1] = y_c - h_half
  97. boxes_exp[:, 3] = y_c + h_half
  98. return boxes_exp
  99. def clip_bbox(bbox):
  100. xmin = max(min(bbox[0], 1.), 0.)
  101. ymin = max(min(bbox[1], 1.), 0.)
  102. xmax = max(min(bbox[2], 1.), 0.)
  103. ymax = max(min(bbox[3], 1.), 0.)
  104. return xmin, ymin, xmax, ymax
  105. def draw_bbox_mask(image, results, threshold=0.5, alpha=0.7):
  106. labels = list()
  107. for dt in np.array(results):
  108. if dt['category'] not in labels:
  109. labels.append(dt['category'])
  110. color_map = get_color_map_list(len(labels))
  111. for dt in np.array(results):
  112. cname, bbox, score = dt['category'], dt['bbox'], dt['score']
  113. if score < threshold:
  114. continue
  115. xmin, ymin, w, h = bbox
  116. xmax = xmin + w
  117. ymax = ymin + h
  118. color = tuple(color_map[labels.index(cname)])
  119. # draw bbox
  120. draw = ImageDraw.Draw(image)
  121. draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
  122. (xmin, ymin)],
  123. width=2,
  124. fill=color)
  125. # draw label
  126. text = "{} {:.2f}".format(cname, score)
  127. tw, th = draw.textsize(text)
  128. draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)],
  129. fill=color)
  130. draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
  131. # draw mask
  132. if 'mask' in dt:
  133. mask = dt['mask']
  134. color_mask = np.array(color_map[labels.index(
  135. dt['category'])]).astype('float32')
  136. img_array = np.array(image).astype('float32')
  137. idx = np.nonzero(mask)
  138. img_array[idx[0], idx[1], :] *= 1.0 - alpha
  139. img_array[idx[0], idx[1], :] += alpha * color_mask
  140. image = Image.fromarray(img_array.astype('uint8'))
  141. return image
  142. def draw_pr_curve(eval_details_file=None,
  143. gt=None,
  144. pred_bbox=None,
  145. pred_mask=None,
  146. iou_thresh=0.5,
  147. save_dir='./'):
  148. if eval_details_file is not None:
  149. import json
  150. with open(eval_details_file, 'r') as f:
  151. eval_details = json.load(f)
  152. pred_bbox = eval_details['bbox']
  153. if 'mask' in eval_details:
  154. pred_mask = eval_details['mask']
  155. gt = eval_details['gt']
  156. if gt is None or pred_bbox is None:
  157. raise Exception(
  158. "gt/pred_bbox/pred_mask is None now, please set right eval_details_file or gt/pred_bbox/pred_mask."
  159. )
  160. if pred_bbox is not None and len(pred_bbox) == 0:
  161. raise Exception("There is no predicted bbox.")
  162. if pred_mask is not None and len(pred_mask) == 0:
  163. raise Exception("There is no predicted mask.")
  164. from pycocotools.coco import COCO
  165. from pycocotools.cocoeval import COCOeval
  166. coco = COCO()
  167. coco.dataset = gt
  168. coco.createIndex()
  169. def _summarize(coco_gt, ap=1, iouThr=None, areaRng='all', maxDets=100):
  170. p = coco_gt.params
  171. aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
  172. mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
  173. if ap == 1:
  174. # dimension of precision: [TxRxKxAxM]
  175. s = coco_gt.eval['precision']
  176. # IoU
  177. if iouThr is not None:
  178. t = np.where(iouThr == p.iouThrs)[0]
  179. s = s[t]
  180. s = s[:, :, :, aind, mind]
  181. else:
  182. # dimension of recall: [TxKxAxM]
  183. s = coco_gt.eval['recall']
  184. if iouThr is not None:
  185. t = np.where(iouThr == p.iouThrs)[0]
  186. s = s[t]
  187. s = s[:, :, aind, mind]
  188. if len(s[s > -1]) == 0:
  189. mean_s = -1
  190. else:
  191. mean_s = np.mean(s[s > -1])
  192. return mean_s
  193. def cal_pr(coco_gt, coco_dt, iou_thresh, save_dir, style='bbox'):
  194. import matplotlib.pyplot as plt
  195. from pycocotools.cocoeval import COCOeval
  196. coco_dt = loadRes(coco_gt, coco_dt)
  197. np.linspace = fixed_linspace
  198. coco_eval = COCOeval(coco_gt, coco_dt, style)
  199. coco_eval.params.iouThrs = np.linspace(
  200. iou_thresh, iou_thresh, 1, endpoint=True)
  201. np.linspace = backup_linspace
  202. coco_eval.evaluate()
  203. coco_eval.accumulate()
  204. stats = _summarize(coco_eval, iouThr=iou_thresh)
  205. catIds = coco_gt.getCatIds()
  206. if len(catIds) != coco_eval.eval['precision'].shape[2]:
  207. raise Exception(
  208. "The category number must be same as the third dimension of precisions."
  209. )
  210. x = np.arange(0.0, 1.01, 0.01)
  211. color_map = get_color_map_list(256)[1:256]
  212. plt.subplot(1, 2, 1)
  213. plt.title(style + " precision-recall IoU={}".format(iou_thresh))
  214. plt.xlabel("recall")
  215. plt.ylabel("precision")
  216. plt.xlim(0, 1.01)
  217. plt.ylim(0, 1.01)
  218. plt.grid(linestyle='--', linewidth=1)
  219. plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
  220. my_x_ticks = np.arange(0, 1.01, 0.1)
  221. my_y_ticks = np.arange(0, 1.01, 0.1)
  222. plt.xticks(my_x_ticks, fontsize=5)
  223. plt.yticks(my_y_ticks, fontsize=5)
  224. for idx, catId in enumerate(catIds):
  225. pr_array = coco_eval.eval['precision'][0, :, idx, 0, 2]
  226. precision = pr_array[pr_array > -1]
  227. ap = np.mean(precision) if precision.size else float('nan')
  228. nm = coco_gt.loadCats(catId)[0]['name'] + ' AP={:0.2f}'.format(
  229. float(ap * 100))
  230. color = tuple(color_map[idx])
  231. color = [float(c) / 255 for c in color]
  232. color.append(0.75)
  233. plt.plot(x, pr_array, color=color, label=nm, linewidth=1)
  234. plt.legend(loc="lower left", fontsize=5)
  235. plt.subplot(1, 2, 2)
  236. plt.title(style + " score-recall IoU={}".format(iou_thresh))
  237. plt.xlabel('recall')
  238. plt.ylabel('score')
  239. plt.xlim(0, 1.01)
  240. plt.ylim(0, 1.01)
  241. plt.grid(linestyle='--', linewidth=1)
  242. plt.xticks(my_x_ticks, fontsize=5)
  243. plt.yticks(my_y_ticks, fontsize=5)
  244. for idx, catId in enumerate(catIds):
  245. nm = coco_gt.loadCats(catId)[0]['name']
  246. sr_array = coco_eval.eval['scores'][0, :, idx, 0, 2]
  247. color = tuple(color_map[idx])
  248. color = [float(c) / 255 for c in color]
  249. color.append(0.75)
  250. plt.plot(x, sr_array, color=color, label=nm, linewidth=1)
  251. plt.legend(loc="lower left", fontsize=5)
  252. plt.savefig(
  253. os.path.join(save_dir, "./{}_pr_curve(iou-{}).png".format(
  254. style, iou_thresh)),
  255. dpi=800)
  256. plt.close()
  257. if not os.path.exists(save_dir):
  258. os.makedirs(save_dir)
  259. cal_pr(coco, pred_bbox, iou_thresh, save_dir, style='bbox')
  260. if pred_mask is not None:
  261. cal_pr(coco, pred_mask, iou_thresh, save_dir, style='segm')