Pārlūkot izejas kodu

modify ComposedTransforms

jiangjiajun 5 gadi atpakaļ
vecāks
revīzija
2934885a88

+ 18 - 13
paddlex/cv/transforms/cls_transforms.py

@@ -18,7 +18,6 @@ import random
 import os.path as osp
 import numpy as np
 from PIL import Image, ImageEnhance
-from .template import TemplateTransforms
 
 
 class ClsTransform:
@@ -93,6 +92,12 @@ class Compose(ClsTransform):
                     outputs = (im, label)
         return outputs
 
+    def add_augmenters(self, augmenters):
+        """Prepend extra augmentation operators to this pipeline.
+
+        Args:
+            augmenters (list): operators to run before the transforms that
+                are already registered.
+
+        Raises:
+            Exception: if `augmenters` is not a list.
+        """
+        if not isinstance(augmenters, list):
+            raise Exception(
+                "augmenters should be list type in func add_augmenters()")
+        # The Composed* subclass in this file passes a plain list to
+        # Compose.__init__, so self.transforms is presumably that list
+        # (Compose.__init__ is not visible here -- TODO confirm). A list has
+        # no nested `.transforms` attribute, so concatenate with it directly.
+        self.transforms = augmenters + self.transforms
+
 
 class RandomCrop(ClsTransform):
     """对图像进行随机剪裁,模型训练时的数据增强操作。
@@ -464,7 +469,7 @@ class ArrangeClassifier(ClsTransform):
         return outputs
 
 
-class BasicClsTransforms(TemplateTransforms):
+class ComposedClsTransforms(Compose):
     """ 分类模型的基础Transforms流程,具体如下
         训练阶段:
         1. 随机从图像中crop一块子图,并resize成crop_size大小
@@ -487,7 +492,6 @@ class BasicClsTransforms(TemplateTransforms):
                  crop_size=[224, 224],
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]):
-        super(TemplateClsTransforms, self).__init__(mode=mode)
         width = crop_size
         if isinstance(crop_size, list):
             if shape[0] != shape[1]:
@@ -499,18 +503,19 @@ class BasicClsTransforms(TemplateTransforms):
                 "In classifier model, width and height should be multiple of 32, e.g 224、256、320...."
             )
 
-        if self.mode == 'train':
+        if mode == 'train':
             # 训练时的transforms,包含数据增强
-            self.transforms = transforms.Compose([
-                transforms.RandomCrop(crop_size=width),
-                transforms.RandomHorizontalFlip(prob=0.5),
-                transforms.Normalize(
+            transforms = [
+                RandomCrop(crop_size=width), RandomHorizontalFlip(prob=0.5),
+                Normalize(
                     mean=mean, std=std)
-            ])
+            ]
         else:
             # 验证/预测时的transforms
-            self.transforms = transforms.Compose([
-                transforms.ReiszeByShort(short_size=int(width * 1.14)),
-                transforms.CenterCrop(crop_size=width), transforms.Normalize(
+            transforms = [
+                ReiszeByShort(short_size=int(width * 1.14)),
+                CenterCrop(crop_size=width), Normalize(
                     mean=mean, std=std)
-            ])
+            ]
+
+        super(ComposedClsTransforms, self).__init__(transforms)

+ 35 - 30
paddlex/cv/transforms/det_transforms.py

@@ -27,7 +27,6 @@ from PIL import Image, ImageEnhance
 from .imgaug_support import execute_imgaug
 from .ops import *
 from .box_utils import *
-from .template import TemplateTransforms
 
 
 class DetTransform:
@@ -153,6 +152,13 @@ class Compose(DetTransform):
                     outputs = (im, im_info)
         return outputs
 
+    def add_augmenters(self, augmenters):
+        """Prepend extra augmentation operators to this pipeline.
+
+        Args:
+            augmenters (list): operators to run before the transforms that
+                are already registered.
+
+        Raises:
+            Exception: if `augmenters` is not a list.
+        """
+        if not isinstance(augmenters, list):
+            raise Exception(
+                "augmenters should be list type in func add_augmenters()")
+        # NOTE(review): the former `assert mode == 'train'` referenced a name
+        # that is neither a parameter nor a local of this method, so it could
+        # not be evaluated; it is dropped here. If a train-only restriction is
+        # wanted, it needs a mode value stored on the instance -- confirm.
+        # self.transforms is presumably the plain list handed to
+        # Compose.__init__ (not visible here -- TODO confirm), so it is
+        # concatenated directly rather than through a nested `.transforms`.
+        self.transforms = augmenters + self.transforms
+
 
 class ResizeByShort(DetTransform):
     """根据图像的短边调整图像大小(resize)。
@@ -1230,7 +1236,7 @@ class ArrangeYOLOv3(DetTransform):
         return outputs
 
 
-class BasicRCNNTransforms(TemplateTransforms):
+class ComposedRCNNTransforms(Compose):
     """ RCNN模型(faster-rcnn/mask-rcnn)图像处理流程,具体如下,
         训练阶段:
         1. 随机以0.5的概率将图像水平翻转
@@ -1257,27 +1263,27 @@ class BasicRCNNTransforms(TemplateTransforms):
                  min_max_size=[800, 1333],
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]):
-        super(RCNNTransforms, self).__init__(mode=mode)
-        if self.mode == 'train':
+        if mode == 'train':
             # 训练时的transforms,包含数据增强
