浏览代码

change ResNetXvd to ResNetX_vd

FlyingQianMM 5 年之前
父节点
当前提交
200eb174db
共有 2 个文件被更改,包括 6 次插入6 次删除
  1. 4 4
      paddlex/cv/models/faster_rcnn.py
  2. 2 2
      paddlex/cv/models/mask_rcnn.py

+ 4 - 4
paddlex/cv/models/faster_rcnn.py

@@ -32,7 +32,7 @@ class FasterRCNN(BaseAPI):
     Args:
         num_classes (int): 包含了背景类的类别数。默认为81。
         backbone (str): FasterRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50',
-            'ResNet50vd', 'ResNet101', 'ResNet101vd']。默认为'ResNet50'。
+            'ResNet50_vd', 'ResNet101', 'ResNet101_vd']。默认为'ResNet50'。
         with_fpn (bool): 是否使用FPN结构。默认为True。
         aspect_ratios (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
         anchor_sizes (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
@@ -47,7 +47,7 @@ class FasterRCNN(BaseAPI):
         self.init_params = locals()
         super(FasterRCNN, self).__init__('detector')
         backbones = [
-            'ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd'
+            'ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd'
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)
@@ -66,7 +66,7 @@ class FasterRCNN(BaseAPI):
         elif backbone_name == 'ResNet50':
             layers = 50
             variant = 'b'
-        elif backbone_name == 'ResNet50vd':
+        elif backbone_name == 'ResNet50_vd':
             layers = 50
             variant = 'd'
             norm_type = 'affine_channel'
@@ -74,7 +74,7 @@ class FasterRCNN(BaseAPI):
             layers = 101
             variant = 'b'
             norm_type = 'affine_channel'
-        elif backbone_name == 'ResNet101vd':
+        elif backbone_name == 'ResNet101_vd':
             layers = 101
             variant = 'd'
             norm_type = 'affine_channel'

+ 2 - 2
paddlex/cv/models/mask_rcnn.py

@@ -32,7 +32,7 @@ class MaskRCNN(FasterRCNN):
     Args:
         num_classes (int): 包含了背景类的类别数。默认为81。
         backbone (str): MaskRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50',
-            'ResNet50vd', 'ResNet101', 'ResNet101vd']。默认为'ResNet50'。
+            'ResNet50_vd', 'ResNet101', 'ResNet101_vd']。默认为'ResNet50'。
         with_fpn (bool): 是否使用FPN结构。默认为True。
         aspect_ratios (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
         anchor_sizes (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
@@ -46,7 +46,7 @@ class MaskRCNN(FasterRCNN):
                  anchor_sizes=[32, 64, 128, 256, 512]):
         self.init_params = locals()
         backbones = [
-            'ResNet18', 'ResNet50', 'ResNet50vd', 'ResNet101', 'ResNet101vd'
+            'ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd'
         ]
         assert backbone in backbones, "backbone should be one of {}".format(
             backbones)