Browse Source

Merge pull request #44 from PaddlePaddle/develop_imgaug

add imgaug support
Jason 5 years ago
parent
commit
9c2f0fe42e

+ 6 - 2
paddlex/cv/datasets/coco.py

@@ -100,7 +100,7 @@ class CocoDetection(VOCDetection):
             gt_score = np.ones((num_bbox, 1), dtype=np.float32)
             gt_score = np.ones((num_bbox, 1), dtype=np.float32)
             is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
             is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
             difficult = np.zeros((num_bbox, 1), dtype=np.int32)
             difficult = np.zeros((num_bbox, 1), dtype=np.int32)
-            gt_poly = [None] * num_bbox
+            gt_poly = None
 
 
             for i, box in enumerate(bboxes):
             for i, box in enumerate(bboxes):
                 catid = box['category_id']
                 catid = box['category_id']
@@ -108,6 +108,8 @@ class CocoDetection(VOCDetection):
                 gt_bbox[i, :] = box['clean_bbox']
                 gt_bbox[i, :] = box['clean_bbox']
                 is_crowd[i][0] = box['iscrowd']
                 is_crowd[i][0] = box['iscrowd']
                 if 'segmentation' in box:
                 if 'segmentation' in box:
+                    if gt_poly is None:
+                        gt_poly = [None] * num_bbox
                     gt_poly[i] = box['segmentation']
                     gt_poly[i] = box['segmentation']
 
 
             im_info = {
             im_info = {
@@ -119,9 +121,11 @@ class CocoDetection(VOCDetection):
                 'gt_class': gt_class,
                 'gt_class': gt_class,
                 'gt_bbox': gt_bbox,
                 'gt_bbox': gt_bbox,
                 'gt_score': gt_score,
                 'gt_score': gt_score,
-                'gt_poly': gt_poly,
                 'difficult': difficult
                 'difficult': difficult
             }
             }
+            if gt_poly is not None:
+                label_info['gt_poly'] = gt_poly
+
             coco_rec = (im_info, label_info)
             coco_rec = (im_info, label_info)
             self.file_list.append([im_fname, coco_rec])
             self.file_list.append([im_fname, coco_rec])
 
 

+ 0 - 1
paddlex/cv/datasets/voc.py

@@ -153,7 +153,6 @@ class VOCDetection(Dataset):
                     'gt_class': gt_class,
                     'gt_class': gt_class,
                     'gt_bbox': gt_bbox,
                     'gt_bbox': gt_bbox,
                     'gt_score': gt_score,
                     'gt_score': gt_score,
-                    'gt_poly': [],
                     'difficult': difficult
                     'difficult': difficult
                 }
                 }
                 voc_rec = (im_info, label_info)
                 voc_rec = (im_info, label_info)

+ 14 - 4
paddlex/cv/models/utils/visualize.py

@@ -16,6 +16,7 @@ import os
 import cv2
 import cv2
 import colorsys
 import colorsys
 import numpy as np
 import numpy as np
+import time
 import paddlex.utils.logging as logging
 import paddlex.utils.logging as logging
 from .detection_eval import fixed_linspace, backup_linspace, loadRes
 from .detection_eval import fixed_linspace, backup_linspace, loadRes
 
 
@@ -25,8 +26,12 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
         Visualize bbox and mask results
         Visualize bbox and mask results
     """
     """
 
 
-    image_name = os.path.split(image)[-1]
-    image = cv2.imread(image)
+    if isinstance(image, np.ndarray):
+        image_name = str(int(time.time())) + '.jpg'
+    else:
+        image = cv2.imread(image)
+        image_name = os.path.split(image)[-1]
+
     image = draw_bbox_mask(image, result, threshold=threshold)
     image = draw_bbox_mask(image, result, threshold=threshold)
     if save_dir is not None:
     if save_dir is not None:
         if not os.path.exists(save_dir):
         if not os.path.exists(save_dir):
@@ -56,13 +61,18 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
     c3 = cv2.LUT(label_map, color_map[:, 2])
     c3 = cv2.LUT(label_map, color_map[:, 2])
     pseudo_img = np.dstack((c1, c2, c3))
     pseudo_img = np.dstack((c1, c2, c3))
 
 
