Browse Source

Merge pull request #867 from FlyingQianMM/develop_add

add dcn to rcnn
FlyingQianMM 4 years ago
parent
commit
5a659133f8
1 changed files with 30 additions and 7 deletions
  1. 30 7
      dygraph/paddlex/cv/models/detector.py

+ 30 - 7
dygraph/paddlex/cv/models/detector.py

@@ -721,6 +721,7 @@ class FasterRCNN(BaseDetector):
                  num_classes=80,
                  num_classes=80,
                  backbone='ResNet50',
                  backbone='ResNet50',
                  with_fpn=True,
                  with_fpn=True,
+                 with_dcn=False,
                  aspect_ratios=[0.5, 1.0, 2.0],
                  aspect_ratios=[0.5, 1.0, 2.0],
                  anchor_sizes=[[32], [64], [128], [256], [512]],
                  anchor_sizes=[[32], [64], [128], [256], [512]],
                  keep_top_k=100,
                  keep_top_k=100,
@@ -741,12 +742,17 @@ class FasterRCNN(BaseDetector):
                 "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
                 "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
                 "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
                 "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
         self.backbone_name = backbone
         self.backbone_name = backbone
+        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
         if backbone == 'HRNet_W18':
         if backbone == 'HRNet_W18':
             if not with_fpn:
             if not with_fpn:
                 logging.warning(
                 logging.warning(
                     "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                     "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                     format(backbone))
                     format(backbone))
                 with_fpn = True
                 with_fpn = True
+            if with_dcn:
+                logging.warning(
+                    "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
+                    format(backbone))
             backbone = self._get_backbone(
             backbone = self._get_backbone(
                 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
                 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
         elif backbone == 'ResNet50_vd_ssld':
         elif backbone == 'ResNet50_vd_ssld':
@@ -762,7 +768,8 @@ class FasterRCNN(BaseDetector):
                 freeze_at=0,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 return_idx=[0, 1, 2, 3],
                 num_stages=4,
                 num_stages=4,
-                lr_mult_list=[0.05, 0.05, 0.1, 0.15])
+                lr_mult_list=[0.05, 0.05, 0.1, 0.15],
+                dcn_v2_stages=dcn_v2_stages)
         elif 'ResNet50' in backbone:
         elif 'ResNet50' in backbone:
             if with_fpn:
             if with_fpn:
                 backbone = self._get_backbone(
                 backbone = self._get_backbone(
@@ -771,8 +778,13 @@ class FasterRCNN(BaseDetector):
                     norm_type='bn',
                     norm_type='bn',
                     freeze_at=0,
                     freeze_at=0,
                     return_idx=[0, 1, 2, 3],
                     return_idx=[0, 1, 2, 3],
-                    num_stages=4)
+                    num_stages=4,
+                    dcn_v2_stages=dcn_v2_stages)
             else:
             else:
+                if with_dcn:
+                    logging.warning(
+                        "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
+                        format(backbone))
                 backbone = self._get_backbone(
                 backbone = self._get_backbone(
                     'ResNet',
                     'ResNet',
                     variant='d' if '_vd' in backbone else 'b',
                     variant='d' if '_vd' in backbone else 'b',
@@ -793,7 +805,8 @@ class FasterRCNN(BaseDetector):
                 norm_type='bn',
                 norm_type='bn',
                 freeze_at=0,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 return_idx=[0, 1, 2, 3],
-                num_stages=4)
+                num_stages=4,
+                dcn_v2_stages=dcn_v2_stages)
         else:
         else:
             if not with_fpn:
             if not with_fpn:
                 logging.warning(
                 logging.warning(
@@ -807,7 +820,8 @@ class FasterRCNN(BaseDetector):
                 norm_type='bn',
                 norm_type='bn',
                 freeze_at=0,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 return_idx=[0, 1, 2, 3],
-                num_stages=4)
+                num_stages=4,
+                dcn_v2_stages=dcn_v2_stages)
 
 
         rpn_in_channel = backbone.out_shape[0].channels
         rpn_in_channel = backbone.out_shape[0].channels
 
 
@@ -1418,6 +1432,7 @@ class MaskRCNN(BaseDetector):
                  num_classes=80,
                  num_classes=80,
                  backbone='ResNet50_vd',
                  backbone='ResNet50_vd',
                  with_fpn=True,
                  with_fpn=True,
+                 with_dcn=False,
                  aspect_ratios=[0.5, 1.0, 2.0],
                  aspect_ratios=[0.5, 1.0, 2.0],
                  anchor_sizes=[[32], [64], [128], [256], [512]],
                  anchor_sizes=[[32], [64], [128], [256], [512]],
                  keep_top_k=100,
                  keep_top_k=100,
@@ -1439,6 +1454,7 @@ class MaskRCNN(BaseDetector):
                 format(backbone))
                 format(backbone))
 
 
         self.backbone_name = backbone + '_fpn' if with_fpn else backbone
         self.backbone_name = backbone + '_fpn' if with_fpn else backbone
+        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
 
 
         if backbone == 'ResNet50':
         if backbone == 'ResNet50':
             if with_fpn:
             if with_fpn:
@@ -1447,8 +1463,13 @@ class MaskRCNN(BaseDetector):
                     norm_type='bn',
                     norm_type='bn',
                     freeze_at=0,
                     freeze_at=0,
                     return_idx=[0, 1, 2, 3],
                     return_idx=[0, 1, 2, 3],
-                    num_stages=4)
+                    num_stages=4,
+                    dcn_v2_stages=dcn_v2_stages)
             else:
             else:
+                if with_dcn:
+                    logging.warning(
+                        "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
+                        format(backbone))
                 backbone = self._get_backbone(
                 backbone = self._get_backbone(
                     'ResNet',
                     'ResNet',
                     norm_type='bn',
                     norm_type='bn',
@@ -1470,7 +1491,8 @@ class MaskRCNN(BaseDetector):
                 return_idx=[0, 1, 2, 3],
                 return_idx=[0, 1, 2, 3],
                 num_stages=4,
                 num_stages=4,
                 lr_mult_list=[0.05, 0.05, 0.1, 0.15]
                 lr_mult_list=[0.05, 0.05, 0.1, 0.15]
-                if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0])
+                if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
+                dcn_v2_stages=dcn_v2_stages)
 
 
         else:
         else:
             if not with_fpn:
             if not with_fpn:
@@ -1485,7 +1507,8 @@ class MaskRCNN(BaseDetector):
                 norm_type='bn',
                 norm_type='bn',
                 freeze_at=0,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 return_idx=[0, 1, 2, 3],
-                num_stages=4)
+                num_stages=4,
+                dcn_v2_stages=dcn_v2_stages)
 
 
         rpn_in_channel = backbone.out_shape[0].channels
         rpn_in_channel = backbone.out_shape[0].channels