-            self.transforms = transforms.Compose([
-                transforms.RandomHorizontalFlip(prob=0.5),
-                transforms.Normalize(
-                    mean=mean, std=std), transforms.ResizeByShort(
+            transforms = [
+                RandomHorizontalFlip(prob=0.5), Normalize(
+                    mean=mean, std=std), ResizeByShort(
                         short_size=min_max_size[0], max_size=min_max_size[1]),
-                transforms.Padding(coarsest_stride=32)
-            ])
+                Padding(coarsest_stride=32)
+            ]
         else:
             # 验证/预测时的transforms
-            self.transforms = transforms.Compose([
-                transforms.Normalize(
-                    mean=mean, std=std), transforms.ResizeByShort(
+            transforms = [
+                Normalize(
+                    mean=mean, std=std), ResizeByShort(
                         short_size=min_max_size[0], max_size=min_max_size[1]),
-                transforms.Padding(coarsest_stride=32)
-            ])
+                Padding(coarsest_stride=32)
+            ]
 
+        super(RCNNTransforms, self).__init__(transforms)
 
-class BasicYOLOTransforms(TemplateTransforms):
+
+class ComposedYOLOTransforms(Compose):
     """YOLOv3模型的图像预处理流程,具体如下,
         训练阶段:
         1. 在前mixup_epoch轮迭代中,使用MixupImage策略,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#mixupimage
@@ -1305,7 +1311,6 @@ class BasicYOLOTransforms(TemplateTransforms):
                  mixup_epoch=250,
                  mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]):
-        super(YOLOTransforms, self).__init__(mode=mode)
         width = shape
         if isinstance(shape, list):
             if shape[0] != shape[1]:
@@ -1317,20 +1322,20 @@ class BasicYOLOTransforms(TemplateTransforms):
                 "In YOLOv3 model, width and height should be multiple of 32, e.g 224、256、320...."
             )
 
-        if self.mode == 'train':
+        if mode == 'train':
             # 训练时的transforms,包含数据增强
-            self.transforms = transforms.Compose([
-                transforms.MixupImage(mixup_epoch=mixup_epoch),
-                transforms.RandomDistort(), transforms.RandomExpand(),
-                transforms.RandomCrop(), transforms.Resize(
-                    target_size=width, interp='RANDOM'),
-                transforms.RandomHorizontalFlip(), transforms.Normalize(
-                    mean=mean, std=std)
-            ])
+            transforms = [
+                MixupImage(mixup_epoch=mixup_epoch), RandomDistort(),
+                RandomExpand(), RandomCrop(), Resize(
+                    target_size=width,
+                    interp='RANDOM'), RandomHorizontalFlip(), Normalize(
+                        mean=mean, std=std)
+            ]
         else:
             # 验证/预测时的transforms
-            self.transforms = transforms.Compose([
-                transforms.Resize(
-                    target_size=width, interp='CUBIC'), transforms.Normalize(
+            transforms = [
+                Resize(
+                    target_size=width, interp='CUBIC'), Normalize(
                         mean=mean, std=std)
-            ])
+            ]
+        super(YOLOTransforms, self).__init__(transforms)

+ 8 - 11
paddlex/cv/transforms/seg_transforms.py

@@ -1091,7 +1091,7 @@ class ArrangeSegmenter(SegTransform):
             return (im, )
 
 
-class BasicSegTransforms(TemplateTransforms):
+class ComposedTransforms(Compose):
     """ 语义分割模型(UNet/DeepLabv3p)的图像处理流程,具体如下
         训练阶段:
         1. 随机对图像以0.5的概率水平翻转
@@ -1113,18 +1113,15 @@ class BasicSegTransforms(TemplateTransforms):
                  train_crop_size=[769, 769],
                  mean=[0.5, 0.5, 0.5],
                  std=[0.5, 0.5, 0.5]):
-        super(TemplateSegTransforms, self).__init__(mode=mode)
         if self.mode == 'train':
             # 训练时的transforms,包含数据增强
-            self.transforms = transforms.Compose([
-                transforms.RandomHorizontalFlip(),
-                transforms.ResizeStepScaling(),
-                transforms.RandomPaddingCrop(crop_size=train_crop_size),
-                transforms.Normalize(
+            transforms = [
+                RandomHorizontalFlip(prob=0.5), ResizeStepScaling(),
+                RandomPaddingCrop(crop_size=train_crop_size), Normalize(
                     mean=mean, std=std)
-            ])
+            ]
         else:
             # 验证/预测时的transforms
-            self.transforms = transforms.Compose(
-                [transforms.Normalize(
-                    mean=mean, std=std)])
+            transforms = [transforms.Normalize(mean=mean, std=std)]
+
+        super(ComposedSegTransforms, self).__init__(transforms)