| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506 |
- # coding: utf8
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import numpy as np
- import os
- import os.path as osp
- import sys
- import argparse
- from PIL import Image
- from tqdm import tqdm
- import imghdr
- import logging
- import pickle
- import gdal
def parse_args():
    """Parse the command-line options of the data analysis script.

    Supported options are ``--data_dir``, ``--num_classes``, ``--separator``
    and ``--ignore_index``. When the script is started without any argument
    the help text is printed and the process exits with status 1.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description='Data analyse and data check before training.')

    # (flag, dest, help, default, type) for every supported option.
    option_specs = (
        ('--data_dir', 'data_dir', 'Dataset directory', None, str),
        ('--num_classes', 'num_classes', 'Number of classes', None, int),
        ('--separator', 'separator', 'file list separator', " ", str),
        ('--ignore_index', 'ignore_index', 'Ignored class index', 255, int),
    )
    for flag, dest, help_text, default, value_type in option_specs:
        parser.add_argument(
            flag,
            dest=dest,
            help=help_text,
            default=default,
            type=value_type)

    # Only the program name is present: show usage instead of running.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
def read_img(img_path):
    """Read an image file into a numpy array.

    Files that ``imghdr`` detects as TIFF, and files with the ``.img``
    extension, are opened with GDAL. Files with the ``.npy`` extension are
    loaded with ``numpy.load``. Any other file is rejected.

    Args:
        img_path: Path of the image file.

    Returns:
        For GDAL-readable files, the raster data in channel-last layout.
        A 2-D (single-band) raster is returned with a trailing channel axis
        of length 1 so callers can always index ``shape[2]``. For ``.npy``
        files the stored array is returned unchanged.

    Raises:
        Exception: If GDAL cannot open the file or the format is unsupported.
    """
    img_format = imghdr.what(img_path)
    ext = osp.splitext(img_path)[1]
    if img_format == 'tiff' or ext == '.img':
        dataset = gdal.Open(img_path)
        if dataset is None:
            raise Exception('Can not open {}'.format(img_path))
        im_data = dataset.ReadAsArray()
        if im_data.ndim == 2:
            # Single-band raster: there is no band axis to move, so add one
            # instead of failing inside transpose().
            return im_data[:, :, np.newaxis]
        # Move the leading band axis to the end (channel-last).
        return im_data.transpose((1, 2, 0))
    if ext == '.npy':
        return np.load(img_path)
    raise Exception('Not support {} image format!'.format(ext))
def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
    """Compute per-channel statistics of one image and update running totals.

    Pixel values are used directly as indices into the per-channel histogram
    lists, so they are expected to be non-negative integers.

    Args:
        img: Image array indexed as ``img[:, :, k]`` for channel ``k``.
        img_value_num: One histogram list per channel; ``img_value_num[k][v]``
            is the number of pixels with value ``v`` seen so far. Updated in
            place and grown when a larger value appears.
        img_min_value: Running per-channel minimum, updated in place.
        img_max_value: Running per-channel maximum, updated in place.

    Returns:
        Tuple ``(means, stds, img_min_value, img_max_value, img_value_num)``
        where ``means`` and ``stds`` are the per-channel mean and standard
        deviation of this image only.
    """
    channel = img.shape[2]
    means = np.zeros(channel)
    stds = np.zeros(channel)
    for k in range(channel):
        img_k = img[:, :, k]

        # Mean and std of this channel.
        means[k] = np.mean(img_k)
        stds[k] = np.std(img_k)

        # Running min and max.
        min_value = np.min(img_k)
        max_value = np.max(img_k)
        if img_max_value[k] < max_value:
            img_max_value[k] = max_value
        if img_min_value[k] > min_value:
            img_min_value[k] = min_value

        # Value histogram. Convert numpy scalars to Python scalars first:
        # arithmetic on e.g. a uint8 scalar can wrap around or raise under
        # NumPy 2 promotion rules, which would size the histogram wrongly.
        unique, counts = np.unique(img_k, return_counts=True)
        values = unique.tolist()
        occurrences = counts.tolist()
        histogram = img_value_num[k]
        add_len = max(values) - len(histogram) + 1
        if add_len > 0:
            histogram.extend([0] * add_len)
        for value, count in zip(values, occurrences):
            histogram[value] += count

    return means, stds, img_min_value, img_max_value, img_value_num
def data_distribution_statistics(data_dir, img_value_num, logger):
    """Save the per-channel pixel value distribution of the whole dataset.

    For each channel the value counts are turned into ratios (rounded to 4
    decimals). The ratios and the raw counts are pickled together into
    ``img_pixel_statistics.pkl`` inside ``data_dir``. Nothing is written
    when ``img_value_num`` is empty.
    """
    logger.info(
        "\n-----------------------------\nThe whole dataset statistics...")
    if not img_value_num:
        return
    logger.info("\nImage pixel statistics:")

    total_ratio = []
    for channel_counts in img_value_num:
        pixel_total = sum(channel_counts)
        ratios = [count / pixel_total for count in channel_counts]
        total_ratio.append(np.around(ratios, decimals=4))

    output_path = os.path.join(data_dir, 'img_pixel_statistics.pkl')
    with open(output_path, 'wb') as f:
        pickle.dump([total_ratio, img_value_num], f)
def data_range_statistics(img_min_value, img_max_value, logger):
    """Log the minimum and maximum pixel values that were collected."""
    message = "value range: \nimg_min_value = {} \nimg_max_value = {}".format(
        img_min_value, img_max_value)
    logger.info(message)
def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
    """Log the normalization coefficients of the dataset.

    The accumulated per-image channel means and standard deviations are
    divided by the number of images and the resulting averages are logged.
    """
    mean_per_channel = total_means / total_img_num
    std_per_channel = total_stds / total_img_num
    report = ("\nCount the channel-by-channel mean and std of the image:\n"
              "mean = {}\nstd = {}").format(mean_per_channel, std_per_channel)
    logger.info(report)
def error_print(str):
    """Return the message prefixed with a newline and the NOT PASS marker."""
    return "\nNOT PASS " + str
def correct_print(str):
    """Return the message prefixed with a newline and the PASS marker."""
    return "\nPASS " + str
def pil_imread(file_path):
    """Read a (pseudo-color) label image with PIL and return it as an array.

    Args:
        file_path: Path of the label image.

    Returns:
        The image converted with ``numpy.asarray``.
    """
    # Open through a context manager so the underlying file handle is
    # released deterministically once the pixel data has been converted.
    with Image.open(file_path) as im:
        return np.asarray(im)
def get_img_shape_range(img, max_width, max_height, min_width, min_height):
    """Update the running width/height extremes with the size of ``img``.

    The height is read from ``img.shape[0]`` and the width from
    ``img.shape[1]``.

    Returns:
        Tuple ``(max_width, max_height, min_width, min_height)``.
    """
    height, width = img.shape[0], img.shape[1]
    return (max(width, max_width), max(height, max_height),
            min(width, min_width), min(height, min_height))
def get_img_channel_num(img, img_channels):
    """Record the channel count of ``img`` in ``img_channels``.

    The channel count is the last entry of ``img.shape``. It is appended to
    ``img_channels`` in place unless it is already present, and the same list
    is returned.
    """
    channel_count = img.shape[-1]
    if channel_count in img_channels:
        return img_channels
    img_channels.append(channel_count)
    return img_channels
def is_label_single_channel(label):
    """Return True when the label is a grayscale (two-dimensional) image."""
    return len(label.shape) == 2
def image_label_shape_check(img, label):
    """Check whether the image and its label have the same height and width.

    Returns:
        True when the first two dimensions of both arrays match, else False.
    """
    img_size = (img.shape[0], img.shape[1])
    label_size = (label.shape[0], label.shape[1])
    return img_size == label_size
def ground_truth_check(label, label_path):
    """Check the label file format and count the pixels of each class.

    Args:
        label: Label image array.
        label_path: Path of the label image file.

    Returns:
        png_format: Whether ``imghdr`` identifies the file as PNG.
        unique: The class values present in the label.
        counts: The number of pixels of each class value.
    """
    png_format = imghdr.what(label_path) == "png"
    unique, counts = np.unique(label, return_counts=True)
    return png_format, unique, counts
def sum_label_check(label_classes, num_of_each_class, ignore_index,
                    num_classes, total_label_classes, total_num_of_each_class):
    """Validate one label's classes and merge its counts into the totals.

    Args:
        label_classes: Class values present in the label image.
        num_of_each_class: Pixel count of each value in ``label_classes``.
        ignore_index: Class value that is exempt from the range check.
        num_classes: Number of classes; valid values are
            ``0 .. num_classes - 1``.
        total_label_classes: Class values seen so far (list, updated in
            place).
        total_num_of_each_class: Accumulated pixel count per entry of
            ``total_label_classes`` (list, updated in place).

    Returns:
        Tuple ``(is_label_correct, total_num_of_each_class,
        total_label_classes)``. ``is_label_correct`` is False when any class
        other than ``ignore_index`` falls outside ``[0, num_classes - 1]``.
    """
    classes_arr = np.asarray(label_classes)
    checked_classes = classes_arr[classes_arr != ignore_index]

    is_label_correct = True
    # A label made only of ignore_index pixels leaves nothing to range-check;
    # calling min()/max() on the empty remainder would raise ValueError.
    if checked_classes.size > 0:
        lowest = checked_classes.min().item()
        highest = checked_classes.max().item()
        if lowest < 0 or highest > num_classes - 1:
            is_label_correct = False

    add_class = []
    add_num = []
    for cls, count in zip(label_classes, num_of_each_class):
        if cls in total_label_classes:
            total_num_of_each_class[total_label_classes.index(cls)] += count
        else:
            # Collected separately so the lookup above only sees classes
            # known before this label.
            add_class.append(cls)
            add_num.append(count)
    total_num_of_each_class += add_num
    total_label_classes += add_class
    return is_label_correct, total_num_of_each_class, total_label_classes
def label_class_check(num_classes, total_label_classes,
                      total_num_of_each_class, wrong_labels, logger):
    """Check that the annotated classes match ``num_classes``.

    **NOTE:**
    Label values must lie in ``[0, num_classes - 1]`` or be equal to
    ``ignore_index``. Label classes should preferably start from 0,
    otherwise accuracy may be affected.

    Args:
        num_classes: Configured number of classes.
        total_label_classes: Class values found in the labels.
        total_num_of_each_class: Pixel count of each class value.
        wrong_labels: File-list lines whose labels are out of range.
        logger: Logger that receives the check result.

    Returns:
        List of ``(label class, ratio, pixel count)`` tuples sorted by label
        class, or an empty list when no class statistics were collected.
    """
    if len(total_label_classes) == 0:
        # Nothing was accumulated; dividing by the zero total or indexing
        # the first class below would fail.
        logger.info("No label class was collected, skip label class check")
        return []

    # Convert explicitly to an array: dividing a plain list only works by
    # accident when its elements happen to be numpy scalars.
    pixel_counts = np.asarray(total_num_of_each_class, dtype=np.float64)
    total_ratio = np.around(pixel_counts / pixel_counts.sum(), decimals=4)
    total_nc = sorted(
        zip(total_label_classes, total_ratio, total_num_of_each_class))

    smallest_class = total_nc[0][0]
    if len(wrong_labels) == 0 and not smallest_class:
        logger.info(correct_print("label class check!"))
    else:
        logger.info(error_print("label class check!"))
        if smallest_class:
            logger.info("Warning: label classes should start from 0")
        if len(wrong_labels) > 0:
            logger.info("fatal error: label class is out of range [0, {}]".
                        format(num_classes - 1))
            for i in wrong_labels:
                logger.debug(i)
    return total_nc
def label_class_statistics(total_nc, logger):
    """Log the label class statistics gathered from the label images."""
    report = ("\nLabel class statistics:\n"
              "(label class, percentage, total pixel number) = {} ")
    logger.info(report.format(total_nc))
def shape_check(shape_unequal_image, logger):
    """Log the result of the image/label shape check.

    Each entry of ``shape_unequal_image`` is written at debug level when the
    check fails.
    """
    if len(shape_unequal_image) > 0:
        logger.info(error_print("shape check"))
        logger.info(
            "Some images are not the same shape as the labels as follow: ")
        for entry in shape_unequal_image:
            logger.debug(entry)
        return
    logger.info(correct_print("shape check"))
    logger.info("All images are the same shape as the labels")
def separator_check(wrong_lines, file_list, separator, logger):
    """Log whether every line of the file list uses the expected separator.

    The check is labelled with the last ``os.sep``-separated component of
    ``file_list``. Offending lines are written at debug level.
    """
    check_name = file_list.rsplit(os.sep, 1)[-1] + " DATASET.separator check"
    if len(wrong_lines) == 0:
        logger.info(correct_print(check_name))
        return
    logger.info(error_print(check_name))
    logger.info(
        "The following list is not separated by {}".format(separator))
    for entry in wrong_lines:
        logger.debug(entry)
def imread_check(imread_failed, logger):
    """Log the result of the dataset reading check.

    Each entry of ``imread_failed`` is written at debug level when at least
    one image could not be read.
    """
    failed_count = len(imread_failed)
    if failed_count == 0:
        logger.info(correct_print("dataset reading check"))
        logger.info("All images can be read successfully")
        return
    logger.info(error_print("dataset reading check"))
    logger.info("Failed to read {} images".format(failed_count))
    for entry in imread_failed:
        logger.debug(entry)
def single_channel_label_check(label_not_single_channel, logger):
    """Log the result of the single-channel label check.

    Each entry of ``label_not_single_channel`` is written at debug level
    when the check fails.
    """
    offender_count = len(label_not_single_channel)
    if offender_count == 0:
        logger.info(correct_print("label single_channel check"))
        logger.info("All label images are single_channel")
        return
    logger.info(error_print("label single_channel check"))
    summary = "{} label images are not single_channel\nLabel pixel statistics may be insignificant"
    logger.info(summary.format(offender_count))
    for entry in label_not_single_channel:
        logger.debug(entry)
def img_shape_range_statistics(max_width, min_width, max_height, min_height,
                               logger):
    """Log the largest and smallest image width and height."""
    logger.info("\nImage size statistics:")
    size_report = "max width = {} min width = {} max height = {} min height = {}"
    logger.info(
        size_report.format(max_width, min_width, max_height, min_height))
def img_channels_statistics(img_channels, logger):
    """Log the distinct channel counts found in the images."""
    distinct_channels = np.unique(img_channels)
    logger.info("\nImage channels statistics\nImage channels = {}".format(
        distinct_channels))
def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
                           logger):
    """Check and analyse the train/val/test file lists of a dataset.

    For each of ``train.txt``, ``val.txt`` and ``test.txt`` in ``data_dir``
    every line is split with ``separator`` into an image path and an optional
    label path (both relative to ``data_dir``). Images and labels are read,
    validated, and their statistics accumulated. Per-list check results are
    logged, followed by statistics over all images that were read.

    Args:
        data_dir: Dataset root directory containing the file lists.
        num_classes: Number of label classes.
        separator: Separator between image and label path in a list line.
        ignore_index: Label value excluded from the class range check.
        logger: Logger receiving the report (details at debug level).
    """
    train_file_list = osp.join(data_dir, 'train.txt')
    val_file_list = osp.join(data_dir, 'val.txt')
    test_file_list = osp.join(data_dir, 'test.txt')

    # Dataset-wide accumulators; sized from the first image that is read.
    # Defined up front so the final report cannot hit undefined names when
    # no image could be read at all.
    total_img_num = 0
    total_means = None
    total_stds = None
    img_min_value = []
    img_max_value = []
    img_value_num = []

    for file_list in [train_file_list, val_file_list, test_file_list]:
        # Per-list state. has_label is reset for every list so that an
        # unlabeled test list is not checked as if it had labels.
        has_label = False
        imread_failed = []
        max_width = 0
        max_height = 0
        min_width = sys.float_info.max
        min_height = sys.float_info.max
        label_not_single_channel = []
        shape_unequal_image = []
        wrong_labels = []
        wrong_lines = []
        total_label_classes = []
        total_num_of_each_class = []
        img_channels = []

        logger.info("\n-----------------------------\nCheck {}...".format(
            file_list))
        if not osp.isfile(file_list):
            # A dataset does not necessarily ship all three lists.
            logger.info("File list does not exist, skip it!")
            continue
        with open(file_list, 'r') as fid:
            lines = fid.readlines()
        if not lines:
            logger.info("File list is empty!")
            continue

        labels_required = file_list in (train_file_list, val_file_list)
        for line in tqdm(lines):
            line = line.strip()
            if not line:
                # Blank lines carry no sample.
                continue
            parts = line.split(separator)
            if len(parts) == 1:
                if labels_required:
                    logger.info("Train or val list must have labels!")
                    break
                img_path = os.path.join(data_dir, parts[0])
                try:
                    img = read_img(img_path)
                except Exception as e:
                    imread_failed.append((line, str(e)))
                    continue
            elif len(parts) == 2:
                has_label = True
                img_name, label_name = parts
                img_path = os.path.join(data_dir, img_name)
                label_path = os.path.join(data_dir, label_name)
                try:
                    img = read_img(img_path)
                    label = pil_imread(label_path)
                except Exception as e:
                    imread_failed.append((line, str(e)))
                    continue
                if not is_label_single_channel(label):
                    label_not_single_channel.append(line)
                    continue
                if not image_label_shape_check(img, label):
                    shape_unequal_image.append(line)
                _, label_classes, num_of_each_class = ground_truth_check(
                    label, label_path)
                (is_label_correct, total_num_of_each_class,
                 total_label_classes) = sum_label_check(
                     label_classes, num_of_each_class, ignore_index,
                     num_classes, total_label_classes,
                     total_num_of_each_class)
                if not is_label_correct:
                    wrong_labels.append(line)
            else:
                # Record the offending line itself, not the whole file.
                wrong_lines.append(line)
                continue

            if total_img_num == 0:
                channel = img.shape[2]
                total_means = np.zeros(channel)
                total_stds = np.zeros(channel)
                img_min_value = [sys.float_info.max] * channel
                img_max_value = [0] * channel
                img_value_num = [[] for _ in range(channel)]
            (means, stds, img_min_value, img_max_value,
             img_value_num) = img_pixel_statistics(
                 img, img_value_num, img_min_value, img_max_value)
            total_means += means
            total_stds += stds
            max_width, max_height, min_width, min_height = get_img_shape_range(
                img, max_width, max_height, min_width, min_height)
            img_channels = get_img_channel_num(img, img_channels)
            total_img_num += 1

        # Data check.
        separator_check(wrong_lines, file_list, separator, logger)
        imread_check(imread_failed, logger)
        total_nc = None
        if has_label:
            single_channel_label_check(label_not_single_channel, logger)
            shape_check(shape_unequal_image, logger)
            if total_label_classes:
                total_nc = label_class_check(num_classes, total_label_classes,
                                             total_num_of_each_class,
                                             wrong_labels, logger)
            else:
                logger.info(
                    "No valid label was read, skip label class check")

        # Data analyse on train, validation, test set.
        img_channels_statistics(img_channels, logger)
        img_shape_range_statistics(max_width, min_width, max_height,
                                   min_height, logger)
        if total_nc is not None:
            label_class_statistics(total_nc, logger)

    # Data analyse on the whole dataset.
    if total_img_num == 0:
        logger.info("\n-----------------------------\n"
                    "No image was read, skip the whole dataset statistics.")
        return
    data_range_statistics(img_min_value, img_max_value, logger)
    data_distribution_statistics(data_dir, img_value_num, logger)
    cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)
def main():
    """Run the dataset check from the command line.

    Parses the arguments, configures the root logger to write INFO and above
    to the console and everything (DEBUG and above) to
    ``data_analyse_and_check.log`` inside the dataset directory, runs the
    analysis and finally prints where the detailed log can be found.
    """
    args = parse_args()

    logger = logging.getLogger()
    logger.setLevel('DEBUG')
    formatter = logging.Formatter("%(message)s")

    # Console output: summary only.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel('INFO')

    # Log file: full details, overwritten on every run.
    log_path = os.path.join(args.data_dir, 'data_analyse_and_check.log')
    file_handler = logging.FileHandler(log_path, 'w')
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    data_analyse_and_check(args.data_dir, args.num_classes, args.separator,
                           args.ignore_index, logger)

    print("\nDetailed error information can be viewed in {}.".format(log_path))
# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|