瀏覽代碼

Decode and Normalize uint16/float32 image correctly; Reimplement randomdistort; Support multi-channel input for classification

FlyingQianMM 4 年之前
父節點
當前提交
ddee84d362

+ 25 - 24
docs/apis/models/classification.md

@@ -3,7 +3,7 @@
 ## paddlex.cls.ResNet50
 
 ```python
-paddlex.cls.ResNet50(num_classes=1000)
+paddlex.cls.ResNet50(num_classes=1000, input_channel=3)
 ```
 
 > 构建ResNet50分类器,并实现其训练、评估和预测。  
@@ -11,6 +11,7 @@ paddlex.cls.ResNet50(num_classes=1000)
 **参数**
 
 > - **num_classes** (int): 类别数。默认为1000。  
+> - **input_channel** (int): 输入图像的通道数量。默认为3。
 
 ### train
 
@@ -103,26 +104,26 @@ PaddleX提供了共计22种分类模型,所有分类模型均提供同`ResNet5
 
 | 模型              | 接口                    |
 | :---------------- | :---------------------- |
-| ResNet18          | paddlex.cls.ResNet18(num_classes=1000) |
-| ResNet34          | paddlex.cls.ResNet34(num_classes=1000) |
-| ResNet50          | paddlex.cls.ResNet50(num_classes=1000) |
-| ResNet50_vd       | paddlex.cls.ResNet50_vd(num_classes=1000) |
-| ResNet50_vd_ssld    | paddlex.cls.ResNet50_vd_ssld(num_classes=1000) |
-| ResNet101          | paddlex.cls.ResNet101(num_classes=1000) |
-| ResNet101_vd        | paddlex.cls.ResNet101_vd(num_classes=1000) |
-| ResNet101_vd_ssld      | paddlex.cls.ResNet101_vd_ssld(num_classes=1000) |
-| DarkNet53      | paddlex.cls.DarkNet53(num_classes=1000) |
-| MobileNetV1         | paddlex.cls.MobileNetV1(num_classes=1000) |
-| MobileNetV2       | paddlex.cls.MobileNetV2(num_classes=1000) |
-| MobileNetV3_small       | paddlex.cls.MobileNetV3_small(num_classes=1000) |
-| MobileNetV3_small_ssld  | paddlex.cls.MobileNetV3_small_ssld(num_classes=1000) |
-| MobileNetV3_large   | paddlex.cls.MobileNetV3_large(num_classes=1000) |
-| MobileNetV3_large_ssld | paddlex.cls.MobileNetV3_large_ssld(num_classes=1000) |
-| Xception65     | paddlex.cls.Xception65(num_classes=1000) |
-| Xception71     | paddlex.cls.Xception71(num_classes=1000) |
-| ShuffleNetV2     | paddlex.cls.ShuffleNetV2(num_classes=1000) |
-| DenseNet121      | paddlex.cls.DenseNet121(num_classes=1000) |
-| DenseNet161       | paddlex.cls.DenseNet161(num_classes=1000) |
-| DenseNet201       | paddlex.cls.DenseNet201(num_classes=1000) |
-| HRNet_W18       | paddlex.cls.HRNet_W18(num_classes=1000) |
-| AlexNet         | paddlex.cls.AlexNet(num_classes=1000) |
+| ResNet18          | paddlex.cls.ResNet18(num_classes=1000, input_channel=3) |
+| ResNet34          | paddlex.cls.ResNet34(num_classes=1000, input_channel=3) |
+| ResNet50          | paddlex.cls.ResNet50(num_classes=1000, input_channel=3) |
+| ResNet50_vd       | paddlex.cls.ResNet50_vd(num_classes=1000, input_channel=3) |
+| ResNet50_vd_ssld    | paddlex.cls.ResNet50_vd_ssld(num_classes=1000, input_channel=3) |
+| ResNet101          | paddlex.cls.ResNet101(num_classes=1000, input_channel=3) |
+| ResNet101_vd        | paddlex.cls.ResNet101_vd(num_classes=1000, input_channel=3) |
+| ResNet101_vd_ssld      | paddlex.cls.ResNet101_vd_ssld(num_classes=1000, input_channel=3) |
+| DarkNet53      | paddlex.cls.DarkNet53(num_classes=1000, input_channel=3) |
+| MobileNetV1         | paddlex.cls.MobileNetV1(num_classes=1000, input_channel=3) |
+| MobileNetV2       | paddlex.cls.MobileNetV2(num_classes=1000, input_channel=3) |
+| MobileNetV3_small       | paddlex.cls.MobileNetV3_small(num_classes=1000, input_channel=3) |
+| MobileNetV3_small_ssld  | paddlex.cls.MobileNetV3_small_ssld(num_classes=1000, input_channel=3) |
+| MobileNetV3_large   | paddlex.cls.MobileNetV3_large(num_classes=1000, input_channel=3) |
+| MobileNetV3_large_ssld | paddlex.cls.MobileNetV3_large_ssld(num_classes=1000, input_channel=3) |
+| Xception65     | paddlex.cls.Xception65(num_classes=1000, input_channel=3) |
+| Xception71     | paddlex.cls.Xception71(num_classes=1000, input_channel=3) |
+| ShuffleNetV2     | paddlex.cls.ShuffleNetV2(num_classes=1000, input_channel=3) |
+| DenseNet121      | paddlex.cls.DenseNet121(num_classes=1000, input_channel=3) |
+| DenseNet161       | paddlex.cls.DenseNet161(num_classes=1000, input_channel=3) |
+| DenseNet201       | paddlex.cls.DenseNet201(num_classes=1000, input_channel=3) |
+| HRNet_W18       | paddlex.cls.HRNet_W18(num_classes=1000, input_channel=3) |
+| AlexNet         | paddlex.cls.AlexNet(num_classes=1000, input_channel=3) |

+ 17 - 12
docs/apis/transforms/cls_transforms.md

@@ -14,16 +14,20 @@ paddlex.cls.transforms.Compose(transforms)
 
 ## Normalize
 ```python
-paddlex.cls.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+paddlex.cls.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0])
 ```
+对图像进行标准化。
 
-对图像进行标准化。  
-1. 对图像进行归一化到区间[0.0, 1.0]。  
-2. 对图像进行减均值除以标准差操作。
+1.像素值减去min_val
+2.像素值除以(max_val-min_val), 归一化到区间 [0.0, 1.0]。
+3.对图像进行减均值除以标准差操作。
 
 ### 参数
-* **mean** (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
-* **std** (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
+* **mean** (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。长度应与图像通道数量相同。
+* **std** (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。长度应与图像通道数量相同。
+* **min_val** (list): 图像数据集的最小值。默认值[0, 0, 0]。长度应与图像通道数量相同。
+* **max_val** (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]。长度应与图像通道数量相同。
+
 
 ## ResizeByShort
 ```python
