Browse Source

default value of anchor and masks is None in 1.x yolo

will-jl944 4 years ago
parent
commit
8bdf091e84
1 changed files with 9 additions and 6 deletions
  1. 9 6
      dygraph/paddlex/det.py

+ 9 - 6
dygraph/paddlex/det.py

@@ -81,6 +81,8 @@ class FasterRCNN(cv.models.FasterRCNN):
             logging.warning(
                 "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
             )
+        if isinstance(anchor_sizes[0], int):
+            anchor_sizes = [[size] for size in anchor_sizes]
         super(FasterRCNN, self).__init__(
             num_classes=num_classes - 1,
             backbone=backbone,
@@ -117,6 +119,11 @@ class YOLOv3(cv.models.YOLOv3):
             logging.warning(
                 "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
             )
+        if anchors is None:
+            anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
+                       [59, 119], [116, 90], [156, 198], [373, 326]]
+        if anchor_masks is None:
+            anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
         super(YOLOv3, self).__init__(
             num_classes=num_classes,
             backbone=backbone,
@@ -150,6 +157,7 @@ class YOLOv3(cv.models.YOLOv3):
             collate_batch = True
 
         custom_batch_transforms = []
+        random_shape_defined = False
         for i, op in enumerate(transforms.transforms):
             if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                 if mode != 'train':
@@ -208,12 +216,6 @@ class PPYOLO(cv.models.PPYOLO):
                 "To use backbone with deformable convolutional networks, "
                 "please specify in `backbone_name`. "
                 "Currently the only backbone with dcn is 'ResNet50_vd_dcn'.")
-        if train_random_shapes is not None:
-            logging.warning(
-                "`train_random_shapes` is deprecated in PaddleX 2.0 and won't take effect. "
-                "To apply multi_scale training, please refer to paddlex.transforms.BatchRandomResize: "
-                "'https://github.com/PaddlePaddle/PaddleX/blob/develop/dygraph/paddlex/cv/transforms/batch_operators.py#L53'"
-            )
         if input_channel is not None:
             logging.warning(
                 "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
@@ -258,6 +260,7 @@ class PPYOLO(cv.models.PPYOLO):
             collate_batch = True
 
         custom_batch_transforms = []
+        random_shape_defined = False
         for i, op in enumerate(transforms.transforms):
             if isinstance(op, (BatchRandomResize, BatchRandomResizeByShort)):
                 if mode != 'train':