|
@@ -848,7 +848,8 @@ class FasterRCNN(BaseDetector):
|
|
|
if test_pre_nms_top_n is None else test_pre_nms_top_n,
|
|
if test_pre_nms_top_n is None else test_pre_nms_top_n,
|
|
|
'post_nms_top_n': test_post_nms_top_n
|
|
'post_nms_top_n': test_post_nms_top_n
|
|
|
}
|
|
}
|
|
|
- head = ppdet.modeling.TwoFCHead(out_channel=1024)
|
|
|
|
|
|
|
+ head = ppdet.modeling.TwoFCHead(
|
|
|
|
|
+ in_channel=neck.out_shape[0].channels, out_channel=1024)
|
|
|
roi_extractor_cfg = {
|
|
roi_extractor_cfg = {
|
|
|
'resolution': 7,
|
|
'resolution': 7,
|
|
|
'spatial_scale': [1. / i.stride for i in neck.out_shape],
|
|
'spatial_scale': [1. / i.stride for i in neck.out_shape],
|