@@ -108,20 +112,21 @@ paddlex.cls.transforms.RandomDistort(brightness_range=0.9, brightness_prob=0.5,
 
 以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。  
 1. 对变换的操作顺序进行随机化操作。
-2. 按照1中的顺序以一定的概率对图像在范围[-range, range]内进行随机像素内容变换。  
+2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。  
 
-【注意】该数据增强必须在数据增强Normalize之前使用。
+【注意】如果输入是uint8/uint16的RGB图像,该数据增强必须在数据增强Normalize之前使用。
 
 ### 参数
-* **brightness_range** (float): 明亮度因子的范围。默认为0.9。
+* **brightness_range** (float): 明亮度的缩放系数范围。从[1-`brightness_range`, 1+`brightness_range`]中随机取值作为明亮度缩放因子`scale`,按照公式`image = image * scale`调整图像明亮度。默认为0.9。
 * **brightness_prob** (float): 随机调整明亮度的概率。默认为0.5。
-* **contrast_range** (float): 对比度因子的范围。默认为0.9。
+* **contrast_range** (float): 对比度的缩放系数范围。从[1-`contrast_range`, 1+`contrast_range`]中随机取值作为对比度缩放因子`scale`,按照公式`image = image * scale + (image_mean + 0.5) * (1 - scale)`调整图像对比度。默认为0.9。
 * **contrast_prob** (float): 随机调整对比度的概率。默认为0.5。
-* **saturation_range** (float): 饱和度因子的范围。默认为0.9。
+* **saturation_range** (float): 饱和度的缩放系数范围。从[1-`saturation_range`, 1+`saturation_range`]中随机取值作为饱和度缩放因子`scale`,按照公式`image = gray * (1 - scale) + image * scale`,其中`gray = R * 299/1000 + G * 587/1000+ B * 114/1000`。默认为0.9。
 * **saturation_prob** (float): 随机调整饱和度的概率。默认为0.5。
-* **hue_range** (int): 色调因子的范围。默认为18
+* **hue_range** (int): 调整色相角度的差值取值范围。从[-`hue_range`, `hue_range`]中随机取值作为色相角度调整差值`delta`,按照公式`hue = hue + delta`调整色相角度 。默认为18,取值范围[0, 360]
 * **hue_prob** (float): 随机调整色调的概率。默认为0.5。
 
+
 <!--
 ## ComposedClsTransforms
 ```python

+ 17 - 11
docs/apis/transforms/det_transforms.md

@@ -14,16 +14,21 @@ paddlex.det.transforms.Compose(transforms)
 
 ## Normalize
 ```python
-paddlex.det.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+paddlex.det.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], min_val=[0., 0., 0.], max_val=[255., 255., 255.])
 ```
 
 对图像进行标准化。  
-1. 归一化图像到到区间[0.0, 1.0]。  
-2. 对图像进行减均值除以标准差操作。
+1.像素值减去min_val
+2.像素值除以(max_val-min_val), 归一化到区间 [0.0, 1.0]。
+3.对图像进行减均值除以标准差操作。
+
 
 ### 参数
-* **mean** (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
-* **std** (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
+* **mean** (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。长度应与图像通道数量相同。
+* **std** (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。长度应与图像通道数量相同。
+* **min_val** (list): 图像数据集的最小值。默认值[0, 0, 0]。长度应与图像通道数量相同。
+* **max_val** (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]。长度应与图像通道数量相同。
+
 
 ## ResizeByShort
 ```python
