visualize.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import os.path as osp
  16. import cv2
  17. from PIL import Image
  18. import numpy as np
  19. import math
  20. from .imgaug_support import execute_imgaug
  21. from .cls_transforms import ClsTransform
  22. from .det_transforms import DetTransform
  23. from .seg_transforms import SegTransform
  24. import paddlex as pdx
  25. from paddlex.cv.models.utils.visualize import get_color_map_list
  26. def _draw_rectangle_and_cname(img, xmin, ymin, xmax, ymax, cname, color):
  27. """ 根据提供的标注信息,给图片描绘框体和类别显示
  28. Args:
  29. img: 图片路径
  30. xmin: 检测框最小的x坐标
  31. ymin: 检测框最小的y坐标
  32. xmax: 检测框最大的x坐标
  33. ymax: 检测框最大的y坐标
  34. cname: 类别信息
  35. color: 类别与颜色的对应信息
  36. """
  37. # 描绘检测框
  38. line_width = math.ceil(2 * max(img.shape[0:2]) / 600)
  39. cv2.rectangle(
  40. img,
  41. pt1=(xmin, ymin),
  42. pt2=(xmax, ymax),
  43. color=color,
  44. thickness=line_width)
  45. return img
  46. def cls_compose(im, label=None, transforms=None, vdl_writer=None, step=0):
  47. """
  48. Args:
  49. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  50. label (int): 每张图像所对应的类别序号。
  51. vdl_writer (visualdl.LogWriter): VisualDL存储器,日志信息将保存在其中。
  52. 当为None时,不对日志进行保存。默认为None。
  53. step (int): 数据预处理的轮数,当vdl_writer不为None时有效。默认为0。
  54. Returns:
  55. tuple: 根据网络所需字段所组成的tuple;
  56. 字段由transforms中的最后一个数据预处理操作决定。
  57. """
  58. if isinstance(im, np.ndarray):
  59. if len(im.shape) != 3:
  60. raise Exception(
  61. "im should be 3-dimension, but now is {}-dimensions".
  62. format(len(im.shape)))
  63. else:
  64. try:
  65. im = cv2.imread(im).astype('float32')
  66. except:
  67. raise TypeError('Can\'t read The image file {}!'.format(im))
  68. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  69. if vdl_writer is not None:
  70. vdl_writer.add_image(tag='0. OriginalImange/' + str(step),
  71. img=im,
  72. step=0)
  73. op_id = 1
  74. for op in transforms:
  75. if isinstance(op, ClsTransform):
  76. if vdl_writer is not None and hasattr(op, 'prob'):
  77. op.prob = 1.0
  78. outputs = op(im, label)
  79. im = outputs[0]
  80. if len(outputs) == 2:
  81. label = outputs[1]
  82. else:
  83. import imgaug.augmenters as iaa
  84. if isinstance(op, iaa.Augmenter):
  85. im = execute_imgaug(op, im)
  86. outputs = (im, )
  87. if label is not None:
  88. outputs = (im, label)
  89. if vdl_writer is not None:
  90. tag = str(op_id) + '. ' + op.__class__.__name__ + '/' + str(step)
  91. vdl_writer.add_image(tag=tag,
  92. img=im,
  93. step=0)
  94. op_id += 1
  95. def det_compose(im, im_info=None, label_info=None, transforms=None, vdl_writer=None, step=0,
  96. labels=[], catid2color=None):
  97. def decode_image(im_file, im_info, label_info):
  98. if im_info is None:
  99. im_info = dict()
  100. if isinstance(im_file, np.ndarray):
  101. if len(im_file.shape) != 3:
  102. raise Exception(
  103. "im should be 3-dimensions, but now is {}-dimensions".
  104. format(len(im_file.shape)))
  105. im = im_file
  106. else:
  107. try:
  108. im = cv2.imread(im_file).astype('float32')
  109. except:
  110. raise TypeError('Can\'t read The image file {}!'.format(
  111. im_file))
  112. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  113. # make default im_info with [h, w, 1]
  114. im_info['im_resize_info'] = np.array(
  115. [im.shape[0], im.shape[1], 1.], dtype=np.float32)
  116. im_info['image_shape'] = np.array([im.shape[0],
  117. im.shape[1]]).astype('int32')
  118. use_mixup = False
  119. for t in transforms:
  120. if type(t).__name__ == 'MixupImage':
  121. use_mixup = True
  122. if not use_mixup:
  123. if 'mixup' in im_info:
  124. del im_info['mixup']
  125. # decode mixup image
  126. if 'mixup' in im_info:
  127. im_info['mixup'] = \
  128. decode_image(im_info['mixup'][0],
  129. im_info['mixup'][1],
  130. im_info['mixup'][2])
  131. if label_info is None:
  132. return (im, im_info)
  133. else:
  134. return (im, im_info, label_info)
  135. outputs = decode_image(im, im_info, label_info)
  136. im = outputs[0]
  137. im_info = outputs[1]
  138. if len(outputs) == 3:
  139. label_info = outputs[2]
  140. if vdl_writer is not None:
  141. vdl_writer.add_image(tag='0. OriginalImange/' + str(step),
  142. img=im,
  143. step=0)
  144. op_id = 1
  145. bboxes = label_info['gt_bbox']
  146. transforms = [None] + transforms
  147. for op in transforms:
  148. if im is None:
  149. return None
  150. if isinstance(op, DetTransform) or op is None:
  151. if vdl_writer is not None and hasattr(op, 'prob'):
  152. op.prob = 1.0
  153. if op is not None:
  154. outputs = op(im, im_info, label_info)
  155. else:
  156. outputs = (im, im_info, label_info)
  157. im = outputs[0]
  158. vdl_im = im
  159. if vdl_writer is not None:
  160. if isinstance(op, pdx.cv.transforms.det_transforms.ResizeByShort):
  161. scale = outputs[1]['im_resize_info'][2]
  162. bboxes = bboxes * scale
  163. elif isinstance(op, pdx.cv.transforms.det_transforms.Resize):
  164. h = outputs[1]['image_shape'][0]
  165. w = outputs[1]['image_shape'][1]
  166. target_size = op.target_size
  167. if isinstance(target_size, int):
  168. h_scale = float(target_size) / h
  169. w_scale = float(target_size) / w
  170. else:
  171. h_scale = float(target_size[0]) / h
  172. w_scale = float(target_size[1]) / w
  173. bboxes[:,0] = bboxes[:,0] * w_scale
  174. bboxes[:,1] = bboxes[:,1] * h_scale
  175. bboxes[:,2] = bboxes[:,2] * w_scale
  176. bboxes[:,3] = bboxes[:,3] * h_scale
  177. else:
  178. bboxes = outputs[2]['gt_bbox']
  179. if not isinstance(op, pdx.cv.transforms.det_transforms.RandomHorizontalFlip):
  180. for i in range(bboxes.shape[0]):
  181. bbox = bboxes[i]
  182. cname = labels[outputs[2]['gt_class'][i][0]-1]
  183. vdl_im = _draw_rectangle_and_cname(vdl_im,
  184. int(bbox[0]),
  185. int(bbox[1]),
  186. int(bbox[2]),
  187. int(bbox[3]),
  188. cname,
  189. catid2color[outputs[2]['gt_class'][i][0]-1])
  190. if isinstance(op, pdx.cv.transforms.det_transforms.Normalize):
  191. vdl_im = im
  192. else:
  193. im = execute_imgaug(op, im)
  194. if label_info is not None:
  195. outputs = (im, im_info, label_info)
  196. else:
  197. outputs = (im, im_info)
  198. vdl_im = im
  199. if vdl_writer is not None:
  200. tag = str(op_id) + '. ' + op.__class__.__name__ + '/' + str(step)
  201. if op is None:
  202. tag = str(op_id) + '. OriginalImangeWithGTBox/' + str(step)
  203. vdl_writer.add_image(tag=tag,
  204. img=vdl_im,
  205. step=0)
  206. op_id += 1
  207. def seg_compose(im, im_info=None, label=None, transforms=None, vdl_writer=None, step=0):
  208. if im_info is None:
  209. im_info = list()
  210. if isinstance(im, np.ndarray):
  211. if len(im.shape) != 3:
  212. raise Exception(
  213. "im should be 3-dimensions, but now is {}-dimensions".
  214. format(len(im.shape)))
  215. else:
  216. try:
  217. im = cv2.imread(im).astype('float32')
  218. except:
  219. raise ValueError('Can\'t read The image file {}!'.format(im))
  220. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  221. if label is not None:
  222. if not isinstance(label, np.ndarray):
  223. label = np.asarray(Image.open(label))
  224. if vdl_writer is not None:
  225. vdl_writer.add_image(tag='0. OriginalImange' + '/' + str(step),
  226. img=im,
  227. step=0)
  228. op_id = 1
  229. for op in transforms:
  230. if isinstance(op, SegTransform):
  231. outputs = op(im, im_info, label)
  232. im = outputs[0]
  233. if len(outputs) >= 2:
  234. im_info = outputs[1]
  235. if len(outputs) == 3:
  236. label = outputs[2]
  237. else:
  238. im = execute_imgaug(op, im)
  239. if label is not None:
  240. outputs = (im, im_info, label)
  241. else:
  242. outputs = (im, im_info)
  243. if vdl_writer is not None:
  244. tag = str(op_id) + '. ' + op.__class__.__name__ + '/' + str(step)
  245. vdl_writer.add_image(tag=tag,
  246. img=im,
  247. step=0)
  248. op_id += 1
  249. def visualize(dataset, img_count=3, save_dir='vdl_output'):
  250. '''对数据预处理/增强中间结果进行可视化。
  251. 可使用VisualDL查看中间结果:
  252. 1. VisualDL启动方式: visualdl --logdir vdl_output --port 8001
  253. 2. 浏览器打开 https://0.0.0.0:8001即可,
  254. 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
  255. Args:
  256. dataset (paddlex.datasets): 数据集读取器。
  257. img_count (int): 需要进行数据预处理/增强的图像数目。默认为3。
  258. save_dir (str): 日志保存的路径。默认为'vdl_output'。
  259. '''
  260. if dataset.num_samples < img_count:
  261. img_count = dataset.num_samples
  262. transforms = dataset.transforms
  263. if not osp.isdir(save_dir):
  264. if osp.exists(save_dir):
  265. os.remove(save_dir)
  266. os.makedirs(save_dir)
  267. from visualdl import LogWriter
  268. vdl_save_dir = osp.join(save_dir, 'image_transforms')
  269. vdl_writer = LogWriter(vdl_save_dir)
  270. for i, data in enumerate(dataset.iterator()):
  271. if i == img_count:
  272. break
  273. data.append(transforms.transforms)
  274. data.append(vdl_writer)
  275. data.append(i)
  276. if isinstance(transforms, ClsTransform):
  277. cls_compose(*data)
  278. elif isinstance(transforms, DetTransform):
  279. labels = dataset.labels
  280. color_map = get_color_map_list(len(labels) + 1)
  281. catid2color = {}
  282. for catid in range(len(labels)):
  283. catid2color[catid] = color_map[catid + 1]
  284. data.append(labels)
  285. data.append(catid2color)
  286. det_compose(*data)
  287. elif isinstance(transforms, SegTransform):
  288. seg_compose(*data)
  289. else:
  290. raise Exception('The transform must the subclass of \
  291. ClsTransform or DetTransform or SegTransform!')