visualize.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #! /usr/bin/env python
  15. # -*- coding: utf-8 -*-
  16. import numpy as np
  17. import cv2
  18. import math
  19. import xml.etree.ElementTree as ET
  20. from PIL import Image
  21. def resize_img(img):
  22. """ 调整图片尺寸
  23. Args:
  24. img: 图片信息
  25. """
  26. h, w = img.shape[:2]
  27. min_size = 580
  28. if w >= h and w > min_size:
  29. new_w = min_size
  30. new_h = new_w * h / w
  31. elif h >= w and h > min_size:
  32. new_h = min_size
  33. new_w = new_h * w / h
  34. else:
  35. new_h = h
  36. new_w = w
  37. new_img = cv2.resize(
  38. img, (int(new_w), int(new_h)), interpolation=cv2.INTER_CUBIC)
  39. scale_value = new_w / w
  40. return new_img, scale_value
  41. def plot_det_label(image, anno, labels):
  42. """ 目标检测类型生成标注图
  43. Args:
  44. image: 图片路径
  45. anno: 图片标注
  46. labels: 图片所属数据集的类别信息
  47. """
  48. catid2color = {}
  49. img = cv2.imread(image)
  50. img, scale_value = resize_img(img)
  51. tree = ET.parse(anno)
  52. objs = tree.findall('object')
  53. color_map = get_color_map_list(len(labels) + 1)
  54. for i, obj in enumerate(objs):
  55. cname = obj.find('name').text
  56. catid = labels.index(cname)
  57. if cname not in labels:
  58. continue
  59. xmin = int(float(obj.find('bndbox').find('xmin').text) * scale_value)
  60. ymin = int(float(obj.find('bndbox').find('ymin').text) * scale_value)
  61. xmax = int(float(obj.find('bndbox').find('xmax').text) * scale_value)
  62. ymax = int(float(obj.find('bndbox').find('ymax').text) * scale_value)
  63. if catid not in catid2color:
  64. catid2color[catid] = color_map[catid + 1]
  65. color = tuple(catid2color[catid])
  66. img = draw_rectangle_and_cname(img, xmin, ymin, xmax, ymax, cname,
  67. color)
  68. return img
  69. def plot_seg_label(anno):
  70. """ 语义分割类型生成标注图
  71. Args:
  72. anno: 图片标注
  73. """
  74. label = pil_imread(anno)
  75. pse_label = gray2pseudo(label)
  76. return pse_label
  77. def plot_insseg_label(image, anno, labels, alpha=0.7):
  78. """ 实例分割类型生成标注图
  79. Args:
  80. image: 图片路径
  81. anno: 图片标注
  82. labels: 图片所属数据集的类别信息
  83. """
  84. anno = np.load(anno, allow_pickle=True).tolist()
  85. catid2color = dict()
  86. img = cv2.imread(image)
  87. img, scale_value = resize_img(img)
  88. color_map = get_color_map_list(len(labels) + 1)
  89. img_h = anno['h']
  90. img_w = anno['w']
  91. gt_class = anno['gt_class']
  92. gt_bbox = anno['gt_bbox']
  93. gt_poly = anno['gt_poly']
  94. num_bbox = gt_bbox.shape[0]
  95. num_mask = len(gt_poly)
  96. # 描绘mask信息
  97. img_array = np.array(img).astype('float32')
  98. for i in range(num_mask):
  99. cname = gt_class[i]
  100. catid = labels.index(cname)
  101. if cname not in labels:
  102. continue
  103. if catid not in catid2color:
  104. catid2color[catid] = color_map[catid + 1]
  105. color = np.array(catid2color[catid]).astype('float32')
  106. import pycocotools.mask as mask_util
  107. for x in range(len(gt_poly[i])):
  108. for y in range(len(gt_poly[i][x])):
  109. gt_poly[i][x][y] = int(float(gt_poly[i][x][y]) * scale_value)
  110. poly = gt_poly[i]
  111. rles = mask_util.frPyObjects(poly,
  112. int(float(img_h) * scale_value),
  113. int(float(img_w) * scale_value))
  114. rle = mask_util.merge(rles)
  115. mask = mask_util.decode(rle) * 255
  116. idx = np.nonzero(mask)
  117. img_array[idx[0], idx[1], :] *= 1.0 - alpha
  118. img_array[idx[0], idx[1], :] += alpha * color
  119. img = img_array.astype('uint8')
  120. for i in range(num_bbox):
  121. cname = gt_class[i]
  122. catid = labels.index(cname)
  123. if cname not in labels:
  124. continue
  125. if catid not in catid2color:
  126. catid2color[catid] = color_map[catid]
  127. color = tuple(catid2color[catid])
  128. xmin, ymin, xmax, ymax = gt_bbox[i]
  129. img = draw_rectangle_and_cname(img,
  130. int(float(xmin) * scale_value),
  131. int(float(ymin) * scale_value),
  132. int(float(xmax) * scale_value),
  133. int(float(ymax) * scale_value), cname,
  134. color)
  135. return img
  136. def draw_rectangle_and_cname(img, xmin, ymin, xmax, ymax, cname, color):
  137. """ 根据提供的标注信息,给图片描绘框体和类别显示
  138. Args:
  139. img: 图片路径
  140. xmin: 检测框最小的x坐标
  141. ymin: 检测框最小的y坐标
  142. xmax: 检测框最大的x坐标
  143. ymax: 检测框最大的y坐标
  144. cname: 类别信息
  145. color: 类别与颜色的对应信息
  146. """
  147. # 描绘检测框
  148. line_width = math.ceil(2 * max(img.shape[0:2]) / 600)
  149. cv2.rectangle(
  150. img,
  151. pt1=(xmin, ymin),
  152. pt2=(xmax, ymax),
  153. color=color,
  154. thickness=line_width)
  155. # 计算并描绘类别信息
  156. text_thickness = math.ceil(2 * max(img.shape[0:2]) / 1200)
  157. fontscale = math.ceil(0.5 * max(img.shape[0:2]) / 600)
  158. tw, th = cv2.getTextSize(
  159. cname, 0, fontScale=fontscale, thickness=text_thickness)[0]
  160. cv2.rectangle(
  161. img,
  162. pt1=(xmin + 1, ymin - th),
  163. pt2=(xmin + int(0.7 * tw) + 1, ymin),
  164. color=color,
  165. thickness=-1)
  166. cv2.putText(
  167. img,
  168. cname, (int(xmin) + 3, int(ymin) - 5),
  169. 0,
  170. 0.6 * fontscale, (255, 255, 255),
  171. lineType=cv2.LINE_AA,
  172. thickness=text_thickness)
  173. return img
  174. def pil_imread(file_path):
  175. """ 将图片读成np格式数据
  176. Args:
  177. file_path: 图片路径
  178. """
  179. img = Image.open(file_path)
  180. return np.asarray(img)
  181. def get_color_map_list(num_classes):
  182. """ 为类别信息生成对应的颜色列表
  183. Args:
  184. num_classes: 类别数量
  185. """
  186. color_map = num_classes * [0, 0, 0]
  187. for i in range(0, num_classes):
  188. j = 0
  189. lab = i
  190. while lab:
  191. color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
  192. color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
  193. color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
  194. j += 1
  195. lab >>= 3
  196. color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
  197. return color_map
  198. def gray2pseudo(gray_image):
  199. """ 将分割的结果映射到图片
  200. Args:
  201. gray_image: 灰度图
  202. """
  203. color_map = get_color_map_list(256)
  204. color_map = np.array(color_map).astype("uint8")
  205. # 用OpenCV进行色彩映射
  206. c1 = cv2.LUT(gray_image, color_map[:, 0])
  207. c2 = cv2.LUT(gray_image, color_map[:, 1])
  208. c3 = cv2.LUT(gray_image, color_map[:, 2])
  209. pseudo_img = np.dstack((c1, c2, c3))
  210. return pseudo_img