Explorar o código

add dcn to rcnn

FlyingQianMM %!s(int64=4) %!d(string=hai) anos
pai
achega
79f3625894
Modificáronse 1 ficheiros con 30 adicións e 7 borrados
  1. 30 7
      dygraph/paddlex/cv/models/detector.py

+ 30 - 7
dygraph/paddlex/cv/models/detector.py

@@ -720,6 +720,7 @@ class FasterRCNN(BaseDetector):
                  num_classes=80,
                  backbone='ResNet50',
                  with_fpn=True,
+                 with_dcn=False,
                  aspect_ratios=[0.5, 1.0, 2.0],
                  anchor_sizes=[[32], [64], [128], [256], [512]],
                  keep_top_k=100,
@@ -740,12 +741,17 @@ class FasterRCNN(BaseDetector):
                 "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
                 "'ResNet101', 'ResNet101_vd', 'HRNet_W18')".format(backbone))
         self.backbone_name = backbone
+        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
         if backbone == 'HRNet_W18':
             if not with_fpn:
                 logging.warning(
                     "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
                     format(backbone))
                 with_fpn = True
+            if with_dcn:
+                logging.warning(
+                    "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
+                    format(backbone))
             backbone = self._get_backbone(
                 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
         elif backbone == 'ResNet50_vd_ssld':
@@ -761,7 +767,8 @@ class FasterRCNN(BaseDetector):
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 num_stages=4,
-                lr_mult_list=[0.05, 0.05, 0.1, 0.15])
+                lr_mult_list=[0.05, 0.05, 0.1, 0.15],
+                dcn_v2_stages=dcn_v2_stages)
         elif 'ResNet50' in backbone:
             if with_fpn:
                 backbone = self._get_backbone(
@@ -770,8 +777,13 @@ class FasterRCNN(BaseDetector):
                     norm_type='bn',
                     freeze_at=0,
                     return_idx=[0, 1, 2, 3],
-                    num_stages=4)
+                    num_stages=4,
+                    dcn_v2_stages=dcn_v2_stages)
             else:
+                if with_dcn:
+                    logging.warning(
+                        "Backbone {} without fpn should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
+                        format(backbone))
                 backbone = self._get_backbone(
                     'ResNet',
                     variant='d' if '_vd' in backbone else 'b',
@@ -792,7 +804,8 @@ class FasterRCNN(BaseDetector):
                 norm_type='bn',
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
-                num_stages=4)
+                num_stages=4,
+                dcn_v2_stages=dcn_v2_stages)
         else:
             if not with_fpn:
                 logging.warning(
@@ -806,7 +819,8 @@ class FasterRCNN(BaseDetector):
                 norm_type='bn',
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
-                num_stages=4)
+                num_stages=4,
+                dcn_v2_stages=dcn_v2_stages)
 
         rpn_in_channel = backbone.out_shape[0].channels
 
@@ -1417,6 +1431,7 @@ class MaskRCNN(BaseDetector):
                  num_classes=80,
                  backbone='ResNet50_vd',
                  with_fpn=True,
+                 with_dcn=False,
                  aspect_ratios=[0.5, 1.0, 2.0],
                  anchor_sizes=[[32], [64], [128], [256], [512]],
                  keep_top_k=100,
@@ -1438,6 +1453,7 @@ class MaskRCNN(BaseDetector):
                 format(backbone))
 
         self.backbone_name = backbone + '_fpn' if with_fpn else backbone
+        dcn_v2_stages = [1, 2, 3] if with_dcn else [-1]
 
         if backbone == 'ResNet50':
             if with_fpn:
@@ -1446,8 +1462,13 @@ class MaskRCNN(BaseDetector):
                     norm_type='bn',
                     freeze_at=0,
                     return_idx=[0, 1, 2, 3],
-                    num_stages=4)
+                    num_stages=4,
+                    dcn_v2_stages=dcn_v2_stages)
             else:
+                if with_dcn:
+                    logging.warning(
+                        "Backbone {} should be used along with dcn disabled, 'with_dcn' is forcibly set to False".
+                        format(backbone))
                 backbone = self._get_backbone(
                     'ResNet',
                     norm_type='bn',
@@ -1469,7 +1490,8 @@ class MaskRCNN(BaseDetector):
                 return_idx=[0, 1, 2, 3],
                 num_stages=4,
                 lr_mult_list=[0.05, 0.05, 0.1, 0.15]
-                if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0])
+                if '_ssld' in backbone else [1.0, 1.0, 1.0, 1.0],
+                dcn_v2_stages=dcn_v2_stages)
 
         else:
             if not with_fpn:
@@ -1484,7 +1506,8 @@ class MaskRCNN(BaseDetector):
                 norm_type='bn',
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
-                num_stages=4)
+                num_stages=4,
+                dcn_v2_stages=dcn_v2_stages)
 
         rpn_in_channel = backbone.out_shape[0].channels