analyse.py

# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import os.path as osp
import sys
import argparse
from PIL import Image
from tqdm import tqdm
import imghdr
import logging
import pickle
import gdal


def parse_args():
    parser = argparse.ArgumentParser(
        description='Data analyse and data check before training.')
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='Dataset directory',
        default=None,
        type=str)
    parser.add_argument(
        '--num_classes',
        dest='num_classes',
        help='Number of classes',
        default=None,
        type=int)
    parser.add_argument(
        '--separator',
        dest='separator',
        help='File list separator',
        default=" ",
        type=str)
    parser.add_argument(
        '--ignore_index',
        dest='ignore_index',
        help='Ignored class index',
        default=255,
        type=int)
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
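
# Example invocation (a sketch only; the dataset path below is illustrative,
# everything else mirrors the flags defined in parse_args above):
#
#   python analyse.py --data_dir ./my_dataset --num_classes 2 \
#       --separator " " --ignore_index 255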


def read_img(img_path):
    img_format = imghdr.what(img_path)
    name, ext = osp.splitext(img_path)
    if img_format == 'tiff' or ext == '.img':
        dataset = gdal.Open(img_path)
        if dataset is None:
            raise Exception('Cannot open {}'.format(img_path))
        im_data = dataset.ReadAsArray()
        return im_data.transpose((1, 2, 0))
    elif ext == '.npy':
        return np.load(img_path)
    else:
        raise Exception('Unsupported image format: {}'.format(ext))


def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
    channel = img.shape[2]
    means = np.zeros(channel)
    stds = np.zeros(channel)
    for k in range(channel):
        img_k = img[:, :, k]

        # count mean, std
        means[k] = np.mean(img_k)
        stds[k] = np.std(img_k)

        # count min, max
        min_value = np.min(img_k)
        max_value = np.max(img_k)
        if img_max_value[k] < max_value:
            img_max_value[k] = max_value
        if img_min_value[k] > min_value:
            img_min_value[k] = min_value

        # count the distribution of image value, value number
        unique, counts = np.unique(img_k, return_counts=True)
        add_num = []
        max_unique = np.max(unique)
        add_len = max_unique - len(img_value_num[k]) + 1
        if add_len > 0:
            img_value_num[k] += ([0] * add_len)
        for i in range(len(unique)):
            value = unique[i]
            img_value_num[k][value] += counts[i]
        img_value_num[k] += add_num
    return means, stds, img_min_value, img_max_value, img_value_num
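
# Note on the bookkeeping above (comment added for clarity): img_value_num is
# a per-channel histogram, so img_value_num[k][v] holds the number of pixels
# with integer value v in channel k, accumulated across every image processed
# so far, while img_min_value / img_max_value track the running per-channel
# extremes. For example, after a single 2x2 single-channel image
# [[0, 1], [1, 3]], img_value_num would be [[1, 2, 0, 1]].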


def data_distribution_statistics(data_dir, img_value_num, logger):
    """Count the per-channel distribution of pixel values over the dataset."""
    logger.info(
        "\n-----------------------------\nThe whole dataset statistics...")

    if not img_value_num:
        return
    logger.info("\nImage pixel statistics:")
    total_ratio = [[] for _ in range(len(img_value_num))]
    for k in range(len(img_value_num)):
        total_num = sum(img_value_num[k])
        total_ratio[k] = [i / total_num for i in img_value_num[k]]
        total_ratio[k] = np.around(total_ratio[k], decimals=4)
    with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f:
        pickle.dump([total_ratio, img_value_num], f)


def data_range_statistics(img_min_value, img_max_value, logger):
    """Print the per-channel min and max pixel values."""
    logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".
                format(img_min_value, img_max_value))


def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
    """Compute the per-channel mean and std over all images."""
    total_means = total_means / total_img_num
    total_stds = total_stds / total_img_num
    logger.info("\nCount the channel-by-channel mean and std of the image:\n"
                "mean = {}\nstd = {}".format(total_means, total_stds))


def error_print(msg):
    return "".join(["\nNOT PASS ", msg])


def correct_print(msg):
    return "".join(["\nPASS ", msg])


def pil_imread(file_path):
    """read pseudo-color label"""
    im = Image.open(file_path)
    return np.asarray(im)


def get_img_shape_range(img, max_width, max_height, min_width, min_height):
    """Update the running max/min image width and height."""
    img_shape = img.shape
    height, width = img_shape[0], img_shape[1]
    max_height = max(height, max_height)
    max_width = max(width, max_width)
    min_height = min(height, min_height)
    min_width = min(width, min_width)
    return max_width, max_height, min_width, min_height


def get_img_channel_num(img, img_channels):
    """Record the number of channels of the image."""
    img_shape = img.shape
    if img_shape[-1] not in img_channels:
        img_channels.append(img_shape[-1])
    return img_channels


def is_label_single_channel(label):
    """Check whether the label is a single-channel (grayscale) image."""
    label_shape = label.shape
    return len(label_shape) == 2


def image_label_shape_check(img, label):
    """Check whether the image and its label have the same height and width."""
    flag = True
    img_height = img.shape[0]
    img_width = img.shape[1]
    label_height = label.shape[0]
    label_width = label.shape[1]
    if img_height != label_height or img_width != label_width:
        flag = False
    return flag


def ground_truth_check(label, label_path):
    """Check the label image format and count its classes and pixel numbers.

    params:
        label: label image (numpy array)
        label_path: path of the label image
    return:
        png_format: whether the label image is in PNG format
        unique: the classes present in the label
        counts: the pixel count of each class
    """
    if imghdr.what(label_path) == "png":
        png_format = True
    else:
        png_format = False
    unique, counts = np.unique(label, return_counts=True)
    return png_format, unique, counts


def sum_label_check(label_classes, num_of_each_class, ignore_index,
                    num_classes, total_label_classes, total_num_of_each_class):
    """Accumulate the classes and per-class pixel counts over all labels.

    params:
        label_classes: classes present in the current label
        num_of_each_class: pixel count of each of those classes
    """
    is_label_correct = True

    if ignore_index in label_classes:
        label_classes2 = np.delete(label_classes,
                                   np.where(label_classes == ignore_index))
    else:
        label_classes2 = label_classes
    if min(label_classes2) < 0 or max(label_classes2) > num_classes - 1:
        is_label_correct = False
    add_class = []
    add_num = []
    for i in range(len(label_classes)):
        gi = label_classes[i]
        if gi in total_label_classes:
            j = total_label_classes.index(gi)
            total_num_of_each_class[j] += num_of_each_class[i]
        else:
            add_class.append(gi)
            add_num.append(num_of_each_class[i])
    total_num_of_each_class += add_num
    total_label_classes += add_class
    return is_label_correct, total_num_of_each_class, total_label_classes


def label_class_check(num_classes, total_label_classes,
                      total_num_of_each_class, wrong_labels, logger):
    """Check whether the annotated classes match `num_classes` and `ignore_index`.

    **NOTE:**
    Label values must lie in [0, `num_classes` - 1] or be equal to `ignore_index`.
    Label classes should preferably start from 0, otherwise accuracy may be affected.
    """
    total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
    total_ratio = np.around(total_ratio, decimals=4)
    total_nc = sorted(
        zip(total_label_classes, total_ratio, total_num_of_each_class))
    if len(wrong_labels) == 0 and not total_nc[0][0]:
        logger.info(correct_print("label class check!"))
    else:
        logger.info(error_print("label class check!"))
        if total_nc[0][0]:
            logger.info("Warning: label classes should start from 0")
        if len(wrong_labels) > 0:
            logger.info("fatal error: label class is out of range [0, {}]".
                        format(num_classes - 1))
            for i in wrong_labels:
                logger.debug(i)
    return total_nc


def label_class_statistics(total_nc, logger):
    """Log the label class statistics."""
    logger.info("\nLabel class statistics:\n"
                "(label class, percentage, total pixel number) = {} ".format(
                    total_nc))


def shape_check(shape_unequal_image, logger):
    """Log the shape check result."""
    if len(shape_unequal_image) == 0:
        logger.info(correct_print("shape check"))
        logger.info("All images are the same shape as the labels")
    else:
        logger.info(error_print("shape check"))
        logger.info(
            "Some images are not the same shape as the labels as follow: ")
        for i in shape_unequal_image:
            logger.debug(i)


def separator_check(wrong_lines, file_list, separator, logger):
    """Check whether the file list is separated by the expected separator."""
    if len(wrong_lines) == 0:
        logger.info(
            correct_print(
                file_list.split(os.sep)[-1] + " DATASET.separator check"))
    else:
        logger.info(
            error_print(
                file_list.split(os.sep)[-1] + " DATASET.separator check"))
        logger.info("The following list is not separated by {}".format(
            separator))
        for i in wrong_lines:
            logger.debug(i)


def imread_check(imread_failed, logger):
    if len(imread_failed) == 0:
        logger.info(correct_print("dataset reading check"))
        logger.info("All images can be read successfully")
    else:
        logger.info(error_print("dataset reading check"))
        logger.info("Failed to read {} images".format(len(imread_failed)))
        for i in imread_failed:
            logger.debug(i)


def single_channel_label_check(label_not_single_channel, logger):
    if len(label_not_single_channel) == 0:
        logger.info(correct_print("label single_channel check"))
        logger.info("All label images are single_channel")
    else:
        logger.info(error_print("label single_channel check"))
        logger.info(
            "{} label images are not single_channel\nLabel pixel statistics may be insignificant"
            .format(len(label_not_single_channel)))
        for i in label_not_single_channel:
            logger.debug(i)


def img_shape_range_statistics(max_width, min_width, max_height, min_height,
                               logger):
    logger.info("\nImage size statistics:")
    logger.info(
        "max width = {} min width = {} max height = {} min height = {}".
        format(max_width, min_width, max_height, min_height))


def img_channels_statistics(img_channels, logger):
    logger.info("\nImage channels statistics\nImage channels = {}".format(
        np.unique(img_channels)))
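
# The checker below walks the three standard file lists under --data_dir.
# A minimal layout, sketched for illustration (the file names come from the
# code below, the sample entries do not):
#
#   data_dir/
#     train.txt   # e.g. "images/a.tif annotations/a.png" (image + separator + label)
#     val.txt     # same format; train/val lists must contain labels
#     test.txt    # the label column may be omitted for the test list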


def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
                           logger):
    train_file_list = osp.join(data_dir, 'train.txt')
    val_file_list = osp.join(data_dir, 'val.txt')
    test_file_list = osp.join(data_dir, 'test.txt')
    total_img_num = 0
    has_label = False
    for file_list in [train_file_list, val_file_list, test_file_list]:
        # initialization
        imread_failed = []
        max_width = 0
        max_height = 0
        min_width = sys.float_info.max
        min_height = sys.float_info.max
        label_not_single_channel = []
        shape_unequal_image = []
        wrong_labels = []
        wrong_lines = []
        total_label_classes = []
        total_num_of_each_class = []
        img_channels = []

        with open(file_list, 'r') as fid:
            logger.info("\n-----------------------------\nCheck {}...".format(
                file_list))
            lines = fid.readlines()
            if not lines:
                logger.info("File list is empty!")
                continue
            for line in tqdm(lines):
                line = line.strip()
                parts = line.split(separator)
                if len(parts) == 1:
                    if file_list == train_file_list or file_list == val_file_list:
                        logger.info("Train or val list must have labels!")
                        break
                    img_name = parts
                    img_path = os.path.join(data_dir, img_name[0])
                    try:
                        img = read_img(img_path)
                    except Exception as e:
                        imread_failed.append((line, str(e)))
                        continue
                elif len(parts) == 2:
                    has_label = True
                    img_name, label_name = parts[0], parts[1]
                    img_path = os.path.join(data_dir, img_name)
                    label_path = os.path.join(data_dir, label_name)
                    try:
                        img = read_img(img_path)
                        label = pil_imread(label_path)
                    except Exception as e:
                        imread_failed.append((line, str(e)))
                        continue

                    is_single_channel = is_label_single_channel(label)
                    if not is_single_channel:
                        label_not_single_channel.append(line)
                        continue
                    is_equal_img_label_shape = image_label_shape_check(img,
                                                                       label)
                    if not is_equal_img_label_shape:
                        shape_unequal_image.append(line)
                    png_format, label_classes, num_of_each_class = ground_truth_check(
                        label, label_path)
                    is_label_correct, total_num_of_each_class, total_label_classes = sum_label_check(
                        label_classes, num_of_each_class, ignore_index,
                        num_classes, total_label_classes,
                        total_num_of_each_class)
                    if not is_label_correct:
                        wrong_labels.append(line)
                else:
                    # line has an unexpected number of fields (wrong separator)
                    wrong_lines.append(line)
                    continue

                if total_img_num == 0:
                    channel = img.shape[2]
                    total_means = np.zeros(channel)
                    total_stds = np.zeros(channel)
                    img_min_value = [sys.float_info.max] * channel
                    img_max_value = [0] * channel
                    img_value_num = [[] for _ in range(channel)]
                means, stds, img_min_value, img_max_value, img_value_num = img_pixel_statistics(
                    img, img_value_num, img_min_value, img_max_value)
                total_means += means
                total_stds += stds
                max_width, max_height, min_width, min_height = get_img_shape_range(
                    img, max_width, max_height, min_width, min_height)
                img_channels = get_img_channel_num(img, img_channels)
                total_img_num += 1

            # data check
            separator_check(wrong_lines, file_list, separator, logger)
            imread_check(imread_failed, logger)
            if has_label:
                single_channel_label_check(label_not_single_channel, logger)
                shape_check(shape_unequal_image, logger)
                total_nc = label_class_check(num_classes, total_label_classes,
                                             total_num_of_each_class,
                                             wrong_labels, logger)

            # data analyse on train, validation, test set.
            img_channels_statistics(img_channels, logger)
            img_shape_range_statistics(max_width, min_width, max_height,
                                       min_height, logger)
            if has_label:
                label_class_statistics(total_nc, logger)

    # data analyse on the whole dataset.
    data_range_statistics(img_min_value, img_max_value, logger)
    data_distribution_statistics(data_dir, img_value_num, logger)
    cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)


def main():
    args = parse_args()
    data_dir = args.data_dir
    ignore_index = args.ignore_index
    num_classes = args.num_classes
    separator = args.separator

    logger = logging.getLogger()
    logger.setLevel('DEBUG')
    BASIC_FORMAT = "%(message)s"
    formatter = logging.Formatter(BASIC_FORMAT)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    sh.setLevel('INFO')
    th = logging.FileHandler(
        os.path.join(data_dir, 'data_analyse_and_check.log'), 'w')
    th.setFormatter(formatter)
    logger.addHandler(sh)
    logger.addHandler(th)

    data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
                           logger)

    print("\nDetailed error information can be viewed in {}.".format(
        os.path.join(data_dir, 'data_analyse_and_check.log')))


if __name__ == "__main__":
    main()