Эх сурвалжийг харах

add hrnet backbone for fasterrcnn

will-jl944 4 жил өмнө
parent
commit
30ea279905

+ 1 - 1
dygraph/PaddleDetection

@@ -1 +1 @@
-Subproject commit 57e9f917dd8191834bcfc98a5df239d7fb42c941
+Subproject commit 66d7eefab9aca8243ddf49a52b748b786b80ffb5

+ 29 - 8
dygraph/paddlex/cv/models/detector.py

@@ -699,14 +699,22 @@ class FasterRCNN(BaseDetector):
         self.init_params = locals()
         if backbone not in [
                 'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
-                'ResNet34_vd', 'ResNet101', 'ResNet101_vd'
+                'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet'
         ]:
             raise ValueError(
                 "backbone: {} is not supported. Please choose one of "
                 "('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
-                "'ResNet101', 'ResNet101_vd')".format(backbone))
-        self.backbone_name = backbone + '_fpn' if with_fpn else backbone
-        if backbone == 'ResNet50_vd_ssld':
+                "'ResNet101', 'ResNet101_vd', 'HRNet')".format(backbone))
+        self.backbone_name = backbone
+        if backbone == 'HRNet':
+            if not with_fpn:
+                logging.warning(
+                    "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
+                    format(backbone))
+                with_fpn = True
+            backbone = self._get_backbone(
+                'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
+        elif backbone == 'ResNet50_vd_ssld':
             if not with_fpn:
                 logging.warning(
                     "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
@@ -769,10 +777,23 @@ class FasterRCNN(BaseDetector):
         rpn_in_channel = backbone.out_shape[0].channels
 
         if with_fpn:
-            neck = ppdet.modeling.FPN(
-                in_channels=[i.channels for i in backbone.out_shape],
-                out_channel=fpn_num_channels,
-                spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
+            self.backbone_name = self.backbone_name + '_fpn'
+
+            if 'HRNet' in self.backbone_name:
+                neck = ppdet.modeling.HRFPN(
+                    in_channel=[i.channels for i in backbone.out_shape],
+                    out_channel=fpn_num_channels,
+                    spatial_scales=[
+                        1.0 / i.stride for i in backbone.out_shape
+                    ],
+                    share_conv=False)
+            else:
+                neck = ppdet.modeling.FPN(
+                    in_channels=[i.channels for i in backbone.out_shape],
+                    out_channel=fpn_num_channels,
+                    spatial_scales=[
+                        1.0 / i.stride for i in backbone.out_shape
+                    ])
             rpn_in_channel = neck.out_shape[0].channels
             anchor_generator_cfg = {
                 'aspect_ratios': aspect_ratios,