visualize.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import cv2
  17. from PIL import Image
  18. import numpy as np
  19. from .imgaug_support import execute_imgaug
  20. from .cls_transforms import ClsTransform
  21. from .det_transforms import DetTransform
  22. from .seg_transforms import SegTransform
  23. import paddlex
  24. from paddlex.cv.models.utils.visualize import get_color_map_list
  25. def _draw_rectangle_and_cname(img, xmin, ymin, xmax, ymax, cname, color):
  26. """ 根据提供的标注信息,给图片描绘框体和类别显示
  27. Args:
  28. img: 图片路径
  29. xmin: 检测框最小的x坐标
  30. ymin: 检测框最小的y坐标
  31. xmax: 检测框最大的x坐标
  32. ymax: 检测框最大的y坐标
  33. cname: 类别信息
  34. color: 类别与颜色的对应信息
  35. """
  36. # 描绘检测框
  37. line_width = math.ceil(2 * max(img.shape[0:2]) / 600)
  38. cv2.rectangle(
  39. img,
  40. pt1=(xmin, ymin),
  41. pt2=(xmax, ymax),
  42. color=color,
  43. thickness=line_width)
  44. return img
  45. def cls_compose(im, label=None, transforms=None, vdl_writer=None, step=0):
  46. """
  47. Args:
  48. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  49. label (int): 每张图像所对应的类别序号。
  50. vdl_writer (visualdl.LogWriter): VisualDL存储器,日志信息将保存在其中。
  51. 当为None时,不对日志进行保存。默认为None。
  52. step (int): 数据预处理的轮数,当vdl_writer不为None时有效。默认为0。
  53. Returns:
  54. tuple: 根据网络所需字段所组成的tuple;
  55. 字段由transforms中的最后一个数据预处理操作决定。
  56. """
  57. if isinstance(im, np.ndarray):
  58. if len(im.shape) != 3:
  59. raise Exception(
  60. "im should be 3-dimension, but now is {}-dimensions".
  61. format(len(im.shape)))
  62. else:
  63. try:
  64. im = cv2.imread(im).astype('float32')
  65. except:
  66. raise TypeError('Can\'t read The image file {}!'.format(im))
  67. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  68. if vdl_writer is not None:
  69. vdl_writer.add_image(tag='0. OriginalImange/' + str(step),
  70. img=im,
  71. step=0)
  72. op_id = 1
  73. for op in transforms:
  74. if isinstance(op, ClsTransform):
  75. if vdl_writer is not None and hasattr(op, 'prob'):
  76. op.prob = 1.0
  77. outputs = op(im, label)
  78. im = outputs[0]
  79. if len(outputs) == 2:
  80. label = outputs[1]
  81. else:
  82. import imgaug.augmenters as iaa
  83. if isinstance(op, iaa.Augmenter):
  84. im = execute_imgaug(op, im)
  85. outputs = (im, )
  86. if label is not None:
  87. outputs = (im, label)
  88. if vdl_writer is not None:
  89. tag = str(op_id) + '. ' + op.__class__.__name__ + '/' + str(step)
  90. vdl_writer.add_image(tag=tag,
  91. img=im,
  92. step=0)
  93. op_id += 1
  94. def det_compose(im, im_info=None, label_info=None, transforms=None, vdl_writer=None, step=0,
  95. labels=[], catid2color=None):
  96. def decode_image(im_file, im_info, label_info):
  97. if im_info is None:
  98. im_info = dict()
  99. if isinstance(im_file, np.ndarray):
  100. if len(im_file.shape) != 3:
  101. raise Exception(
  102. "im should be 3-dimensions, but now is {}-dimensions".
  103. format(len(im_file.shape)))
  104. im = im_file
  105. else:
  106. try:
  107. im = cv2.imread(im_file).astype('float32')
  108. except:
  109. raise TypeError('Can\'t read The image file {}!'.format(
  110. im_file))
  111. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  112. # make default im_info with [h, w, 1]
  113. im_info['im_resize_info'] = np.array(
  114. [im.shape[0], im.shape[1], 1.], dtype=np.float32)
  115. im_info['image_shape'] = np.array([im.shape[0],
  116. im.shape[1]]).astype('int32')
  117. use_mixup = False
  118. for t in transforms:
  119. if type(t).__name__ == 'MixupImage':
  120. use_mixup = True
  121. if not use_mixup:
  122. if 'mixup' in im_info:
  123. del im_info['mixup']
  124. # decode mixup image
  125. if 'mixup' in im_info:
  126. im_info['mixup'] = \
  127. decode_image(im_info['mixup'][0],
  128. im_info['mixup'][1],
  129. im_info['mixup'][2])
  130. if label_info is None:
  131. return (im, im_info)
  132. else:
  133. return (im, im_info, label_info)
  134. outputs = decode_image(im, im_info, label_info)
  135. im = outputs[0]
  136. im_info = outputs[1]
  137. if len(outputs) == 3:
  138. label_info = outputs[2]
  139. if vdl_writer is not None:
  140. vdl_writer.add_image(tag='0. OriginalImange/' + str(step),
  141. img=im,
  142. step=0)
  143. op_id = 1
  144. bboxes = label_info['gt_bbox']
  145. transforms = [None] + transforms
  146. for op in transforms:
  147. if im is None:
  148. return None
  149. if isinstance(op, DetTransform) or op is None:
  150. if vdl_writer is not None and hasattr(op, 'prob'):
  151. op.prob = 1.0
  152. if op is not None:
  153. outputs = op(im, im_info, label_info)
  154. else:
  155. outputs = (im, im_info, label_info)
  156. im = outputs[0]
  157. vdl_im = im
  158. if vdl_writer is not None:
  159. if isinstance(op, paddlex.cv.transforms.det_transforms.ResizeByShort):
  160. scale = outputs[1]['im_resize_info'][2]
  161. bboxes = bboxes * scale
  162. elif isinstance(op, paddlex.cv.transforms.det_transforms.Resize):
  163. h = outputs[1]['image_shape'][0]
  164. w = outputs[1]['image_shape'][1]
  165. target_size = op.target_size
  166. if isinstance(target_size, int):
  167. h_scale = float(target_size) / h
  168. w_scale = float(target_size) / w
  169. else:
  170. h_scale = float(target_size[0]) / h
  171. w_scale = float(target_size[1]) / w
  172. bboxes[:,0] = bboxes[:,0] * w_scale
  173. bboxes[:,1] = bboxes[:,1] * h_scale
  174. bboxes[:,2] = bboxes[:,2] * w_scale
  175. bboxes[:,3] = bboxes[:,3] * h_scale
  176. else:
  177. bboxes = outputs[2]['gt_bbox']
  178. if not isinstance(op, paddlex.cv.transforms.det_transforms.RandomHorizontalFlip):
  179. for i in range(bboxes.shape[0]):
  180. bbox = bboxes[i]
  181. cname = labels[outputs[2]['gt_class'][i][0]-1]
  182. vdl_im = _draw_rectangle_and_cname(vdl_im,
  183. int(bbox[0]),
  184. int(bbox[1]),
  185. int(bbox[2]),
  186. int(bbox[3]),
  187. cname,
  188. catid2color[outputs[2]['gt_class'][i][0]-1])
  189. if isinstance(op, paddlex.cv.transforms.det_transforms.Normalize):
  190. vdl_im = im
  191. else:
  192. im = execute_imgaug(op, im)
  193. if label_info is not None:
  194. outputs = (im, im_info, label_info)
  195. else:
  196. outputs = (im, im_info)
  197. vdl_im = im
  198. if vdl_writer is not None:
  199. tag = str(op_id) + '. ' + op.__class__.__name__ + '/' + str(step)
  200. if op is None:
  201. tag = str(op_id) + '. OriginalImangeWithBbox/' + str(step)
  202. vdl_writer.add_image(tag=tag,
  203. img=vdl_im,
  204. step=0)
  205. op_id += 1
  206. def seg_compose(im, im_info=None, label=None, transforms=None, vdl_writer=None, step=0):
  207. if im_info is None:
  208. im_info = list()
  209. if isinstance(im, np.ndarray):
  210. if len(im.shape) != 3:
  211. raise Exception(
  212. "im should be 3-dimensions, but now is {}-dimensions".
  213. format(len(im.shape)))
  214. else:
  215. try:
  216. im = cv2.imread(im).astype('float32')
  217. except:
  218. raise ValueError('Can\'t read The image file {}!'.format(im))
  219. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  220. if label is not None:
  221. if not isinstance(label, np.ndarray):
  222. label = np.asarray(Image.open(label))
  223. if vdl_writer is not None:
  224. vdl_writer.add_image(tag='0. OriginalImange' + '/' + str(step),
  225. img=im,
  226. step=0)
  227. op_id = 1
  228. for op in transforms:
  229. if isinstance(op, SegTransform):
  230. outputs = op(im, im_info, label)
  231. im = outputs[0]
  232. if len(outputs) >= 2:
  233. im_info = outputs[1]
  234. if len(outputs) == 3:
  235. label = outputs[2]
  236. else:
  237. im = execute_imgaug(op, im)
  238. if label is not None:
  239. outputs = (im, im_info, label)
  240. else:
  241. outputs = (im, im_info)
  242. if vdl_writer is not None:
  243. tag = str(op_id) + '. ' + op.__class__.__name__ + '/' + str(step)
  244. vdl_writer.add_image(tag=tag,
  245. img=im,
  246. step=0)
  247. op_id += 1
  248. def visualize(dataset, img_count=3, save_dir='vdl_output'):
  249. '''对数据预处理/增强中间结果进行可视化。
  250. 可使用VisualDL查看中间结果:
  251. 1. VisualDL启动方式: visualdl --logdir vdl_output --port 8001
  252. 2. 浏览器打开 https://0.0.0.0:8001即可,
  253. 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
  254. Args:
  255. dataset (paddlex.datasets): 数据集读取器。
  256. img_count (int): 需要进行数据预处理/增强的图像数目。默认为3。
  257. save_dir (str): 日志保存的路径。默认为'vdl_output'。
  258. '''
  259. if dataset.num_samples < img_count:
  260. img_count = dataset.num_samples
  261. transforms = dataset.transforms
  262. if not osp.isdir(save_dir):
  263. if osp.exists(save_dir):
  264. os.remove(save_dir)
  265. os.makedirs(save_dir)
  266. from visualdl import LogWriter
  267. vdl_save_dir = osp.join(save_dir, 'image_transforms')
  268. vdl_writer = LogWriter(vdl_save_dir)
  269. for i, data in enumerate(dataset.iterator()):
  270. if i == img_count:
  271. break
  272. data.append(transforms.transforms)
  273. data.append(vdl_writer)
  274. data.append(i)
  275. if isinstance(transforms, ClsTransform):
  276. cls_compose(*data)
  277. elif isinstance(transforms, DetTransform):
  278. labels = dataset.labels
  279. color_map = get_color_map_list(len(labels) + 1)
  280. catid2color = {}
  281. for catid in range(len(labels)):
  282. catid2color[catid] = color_map[catid + 1]
  283. data.append(labels)
  284. data.append(catid2color)
  285. det_compose(*data)
  286. elif isinstance(transforms, SegTransform):
  287. seg_compose(*data)
  288. else:
  289. raise Exception('The transform must the subclass of \
  290. ClsTransform or DetTransform or SegTransform!')