Răsfoiți Sursa

add multichannel RemoteSensing

FlyingQianMM 5 ani în urmă
părinte
comite
564417a4be

+ 0 - 0
paddlex/RemoteSensing/__init__.py


+ 157 - 0
paddlex/RemoteSensing/train_demo.py

@@ -0,0 +1,157 @@
+# coding: utf8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path as osp
+import argparse
+from paddlex.seg import transforms
+import paddlex.RemoteSensing.transforms as custom_transforms
+import paddlex as pdx
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='RemoteSensing training')
+    parser.add_argument(
+        '--data_dir',
+        dest='data_dir',
+        help='dataset directory',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='model save directory',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--num_classes',
+        dest='num_classes',
+        help='Number of classes',
+        default=None,
+        type=int)
+    parser.add_argument(
+        '--channel',
+        dest='channel',
+        help='number of data channel',
+        default=3,
+        type=int)
+    parser.add_argument(
+        '--clip_min_value',
+        dest='clip_min_value',
+        help='Min values for clipping data',
+        nargs='+',
+        default=None,
+        type=int)
+    parser.add_argument(
+        '--clip_max_value',
+        dest='clip_max_value',
+        help='Max values for clipping data',
+        nargs='+',
+        default=None,
+        type=int)
+    parser.add_argument(
+        '--mean',
+        dest='mean',
+        help='Data means',
+        nargs='+',
+        default=None,
+        type=float)
+    parser.add_argument(
+        '--std',
+        dest='std',
+        help='Data standard deviation',
+        nargs='+',
+        default=None,
+        type=float)
+    parser.add_argument(
+        '--num_epochs',
+        dest='num_epochs',
+        help='number of traing epochs',
+        default=100,
+        type=int)
+    parser.add_argument(
+        '--train_batch_size',
+        dest='train_batch_size',
+        help='training batch size',
+        default=4,
+        type=int)
+    parser.add_argument(
+        '--lr', dest='lr', help='learning rate', default=0.01, type=float)
+    return parser.parse_args()
+
+
+args = parse_args()
+data_dir = args.data_dir
+save_dir = args.save_dir
+num_classes = args.num_classes
+channel = args.channel
+clip_min_value = args.clip_min_value
+clip_max_value = args.clip_max_value
+mean = args.mean
+std = args.std
+num_epochs = args.num_epochs
+train_batch_size = args.train_batch_size
+lr = args.lr
+
+# 定义训练和验证时的transforms
+train_transforms = transforms.Compose([
+    transforms.RandomVerticalFlip(0.5),
+    transforms.RandomHorizontalFlip(0.5),
+    transforms.ResizeStepScaling(0.5, 2.0, 0.25),
+    transforms.RandomPaddingCrop(im_padding_value=[1000] * channel),
+    custom_transforms.Clip(
+        min_val=clip_min_value, max_val=clip_max_value),
+    transforms.Normalize(
+        min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
+])
+
+train_transforms.decode_image = custom_transforms.decode_image
+
+eval_transforms = transforms.Compose([
+    custom_transforms.Clip(
+        min_val=clip_min_value, max_val=clip_max_value),
+    transforms.Normalize(
+        min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
+])
+
+eval_transforms.decode_image = custom_transforms.decode_image
+
+train_list = osp.join(data_dir, 'train.txt')
+val_list = osp.join(data_dir, 'val.txt')
+label_list = osp.join(data_dir, 'labels.txt')
+
+train_dataset = pdx.datasets.SegDataset(
+    data_dir=data_dir,
+    file_list=train_list,
+    label_list=label_list,
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir=data_dir,
+    file_list=val_list,
+    label_list=label_list,
+    transforms=eval_transforms)
+
+model = pdx.seg.UNet(num_classes=num_classes, input_channel=channel)
+
+model.train(
+    num_epochs=num_epochs,
+    train_dataset=train_dataset,
+    train_batch_size=train_batch_size,
+    eval_dataset=eval_dataset,
+    save_interval_epochs=5,
+    log_interval_steps=10,
+    save_dir=save_dir,
+    learning_rate=lr,
+    use_vdl=True)

+ 69 - 0
paddlex/RemoteSensing/transforms.py

@@ -0,0 +1,69 @@
+import os
+import os.path as osp
+import imghdr
+import gdal
+import numpy as np
+from PIL import Image
+
+from paddlex.seg import transforms
+
+
+def read_img(img_path):
+    img_format = imghdr.what(img_path)
+    name, ext = osp.splitext(img_path)
+    if img_format == 'tiff' or ext == '.img':
+        dataset = gdal.Open(img_path)
+        if dataset == None:
+            raise Exception('Can not open', img_path)
+        im_data = dataset.ReadAsArray()
+        return im_data.transpose((1, 2, 0))
+    elif img_format == 'png':
+        return np.asarray(Image.open(img_path))
+    elif ext == '.npy':
+        return np.load(img_path)
+    else:
+        raise Exception('Image format {} is not supported!'.format(ext))
+
+
+def decode_image(im, label):
+    if isinstance(im, np.ndarray):
+        if len(im.shape) != 3:
+            raise Exception(
+                "im should be 3-dimensions, but now is {}-dimensions".format(
+                    len(im.shape)))
+    else:
+        try:
+            im = read_img(im)
+        except:
+            raise ValueError('Can\'t read The image file {}!'.format(im))
+    if label is not None:
+        if not isinstance(label, np.ndarray):
+            label = read_img(label)
+    return (im, label)
+
+
+class Clip(transforms.SegTransform):
+    """
+    对图像上超出一定范围的数据进行裁剪。
+
+    Args:
+        min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值0.
+        max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值255.0.
+    """
+
+    def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
+        self.min_val = min_val
+        self.max_val = max_val
+        if not (isinstance(self.min_val, list) and isinstance(self.max_val,
+                                                              list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+
+    def __call__(self, im, im_info=None, label=None):
+        for k in range(im.shape[2]):
+            np.clip(
+                im[:, :, k], self.min_val[k], self.max_val[k], out=im[:, :, k])
+
+        if label is None:
+            return (im, im_info)
+        else:
+            return (im, im_info, label)

+ 1 - 0
paddlex/__init__.py

@@ -32,6 +32,7 @@ from . import slim
 from . import convertor
 from . import tools
 from . import deploy
+from . import RemoteSensing
 
 try:
     import pycocotools

+ 0 - 2
paddlex/cv/datasets/seg_dataset.py

@@ -67,8 +67,6 @@ class SegDataset(Dataset):
                 items = line.strip().split()
                 items[0] = path_normalization(items[0])
                 items[1] = path_normalization(items[1])
-                if not is_pic(items[0]):
-                    continue
                 full_path_im = osp.join(data_dir, items[0])
                 full_path_label = osp.join(data_dir, items[1])
                 if not osp.exists(full_path_im):

+ 3 - 0
paddlex/cv/models/unet.py

@@ -43,6 +43,7 @@ class UNet(DeepLabv3p):
 
     def __init__(self,
                  num_classes=2,
+                 input_channel=3,
                  upsample_mode='bilinear',
                  use_bce_loss=False,
                  use_dice_loss=False,
@@ -71,6 +72,7 @@ class UNet(DeepLabv3p):
                     'Expect class_weight is a list or string but receive {}'.
                     format(type(class_weight)))
         self.num_classes = num_classes
+        self.input_channel = input_channel
         self.upsample_mode = upsample_mode
         self.use_bce_loss = use_bce_loss
         self.use_dice_loss = use_dice_loss
@@ -82,6 +84,7 @@ class UNet(DeepLabv3p):
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.UNet(
             self.num_classes,
+            input_channel=self.input_channel,
             mode=mode,
             upsample_mode=self.upsample_mode,
             use_bce_loss=self.use_bce_loss,

+ 7 - 2
paddlex/cv/nets/segmentation/unet.py

@@ -64,6 +64,7 @@ class UNet(object):
 
     def __init__(self,
                  num_classes,
+                 input_channel=3,
                  mode='train',
                  upsample_mode='bilinear',
                  use_bce_loss=False,
@@ -92,6 +93,7 @@ class UNet(object):
                     'Expect class_weight is a list or string but receive {}'.
                     format(type(class_weight)))
         self.num_classes = num_classes
+        self.input_channel = input_channel
         self.mode = mode
         self.upsample_mode = upsample_mode
         self.use_bce_loss = use_bce_loss
@@ -232,13 +234,16 @@ class UNet(object):
 
         if self.fixed_input_shape is not None:
             input_shape = [
-                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+                None, self.input_channel, self.fixed_input_shape[1],
+                self.fixed_input_shape[0]
             ]
             inputs['image'] = fluid.data(
                 dtype='float32', shape=input_shape, name='image')
         else:
             inputs['image'] = fluid.data(
-                dtype='float32', shape=[None, 3, None, None], name='image')
+                dtype='float32',
+                shape=[None, self.input_channel, None, None],
+                name='image')
         if self.mode == 'train':
             inputs['label'] = fluid.data(
                 dtype='int32', shape=[None, 1, None, None], name='label')

+ 6 - 2
paddlex/cv/transforms/ops.py

@@ -18,8 +18,12 @@ import numpy as np
 from PIL import Image, ImageEnhance
 
 
-def normalize(im, mean, std):
-    im = im / 255.0
+def normalize(im, mean, std, min_value, max_value):
+    # Rescaling (min-max normalization)
+    range_value = [max_value[i] - min_value[i] for i in range(len(max_value))]
+    im = (im - min_value) / range_value
+
+    # Standardization (Z-score Normalization)
     im -= mean
     im /= std
     return im

+ 52 - 33
paddlex/cv/transforms/seg_transforms.py

@@ -60,6 +60,24 @@ class Compose(SegTransform):
                         "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
                     )
 
+    @staticmethod
+    def decode_image(im, label):
+        if isinstance(im, np.ndarray):
+            if len(im.shape) != 3:
+                raise Exception(
+                    "im should be 3-dimensions, but now is {}-dimensions".
+                    format(len(im.shape)))
+        else:
+            try:
+                im = cv2.imread(im)
+            except:
+                raise ValueError('Can\'t read The image file {}!'.format(im))
+        im = im.astype('float32')
+        if label is not None:
+            if not isinstance(label, np.ndarray):
+                label = np.asarray(Image.open(label))
+        return (im, label)
+
     def __call__(self, im, im_info=None, label=None):
         """
         Args:
@@ -73,24 +91,12 @@ class Compose(SegTransform):
             tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
         """
 
-        if isinstance(im, np.ndarray):
-            if len(im.shape) != 3:
-                raise Exception(
-                    "im should be 3-dimensions, but now is {}-dimensions".
-                    format(len(im.shape)))
-        else:
-            try:
-                im = cv2.imread(im)
-            except:
-                raise ValueError('Can\'t read The image file {}!'.format(im))
-        im = im.astype('float32')
-        if im_info is None:
-            im_info = [('origin_shape', im.shape[0:2])]
+        im, label = self.decode_image(im, label)
         if self.to_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        if im_info is None:
+            im_info = [('origin_shape', im.shape[0:2])]
         if label is not None:
-            if not isinstance(label, np.ndarray):
-                label = np.asarray(Image.open(label))
             origin_label = label.copy()
         for op in self.transforms:
             if isinstance(op, SegTransform):
@@ -561,11 +567,21 @@ class Normalize(SegTransform):
         ValueError: mean或std不是list对象。std包含0。
     """
 
-    def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
+    def __init__(self,
+                 mean=[0.5, 0.5, 0.5],
+                 std=[0.5, 0.5, 0.5],
+                 min_val=[0, 0, 0],
+                 max_val=[255.0, 255.0, 255.0]):
+        self.min_val = min_val
+        self.max_val = max_val
         self.mean = mean
         self.std = std
         if not (isinstance(self.mean, list) and isinstance(self.std, list)):
             raise ValueError("{}: input type is invalid.".format(self))
+        if not (isinstance(self.min_val, list) and isinstance(self.max_val,
+                                                              list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+
         from functools import reduce
         if reduce(lambda x, y: x * y, self.std) == 0:
             raise ValueError('{}: std is invalid!'.format(self))
@@ -588,7 +604,7 @@ class Normalize(SegTransform):
 
         mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
         std = np.array(self.std)[np.newaxis, np.newaxis, :]
-        im = normalize(im, mean, std)
+        im = normalize(im, mean, std, self.min_val, self.max_val)
 
         if label is None:
             return (im, im_info)
@@ -752,23 +768,26 @@ class RandomPaddingCrop(SegTransform):
             pad_height = max(crop_height - img_height, 0)
             pad_width = max(crop_width - img_width, 0)
             if (pad_height > 0 or pad_width > 0):
-                im = cv2.copyMakeBorder(
-                    im,
-                    0,
-                    pad_height,
-                    0,
-                    pad_width,
-                    cv2.BORDER_CONSTANT,
-                    value=self.im_padding_value)
+                img_channel = im.shape[2]
+                import copy
+                orig_im = copy.deepcopy(im)
+                im = np.zeros((img_height + pad_height, img_width + pad_width,
+                               img_channel)).astype(orig_im.dtype)
+                for i in range(img_channel):
+                    im[:, :, i] = np.pad(
+                        orig_im[:, :, i],
+                        pad_width=((0, pad_height), (0, pad_width)),
+                        mode='constant',
+                        constant_values=(self.im_padding_value[i],
+                                         self.im_padding_value[i]))
+
                 if label is not None:
-                    label = cv2.copyMakeBorder(
-                        label,
-                        0,
-                        pad_height,
-                        0,
-                        pad_width,
-                        cv2.BORDER_CONSTANT,
-                        value=self.label_padding_value)
+                    label = np.pad(label,
+                                   pad_width=((0, pad_height), (0, pad_width)),
+                                   mode='constant',
+                                   constant_values=(self.label_padding_value,
+                                                    self.label_padding_value))
+
                 img_height = im.shape[0]
                 img_width = im.shape[1]