瀏覽代碼

add analysis for seg dataset

FlyingQianMM 5 年之前
父節點
當前提交
746122f4dd

+ 1 - 1
paddlex/__init__.py

@@ -32,7 +32,7 @@ from . import slim
 from . import convertor
 from . import tools
 from . import deploy
-from . import RemoteSensing
+from . import remotesensing
 
 try:
     import pycocotools

+ 1 - 0
paddlex/cv/datasets/__init__.py

@@ -20,3 +20,4 @@ from .easydata_cls import EasyDataCls
 from .easydata_det import EasyDataDet
 from .easydata_seg import EasyDataSeg
 from .dataset import generate_minibatch
+from .analysis import Seg

+ 366 - 0
paddlex/cv/datasets/analysis.py

@@ -0,0 +1,366 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import numpy as np
+import os.path as osp
+import cv2
+from PIL import Image
+import pickle
+import threading
+import multiprocessing as mp
+
+import paddlex.utils.logging as logging
+from paddlex.utils import path_normalization
+from .dataset import get_encoding
+
+
+class Seg:
+    def __init__(self, data_dir, file_list, label_list):
+        self.data_dir = data_dir
+        self.file_list_path = file_list
+        self.file_list = list()
+        self.labels = list()
+        with open(label_list, encoding=get_encoding(label_list)) as f:
+            for line in f:
+                item = line.strip()
+                self.labels.append(item)
+
+        with open(file_list, encoding=get_encoding(file_list)) as f:
+            for line in f:
+                if line.count(" ") > 1:
+                    raise Exception(
+                        "A space is defined as the separator, but it exists in image or label name {}."
+                        .format(line))
+                items = line.strip().split()
+                items[0] = path_normalization(items[0])
+                items[1] = path_normalization(items[1])
+                full_path_im = osp.join(data_dir, items[0])
+                full_path_label = osp.join(data_dir, items[1])
+                if not osp.exists(full_path_im):
+                    raise IOError('The image file {} is not exist!'.format(
+                        full_path_im))
+                if not osp.exists(full_path_label):
+                    raise IOError('The image file {} is not exist!'.format(
+                        full_path_label))
+                self.file_list.append([full_path_im, full_path_label])
+        self.num_samples = len(self.file_list)
+
+    @staticmethod
+    def decode_image(im, label):
+        if isinstance(im, np.ndarray):
+            if len(im.shape) != 3:
+                raise Exception(
+                    "im should be 3-dimensions, but now is {}-dimensions".
+                    format(len(im.shape)))
+        else:
+            try:
+                im = cv2.imread(im)
+            except:
+                raise ValueError('Can\'t read The image file {}!'.format(im))
+        im = im.astype('float32')
+        if label is not None:
+            if isinstance(label, np.ndarray):
+                if len(label.shape) != 2:
+                    raise Exception(
+                        "label should be 2-dimensions, but now is {}-dimensions".
+                        format(len(label.shape)))
+
+            else:
+                try:
+                    label = np.asarray(Image.open(label))
+                except:
+                    ValueError('Can\'t read The label file {}!'.format(label))
+        im_height, im_width, _ = im.shape
+        label_height, label_width = label.shape
+        if im_height != label_height or im_width != label_width:
+            raise Exception(
+                "The height or width of the image is not same as the label")
+        return (im, label)
+
+    def _get_shape(self):
+        max_height = max(self.im_height_list)
+        max_width = max(self.im_width_list)
+        min_height = min(self.im_height_list)
+        min_width = min(self.im_width_list)
+        shape_info = {
+            'max_height': max_height,
+            'max_width': max_width,
+            'min_height': min_height,
+            'min_width': min_width,
+        }
+        return shape_info
+
+    def _get_label_pixel_info(self):
+        pixel_num = np.dot(self.im_height_list, self.im_width_list)
+        label_pixel_info = dict()
+        for label_value, label_value_num in zip(self.label_value_list,
+                                                self.label_value_num_list):
+            for v, n in zip(label_value, label_value_num):
+                if v not in label_pixel_info.keys():
+                    label_pixel_info[v] = [n, float(n) / float(pixel_num)]
+                else:
+                    label_pixel_info[v][0] += n
+                    label_pixel_info[v][1] += float(n) / float(pixel_num)
+
+        return label_pixel_info
+
+    def _get_image_pixel_info(self):
+        channel = max([len(im_value) for im_value in self.im_value_list])
+        im_pixel_info = [dict() for c in range(channel)]
+        for im_value, im_value_num in zip(self.im_value_list,
+                                          self.im_value_num_list):
+            for c in range(channel):
+                for v, n in zip(im_value[c], im_value_num[c]):
+                    if v not in im_pixel_info[c].keys():
+                        im_pixel_info[c][v] = n
+                    else:
+                        im_pixel_info[c][v] += n
+        mode = osp.split(self.file_list_path)[-1].split('.')[0]
+        with open(
+                osp.join(self.data_dir,
+                         '{}_image_pixel_info.pkl'.format(mode)), 'wb') as f:
+            pickle.dump(im_pixel_info, f)
+
+        import matplotlib.pyplot as plt
+        plot_id = (channel // 3 + 1) * 100 + 31
+        for c in range(channel):
+            if c > 8:
+                continue
+            plt.subplot(plot_id + c)
+            plt.bar(im_pixel_info[c].keys(),
+                    im_pixel_info[c].values(),
+                    width=1,
+                    log=True)
+            plt.xlabel('image pixel value')
+            plt.ylabel('number')
+            plt.title('channel={}'.format(c))
+        plt.savefig(
+            osp.join(self.data_dir, '{}_image_pixel_info.png'.format(mode)),
+            dpi=800)
+        plt.close()
+        return im_pixel_info
+
+    def _get_mean_std(self):
+        im_mean = np.asarray(self.im_mean_list)
+        im_mean = im_mean.sum(axis=0)
+        im_mean = im_mean / len(self.file_list)
+        im_mean /= 255.
+
+        im_std = np.asarray(self.im_std_list)
+        im_std = im_std.sum(axis=0)
+        im_std = im_std / len(self.file_list)
+        im_std /= 255.
+
+        return (im_mean, im_std)
+
+    def _get_image_info(self, start, end):
+        for id in range(start, end):
+            full_path_im, full_path_label = self.file_list[id]
+            image, label = self.decode_image(full_path_im, full_path_label)
+
+            height, width, channel = image.shape
+            self.im_height_list[id] = height
+            self.im_width_list[id] = width
+            self.im_channel_list[id] = channel
+
+            self.im_mean_list[
+                id] = [np.mean(image[:, :, c]) for c in range(channel)]
+            self.im_std_list[
+                id] = [np.mean(image[:, :, c]) for c in range(channel)]
+            for c in range(channel):
+                unique, counts = np.unique(image[:, :, c], return_counts=True)
+                self.im_value_list[id].extend([unique])
+                self.im_value_num_list[id].extend([counts])
+
+            unique, counts = np.unique(label, return_counts=True)
+            self.label_value_list[id] = unique
+            self.label_value_num_list[id] = counts
+
+    def _get_clipped_mean_std(self, start, end, clip_min_value,
+                              clip_max_value):
+        for id in range(start, end):
+            full_path_im, full_path_label = self.file_list[id]
+            image, label = self.decode_image(full_path_im, full_path_label)
+            for c in range(self.channel_num):
+                np.clip(
+                    image[:, :, c],
+                    clip_min_value[c],
+                    clip_max_value[c],
+                    out=image[:, :, c])
+                image[:, :, c] -= clip_min_value[c]
+                image[:, :, c] /= clip_max_value[c] - clip_min_value[c]
+            self.clipped_im_mean_list[id] = [
+                image[:, :, c].mean() for c in range(self.channel_num)
+            ]
+            self.clipped_im_std_list[
+                id] = [image[:, :, c].std() for c in range(self.channel_num)]
+
+    def analysis(self):
+        self.im_mean_list = [[] for i in range(len(self.file_list))]
+        self.im_std_list = [[] for i in range(len(self.file_list))]
+        self.im_value_list = [[] for i in range(len(self.file_list))]
+        self.im_value_num_list = [[] for i in range(len(self.file_list))]
+        self.im_height_list = np.zeros(len(self.file_list), dtype='int32')
+        self.im_width_list = np.zeros(len(self.file_list), dtype='int32')
+        self.im_channel_list = np.zeros(len(self.file_list), dtype='int32')
+        self.label_value_list = [[] for i in range(len(self.file_list))]
+        self.label_value_num_list = [[] for i in range(len(self.file_list))]
+
+        num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
+        num_workers = 6
+        threads = []
+        one_worker_file = len(self.file_list) // num_workers
+        for i in range(num_workers):
+            start = one_worker_file * i
+            end = one_worker_file * (
+                i + 1) if i < num_workers - 1 else len(self.file_list)
+            t = threading.Thread(
+                target=self._get_image_info, args=(start, end))
+            print("====", len(self.file_list), start, end)
+            #t.daemon = True
+            threads.append(t)
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+        print('ok')
+        import time
+        import sys
+        sys.exit(0)
+        time.sleep(1000000)
+        return
+
+        #self._get_image_info(0, len(self.file_list))
+        unique, counts = np.unique(self.im_channel_list, return_counts=True)
+        print('==== unique')
+        if len(unique) > 1:
+            raise Exception("There are {} kinds of image channels: {}.".format(
+                len(unique), unique[:]))
+        self.channel_num = unique[0]
+        shape_info = self._get_shape()
+        print('==== shape_info')
+        self.max_height = shape_info['max_height']
+        self.max_width = shape_info['max_width']
+        self.min_height = shape_info['min_height']
+        self.min_width = shape_info['min_width']
+        self.label_pixel_info = self._get_label_pixel_info()
+        print('==== label_pixel_info')
+        self.im_pixel_info = self._get_image_pixel_info()
+        print('==== im_pixel_info')
+        im_mean, im_std = self._get_mean_std()
+        print('==== get_mean_std')
+        max_im_value = list()
+        min_im_value = list()
+        for c in range(self.channel_num):
+            max_im_value.append(max(self.im_pixel_info[c].keys()))
+            min_im_value.append(min(self.im_pixel_info[c].keys()))
+        self.max_im_value = np.asarray(max_im_value)
+        self.min_im_value = np.asarray(min_im_value)
+
+        logging.info(
+            "############## The analysis results are as follows ##############\n"
+        )
+        logging.info("{} samples in file {}\n".format(
+            len(self.file_list), self.file_list_path))
+        logging.info("Maximal image height: {} Maximal image width: {}.\n".
+                     format(self.max_height, self.max_width))
+        logging.info("Minimal image height: {} Minimal image width: {}.\n".
+                     format(self.min_height, self.min_width))
+        logging.info("Image channel is {}.\n".format(self.channel_num))
+        logging.info(
+            "Image mean value: {} Image standard deviation: {} (normalized by 255, sorted by a BGR format).\n".
+            format(im_mean, im_std))
+        logging.info(
+            "Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):"
+        )
+        for v, (n, r) in self.label_pixel_info.items():
+            logging.info("({}, {}, {})".format(v, n, r))
+        mode = osp.split(self.file_list_path)[-1].split('.')[0]
+        saved_pkl_file = osp.join(self.data_dir,
+                                  '{}_image_pixel_info.pkl'.format(mode))
+        saved_png_file = osp.join(self.data_dir,
+                                  '{}_image_pixel_info.png'.format(mode))
+        logging.info(
+            "Image pixel information is saved in the file '{}' and shown in the file '{}'".
+            format(saved_pkl_file, saved_png_file))
+
+    def cal_clipvalue_ratio(self, clip_min_value, clip_max_value):
+        if len(clip_min_value) != self.channel_num or len(
+                clip_max_value) != self.channel_num:
+            raise Exception(
+                "The length of clip_min_value or clip_max_value should be equal to the number of image channel {}."
+                .format(self.channle_num))
+        for c in range(self.channel_num):
+            if clip_min_value[c] < self.min_im_value[c] or clip_min_value[
+                    c] > self.max_im_value[c]:
+                raise Exception(
+                    "Clip_min_value of the channel {} is not in [{}, {}]".
+                    format(c, self.min_im_value[c], self.max_im_value[c]))
+            if clip_max_value[c] < self.min_im_value[c] or clip_max_value[
+                    c] > self.max_im_value[c]:
+                raise Exception(
+                    "Clip_max_value of the channel {} is not in [{}, {}]".
+                    format(c, self.min_im_value[c], self.max_im_value[c]))
+            clip_pixel_num = 0
+            pixel_num = sum(self.im_pixel_info[c].values())
+            for v, n in self.im_pixel_info[c].items():
+                if v < clip_min_value[c] or v > clip_max_value[c]:
+                    clip_pixel_num += n
+            logging.info("Channel {}, the ratio of pixels to be clipped = {}".
+                         format(c, clip_pixel_num / pixel_num))
+
+    def cal_clipped_mean_std(self, clip_min_value, clip_max_value):
+        for c in range(self.channel_num):
+            if clip_min_value[c] < self.min_im_value[c] or clip_min_value[
+                    c] > self.max_im_value[c]:
+                raise Exception(
+                    "Clip_min_value of the channel {} is not in [{}, {}]".
+                    format(c, self.min_im_value[c], self.max_im_value[c]))
+            if clip_max_value[c] < self.min_im_value[c] or clip_max_value[
+                    c] > self.max_im_value[c]:
+                raise Exception(
+                    "Clip_max_value of the channel {} is not in [{}, {}]".
+                    format(c, self.min_im_value[c], self.max_im_value[c]))
+
+        self.clipped_im_mean_list = [[] for i in range(len(self.file_list))]
+        self.clipped_im_std_list = [[] for i in range(len(self.file_list))]
+
+        num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
+        threads = []
+        one_worker_file = len(self.file_list) // num_workers
+        for i in range(num_workers):
+            start = one_worker_file * i
+            end = one_worker_file * (
+                i + 1) if i < num_workers - 1 else len(self.file_list)
+            t = threading.Thread(
+                target=self._get_clipped_mean_std,
+                args=(start, end, clip_min_value, clip_max_value))
+            threads.append(t)
+        for t in threads:
+            t.setDaemon(True)
+            t.start()
+        t.join()
+
+        im_mean = np.asarray(self.clipped_im_mean_list)
+        im_mean = im_mean.sum(axis=0)
+        im_mean = im_mean / len(self.file_list)
+
+        im_std = np.asarray(self.clipped_im_std_list)
+        im_std = im_std.sum(axis=0)
+        im_std = im_std / len(self.file_list)
+
+        logging.info(
+            "Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value)).\n".
+            format(im_mean, im_std))

+ 4 - 1
paddlex/cv/datasets/seg_dataset.py

@@ -20,7 +20,6 @@ import paddlex.utils.logging as logging
 from paddlex.utils import path_normalization
 from .dataset import Dataset
 from .dataset import get_encoding
-from .dataset import is_pic
 
 
 class SegDataset(Dataset):
@@ -64,6 +63,10 @@ class SegDataset(Dataset):
                     self.labels.append(item)
         with open(file_list, encoding=get_encoding(file_list)) as f:
             for line in f:
+                if line.count(" ") > 1:
+                    raise Exception(
+                        "A space is defined as the separator, but it exists in image or label name {}."
+                        .format(line))
                 items = line.strip().split()
                 items[0] = path_normalization(items[0])
                 items[1] = path_normalization(items[1])

+ 3 - 0
paddlex/cv/models/deeplabv3p.py

@@ -65,6 +65,7 @@ class DeepLabv3p(BaseAPI):
 
     def __init__(self,
                  num_classes=2,
+                 input_channel=3,
                  backbone='MobileNetV2_x1.0',
                  output_stride=16,
                  aspp_with_sep_conv=True,
@@ -114,6 +115,7 @@ class DeepLabv3p(BaseAPI):
 
         self.backbone = backbone
         self.num_classes = num_classes
+        self.input_channel = input_channel
         self.use_bce_loss = use_bce_loss
         self.use_dice_loss = use_dice_loss
         self.class_weight = class_weight
@@ -215,6 +217,7 @@ class DeepLabv3p(BaseAPI):
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.DeepLabv3p(
             self.num_classes,
+            input_channel=self.input_channel,
             mode=mode,
             backbone=self._get_backbone(self.backbone),
             output_stride=self.output_stride,

+ 3 - 0
paddlex/cv/models/fast_scnn.py

@@ -48,6 +48,7 @@ class FastSCNN(DeepLabv3p):
 
     def __init__(self,
                  num_classes=2,
+                 input_channel=3,
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
@@ -86,6 +87,7 @@ class FastSCNN(DeepLabv3p):
             )
 
         self.num_classes = num_classes
+        self.input_channel = input_channel
         self.use_bce_loss = use_bce_loss
         self.use_dice_loss = use_dice_loss
         self.class_weight = class_weight
@@ -97,6 +99,7 @@ class FastSCNN(DeepLabv3p):
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.FastSCNN(
             self.num_classes,
+            input_channel=self.input_channel,
             mode=mode,
             use_bce_loss=self.use_bce_loss,
             use_dice_loss=self.use_dice_loss,

+ 3 - 0
paddlex/cv/models/hrnet.py

@@ -44,6 +44,7 @@ class HRNet(DeepLabv3p):
 
     def __init__(self,
                  num_classes=2,
+                 input_channel=3,
                  width=18,
                  use_bce_loss=False,
                  use_dice_loss=False,
@@ -72,6 +73,7 @@ class HRNet(DeepLabv3p):
                     'Expect class_weight is a list or string but receive {}'.
                     format(type(class_weight)))
         self.num_classes = num_classes
+        self.input_channel = input_channel
         self.width = width
         self.use_bce_loss = use_bce_loss
         self.use_dice_loss = use_dice_loss
@@ -83,6 +85,7 @@ class HRNet(DeepLabv3p):
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.HRNet(
             self.num_classes,
+            input_channel=self.input_channel,
             width=self.width,
             mode=mode,
             use_bce_loss=self.use_bce_loss,

+ 1 - 1
paddlex/cv/models/ppyolo.py

@@ -36,7 +36,7 @@ class PPYOLO(BaseAPI):
 
     Args:
         num_classes (int): 类别数。默认为80。
-        backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd']。默认为'ResNet50_vd'。
+        backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd_ssld']。默认为'ResNet50_vd_ssld'。
         with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。
         anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
                     [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],

+ 7 - 2
paddlex/cv/nets/segmentation/deeplabv3p.py

@@ -72,6 +72,7 @@ class DeepLabv3p(object):
     def __init__(self,
                  num_classes,
                  backbone,
+                 input_channel=3,
                  mode='train',
                  output_stride=16,
                  aspp_with_sep_conv=True,
@@ -115,6 +116,7 @@ class DeepLabv3p(object):
                     format(type(class_weight)))
 
         self.num_classes = num_classes
+        self.input_channel = input_channel
         self.backbone = backbone
         self.mode = mode
         self.use_bce_loss = use_bce_loss
@@ -402,13 +404,16 @@ class DeepLabv3p(object):
 
         if self.fixed_input_shape is not None:
             input_shape = [
-                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+                None, self.input_channel, self.fixed_input_shape[1],
+                self.fixed_input_shape[0]
             ]
             inputs['image'] = fluid.data(
                 dtype='float32', shape=input_shape, name='image')
         else:
             inputs['image'] = fluid.data(
-                dtype='float32', shape=[None, 3, None, None], name='image')
+                dtype='float32',
+                shape=[None, self.input_channel, None, None],
+                name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')

+ 7 - 2
paddlex/cv/nets/segmentation/fast_scnn.py

@@ -33,6 +33,7 @@ from .model_utils.loss import bce_loss
 class FastSCNN(object):
     def __init__(self,
                  num_classes,
+                 input_channel=3,
                  mode='train',
                  use_bce_loss=False,
                  use_dice_loss=False,
@@ -62,6 +63,7 @@ class FastSCNN(object):
                     format(type(class_weight)))
 
         self.num_classes = num_classes
+        self.input_channel = input_channel
         self.mode = mode
         self.use_bce_loss = use_bce_loss
         self.use_dice_loss = use_dice_loss
@@ -137,13 +139,16 @@ class FastSCNN(object):
         inputs = OrderedDict()
         if self.fixed_input_shape is not None:
             input_shape = [
-                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+                None, self.input_channel, self.fixed_input_shape[1],
+                self.fixed_input_shape[0]
             ]
             inputs['image'] = fluid.data(
                 dtype='float32', shape=input_shape, name='image')
         else:
             inputs['image'] = fluid.data(
-                dtype='float32', shape=[None, 3, None, None], name='image')
+                dtype='float32',
+                shape=[None, self.input_channel, None, None],
+                name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')

+ 7 - 2
paddlex/cv/nets/segmentation/hrnet.py

@@ -32,6 +32,7 @@ import paddlex
 class HRNet(object):
     def __init__(self,
                  num_classes,
+                 input_channel=3,
                  mode='train',
                  width=18,
                  use_bce_loss=False,
@@ -61,6 +62,7 @@ class HRNet(object):
                     format(type(class_weight)))
 
         self.num_classes = num_classes
+        self.input_channel = input_channel
         self.mode = mode
         self.use_bce_loss = use_bce_loss
         self.use_dice_loss = use_dice_loss
@@ -136,13 +138,16 @@ class HRNet(object):
 
         if self.fixed_input_shape is not None:
             input_shape = [
-                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+                None, self.input_channel, self.fixed_input_shape[1],
+                self.fixed_input_shape[0]
             ]
             inputs['image'] = fluid.data(
                 dtype='float32', shape=input_shape, name='image')
         else:
             inputs['image'] = fluid.data(
-                dtype='float32', shape=[None, 3, None, None], name='image')
+                dtype='float32',
+                shape=[None, self.input_channel, None, None],
+                name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')

+ 17 - 2
paddlex/cv/transforms/seg_transforms.py

@@ -74,8 +74,22 @@ class Compose(SegTransform):
                 raise ValueError('Can\'t read The image file {}!'.format(im))
         im = im.astype('float32')
         if label is not None:
-            if not isinstance(label, np.ndarray):
-                label = np.asarray(Image.open(label))
+            if isinstance(label, np.ndarray):
+                if len(label.shape) != 2:
+                    raise Exception(
+                        "label should be 2-dimensions, but now is {}-dimensions".
+                        format(len(label.shape)))
+
+            else:
+                try:
+                    label = np.asarray(Image.open(label))
+                except:
+                    ValueError('Can\'t read The label file {}!'.format(label))
+        im_height, im_width, _ = im.shape
+        label_height, label_width = label.shape
+        if im_height != label_height or im_width != label_width:
+            raise Exception(
+                "The height or width of the image is not same as the label")
         return (im, label)
 
     def __call__(self, im, im_info=None, label=None):
@@ -605,6 +619,7 @@ class Normalize(SegTransform):
         mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
         std = np.array(self.std)[np.newaxis, np.newaxis, :]
         im = normalize(im, mean, std, self.min_val, self.max_val)
+        im = im.astype('float32')
 
         if label is None:
             return (im, im_info)

+ 0 - 0
paddlex/RemoteSensing/__init__.py → paddlex/remotesensing/__init__.py


+ 5 - 5
paddlex/RemoteSensing/train_demo.py → paddlex/remotesensing/train_demo.py

@@ -16,7 +16,7 @@
 import os.path as osp
 import argparse
 from paddlex.seg import transforms
-import paddlex.RemoteSensing.transforms as custom_transforms
+import paddlex.remotesensing.transforms as rs_transforms
 import paddlex as pdx
 
 
@@ -110,22 +110,22 @@ train_transforms = transforms.Compose([
     transforms.RandomHorizontalFlip(0.5),
     transforms.ResizeStepScaling(0.5, 2.0, 0.25),
     transforms.RandomPaddingCrop(im_padding_value=[1000] * channel),
-    custom_transforms.Clip(
+    rs_transforms.Clip(
         min_val=clip_min_value, max_val=clip_max_value),
     transforms.Normalize(
         min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
 ])
 
-train_transforms.decode_image = custom_transforms.decode_image
+train_transforms.decode_image = rs_transforms.decode_image
 
 eval_transforms = transforms.Compose([
-    custom_transforms.Clip(
+    rs_transforms.Clip(
         min_val=clip_min_value, max_val=clip_max_value),
     transforms.Normalize(
         min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
 ])
 
-eval_transforms.decode_image = custom_transforms.decode_image
+eval_transforms.decode_image = rs_transforms.decode_image
 
 train_list = osp.join(data_dir, 'train.txt')
 val_list = osp.join(data_dir, 'val.txt')

+ 25 - 3
paddlex/RemoteSensing/transforms.py → paddlex/remotesensing/transforms.py

@@ -2,17 +2,23 @@ import os
 import os.path as osp
 import imghdr
 import gdal
+gdal.UseExceptions()
+gdal.PushErrorHandler('CPLQuietErrorHandler')
 import numpy as np
 from PIL import Image
 
 from paddlex.seg import transforms
+import paddlex.utils.logging as logging
 
 
 def read_img(img_path):
     img_format = imghdr.what(img_path)
     name, ext = osp.splitext(img_path)
     if img_format == 'tiff' or ext == '.img':
-        dataset = gdal.Open(img_path)
+        try:
+            dataset = gdal.Open(img_path)
+        except:
+            logging.error(gdal.GetLastErrorMsg())
         if dataset == None:
             raise Exception('Can not open', img_path)
         im_data = dataset.ReadAsArray()
@@ -36,9 +42,25 @@ def decode_image(im, label):
             im = read_img(im)
         except:
             raise ValueError('Can\'t read The image file {}!'.format(im))
+    im = im.astype('float32')
+
     if label is not None:
-        if not isinstance(label, np.ndarray):
-            label = read_img(label)
+        if isinstance(label, np.ndarray):
+            if len(label.shape) != 2:
+                raise Exception(
+                    "label should be 2-dimensions, but now is {}-dimensions".
+                    format(len(label.shape)))
+
+        else:
+            try:
+                label = np.asarray(Image.open(label))
+            except:
+                ValueError('Can\'t read The label file {}!'.format(label))
+    im_height, im_width, _ = im.shape
+    label_height, label_width = label.shape
+    if im_height != label_height or im_width != label_width:
+        raise Exception(
+            "The height or width of the image is not same as the label")
     return (im, label)
 
 

+ 0 - 0
paddlex/remotesensing/utils/__init__.py


+ 506 - 0
paddlex/remotesensing/utils/analyse.py

@@ -0,0 +1,506 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import os
+import os.path as osp
+import sys
+import argparse
+from PIL import Image
+from tqdm import tqdm
+import imghdr
+import logging
+import pickle
+import gdal
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Data analyse and data check before training.')
+    parser.add_argument(
+        '--data_dir',
+        dest='data_dir',
+        help='Dataset directory',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--num_classes',
+        dest='num_classes',
+        help='Number of classes',
+        default=None,
+        type=int)
+    parser.add_argument(
+        '--separator',
+        dest='separator',
+        help='file list separator',
+        default=" ",
+        type=str)
+    parser.add_argument(
+        '--ignore_index',
+        dest='ignore_index',
+        help='Ignored class index',
+        default=255,
+        type=int)
+    if len(sys.argv) == 1:
+        parser.print_help()
+        sys.exit(1)
+    return parser.parse_args()
+
+
+def read_img(img_path):
+    img_format = imghdr.what(img_path)
+    name, ext = osp.splitext(img_path)
+    if img_format == 'tiff' or ext == '.img':
+        dataset = gdal.Open(img_path)
+        if dataset == None:
+            raise Exception('Can not open', img_path)
+        im_data = dataset.ReadAsArray()
+        return im_data.transpose((1, 2, 0))
+    elif ext == '.npy':
+        return np.load(img_path)
+    else:
+        raise Exception('Not support {} image format!'.format(ext))
+
+
+def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
+    channel = img.shape[2]
+    means = np.zeros(channel)
+    stds = np.zeros(channel)
+    for k in range(channel):
+        img_k = img[:, :, k]
+
+        # count mean, std
+        means[k] = np.mean(img_k)
+        stds[k] = np.std(img_k)
+
+        # count min, max
+        min_value = np.min(img_k)
+        max_value = np.max(img_k)
+        if img_max_value[k] < max_value:
+            img_max_value[k] = max_value
+        if img_min_value[k] > min_value:
+            img_min_value[k] = min_value
+
+        # count the distribution of image value, value number
+        unique, counts = np.unique(img_k, return_counts=True)
+        add_num = []
+        max_unique = np.max(unique)
+        add_len = max_unique - len(img_value_num[k]) + 1
+        if add_len > 0:
+            img_value_num[k] += ([0] * add_len)
+        for i in range(len(unique)):
+            value = unique[i]
+            img_value_num[k][value] += counts[i]
+
+        img_value_num[k] += add_num
+    return means, stds, img_min_value, img_max_value, img_value_num
+
+
+def data_distribution_statistics(data_dir, img_value_num, logger):
+    """count the distribution of image value, value number
+    """
+    logger.info(
+        "\n-----------------------------\nThe whole dataset statistics...")
+
+    if not img_value_num:
+        return
+    logger.info("\nImage pixel statistics:")
+    total_ratio = []
+    [total_ratio.append([]) for i in range(len(img_value_num))]
+    for k in range(len(img_value_num)):
+        total_num = sum(img_value_num[k])
+        total_ratio[k] = [i / total_num for i in img_value_num[k]]
+        total_ratio[k] = np.around(total_ratio[k], decimals=4)
+    with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f:
+        pickle.dump([total_ratio, img_value_num], f)
+
+
+def data_range_statistics(img_min_value, img_max_value, logger):
+    """print min value, max value
+    """
+    logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".
+                format(img_min_value, img_max_value))
+
+
+def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
+    """count mean, std
+    """
+    total_means = total_means / total_img_num
+    total_stds = total_stds / total_img_num
+    logger.info("\nCount the channel-by-channel mean and std of the image:\n"
+                "mean = {}\nstd = {}".format(total_means, total_stds))
+
+
+def error_print(str):
+    return "".join(["\nNOT PASS ", str])
+
+
+def correct_print(str):
+    return "".join(["\nPASS ", str])
+
+
+def pil_imread(file_path):
+    """read pseudo-color label"""
+    im = Image.open(file_path)
+    return np.asarray(im)
+
+
+def get_img_shape_range(img, max_width, max_height, min_width, min_height):
+    """获取图片最大和最小宽高"""
+    img_shape = img.shape
+    height, width = img_shape[0], img_shape[1]
+    max_height = max(height, max_height)
+    max_width = max(width, max_width)
+    min_height = min(height, min_height)
+    min_width = min(width, min_width)
+    return max_width, max_height, min_width, min_height
+
+
+def get_img_channel_num(img, img_channels):
+    """获取图像的通道数"""
+    img_shape = img.shape
+    if img_shape[-1] not in img_channels:
+        img_channels.append(img_shape[-1])
+    return img_channels
+
+
+def is_label_single_channel(label):
+    """判断标签是否为灰度图"""
+    label_shape = label.shape
+    if len(label_shape) == 2:
+        return True
+    else:
+        return False
+
+
+def image_label_shape_check(img, label):
+    """
+    验证图像和标注的大小是否匹配
+    """
+
+    flag = True
+    img_height = img.shape[0]
+    img_width = img.shape[1]
+    label_height = label.shape[0]
+    label_width = label.shape[1]
+
+    if img_height != label_height or img_width != label_width:
+        flag = False
+    return flag
+
+
+def ground_truth_check(label, label_path):
+    """
+    验证标注图像的格式
+    统计标注图类别和像素数
+    params:
+        label: 标注图
+        label_path: 标注图路径
+    return:
+        png_format: 返回是否是png格式图片
+        unique: 返回标注类别
+        counts: 返回标注的像素数
+    """
+    if imghdr.what(label_path) == "png":
+        png_format = True
+    else:
+        png_format = False
+
+    unique, counts = np.unique(label, return_counts=True)
+
+    return png_format, unique, counts
+
+
+def sum_label_check(label_classes, num_of_each_class, ignore_index,
+                    num_classes, total_label_classes, total_num_of_each_class):
+    """
+    统计所有标注图上的类别和每个类别的像素数
+    params:
+        label_classes: 标注类别
+        num_of_each_class: 各个类别的像素数目
+    """
+    is_label_correct = True
+
+    if ignore_index in label_classes:
+        label_classes2 = np.delete(label_classes,
+                                   np.where(label_classes == ignore_index))
+    else:
+        label_classes2 = label_classes
+    if min(label_classes2) < 0 or max(label_classes2) > num_classes - 1:
+        is_label_correct = False
+    add_class = []
+    add_num = []
+    for i in range(len(label_classes)):
+        gi = label_classes[i]
+        if gi in total_label_classes:
+            j = total_label_classes.index(gi)
+            total_num_of_each_class[j] += num_of_each_class[i]
+        else:
+            add_class.append(gi)
+            add_num.append(num_of_each_class[i])
+    total_num_of_each_class += add_num
+    total_label_classes += add_class
+    return is_label_correct, total_num_of_each_class, total_label_classes
+
+
+def label_class_check(num_classes, total_label_classes,
+                      total_num_of_each_class, wrong_labels, logger):
+    """
+    检查实际标注类别是否和配置参数`num_classes`,`ignore_index`匹配。
+
+    **NOTE:**
+    标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`。
+    标注类别最好从0开始,否则可能影响精度。
+    """
+    total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
+    total_ratio = np.around(total_ratio, decimals=4)
+    total_nc = sorted(
+        zip(total_label_classes, total_ratio, total_num_of_each_class))
+    if len(wrong_labels) == 0 and not total_nc[0][0]:
+        logger.info(correct_print("label class check!"))
+    else:
+        logger.info(error_print("label class check!"))
+        if total_nc[0][0]:
+            logger.info("Warning: label classes should start from 0")
+        if len(wrong_labels) > 0:
+            logger.info("fatal error: label class is out of range [0, {}]".
+                        format(num_classes - 1))
+            for i in wrong_labels:
+                logger.debug(i)
+    return total_nc
+
+
+def label_class_statistics(total_nc, logger):
+    """
+    对标注图像进行校验,输出校验结果
+    """
+    logger.info("\nLabel class statistics:\n"
+                "(label class, percentage, total pixel number) = {} ".format(
+                    total_nc))
+
+
+def shape_check(shape_unequal_image, logger):
+    """输出shape校验结果"""
+    if len(shape_unequal_image) == 0:
+        logger.info(correct_print("shape check"))
+        logger.info("All images are the same shape as the labels")
+    else:
+        logger.info(error_print("shape check"))
+        logger.info(
+            "Some images are not the same shape as the labels as follow: ")
+        for i in shape_unequal_image:
+            logger.debug(i)
+
+
+def separator_check(wrong_lines, file_list, separator, logger):
+    """检查分割符是否复合要求"""
+    if len(wrong_lines) == 0:
+        logger.info(
+            correct_print(
+                file_list.split(os.sep)[-1] + " DATASET.separator check"))
+    else:
+        logger.info(
+            error_print(
+                file_list.split(os.sep)[-1] + " DATASET.separator check"))
+        logger.info("The following list is not separated by {}".format(
+            separator))
+        for i in wrong_lines:
+            logger.debug(i)
+
+
+def imread_check(imread_failed, logger):
+    if len(imread_failed) == 0:
+        logger.info(correct_print("dataset reading check"))
+        logger.info("All images can be read successfully")
+    else:
+        logger.info(error_print("dataset reading check"))
+        logger.info("Failed to read {} images".format(len(imread_failed)))
+        for i in imread_failed:
+            logger.debug(i)
+
+
+def single_channel_label_check(label_not_single_channel, logger):
+    if len(label_not_single_channel) == 0:
+        logger.info(correct_print("label single_channel check"))
+        logger.info("All label images are single_channel")
+    else:
+        logger.info(error_print("label single_channel check"))
+        logger.info(
+            "{} label images are not single_channel\nLabel pixel statistics may be insignificant"
+            .format(len(label_not_single_channel)))
+        for i in label_not_single_channel:
+            logger.debug(i)
+
+
+def img_shape_range_statistics(max_width, min_width, max_height, min_height,
+                               logger):
+    logger.info("\nImage size statistics:")
+    logger.info(
+        "max width = {}  min width = {}  max height = {}  min height = {}".
+        format(max_width, min_width, max_height, min_height))
+
+
+def img_channels_statistics(img_channels, logger):
+    logger.info("\nImage channels statistics\nImage channels = {}".format(
+        np.unique(img_channels)))
+
+
+def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
+                           logger):
+    train_file_list = osp.join(data_dir, 'train.txt')
+    val_file_list = osp.join(data_dir, 'val.txt')
+    test_file_list = osp.join(data_dir, 'test.txt')
+    total_img_num = 0
+    has_label = False
+    for file_list in [train_file_list, val_file_list, test_file_list]:
+        # initialization
+        imread_failed = []
+        max_width = 0
+        max_height = 0
+        min_width = sys.float_info.max
+        min_height = sys.float_info.max
+        label_not_single_channel = []
+        shape_unequal_image = []
+        wrong_labels = []
+        wrong_lines = []
+        total_label_classes = []
+        total_num_of_each_class = []
+        img_channels = []
+
+        with open(file_list, 'r') as fid:
+            logger.info("\n-----------------------------\nCheck {}...".format(
+                file_list))
+            lines = fid.readlines()
+            if not lines:
+                logger.info("File list is empty!")
+                continue
+            for line in tqdm(lines):
+                line = line.strip()
+                parts = line.split(separator)
+                if len(parts) == 1:
+                    if file_list == train_file_list or file_list == val_file_list:
+                        logger.info("Train or val list must have labels!")
+                        break
+                    img_name = parts
+                    img_path = os.path.join(data_dir, img_name[0])
+                    try:
+                        img = read_img(img_path)
+                    except Exception as e:
+                        imread_failed.append((line, str(e)))
+                        continue
+                elif len(parts) == 2:
+                    has_label = True
+                    img_name, label_name = parts[0], parts[1]
+                    img_path = os.path.join(data_dir, img_name)
+                    label_path = os.path.join(data_dir, label_name)
+                    try:
+                        img = read_img(img_path)
+                        label = pil_imread(label_path)
+                    except Exception as e:
+                        imread_failed.append((line, str(e)))
+                        continue
+
+                    is_single_channel = is_label_single_channel(label)
+                    if not is_single_channel:
+                        label_not_single_channel.append(line)
+                        continue
+                    is_equal_img_label_shape = image_label_shape_check(img,
+                                                                       label)
+                    if not is_equal_img_label_shape:
+                        shape_unequal_image.append(line)
+                    png_format, label_classes, num_of_each_class = ground_truth_check(
+                        label, label_path)
+                    is_label_correct, total_num_of_each_class, total_label_classes = sum_label_check(
+                        label_classes, num_of_each_class, ignore_index,
+                        num_classes, total_label_classes,
+                        total_num_of_each_class)
+                    if not is_label_correct:
+                        wrong_labels.append(line)
+                else:
+                    wrong_lines.append(lines)
+                    continue
+
+                if total_img_num == 0:
+                    channel = img.shape[2]
+                    total_means = np.zeros(channel)
+                    total_stds = np.zeros(channel)
+                    img_min_value = [sys.float_info.max] * channel
+                    img_max_value = [0] * channel
+                    img_value_num = []
+                    [img_value_num.append([]) for i in range(channel)]
+                means, stds, img_min_value, img_max_value, img_value_num = img_pixel_statistics(
+                    img, img_value_num, img_min_value, img_max_value)
+                total_means += means
+                total_stds += stds
+                max_width, max_height, min_width, min_height = get_img_shape_range(
+                    img, max_width, max_height, min_width, min_height)
+                img_channels = get_img_channel_num(img, img_channels)
+                total_img_num += 1
+
+            # data check
+            separator_check(wrong_lines, file_list, separator, logger)
+            imread_check(imread_failed, logger)
+            if has_label:
+                single_channel_label_check(label_not_single_channel, logger)
+                shape_check(shape_unequal_image, logger)
+                total_nc = label_class_check(num_classes, total_label_classes,
+                                             total_num_of_each_class,
+                                             wrong_labels, logger)
+
+            # data analyse on train, validation, test set.
+            img_channels_statistics(img_channels, logger)
+            img_shape_range_statistics(max_width, min_width, max_height,
+                                       min_height, logger)
+            if has_label:
+                label_class_statistics(total_nc, logger)
+    # data analyse on the whole dataset.
+    data_range_statistics(img_min_value, img_max_value, logger)
+    data_distribution_statistics(data_dir, img_value_num, logger)
+    cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)
+
+
+def main():
+    args = parse_args()
+    data_dir = args.data_dir
+    ignore_index = args.ignore_index
+    num_classes = args.num_classes
+    separator = args.separator
+
+    logger = logging.getLogger()
+    logger.setLevel('DEBUG')
+    BASIC_FORMAT = "%(message)s"
+    formatter = logging.Formatter(BASIC_FORMAT)
+    sh = logging.StreamHandler()
+    sh.setFormatter(formatter)
+    sh.setLevel('INFO')
+    th = logging.FileHandler(
+        os.path.join(data_dir, 'data_analyse_and_check.log'), 'w')
+    th.setFormatter(formatter)
+    logger.addHandler(sh)
+    logger.addHandler(th)
+
+    data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
+                           logger)
+
+    print("\nDetailed error information can be viewed in {}.".format(
+        os.path.join(data_dir, 'data_analyse_and_check.log')))
+
+
+if __name__ == "__main__":
+    main()