@@ -85,20 +90,21 @@ paddlex.det.transforms.RandomDistort(brightness_range=0.5, brightness_prob=0.5,
 
 以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。  
 1. 对变换的操作顺序进行随机化操作。
-2. 按照1中的顺序以一定的概率对图像在范围[-range, range]内进行随机像素内容变换。  
+2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。  
 
-【注意】该数据增强必须在数据增强Normalize之前使用。
+【注意】如果输入是uint8/uint16的RGB图像,该数据增强必须在数据增强Normalize之前使用。
 
 ### 参数
-* **brightness_range** (float): 明亮度因子的范围。默认为0.5。
+* **brightness_range** (float): 明亮度的缩放系数范围。从[1-`brightness_range`, 1+`brightness_range`]中随机取值作为明亮度缩放因子`scale`,按照公式`image = image * scale`调整图像明亮度。默认为0.5。
 * **brightness_prob** (float): 随机调整明亮度的概率。默认为0.5。
-* **contrast_range** (float): 对比度因子的范围。默认为0.5。
+* **contrast_range** (float): 对比度的缩放系数范围。从[1-`contrast_range`, 1+`contrast_range`]中随机取值作为对比度缩放因子`scale`,按照公式`image = image * scale + (image_mean + 0.5) * (1 - scale)`调整图像对比度。默认为0.5。
 * **contrast_prob** (float): 随机调整对比度的概率。默认为0.5。
-* **saturation_range** (float): 饱和度因子的范围。默认为0.5。
+* **saturation_range** (float): 饱和度的缩放系数范围。从[1-`saturation_range`, 1+`saturation_range`]中随机取值作为饱和度缩放因子`scale`,按照公式`image = gray * (1 - scale) + image * scale`,其中`gray = R * 299/1000 + G * 587/1000+ B * 114/1000`。默认为0.5。
 * **saturation_prob** (float): 随机调整饱和度的概率。默认为0.5。
-* **hue_range** (int): 色调因子的范围。默认为18
+* **hue_range** (int): 调整色相角度的差值取值范围。从[-`hue_range`, `hue_range`]中随机取值作为色相角度调整差值`delta`,按照公式`hue = hue + delta`调整色相角度 。默认为18,取值范围[0, 360]
 * **hue_prob** (float): 随机调整色调的概率。默认为0.5。
 
+
 ## MixupImage
 ```python
 paddlex.det.transforms.MixupImage(alpha=1.5, beta=1.5, mixup_epoch=-1)

+ 9 - 8
docs/apis/transforms/seg_transforms.md

@@ -153,23 +153,24 @@ paddlex.seg.transforms.RandomScaleAspect(min_scale=0.5, aspect_ratio=0.33)
 ```python
 paddlex.seg.transforms.RandomDistort(brightness_range=0.5, brightness_prob=0.5, contrast_range=0.5, contrast_prob=0.5, saturation_range=0.5, saturation_prob=0.5, hue_range=18, hue_prob=0.5)
 ```
-以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。目前支持多通道的RGB图像,例如支持多张RGB图像沿通道轴做concatenate后的图像数据,不支持通道数量不是3的倍数的图像数据。
+以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。
 
-1.对变换的操作顺序进行随机化操作。
-2.按照1中的顺序以一定的概率对图像在范围[-range, range]内进行随机像素内容变换。  
+【注意】如果输入是uint8/uint16的RGB图像,该数据增强必须在数据增强Normalize之前使用。如果输入是由多张RGB图像数据沿通道方向做拼接而成的图像数据,则会把每3个通道数据视为一张RGB图像数据,依次对每3个通道数据做随机像素内容变化。
 
-【注意】该数据增强必须在数据增强Normalize之前使用。
+1. 对变换的操作顺序进行随机化操作。
+2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。  
 
 ### 参数
-* **brightness_range** (float): 明亮度因子的范围。默认为0.5。
+* **brightness_range** (float): 明亮度的缩放系数范围。从[1-`brightness_range`, 1+`brightness_range`]中随机取值作为明亮度缩放因子`scale`,按照公式`image = image * scale`调整图像明亮度。默认为0.5。
 * **brightness_prob** (float): 随机调整明亮度的概率。默认为0.5。
-* **contrast_range** (float): 对比度因子的范围。默认为0.5。
+* **contrast_range** (float): 对比度的缩放系数范围。从[1-`contrast_range`, 1+`contrast_range`]中随机取值作为对比度缩放因子`scale`,按照公式`image = image * scale + (image_mean + 0.5) * (1 - scale)`调整图像对比度。默认为0.5。
 * **contrast_prob** (float): 随机调整对比度的概率。默认为0.5。
-* **saturation_range** (float): 饱和度因子的范围。默认为0.5。
+* **saturation_range** (float): 饱和度的缩放系数范围。从[1-`saturation_range`, 1+`saturation_range`]中随机取值作为饱和度缩放因子`scale`,按照公式`image = gray * (1 - scale) + image * scale`,其中`gray = R * 299/1000 + G * 587/1000+ B * 114/1000`。默认为0.5。
 * **saturation_prob** (float): 随机调整饱和度的概率。默认为0.5。
-* **hue_range** (int): 色调因子的范围。默认为18
+* **hue_range** (int): 调整色相角度的差值取值范围。从[-`hue_range`, `hue_range`]中随机取值作为色相角度调整差值`delta`,按照公式`hue = hue + delta`调整色相角度 。默认为18,取值范围[0, 360]
 * **hue_prob** (float): 随机调整色调的概率。默认为0.5。
 
+
 ## Clip
 ```python
 paddlex.seg.transforms.Clip(min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0])

+ 104 - 46
paddlex/cv/models/classifier.py

@@ -37,9 +37,13 @@ class BaseClassifier(BaseAPI):
                           'MobileNetV1', 'MobileNetV2', 'Xception41',
                           'Xception65', 'Xception71']。默认为'ResNet50'。
         num_classes (int): 类别数。默认为1000。
+        input_channel (int): 输入图像的通道数量。默认为3。
     """
 
-    def __init__(self, model_name='ResNet50', num_classes=1000):
+    def __init__(self,
+                 model_name='ResNet50',
+                 num_classes=1000,
+                 input_channel=3):
         self.init_params = locals()
         super(BaseClassifier, self).__init__('classifier')
         if not hasattr(paddlex.cv.nets, str.lower(model_name)):
@@ -49,19 +53,23 @@ class BaseClassifier(BaseAPI):
         self.labels = None
         self.num_classes = num_classes
         self.fixed_input_shape = None
+        self.input_channel = input_channel
 
     def build_net(self, mode='train'):
         if self.__class__.__name__ == "AlexNet":
             assert self.fixed_input_shape is not None, "In AlexNet, input_shape should be defined, e.g. model = paddlex.cls.AlexNet(num_classes=1000, input_shape=[224, 224])"
         if self.fixed_input_shape is not None:
             input_shape = [
-                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
+                None, self.input_channel, self.fixed_input_shape[1],
+                self.fixed_input_shape[0]
             ]
             image = fluid.data(
                 dtype='float32', shape=input_shape, name='image')
         else:
             image = fluid.data(
-                dtype='float32', shape=[None, 3, None, None], name='image')
+                dtype='float32',
+                shape=[None, self.input_channel, None, None],
+                name='image')
         if mode != 'test':
             label = fluid.data(dtype='int64', shape=[None, 1], name='label')
         model = getattr(paddlex.cv.nets, str.lower(self.model_name))
@@ -223,11 +231,13 @@ class BaseClassifier(BaseAPI):
           tuple (metrics, eval_details): 当return_details为True时,增加返回dict,
               包含关键字:'true_labels'、'pred_scores',分别代表真实类别id、每个类别的预测得分。
         """
+        input_channel = getattr(self, 'input_channel', 3)
         arrange_transforms(
             model_type=self.model_type,
             class_name=self.__class__.__name__,
             transforms=eval_dataset.transforms,
-            mode='eval')
+            mode='eval',
+            input_channel=input_channel)
         data_generator = eval_dataset.generator(
             batch_size=batch_size, drop_last=False)
         k = min(5, self.num_classes)
@@ -283,12 +293,14 @@ class BaseClassifier(BaseAPI):
                     transforms,
                     model_type,
                     class_name,
-                    thread_pool=None):
+                    thread_pool=None,
+                    input_channel=3):
         arrange_transforms(
             model_type=model_type,
             class_name=class_name,
             transforms=transforms,
-            mode='test')
+            mode='test',
+            input_channel=input_channel)
         if thread_pool is not None:
             batch_data = thread_pool.map(transforms, images)
         else:
@@ -334,8 +346,13 @@ class BaseClassifier(BaseAPI):
 
         if transforms is None:
             transforms = self.test_transforms
