Переглянути джерело

override _compose_batch_transform in det.py

will-jl944 4 роки тому
батько
коміт
cc857bb773
2 змінених файлів з 96 додано та 10 видалено
  1. 95 9
      dygraph/paddlex/det.py
  2. 1 1
      dygraph/paddlex/seg.py

+ 95 - 9
dygraph/paddlex/det.py

@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 
+import copy
 from . import cv
 from .cv.models.utils.visualize import visualize_detection, draw_pr_curve
 from paddlex.cv.transforms import det_transforms
+from paddlex.cv.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
+from paddlex.cv.transforms.batch_operators import BatchCompose, BatchRandomResize, BatchRandomResizeByShort, \
+    _BatchPadding, _Gt2YoloTarget
+import paddlex.utils.logging as logging
 
 transforms = det_transforms
 
@@ -105,14 +109,10 @@ class YOLOv3(cv.models.YOLOv3):
                  nms_keep_topk=100,
                  nms_iou_threshold=0.45,
                  label_smooth=False,
-                 train_random_shapes=None,
+                 train_random_shapes=[
+                     320, 352, 384, 416, 448, 480, 512, 544, 576, 608
+                 ],
                  input_channel=None):
-        if train_random_shapes is not None:
-            logging.warning(
-                "`train_random_shapes` is deprecated in PaddleX 2.0 and won't take effect. "
-                "To apply multi_scale training, please refer to paddlex.transforms.BatchRandomResize: "
-                "'https://github.com/PaddlePaddle/PaddleX/blob/develop/dygraph/paddlex/cv/transforms/batch_operators.py#L53'"
-            )
         if input_channel is not None:
             logging.warning(
                 "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
@@ -128,6 +128,48 @@ class YOLOv3(cv.models.YOLOv3):
             nms_keep_topk=nms_keep_topk,
             nms_iou_threshold=nms_iou_threshold,
             label_smooth=label_smooth)
+        self.train_random_shapes = train_random_shapes
+
+    def _compose_batch_transform(self, transforms, mode='train'):
+        if mode == 'train':
+            default_batch_transforms = [
+                _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
+                _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
+                _Gt2YoloTarget(
+                    anchor_masks=self.anchor_masks,
+                    anchors=self.anchors,
+                    downsample_ratios=getattr(self, 'downsample_ratios',
+                                              [32, 16, 8]),
+                    num_classes=self.num_classes)
+            ]
+        else:
+            default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
+        if mode == 'eval' and self.metric == 'voc':
+            collate_batch = False
+        else:
+            collate_batch = True
+
+        custom_batch_transforms = []
+        for i, op in enumerate(transforms.transforms):
+            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
+                if mode != 'train':
+                    raise Exception(
+                        "{} cannot be present in the {} transforms. ".format(
+                            op.__class__.__name__, mode) +
+                        "Please check the {} transforms.".format(mode))
+                custom_batch_transforms.insert(0, copy.deepcopy(op))
+                random_shape_defined = True
+        if not random_shape_defined:
+            default_batch_transforms.insert(
+                0,
+                BatchRandomResize(
+                    target_sizes=self.train_random_shapes, interp='RANDOM'))
+
+        batch_transforms = BatchCompose(
+            custom_batch_transforms + default_batch_transforms,
+            collate_batch=collate_batch)
+
+        return batch_transforms
 
 
 class PPYOLO(cv.models.PPYOLO):
@@ -154,7 +196,9 @@ class PPYOLO(cv.models.PPYOLO):
             nms_topk=1000,
             nms_keep_topk=100,
             nms_iou_threshold=0.45,
-            train_random_shapes=None,
+            train_random_shapes=[
+                320, 352, 384, 416, 448, 480, 512, 544, 576, 608
+            ],
             input_channel=None):
         if backbone == 'ResNet50_vd_ssld':
             backbone = 'ResNet50_vd_dcn'
@@ -192,3 +236,45 @@ class PPYOLO(cv.models.PPYOLO):
             nms_topk=nms_topk,
             nms_keep_topk=nms_keep_topk,
             nms_iou_threshold=nms_iou_threshold)
+        self.train_random_shapes = train_random_shapes
+
+    def _compose_batch_transform(self, transforms, mode='train'):
+        if mode == 'train':
+            default_batch_transforms = [
+                _BatchPadding(pad_to_stride=-1), _NormalizeBox(),
+                _PadBox(getattr(self, 'num_max_boxes', 50)), _BboxXYXY2XYWH(),
+                _Gt2YoloTarget(
+                    anchor_masks=self.anchor_masks,
+                    anchors=self.anchors,
+                    downsample_ratios=getattr(self, 'downsample_ratios',
+                                              [32, 16, 8]),
+                    num_classes=self.num_classes)
+            ]
+        else:
+            default_batch_transforms = [_BatchPadding(pad_to_stride=-1)]
+        if mode == 'eval' and self.metric == 'voc':
+            collate_batch = False
+        else:
+            collate_batch = True
+
+        custom_batch_transforms = []
+        for i, op in enumerate(transforms.transforms):
+            if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
+                if mode != 'train':
+                    raise Exception(
+                        "{} cannot be present in the {} transforms. ".format(
+                            op.__class__.__name__, mode) +
+                        "Please check the {} transforms.".format(mode))
+                custom_batch_transforms.insert(0, copy.deepcopy(op))
+                random_shape_defined = True
+        if not random_shape_defined:
+            default_batch_transforms.insert(
+                0,
+                BatchRandomResize(
+                    target_sizes=self.train_random_shapes, interp='RANDOM'))
+
+        batch_transforms = BatchCompose(
+            custom_batch_transforms + default_batch_transforms,
+            collate_batch=collate_batch)
+
+        return batch_transforms

+ 1 - 1
dygraph/paddlex/seg.py

@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 
 from . import cv
 from .cv.models.utils.visualize import visualize_segmentation
 from paddlex.cv.transforms import seg_transforms
+import paddlex.utils.logging as logging
 
 transforms = seg_transforms