-    im = cv2.imread(image)
+    if isinstance(image, np.ndarray):
+        im = image
+        image_name = str(int(time.time())) + '.jpg'
+    else:
+        image = cv2.imread(image)
+        image_name = os.path.split(image)[-1]
+
     vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
     vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
 
 
     if save_dir is not None:
     if save_dir is not None:
         if not os.path.exists(save_dir):
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
             os.makedirs(save_dir)
-        image_name = os.path.split(image)[-1]
         out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
         out_path = os.path.join(save_dir, 'visualize_{}'.format(image_name))
         cv2.imwrite(out_path, vis_result)
         cv2.imwrite(out_path, vis_result)
         logging.info('The visualized result is saved as {}'.format(out_path))
         logging.info('The visualized result is saved as {}'.format(out_path))

+ 50 - 18
paddlex/cv/transforms/cls_transforms.py

@@ -13,13 +13,22 @@
 # limitations under the License.
 # limitations under the License.
 
 
 from .ops import *
 from .ops import *
+from .imgaug_support import execute_imgaug
 import random
 import random
 import os.path as osp
 import os.path as osp
 import numpy as np
 import numpy as np
 from PIL import Image, ImageEnhance
 from PIL import Image, ImageEnhance
 
 
 
 
-class Compose:
+class ClsTransform:
+    """分类Transform的基类
+    """
+
+    def __init__(self):
+        pass
+
+
+class Compose(ClsTransform):
     """根据数据预处理/增强算子对输入数据进行操作。
     """根据数据预处理/增强算子对输入数据进行操作。
        所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
        所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
 
 
@@ -39,6 +48,15 @@ class Compose:
                             'must be equal or larger than 1!')
                             'must be equal or larger than 1!')
         self.transforms = transforms
         self.transforms = transforms
 
 
+        # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
+        for op in self.transforms:
+            if not isinstance(op, ClsTransform):
+                import imgaug.augmenters as iaa
+                if not isinstance(op, iaa.Augmenter):
+                    raise Exception(
+                        "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
+                    )
+
     def __call__(self, im, label=None):
     def __call__(self, im, label=None):
         """
         """
         Args:
         Args:
@@ -48,20 +66,34 @@ class Compose:
             tuple: 根据网络所需字段所组成的tuple;
             tuple: 根据网络所需字段所组成的tuple;
                 字段由transforms中的最后一个数据预处理操作决定。
                 字段由transforms中的最后一个数据预处理操作决定。
         """
         """
-        try:
-            im = cv2.imread(im).astype('float32')
-        except:
-            raise TypeError('Can\'t read The image file {}!'.format(im))
+        if isinstance(im, np.ndarray):
+            if len(im.shape) != 3:
+                raise Exception(
+                    "im should be 3-dimension, but now is {}-dimensions".
+                    format(len(im.shape)))
+        else:
+            try:
+                im = cv2.imread(im).astype('float32')
+            except:
+                raise TypeError('Can\'t read The image file {}!'.format(im))
         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         for op in self.transforms:
         for op in self.transforms:
-            outputs = op(im, label)
-            im = outputs[0]
-            if len(outputs) == 2:
-                label = outputs[1]
+            if isinstance(op, ClsTransform):
+                outputs = op(im, label)
+                im = outputs[0]
+                if len(outputs) == 2:
+                    label = outputs[1]
+            else:
+                import imgaug.augmenters as iaa
+                if isinstance(op, iaa.Augmenter):
+                    im, = execute_imgaug(op, im)
+                output = (im, )
+                if label is not None:
+                    output = (im, label)
         return outputs
         return outputs
 
 
 
 
-class RandomCrop:
+class RandomCrop(ClsTransform):
     """对图像进行随机剪裁,模型训练时的数据增强操作。
     """对图像进行随机剪裁,模型训练时的数据增强操作。
 
 
     1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。
     1. 根据lower_scale、lower_ratio、upper_ratio计算随机剪裁的高、宽。