-        im = BaseClassifier._preprocess(images, transforms, self.model_type,
-                                        self.__class__.__name__)
+        input_channel = getattr(self, 'input_channel', 3)
+        im = BaseClassifier._preprocess(
+            images,
+            transforms,
+            self.model_type,
+            self.__class__.__name__,
+            input_channel=input_channel)
 
         with fluid.scope_guard(self.scope):
             result = self.exe.run(self.test_prog,
@@ -366,9 +383,14 @@ class BaseClassifier(BaseAPI):
 
         if transforms is None:
             transforms = self.test_transforms
+        input_channel = getattr(self, 'input_channel', 3)
         im = BaseClassifier._preprocess(
-            img_file_list, transforms, self.model_type,
-            self.__class__.__name__, self.thread_pool)
+            img_file_list,
+            transforms,
+            self.model_type,
+            self.__class__.__name__,
+            self.thread_pool,
+            input_channel=input_channel)
 
         with fluid.scope_guard(self.scope):
             result = self.exe.run(self.test_prog,
@@ -470,109 +492,145 @@ class ResNet50_vd(BaseClassifier):
 
 
 class ResNet101_vd(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ResNet101_vd, self).__init__(
-            model_name='ResNet101_vd', num_classes=num_classes)
+            model_name='ResNet101_vd',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class ResNet50_vd_ssld(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ResNet50_vd_ssld, self).__init__(
-            model_name='ResNet50_vd_ssld', num_classes=num_classes)
+            model_name='ResNet50_vd_ssld',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class ResNet101_vd_ssld(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ResNet101_vd_ssld, self).__init__(
-            model_name='ResNet101_vd_ssld', num_classes=num_classes)
+            model_name='ResNet101_vd_ssld',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class DarkNet53(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(DarkNet53, self).__init__(
-            model_name='DarkNet53', num_classes=num_classes)
+            model_name='DarkNet53',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class MobileNetV1(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(MobileNetV1, self).__init__(
-            model_name='MobileNetV1', num_classes=num_classes)
+            model_name='MobileNetV1',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class MobileNetV2(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(MobileNetV2, self).__init__(
-            model_name='MobileNetV2', num_classes=num_classes)
+            model_name='MobileNetV2',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class MobileNetV3_small(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(MobileNetV3_small, self).__init__(
-            model_name='MobileNetV3_small', num_classes=num_classes)
+            model_name='MobileNetV3_small',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class MobileNetV3_large(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(MobileNetV3_large, self).__init__(
-            model_name='MobileNetV3_large', num_classes=num_classes)
+            model_name='MobileNetV3_large',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class MobileNetV3_small_ssld(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(MobileNetV3_small_ssld, self).__init__(
-            model_name='MobileNetV3_small_ssld', num_classes=num_classes)
+            model_name='MobileNetV3_small_ssld',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class MobileNetV3_large_ssld(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(MobileNetV3_large_ssld, self).__init__(
-            model_name='MobileNetV3_large_ssld', num_classes=num_classes)
+            model_name='MobileNetV3_large_ssld',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class Xception65(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(Xception65, self).__init__(
-            model_name='Xception65', num_classes=num_classes)
+            model_name='Xception65',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class Xception41(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(Xception41, self).__init__(
-            model_name='Xception41', num_classes=num_classes)
+            model_name='Xception41',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class DenseNet121(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(DenseNet121, self).__init__(
-            model_name='DenseNet121', num_classes=num_classes)
+            model_name='DenseNet121',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class DenseNet161(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(DenseNet161, self).__init__(
-            model_name='DenseNet161', num_classes=num_classes)
+            model_name='DenseNet161',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class DenseNet201(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(DenseNet201, self).__init__(
-            model_name='DenseNet201', num_classes=num_classes)
+            model_name='DenseNet201',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class ShuffleNetV2(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(ShuffleNetV2, self).__init__(
-            model_name='ShuffleNetV2', num_classes=num_classes)
+            model_name='ShuffleNetV2',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class HRNet_W18(BaseClassifier):
-    def __init__(self, num_classes=1000):
+    def __init__(self, num_classes=1000, input_channel=3):
         super(HRNet_W18, self).__init__(
-            model_name='HRNet_W18', num_classes=num_classes)
+            model_name='HRNet_W18',
+            num_classes=num_classes,
+            input_channel=input_channel)
 
 
 class AlexNet(BaseClassifier):
-    def __init__(self, num_classes=1000, input_shape=None):
+    def __init__(self, num_classes=1000, input_shape=None, input_channel=3):
         super(AlexNet, self).__init__(
-            model_name='AlexNet', num_classes=num_classes)
+            model_name='AlexNet',
+            num_classes=num_classes,
+            input_channel=input_channel)
         self.fixed_input_shape = input_shape

+ 96 - 30
paddlex/cv/transforms/cls_transforms.py

@@ -47,6 +47,8 @@ class Compose(ClsTransform):
                             'must be equal or larger than 1!')
         self.transforms = transforms
         self.batch_transforms = None
+        self.data_type = np.uint8
+        self.to_rgb = True
         # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
         for op in self.transforms:
             if not isinstance(op, ClsTransform):
@@ -56,7 +58,7 @@ class Compose(ClsTransform):
                         "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
                     )
 
-    def __call__(self, im, label=None):
+    def __call__(self, im_file, label=None):
         """
         Args:
             im (str/np.ndarray): 图像路径/图像np.ndarray数据。
@@ -65,28 +67,43 @@ class Compose(ClsTransform):
             tuple: 根据网络所需字段所组成的tuple;
                 字段由transforms中的最后一个数据预处理操作决定。
         """
-        if isinstance(im, np.ndarray):
-            if len(im.shape) != 3:
+        input_channel = getattr(self, 'input_channel', 3)
+        if isinstance(im_file, np.ndarray):
+            if len(im_file.shape) != 3:
                 raise Exception(
-                    "im should be 3-dimension, but now is {}-dimensions".format(
-                        len(im.shape)))
+                    "im should be 3-dimension, but now is {}-dimensions".
+                    format(len(im_file.shape)))
         else:
             try:
-                im_path = im
-                im = cv2.imread(im).astype('float32')
+                if input_channel == 3:
+                    im = cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
+                                    cv2.IMREAD_ANYCOLOR)
+                else:
+                    im = cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
+                    if im.ndim < 3:
+                        im = np.expand_dims(im, axis=-1)
             except:
                 raise TypeError('Can\'t read The image file {}!'.format(
-                    im_path))
+                    im_file))
+        self.data_type = im.dtype
         im = im.astype('float32')
-        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        if input_channel == 3 and self.to_rgb:
+            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         for op in self.transforms:
             if isinstance(op, ClsTransform):
+                if op.__class__.__name__ == 'RandomDistort':
+                    op.to_rgb = self.to_rgb
+                    op.data_type = self.data_type
                 outputs = op(im, label)
                 im = outputs[0]
                 if len(outputs) == 2:
                     label = outputs[1]
             else:
                 import imgaug.augmenters as iaa
+                if im.shape[-1] != 3:
+                    raise Exception(
+                        "Only the 3-channel RGB image is supported in the imgaug operator, but recieved image channel is {}".
+                        format(im.shape[-1]))
                 if isinstance(op, iaa.Augmenter):
                     im = execute_imgaug(op, im)
                 outputs = (im, )
@@ -142,8 +159,8 @@ class RandomCrop(ClsTransform):
             tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
                    当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
         """
-        im = random_crop(im, self.crop_size, self.lower_scale, self.lower_ratio,
-                         self.upper_ratio)
+        im = random_crop(im, self.crop_size, self.lower_scale,
+                         self.lower_ratio, self.upper_ratio)
         if label is None:
             return (im, )
         else:
@@ -208,19 +225,39 @@ class RandomVerticalFlip(ClsTransform):
 
 class Normalize(ClsTransform):
     """对图像进行标准化。
-
-    1. 对图像进行归一化到区间[0.0, 1.0]。
-    2. 对图像进行减均值除以标准差操作。
+    1.像素值减去min_val
+    2.像素值除以(max_val-min_val)
+    3.对图像进行减均值除以标准差操作。
 
     Args:
-        mean (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
-        std (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
+        mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
+        std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
+        min_val (list): 图像数据集的最小值。默认值[0, 0, 0]。
+        max_val (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]。
 
+    Raises:
+        ValueError: mean或std不是list对象。std包含0。
     """
 
-    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+    def __init__(self,
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225],
+                 min_val=[0, 0, 0],
+                 max_val=[255.0, 255.0, 255.0]):
         self.mean = mean
         self.std = std
+        self.min_val = min_val
+        self.max_val = max_val
+
+        if not (isinstance(self.mean, list) and isinstance(self.std, list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+        if not (isinstance(self.min_val, list) and isinstance(self.max_val,
+                                                              list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+
+        from functools import reduce
+        if reduce(lambda x, y: x * y, self.std) == 0:
+            raise ValueError('{}: std is invalid!'.format(self))
 
     def __call__(self, im, label=None):
         """
@@ -234,7 +271,7 @@ class Normalize(ClsTransform):
         """
         mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
         std = np.array(self.std)[np.newaxis, np.newaxis, :]
-        im = normalize(im, mean, std)
+        im = normalize(im, mean, std, self.min_val, self.max_val)
         if label is None:
             return (im, )
         else:
@@ -273,12 +310,14 @@ class ResizeByShort(ClsTransform):
         im_short_size = min(im.shape[0], im.shape[1])
         im_long_size = max(im.shape[0], im.shape[1])
         scale = float(self.short_size) / im_short_size
-        if self.max_size > 0 and np.round(scale * im_long_size) > self.max_size:
+        if self.max_size > 0 and np.round(scale *
+                                          im_long_size) > self.max_size:
             scale = float(self.max_size) / float(im_long_size)
         resized_width = int(round(im.shape[1] * scale))
         resized_height = int(round(im.shape[0] * scale))
         im = cv2.resize(
-            im, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)
+            im, (resized_width, resized_height),
+            interpolation=cv2.INTER_LINEAR)
 
         if label is None:
             return (im, )
@@ -351,19 +390,30 @@ class RandomRotate(ClsTransform):
 
 
 class RandomDistort(ClsTransform):
-    """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
+    """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
 
     1. 对变换的操作顺序进行随机化操作。
-    2. 按照1中的顺序以一定的概率对图像在范围[-range, range]内进行随机像素内容变换。
+    2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。
+
+    【注意】如果输入是uint8/uint16的RGB图像,该数据增强必须在数据增强Normalize之前使用。
 
     Args:
-        brightness_range (float): 明亮度因子的范围。默认为0.9。
+        brightness_range (float): 明亮度的缩放系数范围。
+            从[1-`brightness_range`, 1+`brightness_range`]中随机取值作为明亮度缩放因子`scale`,
+            按照公式`image = image * scale`调整图像明亮度。默认值为0.9。
         brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
-        contrast_range (float): 对比度因子的范围。默认为0.9。
+        contrast_range (float): 对比度的缩放系数范围。
+            从[1-`contrast_range`, 1+`contrast_range`]中随机取值作为对比度缩放因子`scale`,
+            按照公式`image = image * scale + (image_mean + 0.5) * (1 - scale)`调整图像对比度。默认为0.9。
         contrast_prob (float): 随机调整对比度的概率。默认为0.5。
-        saturation_range (float): 饱和度因子的范围。默认为0.9。
+        saturation_range (float): 饱和度的缩放系数范围。
+            从[1-`saturation_range`, 1+`saturation_range`]中随机取值作为饱和度缩放因子`scale`,
+            按照公式`image = gray * (1 - scale) + image * scale`,
+            其中`gray = R * 299/1000 + G * 587/1000+ B * 114/1000`。默认为0.9。
         saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
-        hue_range (int): 色调因子的范围。默认为18。
+        hue_range (int): 调整色相角度的差值取值范围。
+            从[-`hue_range`, `hue_range`]中随机取值作为色相角度调整差值`delta`,
+            按照公式`hue = hue + delta`调整色相角度 。默认为18,取值范围[0, 360]。
         hue_prob (float): 随机调整色调的概率。默认为0.5。
     """
 
@@ -395,6 +445,16 @@ class RandomDistort(ClsTransform):
             tuple: 当label为空时,返回的tuple为(im, ),对应图像np.ndarray数据;
                    当label不为空时,返回的tuple为(im, label),分别对应图像np.ndarray数据、图像类别id。
         """
+        if im.shape[-1] != 3:
+            raise Exception(
+                "Only the 3-channel RGB image is supported in the RandomDistort operator, but recieved image channel is {}".
+                format(im.shape[-1]))
+
+        if self.data_type not in [np.uint8, np.uint16, np.float32]:
+            raise Exception(
+                "Only the uint8/uint16/float32 RGB image is supported in the RandomDistort operator, but recieved image data type is {}".
+                format(self.data_type))
+
         brightness_lower = 1 - self.brightness_range
         brightness_upper = 1 + self.brightness_range
         contrast_lower = 1 - self.contrast_range
@@ -408,19 +468,25 @@ class RandomDistort(ClsTransform):
         params_dict = {
             'brightness': {
                 'brightness_lower': brightness_lower,
-                'brightness_upper': brightness_upper
+                'brightness_upper': brightness_upper,
+                'dtype': self.data_type
             },
             'contrast': {
                 'contrast_lower': contrast_lower,
-                'contrast_upper': contrast_upper
+                'contrast_upper': contrast_upper,
+                'dtype': self.data_type
             },
             'saturation': {
                 'saturation_lower': saturation_lower,
-                'saturation_upper': saturation_upper
+                'saturation_upper': saturation_upper,
+                'is_rgb': self.to_rgb,
+                'dtype': self.data_type
             },
             'hue': {
                 'hue_lower': hue_lower,
-                'hue_upper': hue_upper
+                'hue_upper': hue_upper,
+                'is_rgb': self.to_rgb,
+                'dtype': self.data_type
             }
         }
         prob_dict = {

+ 67 - 24
paddlex/cv/transforms/det_transforms.py

@@ -22,7 +22,6 @@ import os.path as osp
 import numpy as np
 
 import cv2
-from PIL import Image, ImageEnhance
 
 from .imgaug_support import execute_imgaug
 from .ops import *
@@ -57,6 +56,8 @@ class Compose(DetTransform):
         self.transforms = transforms
         self.batch_transforms = None
         self.use_mixup = False
+        self.data_type = np.uint8
+        self.to_rgb = True
         for t in self.transforms:
             if type(t).__name__ == 'MixupImage':
                 self.use_mixup = True
@@ -110,17 +111,18 @@ class Compose(DetTransform):
             else:
                 try:
                     if input_channel == 3:
-                        im = cv2.imread(im_file).astype('float32')
+                        im = cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
+                                        cv2.IMREAD_ANYCOLOR)
                     else:
-                        im = cv2.imread(im_file,
-                                        cv2.IMREAD_UNCHANGED).astype('float32')
+                        im = cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
                         if im.ndim < 3:
                             im = np.expand_dims(im, axis=-1)
                 except:
                     raise TypeError('Can\'t read The image file {}!'.format(
                         im_file))
+            self.data_type = im.dtype
             im = im.astype('float32')
-            if input_channel == 3:
+            if input_channel == 3 and self.to_rgb:
                 im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             # make default im_info with [h, w, 1]
             im_info['im_resize_info'] = np.array(
@@ -135,7 +137,8 @@ class Compose(DetTransform):
                 im_info['mixup'] = \
                   decode_image(im_info['mixup'][0],
                                im_info['mixup'][1],
-                               im_info['mixup'][2])
+                               im_info['mixup'][2],
+                               input_channel)
             if label_info is None:
                 return (im, im_info)
             else:
@@ -151,14 +154,19 @@ class Compose(DetTransform):
             if im is None:
                 return None
             if isinstance(op, DetTransform):
+                if op.__class__.__name__ == 'RandomDistort':
+                    op.to_rgb = self.to_rgb
+                    op.data_type = self.data_type
                 outputs = op(im, im_info, label_info)
                 im = outputs[0]
             else:
+                import imgaug.augmenters as iaa
                 if im.shape[-1] != 3:
                     raise Exception(
                         "Only the 3-channel RGB image is supported in the imgaug operator, but recieved image channel is {}".
                         format(im.shape[-1]))
-                im = execute_imgaug(op, im)
+                if isinstance(op, iaa.Augmenter):
+                    im = execute_imgaug(op, im)
                 if label_info is not None:
                     outputs = (im, im_info, label_info)
                 else:
@@ -515,22 +523,37 @@ class RandomHorizontalFlip(DetTransform):
 class Normalize(DetTransform):
     """对图像进行标准化。
 
-    1. 归一化图像到到区间[0.0, 1.0]。
-    2. 对图像进行减均值除以标准差操作。
+    1.像素值减去min_val
+    2.像素值除以(max_val-min_val)
+    3.对图像进行减均值除以标准差操作。
 
     Args:
-        mean (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
-        std (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
+        mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
+        std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
+        min_val (list): 图像数据集的最小值。默认值[0, 0, 0]。
+        max_val (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]。
 
     Raises:
         TypeError: 形参数据类型不满足需求。
     """
 
-    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+    def __init__(self,
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225],
+                 min_val=[0, 0, 0],
+                 max_val=[255.0, 255.0, 255.0]):
         self.mean = mean
         self.std = std
+        self.min_val = min_val
+        self.max_val = max_val
+
         if not (isinstance(self.mean, list) and isinstance(self.std, list)):
             raise TypeError("NormalizeImage: input type is invalid.")
+
+        if not (isinstance(self.min_val, list) and isinstance(self.max_val,
+                                                              list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+
         from functools import reduce
         if reduce(lambda x, y: x * y, self.std) == 0:
             raise TypeError('NormalizeImage: std is invalid!')
@@ -549,9 +572,7 @@ class Normalize(DetTransform):
         """
         mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
         std = np.array(self.std)[np.newaxis, np.newaxis, :]
-        min_val = [0] * im.shape[-1]
-        max_val = [255] * im.shape[-1]
-        im = normalize(im, mean, std, min_val, max_val)
+        im = normalize(im, mean, std, self.min_val, self.max_val)
         if label_info is None:
             return (im, im_info)
         else:
@@ -562,16 +583,27 @@ class RandomDistort(DetTransform):
     """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
 
     1. 对变换的操作顺序进行随机化操作。
-    2. 按照1中的顺序以一定的概率在范围[-range, range]对图像进行随机像素内容变换。
+    2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。
+
+    【注意】如果输入是uint8/uint16的RGB图像,该数据增强必须在数据增强Normalize之前使用。
 
     Args:
-        brightness_range (float): 明亮度因子的范围。默认为0.5。
+        brightness_range (float): 明亮度的缩放系数范围。
+            从[1-`brightness_range`, 1+`brightness_range`]中随机取值作为明亮度缩放因子`scale`,
+            按照公式`image = image * scale`调整图像明亮度。默认值为0.5。
         brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
-        contrast_range (float): 对比度因子的范围。默认为0.5。
+        contrast_range (float): 对比度的缩放系数范围。
+            从[1-`contrast_range`, 1+`contrast_range`]中随机取值作为对比度缩放因子`scale`,
+            按照公式`image = image * scale + (image_mean + 0.5) * (1 - scale)`调整图像对比度。默认为0.5。
         contrast_prob (float): 随机调整对比度的概率。默认为0.5。
-        saturation_range (float): 饱和度因子的范围。默认为0.5。
+        saturation_range (float): 饱和度的缩放系数范围。
+            从[1-`saturation_range`, 1+`saturation_range`]中随机取值作为饱和度缩放因子`scale`,
+            按照公式`image = gray * (1 - scale) + image * scale`,
+            其中`gray = R * 299/1000 + G * 587/1000+ B * 114/1000`。默认为0.5。
         saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
-        hue_range (int): 色调因子的范围。默认为18。
+        hue_range (int): 调整色相角度的差值取值范围。
+            从[-`hue_range`, `hue_range`]中随机取值作为色相角度调整差值`delta`,
+            按照公式`hue = hue + delta`调整色相角度 。默认为18,取值范围[0, 360]。
         hue_prob (float): 随机调整色调的概率。默认为0.5。
     """
 
@@ -610,6 +642,11 @@ class RandomDistort(DetTransform):
                 "Only the 3-channel RGB image is supported in the RandomDistort operator, but recieved image channel is {}".
                 format(im.shape[-1]))
 
+        if self.data_type not in [np.uint8, np.uint16, np.float32]:
+            raise Exception(
+                "Only the uint8/uint16/float32 RGB image is supported in the RandomDistort operator, but recieved image data type is {}".
+                format(self.data_type))
+
         brightness_lower = 1 - self.brightness_range
         brightness_upper = 1 + self.brightness_range
         contrast_lower = 1 - self.contrast_range
@@ -623,19 +660,25 @@ class RandomDistort(DetTransform):
         params_dict = {
             'brightness': {
                 'brightness_lower': brightness_lower,
-                'brightness_upper': brightness_upper
+                'brightness_upper': brightness_upper,
+                'dtype': self.data_type
             },
             'contrast': {
                 'contrast_lower': contrast_lower,
-                'contrast_upper': contrast_upper
+                'contrast_upper': contrast_upper,
+                'dtype': self.data_type
             },
             'saturation': {
                 'saturation_lower': saturation_lower,
-                'saturation_upper': saturation_upper
+                'saturation_upper': saturation_upper,
+                'is_rgb': self.to_rgb,
+                'dtype': self.data_type
             },
             'hue': {
                 'hue_lower': hue_lower,
-                'hue_upper': hue_upper
+                'hue_upper': hue_upper,
+                'is_rgb': self.to_rgb,
+                'dtype': self.data_type
             }
         }
         prob_dict = {

+ 52 - 15
paddlex/cv/transforms/ops.py

@@ -120,39 +120,76 @@ def bgr2rgb(im):
     return im[:, :, ::-1]
 
 
-def hue(im, hue_lower, hue_upper):
+def hue(im, hue_lower, hue_upper, is_rgb=False, dtype=np.uint8):
     delta = np.random.uniform(hue_lower, hue_upper)
-    u = np.cos(delta * np.pi)
-    w = np.sin(delta * np.pi)
-    bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
-    tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
-                     [0.211, -0.523, 0.311]])
-    ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
-                      [1.0, -1.107, 1.705]])
-    t = np.dot(np.dot(ityiq, bt), tyiq).T
-    im = np.dot(im, t)
+    if is_rgb:
+        im = cv2.cvtColor(im, cv2.COLOR_RGB2HSV)
+    else:
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
+    im[:, :, 0] = im[:, :, 0] + delta
+    im[:, :, 0] = np.clip(im[:, :, 0], 0, 360.)
+    if is_rgb:
+        im = cv2.cvtColor(im, cv2.COLOR_HSV2RGB)
+    else:
+        im = cv2.cvtColor(im, cv2.COLOR_HSV2BGR)
+    if dtype == np.uint8:
+        im = np.clip(im, 0., 255.)
+    elif dtype == np.uint16:
+        im = np.clip(im, 0., 65535.)
+    elif dtype == np.float32:
+        im = np.clip(im, 0., 1.)
     return im
 
 
-def saturation(im, saturation_lower, saturation_upper):
+def saturation(im,
+               saturation_lower,
+               saturation_upper,
+               is_rgb=False,
+               dtype=np.uint8):
+    if is_rgb:
+        gray_scale = np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
+    else:
+        gray_scale = np.array([[[0.114, 0.587, 0.299]]], dtype=np.float32)
     delta = np.random.uniform(saturation_lower, saturation_upper)
-    gray = im * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
+    gray = im * gray_scale
     gray = gray.sum(axis=2, keepdims=True)
     gray *= (1.0 - delta)
     im *= delta
     im += gray
+    if dtype == np.uint8:
+        im = np.clip(im, 0., 255.)
+    elif dtype == np.uint16:
+        im = np.clip(im, 0., 65535.)
+    elif dtype == np.float32:
+        im = np.clip(im, 0., 1.)
     return im
 
 
-def contrast(im, contrast_lower, contrast_upper):
+def contrast(im, contrast_lower, contrast_upper, dtype=np.uint8):
     delta = np.random.uniform(contrast_lower, contrast_upper)
+    im_mean = im.mean() + 0.5
+    im1 = np.full_like(im, im_mean)
     im *= delta
+    im += im1 * (1 - delta)
+    if dtype == np.uint8:
+        im = np.clip(im, 0., 255.)
+    elif dtype == np.uint16:
+        im = np.clip(im, 0., 65535.)
+    elif dtype == np.float32:
+        im = np.clip(im, 0., 1.)
     return im
 
 
-def brightness(im, brightness_lower, brightness_upper):
+def brightness(im, brightness_lower, brightness_upper, dtype=np.uint8):
     delta = np.random.uniform(brightness_lower, brightness_upper)
-    im += delta
+    im *= delta
+    if dtype == np.uint8:
+        im = np.clip(im, 0., 255.)
+    elif dtype == np.uint16:
+        im = np.clip(im, 0., 65535.)
+    elif dtype == np.float32:
+        im = np.clip(im, 0., 1.)
+
     return im
 
 

+ 55 - 13
paddlex/cv/transforms/seg_transforms.py

@@ -54,6 +54,7 @@ class Compose(SegTransform):
                             'must be equal or larger than 1!')
         self.transforms = transforms
         self.batch_transforms = None
+        self.data_type = np.uint8
         self.to_rgb = False
         # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
         for op in self.transforms:
@@ -84,7 +85,8 @@ class Compose(SegTransform):
             return im_data.transpose((1, 2, 0))
         elif img_format in ['jpeg', 'bmp', 'png']:
             if input_channel == 3:
-                return cv2.imread(img_path)
+                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
+                                  cv2.IMREAD_ANYCOLOR)
             else:
                 return cv2.imread(im_file, cv2.IMREAD_UNCHANGED)
         elif ext == '.npy':
@@ -102,11 +104,10 @@ class Compose(SegTransform):
             im = im_path
         else:
             try:
-                im = Compose.read_img(im_path, input_channel).astype('float32')
+                im = Compose.read_img(im_path, input_channel)
             except:
                 raise ValueError('Can\'t read The image file {}!'.format(
                     im_path))
-        im = im.astype('float32')
         if label is not None:
             if isinstance(label, np.ndarray):
                 if len(label.shape) != 2:
@@ -145,6 +146,8 @@ class Compose(SegTransform):
 
         input_channel = getattr(self, 'input_channel', 3)
         im, label = self.decode_image(im, label, input_channel)
+        self.data_type = im.dtype
+        im = im.astype('float32')
         if self.to_rgb and input_channel == 3:
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         if im_info is None:
@@ -153,6 +156,9 @@ class Compose(SegTransform):
             origin_label = label.copy()
         for op in self.transforms:
             if isinstance(op, SegTransform):
+                if op.__class__.__name__ == 'RandomDistort':
+                    op.to_rgb = self.to_rgb
+                    op.data_type = self.data_type
                 outputs = op(im, im_info, label)
                 im = outputs[0]
                 if len(outputs) >= 2:
@@ -160,7 +166,13 @@ class Compose(SegTransform):
                 if len(outputs) == 3:
                     label = outputs[2]
             else:
-                im = execute_imgaug(op, im)
+                import imgaug.augmenters as iaa
+                if im.shape[-1] != 3:
+                    raise Exception(
+                        "Only the 3-channel RGB image is supported in the imgaug operator, but recieved image channel is {}".
+                        format(im.shape[-1]))
+                if isinstance(op, iaa.Augmenter):
+                    im = execute_imgaug(op, im)
                 if label is not None:
                     outputs = (im, im_info, label)
                 else:
@@ -1059,19 +1071,33 @@ class RandomScaleAspect(SegTransform):
 
 
 class RandomDistort(SegTransform):
-    """对图像进行随机失真。
+    """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
 
     1. 对变换的操作顺序进行随机化操作。
     2. 按照1中的顺序以一定的概率对图像进行随机像素内容变换。
 
+    【注意】如果输入是uint8/uint16的RGB图像,该数据增强必须在数据增强Normalize之前使用。
+    如果输入是由多张RGB图像数据沿通道方向做拼接而成的图像数据,则会把每3个通道数据视为一张RGB图像数据,
+    依次对每3个通道数据做随机像素内容变化。
+
+
     Args:
-        brightness_range (float): 明亮度因子的范围。默认为0.5。
+        brightness_range (float): 明亮度的缩放系数范围。
+            从[1-`brightness_range`, 1+`brightness_range`]中随机取值作为明亮度缩放因子`scale`,
+            按照公式`image = image * scale`调整图像明亮度。默认值为0.5。
         brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
-        contrast_range (float): 对比度因子的范围。默认为0.5。
+        contrast_range (float): 对比度的缩放系数范围。
+            从[1-`contrast_range`, 1+`contrast_range`]中随机取值作为对比度缩放因子`scale`,
+            按照公式`image = image * scale + (image_mean + 0.5) * (1 - scale)`调整图像对比度。默认为0.5。
         contrast_prob (float): 随机调整对比度的概率。默认为0.5。
-        saturation_range (float): 饱和度因子的范围。默认为0.5。
+        saturation_range (float): 饱和度的缩放系数范围。
+            从[1-`saturation_range`, 1+`saturation_range`]中随机取值作为饱和度缩放因子`scale`,
+            按照公式`image = gray * (1 - scale) + image * scale`,
+            其中`gray = R * 299/1000 + G * 587/1000+ B * 114/1000`。默认为0.5。
         saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
-        hue_range (int): 色调因子的范围。默认为18。
+        hue_range (int): 调整色相角度的差值取值范围。
+            从[-`hue_range`, `hue_range`]中随机取值作为色相角度调整差值`delta`,
+            按照公式`hue = hue + delta`调整色相角度 。默认为18,取值范围[0, 360]。
         hue_prob (float): 随机调整色调的概率。默认为0.5。
     """
 
@@ -1108,6 +1134,16 @@ class RandomDistort(SegTransform):
                 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
                 存储与图像相关信息的字典和标注图像np.ndarray数据。
         """
+        if im.shape[-1] % 3 != 0:
+            raise Exception(
+                "Only the 3-channel RGB image or the image array composed by concatenating many 3-channel RGB images along the channel axis is supported in the RandomDistort operator, but recieved image channel is {} which cannot be divided by 3.".
+                format(im.shape[-1]))
+
+        if self.data_type not in [np.uint8, np.uint16, np.float32]:
+            raise Exception(
+                "Only the uint8/uint16/float32 RGB image is supported in the RandomDistort operator, but recieved image data type is {}".
+                format(self.data_type))
+
         brightness_lower = 1 - self.brightness_range
         brightness_upper = 1 + self.brightness_range
         contrast_lower = 1 - self.contrast_range
@@ -1121,19 +1157,25 @@ class RandomDistort(SegTransform):
         params_dict = {
             'brightness': {
                 'brightness_lower': brightness_lower,
-                'brightness_upper': brightness_upper
+                'brightness_upper': brightness_upper,
+                'dtype': self.data_type
             },
             'contrast': {
                 'contrast_lower': contrast_lower,
-                'contrast_upper': contrast_upper
+                'contrast_upper': contrast_upper,
+                'dtype': self.data_type
             },
             'saturation': {
                 'saturation_lower': saturation_lower,
-                'saturation_upper': saturation_upper
+                'saturation_upper': saturation_upper,
+                'is_rgb': self.to_rgb,
+                'dtype': self.data_type
             },
             'hue': {
                 'hue_lower': hue_lower,
-                'hue_upper': hue_upper
+                'hue_upper': hue_upper,
+                'is_rgb': self.to_rgb,
+                'dtype': self.data_type
             }
         }
         prob_dict = {