|
|
@@ -699,14 +699,22 @@ class FasterRCNN(BaseDetector):
|
|
|
self.init_params = locals()
|
|
|
if backbone not in [
|
|
|
'ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34',
|
|
|
- 'ResNet34_vd', 'ResNet101', 'ResNet101_vd'
|
|
|
+ 'ResNet34_vd', 'ResNet101', 'ResNet101_vd', 'HRNet'
|
|
|
]:
|
|
|
raise ValueError(
|
|
|
"backbone: {} is not supported. Please choose one of "
|
|
|
"('ResNet50', 'ResNet50_vd', 'ResNet50_vd_ssld', 'ResNet34', 'ResNet34_vd', "
|
|
|
- "'ResNet101', 'ResNet101_vd')".format(backbone))
|
|
|
- self.backbone_name = backbone + '_fpn' if with_fpn else backbone
|
|
|
- if backbone == 'ResNet50_vd_ssld':
|
|
|
+ "'ResNet101', 'ResNet101_vd', 'HRNet')".format(backbone))
|
|
|
+ self.backbone_name = backbone
|
|
|
+ if backbone == 'HRNet':
|
|
|
+ if not with_fpn:
|
|
|
+ logging.warning(
|
|
|
+ "Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
|
|
|
+ format(backbone))
|
|
|
+ with_fpn = True
|
|
|
+ backbone = self._get_backbone(
|
|
|
+ 'HRNet', width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
|
|
|
+ elif backbone == 'ResNet50_vd_ssld':
|
|
|
if not with_fpn:
|
|
|
logging.warning(
|
|
|
"Backbone {} should be used along with fpn enabled, 'with_fpn' is forcibly set to True".
|
|
|
@@ -769,10 +777,23 @@ class FasterRCNN(BaseDetector):
|
|
|
rpn_in_channel = backbone.out_shape[0].channels
|
|
|
|
|
|
if with_fpn:
|
|
|
- neck = ppdet.modeling.FPN(
|
|
|
- in_channels=[i.channels for i in backbone.out_shape],
|
|
|
- out_channel=fpn_num_channels,
|
|
|
- spatial_scales=[1.0 / i.stride for i in backbone.out_shape])
|
|
|
+ self.backbone_name = self.backbone_name + '_fpn'
|
|
|
+
|
|
|
+ if 'HRNet' in self.backbone_name:
|
|
|
+ neck = ppdet.modeling.HRFPN(
|
|
|
+ in_channel=[i.channels for i in backbone.out_shape],
|
|
|
+ out_channel=fpn_num_channels,
|
|
|
+ spatial_scales=[
|
|
|
+ 1.0 / i.stride for i in backbone.out_shape
|
|
|
+ ],
|
|
|
+ share_conv=False)
|
|
|
+ else:
|
|
|
+ neck = ppdet.modeling.FPN(
|
|
|
+ in_channels=[i.channels for i in backbone.out_shape],
|
|
|
+ out_channel=fpn_num_channels,
|
|
|
+ spatial_scales=[
|
|
|
+ 1.0 / i.stride for i in backbone.out_shape
|
|
|
+ ])
|
|
|
rpn_in_channel = neck.out_shape[0].channels
|
|
|
anchor_generator_cfg = {
|
|
|
'aspect_ratios': aspect_ratios,
|