@@ -104,7 +136,7 @@ class RandomCrop:
             return (im, label)
             return (im, label)
 
 
 
 
-class RandomHorizontalFlip:
+class RandomHorizontalFlip(ClsTransform):
     """以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。
     """以一定的概率对图像进行随机水平翻转,模型训练时的数据增强操作。
 
 
     Args:
     Args:
@@ -132,7 +164,7 @@ class RandomHorizontalFlip:
             return (im, label)
             return (im, label)
 
 
 
 
-class RandomVerticalFlip:
+class RandomVerticalFlip(ClsTransform):
     """以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。
     """以一定的概率对图像进行随机垂直翻转,模型训练时的数据增强操作。
 
 
     Args:
     Args:
@@ -160,7 +192,7 @@ class RandomVerticalFlip:
             return (im, label)
             return (im, label)
 
 
 
 
-class Normalize:
+class Normalize(ClsTransform):
     """对图像进行标准化。
     """对图像进行标准化。
 
 
     1. 对图像进行归一化到区间[0.0, 1.0]。
     1. 对图像进行归一化到区间[0.0, 1.0]。
@@ -195,7 +227,7 @@ class Normalize:
             return (im, label)
             return (im, label)
 
 
 
 
-class ResizeByShort:
+class ResizeByShort(ClsTransform):
     """根据图像短边对图像重新调整大小(resize)。
     """根据图像短边对图像重新调整大小(resize)。
 
 
     1. 获取图像的长边和短边长度。
     1. 获取图像的长边和短边长度。
@@ -242,7 +274,7 @@ class ResizeByShort:
             return (im, label)
             return (im, label)
 
 
 
 
-class CenterCrop:
+class CenterCrop(ClsTransform):
     """以图像中心点扩散裁剪长宽为`crop_size`的正方形
     """以图像中心点扩散裁剪长宽为`crop_size`的正方形
 
 
     1. 计算剪裁的起始点。
     1. 计算剪裁的起始点。
@@ -272,7 +304,7 @@ class CenterCrop:
             return (im, label)
             return (im, label)
 
 
 
 
-class RandomRotate:
+class RandomRotate(ClsTransform):
     def __init__(self, rotate_range=30, prob=0.5):
     def __init__(self, rotate_range=30, prob=0.5):
         """以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。
         """以一定的概率对图像在[-rotate_range, rotaterange]角度范围内进行旋转,模型训练时的数据增强操作。
 
 
@@ -306,7 +338,7 @@ class RandomRotate:
             return (im, label)
             return (im, label)
 
 
 
 
-class RandomDistort:
+class RandomDistort(ClsTransform):
     """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。
     """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作。
 
 
     1. 对变换的操作顺序进行随机化操作。
     1. 对变换的操作顺序进行随机化操作。
@@ -397,7 +429,7 @@ class RandomDistort:
             return (im, label)
             return (im, label)
 
 
 
 
-class ArrangeClassifier:
+class ArrangeClassifier(ClsTransform):
     """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用
     """获取训练/验证/预测所需信息。注意:此操作不需用户自己显示调用
 
 
     Args:
     Args:

+ 61 - 21
paddlex/cv/transforms/det_transforms.py

@@ -24,11 +24,20 @@ import numpy as np
 import cv2
 import cv2
 from PIL import Image, ImageEnhance
 from PIL import Image, ImageEnhance
 
 
+from .imgaug_support import execute_imgaug
 from .ops import *
 from .ops import *
 from .box_utils import *
 from .box_utils import *
 
 
 
 
-class Compose:
+class DetTransform:
+    """检测数据处理基类
+    """
+
+    def __init__(self):
+        pass
+
+
+class Compose(DetTransform):
     """根据数据预处理/增强列表对输入数据进行操作。
     """根据数据预处理/增强列表对输入数据进行操作。
        所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
        所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
 
 
@@ -49,8 +58,16 @@ class Compose:
         self.transforms = transforms
         self.transforms = transforms
         self.use_mixup = False
         self.use_mixup = False
         for t in self.transforms:
         for t in self.transforms:
-            if t.__class__.__name__ == 'MixupImage':
+            if type(t).__name__ == 'MixupImage':
                 self.use_mixup = True
                 self.use_mixup = True
+        # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
+        for op in self.transforms:
+            if not isinstance(op, DetTransform):
+                import imgaug.augmenters as iaa
+                if not isinstance(op, iaa.Augmenter):
+                    raise Exception(
+                        "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
+                    )
 
 
     def __call__(self, im, im_info=None, label_info=None):
     def __call__(self, im, im_info=None, label_info=None):
         """
         """
@@ -84,11 +101,18 @@ class Compose:
         def decode_image(im_file, im_info, label_info):
         def decode_image(im_file, im_info, label_info):
             if im_info is None:
             if im_info is None:
                 im_info = dict()
                 im_info = dict()
-            try:
-                im = cv2.imread(im_file).astype('float32')
-            except:
-                raise TypeError(
-                    'Can\'t read The image file {}!'.format(im_file))
+            if isinstance(im_file, np.ndarray):
+                if len(im_file.shape) != 3:
+                    raise Exception(
+                        "im should be 3-dimensions, but now is {}-dimensions".
+                        format(len(im_file.shape)))
+                im = im_file
+            else:
+                try:
+                    im = cv2.imread(im_file).astype('float32')
+                except:
+                    raise TypeError(
+                        'Can\'t read The image file {}!'.format(im_file))
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             # make default im_info with [h, w, 1]
             # make default im_info with [h, w, 1]
             im_info['im_resize_info'] = np.array(
             im_info['im_resize_info'] = np.array(
@@ -117,12 +141,28 @@ class Compose:
         for op in self.transforms:
         for op in self.transforms:
             if im is None:
             if im is None:
                 return None
                 return None
-            outputs = op(im, im_info, label_info)
-            im = outputs[0]
+            if isinstance(op, DetTransform):
+                outputs = op(im, im_info, label_info)
+                im = outputs[0]
+            else:
+                if label_info is not None:
+                    gt_poly = label_info.get('gt_poly', None)
+                    gt_bbox = label_info['gt_bbox']
+                    if gt_poly is None:
+                        im, aug_bbox = execute_imgaug(op, im, bboxes=gt_bbox)
+                    else:
+                        im, aug_bbox, aug_poly = execute_imgaug(
+                            op, im, bboxes=gt_bbox, polygons=gt_poly)
+                        label_info['gt_poly'] = aug_poly
+                    label_info['gt_bbox'] = aug_bbox
+                    outputs = (im, im_info, label_info)
+                else:
+                    im, = execute_imgaug(op, im)
+                    outputs = (im, im_info)
         return outputs
         return outputs
 
 
 
 
-class ResizeByShort:
+class ResizeByShort(DetTransform):
     """根据图像的短边调整图像大小(resize)。
     """根据图像的短边调整图像大小(resize)。
 
 
     1. 获取图像的长边和短边长度。
     1. 获取图像的长边和短边长度。
@@ -194,7 +234,7 @@ class ResizeByShort:
             return (im, im_info, label_info)
             return (im, im_info, label_info)
 
 
 
 
-class Padding:
+class Padding(DetTransform):
     """1.将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
     """1.将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
        `coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值
        `coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值
        进行padding,最终输出图像为[320, 640]。
        进行padding,最终输出图像为[320, 640]。
@@ -290,7 +330,7 @@ class Padding:
             return (padding_im, im_info, label_info)
             return (padding_im, im_info, label_info)
 
 
 
 
-class Resize:
+class Resize(DetTransform):
     """调整图像大小(resize)。
     """调整图像大小(resize)。
 
 
     - 当目标大小(target_size)类型为int时,根据插值方式,
     - 当目标大小(target_size)类型为int时,根据插值方式,
@@ -369,7 +409,7 @@ class Resize:
             return (im, im_info, label_info)
             return (im, im_info, label_info)
 
 
 
 
-class RandomHorizontalFlip:
+class RandomHorizontalFlip(DetTransform):
     """随机翻转图像、标注框、分割信息,模型训练时的数据增强操作。
     """随机翻转图像、标注框、分割信息,模型训练时的数据增强操作。
 
 
     1. 随机采样一个0-1之间的小数,当小数小于水平翻转概率时,
     1. 随机采样一个0-1之间的小数,当小数小于水平翻转概率时,
@@ -447,7 +487,7 @@ class RandomHorizontalFlip:
             return (im, im_info, label_info)
             return (im, im_info, label_info)
 
 
 
 
-class Normalize:
+class Normalize(DetTransform):
     """对图像进行标准化。
     """对图像进行标准化。
 
 
     1. 归一化图像到到区间[0.0, 1.0]。
     1. 归一化图像到到区间[0.0, 1.0]。
@@ -491,7 +531,7 @@ class Normalize:
             return (im, im_info, label_info)
             return (im, im_info, label_info)
 
 
 
 
-class RandomDistort:
+class RandomDistort(DetTransform):
     """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
     """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
 
 
     1. 对变换的操作顺序进行随机化操作。
     1. 对变换的操作顺序进行随机化操作。
@@ -585,7 +625,7 @@ class RandomDistort:
             return (im, im_info, label_info)
             return (im, im_info, label_info)
 
 
 
 
-class MixupImage:
+class MixupImage(DetTransform):
     """对图像进行mixup操作,模型训练时的数据增强操作,目前仅YOLOv3模型支持该transform。
     """对图像进行mixup操作,模型训练时的数据增强操作,目前仅YOLOv3模型支持该transform。
 
 
     当label_info中不存在mixup字段时,直接返回,否则进行下述操作:
     当label_info中不存在mixup字段时,直接返回,否则进行下述操作:
@@ -714,7 +754,7 @@ class MixupImage:
             return (im, im_info, label_info)
             return (im, im_info, label_info)
 
 
 
 
-class RandomExpand:
+class RandomExpand(DetTransform):
     """随机扩张图像,模型训练时的数据增强操作。
     """随机扩张图像,模型训练时的数据增强操作。
     1. 随机选取扩张比例(扩张比例大于1时才进行扩张)。
     1. 随机选取扩张比例(扩张比例大于1时才进行扩张)。
     2. 计算扩张后图像大小。
     2. 计算扩张后图像大小。
@@ -796,7 +836,7 @@ class RandomExpand:
         return (canvas, im_info, label_info)
         return (canvas, im_info, label_info)
 
 
 
 
-class RandomCrop:
+class RandomCrop(DetTransform):
     """随机裁剪图像。
     """随机裁剪图像。
     1. 若allow_no_crop为True,则在thresholds加入’no_crop’。
     1. 若allow_no_crop为True,则在thresholds加入’no_crop’。
     2. 随机打乱thresholds。
     2. 随机打乱thresholds。
@@ -944,7 +984,7 @@ class RandomCrop:
         return (im, im_info, label_info)
         return (im, im_info, label_info)
 
 
 
 
-class ArrangeFasterRCNN:
+class ArrangeFasterRCNN(DetTransform):
     """获取FasterRCNN模型训练/验证/预测所需信息。
     """获取FasterRCNN模型训练/验证/预测所需信息。
 
 
     Args:
     Args:
@@ -1019,7 +1059,7 @@ class ArrangeFasterRCNN:
         return outputs
         return outputs
 
 
 
 
-class ArrangeMaskRCNN:
+class ArrangeMaskRCNN(DetTransform):
     """获取MaskRCNN模型训练/验证/预测所需信息。
     """获取MaskRCNN模型训练/验证/预测所需信息。
 
 
     Args:
     Args:
@@ -1103,7 +1143,7 @@ class ArrangeMaskRCNN:
         return outputs
         return outputs
 
 
 
 
-class ArrangeYOLOv3:
+class ArrangeYOLOv3(DetTransform):
     """获取YOLOv3模型训练/验证/预测所需信息。
     """获取YOLOv3模型训练/验证/预测所需信息。
 
 
     Args:
     Args:

+ 131 - 0
paddlex/cv/transforms/imgaug_support.py

@@ -0,0 +1,131 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def execute_imgaug(augmenter, im, bboxes=None, polygons=None,
+                   segment_map=None):
+    # 预处理,将bboxes, polygons转换成imgaug格式
+    import imgaug.augmentables.polys as polys
+    import imgaug.augmentables.bbs as bbs
+
+    aug_im = im.astype('uint8')
+
+    aug_bboxes = None
+    if bboxes is not None:
+        aug_bboxes = list()
+        for i in range(len(bboxes)):
+            x1 = bboxes[i, 0] - 1
+            y1 = bboxes[i, 1]
+            x2 = bboxes[i, 2]
+            y2 = bboxes[i, 3]
+            aug_bboxes.append(bbs.BoundingBox(x1, y1, x2, y2))
+
+    aug_polygons = None
+    lod_info = list()
+    if polygons is not None:
+        aug_polygons = list()
+        for i in range(len(polygons)):
+            num = len(polygons[i])
+            lod_info.append(num)
+            for j in range(num):
+                points = np.reshape(polygons[i][j], (-1, 2))
+                aug_polygons.append(polys.Polygon(points))
+
+    aug_segment_map = None
+    if segment_map is not None:
+        if len(segment_map.shape) == 2:
+            h, w = segment_map.shape
+            aug_segment_map = np.reshape(segment_map, (1, h, w, 1))
+        elif len(segment_map.shape) == 3:
+            h, w, c = segment_map.shape
+            aug_segment_map = np.reshape(segment_map, (1, h, w, c))
+        else:
+            raise Exception(
+                "Only support 2-dimensions for 3-dimensions for segment_map")
+
+    aug_im, aug_bboxes, aug_polygons, aug_seg_map = augmenter.augment(
+        image=aug_im,
+        bounding_boxes=aug_bboxes,
+        polygons=aug_polygons,
+        segmentation_maps=aug_segment_map)
+
+    aug_im = aug_im.astype('float32')
+
+    if aug_polygons is not None:
+        assert len(aug_bboxes) == len(
+            lod_info
+        ), "Number of aug_bboxes should be equal to number of aug_polygons"
+
+    if aug_bboxes is not None:
+        # 裁剪掉在图像之外的bbox和polygon
+        for i in range(len(aug_bboxes)):
+            aug_bboxes[i] = aug_bboxes[i].clip_out_of_image(aug_im)
+        if aug_polygons is not None:
+            for i in range(len(aug_polygons)):
+                aug_polygons[i] = aug_polygons[i].clip_out_of_image(aug_im)
+
+        # 过滤掉无效的bbox和polygon,并转换为训练数据格式
+        converted_bboxes = list()
+        converted_polygons = list()
+        poly_index = 0
+        for i in range(len(aug_bboxes)):
+            # 过滤width或height不足1像素的框
+            if aug_bboxes[i].width < 1 or aug_bboxes[i].height < 1:
+                continue
+            if aug_polygons is None:
+                converted_bboxes.append([
+                    aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2,
+                    aug_bboxes[i].y2
+                ])
+                continue
+
+            # 如若有polygons,将会继续执行下面代码
+            polygons_this_box = list()
+            for ps in aug_polygons[poly_index:poly_index + lod_info[i]]:
+                if len(ps) == 0:
+                    continue
+                for p in ps:
+                    # 没有3个point的polygon被过滤
+                    if len(p.exterior) < 3:
+                        continue
+                    polygons_this_box.append(p.exterior.flatten().tolist())
+            poly_index += lod_info[i]
+
+            if len(polygons_this_box) == 0:
+                continue
+            converted_bboxes.append([
+                aug_bboxes[i].x1, aug_bboxes[i].y1, aug_bboxes[i].x2,
+                aug_bboxes[i].y2
+            ])
+            converted_polygons.append(polygons_this_box)
+        if len(converted_bboxes) == 0:
+            aug_im = im
+            converted_bboxes = bboxes
+            converted_polygons = polygons
+
+    result = [aug_im]
+    if bboxes is not None:
+        result.append(np.array(converted_bboxes))
+    if polygons is not None:
+        result.append(converted_polygons)
+    if segment_map is not None:
+        n, h, w, c = aug_seg_map.shape
+        if len(segment_map.shape) == 2:
+            aug_seg_map = np.reshape(aug_seg_map, (h, w))
+        elif len(segment_map.shape) == 3:
+            aug_seg_map = np.reshape(aug_seg_map, (h, w, c))
+        result.append(aug_seg_map)
+    return result

+ 57 - 26
paddlex/cv/transforms/seg_transforms.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 # limitations under the License.
 
 
 from .ops import *
 from .ops import *
+from .imgaug_support import execute_imgaug
 import random
 import random
 import os.path as osp
 import os.path as osp
 import numpy as np
 import numpy as np
@@ -22,7 +23,15 @@ import cv2
 from collections import OrderedDict
 from collections import OrderedDict
 
 
 
 
-class Compose:
+class SegTransform:
+    """ 分割transform基类
+    """
+
+    def __init__(self):
+        pass
+
+
+class Compose(SegTransform):
     """根据数据预处理/增强算子对输入数据进行操作。
     """根据数据预处理/增强算子对输入数据进行操作。
        所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
        所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
 
 
@@ -43,6 +52,14 @@ class Compose:
                             'must be equal or larger than 1!')
                             'must be equal or larger than 1!')
         self.transforms = transforms
         self.transforms = transforms
         self.to_rgb = False
         self.to_rgb = False
