# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import multiprocessing as mp
import os.path as osp
import pickle
import threading
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
from PIL import Image

import paddlex.utils.logging as logging
from paddlex.cv.transforms.seg_transforms import Compose
from paddlex.utils import path_normalization

from .dataset import get_encoding


  26. class Seg:
  27. def __init__(self, data_dir, file_list, label_list):
  28. self.data_dir = data_dir
  29. self.file_list_path = file_list
  30. self.file_list = list()
  31. self.labels = list()
  32. with open(label_list, encoding=get_encoding(label_list)) as f:
  33. for line in f:
  34. item = line.strip()
  35. self.labels.append(item)
  36. with open(file_list, encoding=get_encoding(file_list)) as f:
  37. for line in f:
  38. items = line.strip().split()
  39. if len(items) > 2:
  40. raise Exception(
  41. "A space is defined as the separator, but it exists in image or label name {}."
  42. .format(line))
  43. items[0] = path_normalization(items[0])
  44. items[1] = path_normalization(items[1])
  45. full_path_im = osp.join(data_dir, items[0])
  46. full_path_label = osp.join(data_dir, items[1])
  47. if not osp.exists(full_path_im):
  48. raise IOError('The image file {} is not exist!'.format(
  49. full_path_im))
  50. if not osp.exists(full_path_label):
  51. raise IOError('The image file {} is not exist!'.format(
  52. full_path_label))
  53. self.file_list.append([full_path_im, full_path_label])
  54. self.num_samples = len(self.file_list)
  55. def _get_shape(self):
  56. max_height = max(self.im_height_list)
  57. max_width = max(self.im_width_list)
  58. min_height = min(self.im_height_list)
  59. min_width = min(self.im_width_list)
  60. shape_info = {
  61. 'max_height': max_height,
  62. 'max_width': max_width,
  63. 'min_height': min_height,
  64. 'min_width': min_width,
  65. }
  66. return shape_info
  67. def _get_label_pixel_info(self):
  68. pixel_num = np.dot(self.im_height_list, self.im_width_list)
  69. label_pixel_info = dict()
  70. for label_value, label_value_num in zip(self.label_value_list,
  71. self.label_value_num_list):
  72. for v, n in zip(label_value, label_value_num):
  73. if v not in label_pixel_info.keys():
  74. label_pixel_info[v] = [n, float(n) / float(pixel_num)]
  75. else:
  76. label_pixel_info[v][0] += n
  77. label_pixel_info[v][1] += float(n) / float(pixel_num)
  78. return label_pixel_info
  79. def _get_image_pixel_info(self):
  80. channel = max([len(im_value) for im_value in self.im_value_list])
  81. im_pixel_info = [dict() for c in range(channel)]
  82. for im_value, im_value_num in zip(self.im_value_list,
  83. self.im_value_num_list):
  84. for c in range(channel):
  85. for v, n in zip(im_value[c], im_value_num[c]):
  86. if v not in im_pixel_info[c].keys():
  87. im_pixel_info[c][v] = n
  88. else:
  89. im_pixel_info[c][v] += n
  90. return im_pixel_info
  91. def _get_mean_std(self):
  92. im_mean = np.asarray(self.im_mean_list)
  93. im_mean = im_mean.sum(axis=0)
  94. im_mean = im_mean / len(self.file_list)
  95. im_mean /= self.max_im_value - self.min_im_value
  96. im_std = np.asarray(self.im_std_list)
  97. im_std = im_std.sum(axis=0)
  98. im_std = im_std / len(self.file_list)
  99. im_std /= self.max_im_value - self.min_im_value
  100. return (im_mean, im_std)
  101. def _get_image_info(self, start, end):
  102. for id in range(start, end):
  103. full_path_im, full_path_label = self.file_list[id]
  104. image, label = Compose.decode_image(full_path_im, full_path_label)
  105. height, width, channel = image.shape
  106. self.im_height_list[id] = height
  107. self.im_width_list[id] = width
  108. self.im_channel_list[id] = channel
  109. self.im_mean_list[
  110. id] = [image[:, :, c].mean() for c in range(channel)]
  111. self.im_std_list[
  112. id] = [image[:, :, c].std() for c in range(channel)]
  113. for c in range(channel):
  114. unique, counts = np.unique(image[:, :, c], return_counts=True)
  115. self.im_value_list[id].extend([unique])
  116. self.im_value_num_list[id].extend([counts])
  117. unique, counts = np.unique(label, return_counts=True)
  118. self.label_value_list[id] = unique
  119. self.label_value_num_list[id] = counts
  120. def _get_clipped_mean_std(self, start, end, clip_min_value,
  121. clip_max_value):
  122. for id in range(start, end):
  123. full_path_im, full_path_label = self.file_list[id]
  124. image, label = Compose.decode_image(full_path_im, full_path_label)
  125. for c in range(self.channel_num):
  126. np.clip(
  127. image[:, :, c],
  128. clip_min_value[c],
  129. clip_max_value[c],
  130. out=image[:, :, c])
  131. image[:, :, c] -= clip_min_value[c]
  132. image[:, :, c] /= clip_max_value[c] - clip_min_value[c]
  133. self.clipped_im_mean_list[id] = [
  134. image[:, :, c].mean() for c in range(self.channel_num)
  135. ]
  136. self.clipped_im_std_list[
  137. id] = [image[:, :, c].std() for c in range(self.channel_num)]
  138. def analysis(self):
  139. self.im_mean_list = [[] for i in range(len(self.file_list))]
  140. self.im_std_list = [[] for i in range(len(self.file_list))]
  141. self.im_value_list = [[] for i in range(len(self.file_list))]
  142. self.im_value_num_list = [[] for i in range(len(self.file_list))]
  143. self.im_height_list = np.zeros(len(self.file_list), dtype='int32')
  144. self.im_width_list = np.zeros(len(self.file_list), dtype='int32')
  145. self.im_channel_list = np.zeros(len(self.file_list), dtype='int32')
  146. self.label_value_list = [[] for i in range(len(self.file_list))]
  147. self.label_value_num_list = [[] for i in range(len(self.file_list))]
  148. num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
  149. threads = []
  150. one_worker_file = len(self.file_list) // num_workers
  151. for i in range(num_workers):
  152. start = one_worker_file * i
  153. end = one_worker_file * (
  154. i + 1) if i < num_workers - 1 else len(self.file_list)
  155. t = threading.Thread(
  156. target=self._get_image_info, args=(start, end))
  157. threads.append(t)
  158. for t in threads:
  159. t.start()
  160. for t in threads:
  161. t.join()
  162. unique, counts = np.unique(self.im_channel_list, return_counts=True)
  163. if len(unique) > 1:
  164. raise Exception("There are {} kinds of image channels: {}.".format(
  165. len(unique), unique[:]))
  166. self.channel_num = unique[0]
  167. shape_info = self._get_shape()
  168. self.max_height = shape_info['max_height']
  169. self.max_width = shape_info['max_width']
  170. self.min_height = shape_info['min_height']
  171. self.min_width = shape_info['min_width']
  172. self.label_pixel_info = self._get_label_pixel_info()
  173. self.im_pixel_info = self._get_image_pixel_info()
  174. mode = osp.split(self.file_list_path)[-1].split('.')[0]
  175. import matplotlib.pyplot as plt
  176. for c in range(self.channel_num):
  177. plt.figure()
  178. plt.bar(self.im_pixel_info[c].keys(),
  179. self.im_pixel_info[c].values(),
  180. width=1,
  181. log=True)
  182. plt.xlabel('image pixel value')
  183. plt.ylabel('number')
  184. plt.title('channel={}'.format(c))
  185. plt.savefig(
  186. osp.join(self.data_dir,
  187. '{}_channel{}_distribute.png'.format(mode, c)),
  188. dpi=100)
  189. plt.close()
  190. max_im_value = list()
  191. min_im_value = list()
  192. for c in range(self.channel_num):
  193. max_im_value.append(max(self.im_pixel_info[c].keys()))
  194. min_im_value.append(min(self.im_pixel_info[c].keys()))
  195. self.max_im_value = np.asarray(max_im_value)
  196. self.min_im_value = np.asarray(min_im_value)
  197. im_mean, im_std = self._get_mean_std()
  198. info = {
  199. 'channel_num': self.channel_num,
  200. 'image_pixel': self.im_pixel_info,
  201. 'label_pixel': self.label_pixel_info,
  202. 'file_num': len(self.file_list),
  203. 'max_height': self.max_height,
  204. 'max_width': self.max_width,
  205. 'min_height': self.min_height,
  206. 'min_width': self.min_width,
  207. 'max_image_value': self.max_im_value,
  208. 'min_image_value': self.min_im_value
  209. }
  210. saved_pkl_file = osp.join(self.data_dir,
  211. '{}_infomation.pkl'.format(mode))
  212. with open(osp.join(saved_pkl_file), 'wb') as f:
  213. pickle.dump(info, f)
  214. logging.info(
  215. "############## The analysis results are as follows ##############\n"
  216. )
  217. logging.info("{} samples in file {}\n".format(
  218. len(self.file_list), self.file_list_path))
  219. logging.info("Minimal image height: {} Minimal image width: {}.\n".
  220. format(self.min_height, self.min_width))
  221. logging.info("Maximal image height: {} Maximal image width: {}.\n".
  222. format(self.max_height, self.max_width))
  223. logging.info("Image channel is {}.\n".format(self.channel_num))
  224. logging.info(
  225. "Minimal image value: {} Maximal image value: {} (arranged in 0-{} channel order) \n".
  226. format(self.min_im_value, self.max_im_value, self.channel_num))
  227. logging.info(
  228. "Image pixel distribution of each channel is saved with 'distribute.png' in the {}"
  229. .format(self.data_dir))
  230. logging.info(
  231. "Image mean value: {} Image standard deviation: {} (normalized by the (max_im_value - min_im_value), arranged in 0-{} channel order).\n".
  232. format(im_mean, im_std, self.channel_num))
  233. logging.info(
  234. "Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):"
  235. )
  236. for v, (n, r) in self.label_pixel_info.items():
  237. logging.info("({}, {}, {})".format(v, n, r))
  238. logging.info("Dataset information is saved in {}".format(
  239. saved_pkl_file))
  240. def cal_clipped_mean_std(self, clip_min_value, clip_max_value,
  241. data_info_file):
  242. if not osp.exists(data_info_file):
  243. raise Exception("Dataset information file {} does not exist.".
  244. format(data_info_file))
  245. with open(data_info_file, 'rb') as f:
  246. im_info = pickle.load(f)
  247. channel_num = im_info['channel_num']
  248. min_im_value = im_info['min_image_value']
  249. max_im_value = im_info['max_image_value']
  250. im_pixel_info = im_info['image_pixel']
  251. if len(clip_min_value) != channel_num or len(
  252. clip_max_value) != channel_num:
  253. raise Exception(
  254. "The length of clip_min_value or clip_max_value should be equal to the number of image channel {}."
  255. .format(channle_num))
  256. for c in range(channel_num):
  257. if clip_min_value[c] < min_im_value[c] or clip_min_value[
  258. c] > max_im_value[c]:
  259. raise Exception(
  260. "Clip_min_value of the channel {} is not in [{}, {}]".
  261. format(c, min_im_value[c], max_im_value[c]))
  262. if clip_max_value[c] < min_im_value[c] or clip_max_value[
  263. c] > max_im_value[c]:
  264. raise Exception(
  265. "Clip_max_value of the channel {} is not in [{}, {}]".
  266. format(c, min_im_value[c], self.max_im_value[c]))
  267. self.clipped_im_mean_list = [[] for i in range(len(self.file_list))]
  268. self.clipped_im_std_list = [[] for i in range(len(self.file_list))]
  269. num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
  270. threads = []
  271. one_worker_file = len(self.file_list) // num_workers
  272. self.channel_num = channel_num
  273. for i in range(num_workers):
  274. start = one_worker_file * i
  275. end = one_worker_file * (
  276. i + 1) if i < num_workers - 1 else len(self.file_list)
  277. t = threading.Thread(
  278. target=self._get_clipped_mean_std,
  279. args=(start, end, clip_min_value, clip_max_value))
  280. threads.append(t)
  281. for t in threads:
  282. t.start()
  283. for t in threads:
  284. t.join()
  285. im_mean = np.asarray(self.clipped_im_mean_list)
  286. im_mean = im_mean.sum(axis=0)
  287. im_mean = im_mean / len(self.file_list)
  288. im_std = np.asarray(self.clipped_im_std_list)
  289. im_std = im_std.sum(axis=0)
  290. im_std = im_std / len(self.file_list)
  291. for c in range(channel_num):
  292. clip_pixel_num = 0
  293. pixel_num = sum(im_pixel_info[c].values())
  294. for v, n in im_pixel_info[c].items():
  295. if v < clip_min_value[c] or v > clip_max_value[c]:
  296. clip_pixel_num += n
  297. logging.info("Channel {}, the ratio of pixels to be clipped = {}".
  298. format(c, clip_pixel_num / pixel_num))
  299. logging.info(
  300. "Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value), arranged in 0-{} channel order).\n".
  301. format(im_mean, im_std, self.channel_num))