# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import

import multiprocessing as mp
import os.path as osp
import pickle
import threading
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
from PIL import Image

import paddlex.utils.logging as logging
from paddlex.utils import path_normalization
from .dataset import get_encoding
  25. class Seg:
  26. def __init__(self, data_dir, file_list, label_list):
  27. self.data_dir = data_dir
  28. self.file_list_path = file_list
  29. self.file_list = list()
  30. self.labels = list()
  31. with open(label_list, encoding=get_encoding(label_list)) as f:
  32. for line in f:
  33. item = line.strip()
  34. self.labels.append(item)
  35. with open(file_list, encoding=get_encoding(file_list)) as f:
  36. for line in f:
  37. if line.count(" ") > 1:
  38. raise Exception(
  39. "A space is defined as the separator, but it exists in image or label name {}."
  40. .format(line))
  41. items = line.strip().split()
  42. items[0] = path_normalization(items[0])
  43. items[1] = path_normalization(items[1])
  44. full_path_im = osp.join(data_dir, items[0])
  45. full_path_label = osp.join(data_dir, items[1])
  46. if not osp.exists(full_path_im):
  47. raise IOError('The image file {} is not exist!'.format(
  48. full_path_im))
  49. if not osp.exists(full_path_label):
  50. raise IOError('The image file {} is not exist!'.format(
  51. full_path_label))
  52. self.file_list.append([full_path_im, full_path_label])
  53. self.num_samples = len(self.file_list)
  54. @staticmethod
  55. def decode_image(im, label):
  56. if isinstance(im, np.ndarray):
  57. if len(im.shape) != 3:
  58. raise Exception(
  59. "im should be 3-dimensions, but now is {}-dimensions".
  60. format(len(im.shape)))
  61. else:
  62. try:
  63. im = cv2.imread(im)
  64. except:
  65. raise ValueError('Can\'t read The image file {}!'.format(im))
  66. im = im.astype('float32')
  67. if label is not None:
  68. if isinstance(label, np.ndarray):
  69. if len(label.shape) != 2:
  70. raise Exception(
  71. "label should be 2-dimensions, but now is {}-dimensions".
  72. format(len(label.shape)))
  73. else:
  74. try:
  75. label = np.asarray(Image.open(label))
  76. except:
  77. ValueError('Can\'t read The label file {}!'.format(label))
  78. im_height, im_width, _ = im.shape
  79. label_height, label_width = label.shape
  80. if im_height != label_height or im_width != label_width:
  81. raise Exception(
  82. "The height or width of the image is not same as the label")
  83. return (im, label)
  84. def _get_shape(self):
  85. max_height = max(self.im_height_list)
  86. max_width = max(self.im_width_list)
  87. min_height = min(self.im_height_list)
  88. min_width = min(self.im_width_list)
  89. shape_info = {
  90. 'max_height': max_height,
  91. 'max_width': max_width,
  92. 'min_height': min_height,
  93. 'min_width': min_width,
  94. }
  95. return shape_info
  96. def _get_label_pixel_info(self):
  97. pixel_num = np.dot(self.im_height_list, self.im_width_list)
  98. label_pixel_info = dict()
  99. for label_value, label_value_num in zip(self.label_value_list,
  100. self.label_value_num_list):
  101. for v, n in zip(label_value, label_value_num):
  102. if v not in label_pixel_info.keys():
  103. label_pixel_info[v] = [n, float(n) / float(pixel_num)]
  104. else:
  105. label_pixel_info[v][0] += n
  106. label_pixel_info[v][1] += float(n) / float(pixel_num)
  107. return label_pixel_info
  108. def _get_image_pixel_info(self):
  109. channel = max([len(im_value) for im_value in self.im_value_list])
  110. im_pixel_info = [dict() for c in range(channel)]
  111. for im_value, im_value_num in zip(self.im_value_list,
  112. self.im_value_num_list):
  113. for c in range(channel):
  114. for v, n in zip(im_value[c], im_value_num[c]):
  115. if v not in im_pixel_info[c].keys():
  116. im_pixel_info[c][v] = n
  117. else:
  118. im_pixel_info[c][v] += n
  119. mode = osp.split(self.file_list_path)[-1].split('.')[0]
  120. with open(
  121. osp.join(self.data_dir,
  122. '{}_image_pixel_info.pkl'.format(mode)), 'wb') as f:
  123. pickle.dump(im_pixel_info, f)
  124. import matplotlib.pyplot as plt
  125. plot_id = (channel // 3 + 1) * 100 + 31
  126. for c in range(channel):
  127. if c > 8:
  128. continue
  129. plt.subplot(plot_id + c)
  130. plt.bar(im_pixel_info[c].keys(),
  131. im_pixel_info[c].values(),
  132. width=1,
  133. log=True)
  134. plt.xlabel('image pixel value')
  135. plt.ylabel('number')
  136. plt.title('channel={}'.format(c))
  137. plt.savefig(
  138. osp.join(self.data_dir, '{}_image_pixel_info.png'.format(mode)),
  139. dpi=800)
  140. plt.close()
  141. return im_pixel_info
  142. def _get_mean_std(self):
  143. im_mean = np.asarray(self.im_mean_list)
  144. im_mean = im_mean.sum(axis=0)
  145. im_mean = im_mean / len(self.file_list)
  146. im_mean /= 255.
  147. im_std = np.asarray(self.im_std_list)
  148. im_std = im_std.sum(axis=0)
  149. im_std = im_std / len(self.file_list)
  150. im_std /= 255.
  151. return (im_mean, im_std)
  152. def _get_image_info(self, start, end):
  153. for id in range(start, end):
  154. full_path_im, full_path_label = self.file_list[id]
  155. image, label = self.decode_image(full_path_im, full_path_label)
  156. height, width, channel = image.shape
  157. self.im_height_list[id] = height
  158. self.im_width_list[id] = width
  159. self.im_channel_list[id] = channel
  160. self.im_mean_list[
  161. id] = [np.mean(image[:, :, c]) for c in range(channel)]
  162. self.im_std_list[
  163. id] = [np.mean(image[:, :, c]) for c in range(channel)]
  164. for c in range(channel):
  165. unique, counts = np.unique(image[:, :, c], return_counts=True)
  166. self.im_value_list[id].extend([unique])
  167. self.im_value_num_list[id].extend([counts])
  168. unique, counts = np.unique(label, return_counts=True)
  169. self.label_value_list[id] = unique
  170. self.label_value_num_list[id] = counts
  171. def _get_clipped_mean_std(self, start, end, clip_min_value,
  172. clip_max_value):
  173. for id in range(start, end):
  174. full_path_im, full_path_label = self.file_list[id]
  175. image, label = self.decode_image(full_path_im, full_path_label)
  176. for c in range(self.channel_num):
  177. np.clip(
  178. image[:, :, c],
  179. clip_min_value[c],
  180. clip_max_value[c],
  181. out=image[:, :, c])
  182. image[:, :, c] -= clip_min_value[c]
  183. image[:, :, c] /= clip_max_value[c] - clip_min_value[c]
  184. self.clipped_im_mean_list[id] = [
  185. image[:, :, c].mean() for c in range(self.channel_num)
  186. ]
  187. self.clipped_im_std_list[
  188. id] = [image[:, :, c].std() for c in range(self.channel_num)]
  189. def analysis(self):
  190. self.im_mean_list = [[] for i in range(len(self.file_list))]
  191. self.im_std_list = [[] for i in range(len(self.file_list))]
  192. self.im_value_list = [[] for i in range(len(self.file_list))]
  193. self.im_value_num_list = [[] for i in range(len(self.file_list))]
  194. self.im_height_list = np.zeros(len(self.file_list), dtype='int32')
  195. self.im_width_list = np.zeros(len(self.file_list), dtype='int32')
  196. self.im_channel_list = np.zeros(len(self.file_list), dtype='int32')
  197. self.label_value_list = [[] for i in range(len(self.file_list))]
  198. self.label_value_num_list = [[] for i in range(len(self.file_list))]
  199. num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
  200. num_workers = 6
  201. threads = []
  202. one_worker_file = len(self.file_list) // num_workers
  203. for i in range(num_workers):
  204. start = one_worker_file * i
  205. end = one_worker_file * (
  206. i + 1) if i < num_workers - 1 else len(self.file_list)
  207. t = threading.Thread(
  208. target=self._get_image_info, args=(start, end))
  209. print("====", len(self.file_list), start, end)
  210. #t.daemon = True
  211. threads.append(t)
  212. for t in threads:
  213. t.start()
  214. for t in threads:
  215. t.join()
  216. print('ok')
  217. import time
  218. import sys
  219. sys.exit(0)
  220. time.sleep(1000000)
  221. return
  222. #self._get_image_info(0, len(self.file_list))
  223. unique, counts = np.unique(self.im_channel_list, return_counts=True)
  224. print('==== unique')
  225. if len(unique) > 1:
  226. raise Exception("There are {} kinds of image channels: {}.".format(
  227. len(unique), unique[:]))
  228. self.channel_num = unique[0]
  229. shape_info = self._get_shape()
  230. print('==== shape_info')
  231. self.max_height = shape_info['max_height']
  232. self.max_width = shape_info['max_width']
  233. self.min_height = shape_info['min_height']
  234. self.min_width = shape_info['min_width']
  235. self.label_pixel_info = self._get_label_pixel_info()
  236. print('==== label_pixel_info')
  237. self.im_pixel_info = self._get_image_pixel_info()
  238. print('==== im_pixel_info')
  239. im_mean, im_std = self._get_mean_std()
  240. print('==== get_mean_std')
  241. max_im_value = list()
  242. min_im_value = list()
  243. for c in range(self.channel_num):
  244. max_im_value.append(max(self.im_pixel_info[c].keys()))
  245. min_im_value.append(min(self.im_pixel_info[c].keys()))
  246. self.max_im_value = np.asarray(max_im_value)
  247. self.min_im_value = np.asarray(min_im_value)
  248. logging.info(
  249. "############## The analysis results are as follows ##############\n"
  250. )
  251. logging.info("{} samples in file {}\n".format(
  252. len(self.file_list), self.file_list_path))
  253. logging.info("Maximal image height: {} Maximal image width: {}.\n".
  254. format(self.max_height, self.max_width))
  255. logging.info("Minimal image height: {} Minimal image width: {}.\n".
  256. format(self.min_height, self.min_width))
  257. logging.info("Image channel is {}.\n".format(self.channel_num))
  258. logging.info(
  259. "Image mean value: {} Image standard deviation: {} (normalized by 255, sorted by a BGR format).\n".
  260. format(im_mean, im_std))
  261. logging.info(
  262. "Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):"
  263. )
  264. for v, (n, r) in self.label_pixel_info.items():
  265. logging.info("({}, {}, {})".format(v, n, r))
  266. mode = osp.split(self.file_list_path)[-1].split('.')[0]
  267. saved_pkl_file = osp.join(self.data_dir,
  268. '{}_image_pixel_info.pkl'.format(mode))
  269. saved_png_file = osp.join(self.data_dir,
  270. '{}_image_pixel_info.png'.format(mode))
  271. logging.info(
  272. "Image pixel information is saved in the file '{}' and shown in the file '{}'".
  273. format(saved_pkl_file, saved_png_file))
  274. def cal_clipvalue_ratio(self, clip_min_value, clip_max_value):
  275. if len(clip_min_value) != self.channel_num or len(
  276. clip_max_value) != self.channel_num:
  277. raise Exception(
  278. "The length of clip_min_value or clip_max_value should be equal to the number of image channel {}."
  279. .format(self.channle_num))
  280. for c in range(self.channel_num):
  281. if clip_min_value[c] < self.min_im_value[c] or clip_min_value[
  282. c] > self.max_im_value[c]:
  283. raise Exception(
  284. "Clip_min_value of the channel {} is not in [{}, {}]".
  285. format(c, self.min_im_value[c], self.max_im_value[c]))
  286. if clip_max_value[c] < self.min_im_value[c] or clip_max_value[
  287. c] > self.max_im_value[c]:
  288. raise Exception(
  289. "Clip_max_value of the channel {} is not in [{}, {}]".
  290. format(c, self.min_im_value[c], self.max_im_value[c]))
  291. clip_pixel_num = 0
  292. pixel_num = sum(self.im_pixel_info[c].values())
  293. for v, n in self.im_pixel_info[c].items():
  294. if v < clip_min_value[c] or v > clip_max_value[c]:
  295. clip_pixel_num += n
  296. logging.info("Channel {}, the ratio of pixels to be clipped = {}".
  297. format(c, clip_pixel_num / pixel_num))
  298. def cal_clipped_mean_std(self, clip_min_value, clip_max_value):
  299. for c in range(self.channel_num):
  300. if clip_min_value[c] < self.min_im_value[c] or clip_min_value[
  301. c] > self.max_im_value[c]:
  302. raise Exception(
  303. "Clip_min_value of the channel {} is not in [{}, {}]".
  304. format(c, self.min_im_value[c], self.max_im_value[c]))
  305. if clip_max_value[c] < self.min_im_value[c] or clip_max_value[
  306. c] > self.max_im_value[c]:
  307. raise Exception(
  308. "Clip_max_value of the channel {} is not in [{}, {}]".
  309. format(c, self.min_im_value[c], self.max_im_value[c]))
  310. self.clipped_im_mean_list = [[] for i in range(len(self.file_list))]
  311. self.clipped_im_std_list = [[] for i in range(len(self.file_list))]
  312. num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
  313. threads = []
  314. one_worker_file = len(self.file_list) // num_workers
  315. for i in range(num_workers):
  316. start = one_worker_file * i
  317. end = one_worker_file * (
  318. i + 1) if i < num_workers - 1 else len(self.file_list)
  319. t = threading.Thread(
  320. target=self._get_clipped_mean_std,
  321. args=(start, end, clip_min_value, clip_max_value))
  322. threads.append(t)
  323. for t in threads:
  324. t.setDaemon(True)
  325. t.start()
  326. t.join()
  327. im_mean = np.asarray(self.clipped_im_mean_list)
  328. im_mean = im_mean.sum(axis=0)
  329. im_mean = im_mean / len(self.file_list)
  330. im_std = np.asarray(self.clipped_im_std_list)
  331. im_std = im_std.sum(axis=0)
  332. im_std = im_std / len(self.file_list)
  333. logging.info(
  334. "Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value)).\n".
  335. format(im_mean, im_std))