+        # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
+        for op in self.transforms:
+            if not isinstance(op, SegTransform):
+                import imgaug.augmenters as iaa
+                if not isinstance(op, iaa.Augmenter):
+                    raise Exception(
+                        "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
+                    )
 
 
     def __call__(self, im, im_info=None, label=None):
     def __call__(self, im, im_info=None, label=None):
         """
         """
@@ -60,26 +77,40 @@ class Compose:
 
 
         if im_info is None:
         if im_info is None:
             im_info = list()
             im_info = list()
-        try:
-            im = cv2.imread(im).astype('float32')
-        except:
-            raise ValueError('Can\'t read The image file {}!'.format(im))
+        if isinstance(im, np.ndarray):
+            if len(im.shape) != 3:
+                raise Exception(
+                    "im should be 3-dimensions, but now is {}-dimensions".
+                    format(len(im.shape)))
+        else:
+            try:
+                im = cv2.imread(im).astype('float32')
+            except:
+                raise ValueError('Can\'t read The image file {}!'.format(im))
         if self.to_rgb:
         if self.to_rgb:
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
             im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
         if label is not None:
         if label is not None:
             if not isinstance(label, np.ndarray):
             if not isinstance(label, np.ndarray):
                 label = np.asarray(Image.open(label))
                 label = np.asarray(Image.open(label))
         for op in self.transforms:
         for op in self.transforms:
