bbox_head.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. from paddle import fluid
  19. from paddle.fluid.param_attr import ParamAttr
  20. from paddle.fluid.initializer import Normal, Xavier
  21. from paddle.fluid.regularizer import L2Decay
  22. from paddle.fluid.initializer import MSRA
  23. __all__ = ['BBoxHead', 'TwoFCHead']
  24. class TwoFCHead(object):
  25. """
  26. RCNN head with two Fully Connected layers
  27. Args:
  28. mlp_dim (int): num of filters for the fc layers
  29. """
  30. def __init__(self, mlp_dim=1024):
  31. super(TwoFCHead, self).__init__()
  32. self.mlp_dim = mlp_dim
  33. def __call__(self, roi_feat):
  34. fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]
  35. fc6 = fluid.layers.fc(
  36. input=roi_feat,
  37. size=self.mlp_dim,
  38. act='relu',
  39. name='fc6',
  40. param_attr=ParamAttr(
  41. name='fc6_w', initializer=Xavier(fan_out=fan)),
  42. bias_attr=ParamAttr(
  43. name='fc6_b', learning_rate=2., regularizer=L2Decay(0.)))
  44. head_feat = fluid.layers.fc(
  45. input=fc6,
  46. size=self.mlp_dim,
  47. act='relu',
  48. name='fc7',
  49. param_attr=ParamAttr(name='fc7_w', initializer=Xavier()),
  50. bias_attr=ParamAttr(
  51. name='fc7_b', learning_rate=2., regularizer=L2Decay(0.)))
  52. return head_feat
  53. class BBoxHead(object):
  54. def __init__(
  55. self,
  56. head,
  57. #box_coder
  58. prior_box_var=[0.1, 0.1, 0.2, 0.2],
  59. code_type='decode_center_size',
  60. box_normalized=False,
  61. axis=1,
  62. #MultiClassNMS
  63. score_threshold=.05,
  64. nms_top_k=-1,
  65. keep_top_k=100,
  66. nms_threshold=.5,
  67. normalized=False,
  68. nms_eta=1.0,
  69. background_label=0,
  70. #bbox_loss
  71. sigma=1.0,
  72. num_classes=81):
  73. super(BBoxHead, self).__init__()
  74. self.head = head
  75. self.prior_box_var = prior_box_var
  76. self.code_type = code_type
  77. self.box_normalized = box_normalized
  78. self.axis = axis
  79. self.score_threshold = score_threshold
  80. self.nms_top_k = nms_top_k
  81. self.keep_top_k = keep_top_k
  82. self.nms_threshold = nms_threshold
  83. self.normalized = normalized
  84. self.nms_eta = nms_eta
  85. self.background_label = background_label
  86. self.sigma = sigma
  87. self.num_classes = num_classes
  88. self.head_feat = None
  89. def get_head_feat(self, input=None):
  90. """
  91. Get the bbox head feature map.
  92. """
  93. if input is not None:
  94. feat = self.head(input)
  95. if isinstance(feat, OrderedDict):
  96. feat = list(feat.values())[0]
  97. self.head_feat = feat
  98. return self.head_feat
  99. def _get_output(self, roi_feat):
  100. """
  101. Get bbox head output.
  102. Args:
  103. roi_feat (Variable): RoI feature from RoIExtractor.
  104. Returns:
  105. cls_score(Variable): Output of rpn head with shape of
  106. [N, num_anchors, H, W].
  107. bbox_pred(Variable): Output of rpn head with shape of
  108. [N, num_anchors * 4, H, W].
  109. """
  110. head_feat = self.get_head_feat(roi_feat)
  111. # when ResNetC5 output a single feature map
  112. if not isinstance(self.head, TwoFCHead):
  113. head_feat = fluid.layers.pool2d(
  114. head_feat, pool_type='avg', global_pooling=True)
  115. cls_score = fluid.layers.fc(
  116. input=head_feat,
  117. size=self.num_classes,
  118. act=None,
  119. name='cls_score',
  120. param_attr=ParamAttr(
  121. name='cls_score_w', initializer=Normal(loc=0.0, scale=0.01)),
  122. bias_attr=ParamAttr(
  123. name='cls_score_b', learning_rate=2., regularizer=L2Decay(0.)))
  124. bbox_pred = fluid.layers.fc(
  125. input=head_feat,
  126. size=4 * self.num_classes,
  127. act=None,
  128. name='bbox_pred',
  129. param_attr=ParamAttr(
  130. name='bbox_pred_w', initializer=Normal(loc=0.0, scale=0.001)),
  131. bias_attr=ParamAttr(
  132. name='bbox_pred_b', learning_rate=2., regularizer=L2Decay(0.)))
  133. return cls_score, bbox_pred
  134. def get_loss(self, roi_feat, labels_int32, bbox_targets,
  135. bbox_inside_weights, bbox_outside_weights):
  136. """
  137. Get bbox_head loss.
  138. Args:
  139. roi_feat (Variable): RoI feature from RoIExtractor.
  140. labels_int32(Variable): Class label of a RoI with shape [P, 1].
  141. P is the number of RoI.
  142. bbox_targets(Variable): Box label of a RoI with shape
  143. [P, 4 * class_nums].
  144. bbox_inside_weights(Variable): Indicates whether a box should
  145. contribute to loss. Same shape as bbox_targets.
  146. bbox_outside_weights(Variable): Indicates whether a box should
  147. contribute to loss. Same shape as bbox_targets.
  148. Return:
  149. Type: Dict
  150. loss_cls(Variable): bbox_head loss.
  151. loss_bbox(Variable): bbox_head loss.
  152. """
  153. cls_score, bbox_pred = self._get_output(roi_feat)
  154. labels_int64 = fluid.layers.cast(x=labels_int32, dtype='int64')
  155. labels_int64.stop_gradient = True
  156. loss_cls = fluid.layers.softmax_with_cross_entropy(
  157. logits=cls_score, label=labels_int64, numeric_stable_mode=True)
  158. loss_cls = fluid.layers.reduce_mean(loss_cls)
  159. loss_bbox = fluid.layers.smooth_l1(
  160. x=bbox_pred,
  161. y=bbox_targets,
  162. inside_weight=bbox_inside_weights,
  163. outside_weight=bbox_outside_weights,
  164. sigma=self.sigma)
  165. loss_bbox = fluid.layers.reduce_mean(loss_bbox)
  166. return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
  167. def get_prediction(self,
  168. roi_feat,
  169. rois,
  170. im_info,
  171. im_shape,
  172. return_box_score=False):
  173. """
  174. Get prediction bounding box in test stage.
  175. Args:
  176. roi_feat (Variable): RoI feature from RoIExtractor.
  177. rois (Variable): Output of generate_proposals in rpn head.
  178. im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
  179. number of input images, each element consists of im_height,
  180. im_width, im_scale.
  181. im_shape (Variable): Actual shape of original image with shape
  182. [B, 3]. B is the number of images, each element consists of
  183. original_height, original_width, 1
  184. Returns:
  185. pred_result(Variable): Prediction result with shape [N, 6]. Each
  186. row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
  187. N is the total number of prediction.
  188. """
  189. cls_score, bbox_pred = self._get_output(roi_feat)
  190. im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
  191. im_scale = fluid.layers.sequence_expand(im_scale, rois)
  192. boxes = rois / im_scale
  193. cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
  194. bbox_pred = fluid.layers.reshape(bbox_pred, (-1, self.num_classes, 4))
  195. decoded_box = fluid.layers.box_coder(
  196. prior_box=boxes,
  197. target_box=bbox_pred,
  198. prior_box_var=self.prior_box_var,
  199. code_type=self.code_type,
  200. box_normalized=self.box_normalized,
  201. axis=self.axis)
  202. cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
  203. if return_box_score:
  204. return {'bbox': cliped_box, 'score': cls_prob}
  205. pred_result = fluid.layers.multiclass_nms(
  206. bboxes=cliped_box,
  207. scores=cls_prob,
  208. score_threshold=self.score_threshold,
  209. nms_top_k=self.nms_top_k,
  210. keep_top_k=self.keep_top_k,
  211. nms_threshold=self.nms_threshold,
  212. normalized=self.normalized,
  213. nms_eta=self.nms_eta,
  214. background_label=self.background_label)
  215. return {'bbox': pred_result}