-            outputs = op(im, im_info, label)
-            im = outputs[0]
-            if len(outputs) >= 2:
-                im_info = outputs[1]
-            if len(outputs) == 3:
-                label = outputs[2]
+            if isinstance(op, SegTransform):
+                outputs = op(im, im_info, label)
+                im = outputs[0]
+                if len(outputs) >= 2:
+                    im_info = outputs[1]
+                if len(outputs) == 3:
+                    label = outputs[2]
+            else:
+                if label is not None:
+                    im, label = execute_imgaug(op, im, segment_map=label)
+                    outputs = (im, im_info, label)
+                else:
+                    im, = execute_imgaug(op, im)
+                    outputs = (im, im_info)
         return outputs
         return outputs
 
 
 
 
-class RandomHorizontalFlip:
+class RandomHorizontalFlip(SegTransform):
     """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。
     """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。
 
 
     Args:
     Args:
@@ -115,7 +146,7 @@ class RandomHorizontalFlip:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomVerticalFlip:
+class RandomVerticalFlip(SegTransform):
     """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。
     """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。
 
 
     Args:
     Args:
@@ -150,7 +181,7 @@ class RandomVerticalFlip:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class Resize:
+class Resize(SegTransform):
     """调整图像大小(resize),当存在标注图像时,则同步进行处理。
     """调整图像大小(resize),当存在标注图像时,则同步进行处理。
 
 
     - 当目标大小(target_size)类型为int时,根据插值方式,
     - 当目标大小(target_size)类型为int时,根据插值方式,
@@ -260,7 +291,7 @@ class Resize:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ResizeByLong:
+class ResizeByLong(SegTransform):
     """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
     """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
 
 
     Args:
     Args:
@@ -301,7 +332,7 @@ class ResizeByLong:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ResizeByShort:
+class ResizeByShort(SegTransform):
     """根据图像的短边调整图像大小(resize)。
     """根据图像的短边调整图像大小(resize)。
 
 
     1. 获取图像的长边和短边长度。
     1. 获取图像的长边和短边长度。
@@ -378,7 +409,7 @@ class ResizeByShort:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ResizeRangeScaling:
+class ResizeRangeScaling(SegTransform):
     """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
     """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
 
 
     Args:
     Args:
@@ -427,7 +458,7 @@ class ResizeRangeScaling:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ResizeStepScaling:
+class ResizeStepScaling(SegTransform):
     """对图像按照某一个比例resize,这个比例以scale_step_size为步长
     """对图像按照某一个比例resize,这个比例以scale_step_size为步长
     在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
     在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
 
 
@@ -502,7 +533,7 @@ class ResizeStepScaling:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class Normalize:
+class Normalize(SegTransform):
     """对图像进行标准化。
     """对图像进行标准化。
     1.尺度缩放到 [0,1]。
     1.尺度缩放到 [0,1]。
     2.对图像进行减均值除以标准差操作。
     2.对图像进行减均值除以标准差操作。
@@ -550,7 +581,7 @@ class Normalize:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class Padding:
+class Padding(SegTransform):
     """对图像或标注图像进行padding,padding方向为右和下。
     """对图像或标注图像进行padding,padding方向为右和下。
     根据提供的值对图像或标注图像进行padding操作。
     根据提供的值对图像或标注图像进行padding操作。
 
 
@@ -642,7 +673,7 @@ class Padding:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomPaddingCrop:
+class RandomPaddingCrop(SegTransform):
     """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
     """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。
 
 
     Args:
     Args:
@@ -741,7 +772,7 @@ class RandomPaddingCrop:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomBlur:
+class RandomBlur(SegTransform):
     """以一定的概率对图像进行高斯模糊。
     """以一定的概率对图像进行高斯模糊。
 
 
     Args:
     Args:
@@ -787,7 +818,7 @@ class RandomBlur:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomRotate:
+class RandomRotate(SegTransform):
     """对图像进行随机旋转, 模型训练时的数据增强操作。
     """对图像进行随机旋转, 模型训练时的数据增强操作。
     在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
     在旋转区间[-rotate_range, rotate_range]内,对图像进行随机旋转,当存在标注图像时,同步进行,
     并对旋转后的图像和标注图像进行相应的padding。
     并对旋转后的图像和标注图像进行相应的padding。
@@ -859,7 +890,7 @@ class RandomRotate:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomScaleAspect:
+class RandomScaleAspect(SegTransform):
     """裁剪并resize回原始尺寸的图像和标注图像。
     """裁剪并resize回原始尺寸的图像和标注图像。
     按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
     按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
 
 
@@ -922,7 +953,7 @@ class RandomScaleAspect:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class RandomDistort:
+class RandomDistort(SegTransform):
     """对图像进行随机失真。
     """对图像进行随机失真。
 
 
     1. 对变换的操作顺序进行随机化操作。
     1. 对变换的操作顺序进行随机化操作。
@@ -1018,7 +1049,7 @@ class RandomDistort:
             return (im, im_info, label)
             return (im, im_info, label)
 
 
 
 
-class ArrangeSegmenter:
+class ArrangeSegmenter(SegTransform):
     """获取训练/验证/预测所需的信息。
     """获取训练/验证/预测所需的信息。
 
 
     Args:
     Args: