bbox_head.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from collections import OrderedDict
  18. from paddle import fluid
  19. from paddle.fluid.param_attr import ParamAttr
  20. from paddle.fluid.initializer import Normal, Xavier
  21. from paddle.fluid.regularizer import L2Decay
  22. from paddle.fluid.initializer import MSRA
  23. from .loss.diou_loss import DiouLoss
  24. from .ops import MultiClassNMS, MatrixNMS, MultiClassSoftNMS, MultiClassDiouNMS
  25. __all__ = ['BBoxHead', 'TwoFCHead']
  26. class TwoFCHead(object):
  27. """
  28. RCNN head with two Fully Connected layers
  29. Args:
  30. mlp_dim (int): num of filters for the fc layers
  31. """
  32. def __init__(self, mlp_dim=1024):
  33. super(TwoFCHead, self).__init__()
  34. self.mlp_dim = mlp_dim
  35. def __call__(self, roi_feat):
  36. fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]
  37. fc6 = fluid.layers.fc(input=roi_feat,
  38. size=self.mlp_dim,
  39. act='relu',
  40. name='fc6',
  41. param_attr=ParamAttr(
  42. name='fc6_w',
  43. initializer=Xavier(fan_out=fan)),
  44. bias_attr=ParamAttr(
  45. name='fc6_b',
  46. learning_rate=2.,
  47. regularizer=L2Decay(0.)))
  48. head_feat = fluid.layers.fc(input=fc6,
  49. size=self.mlp_dim,
  50. act='relu',
  51. name='fc7',
  52. param_attr=ParamAttr(
  53. name='fc7_w', initializer=Xavier()),
  54. bias_attr=ParamAttr(
  55. name='fc7_b',
  56. learning_rate=2.,
  57. regularizer=L2Decay(0.)))
  58. return head_feat
  59. class BBoxHead(object):
  60. def __init__(
  61. self,
  62. head,
  63. #box_coder
  64. prior_box_var=[0.1, 0.1, 0.2, 0.2],
  65. code_type='decode_center_size',
  66. box_normalized=False,
  67. axis=1,
  68. #MultiClassNMS
  69. rcnn_nms='MultiClassNMS',
  70. score_threshold=.05,
  71. nms_top_k=-1,
  72. keep_top_k=100,
  73. nms_threshold=.5,
  74. normalized=False,
  75. nms_eta=1.0,
  76. background_label=0,
  77. post_threshold=.05,
  78. softnms_sigma=0.5,
  79. #bbox_loss
  80. sigma=1.0,
  81. num_classes=81,
  82. rcnn_bbox_loss='SmoothL1Loss',
  83. diouloss_weight=10.0,
  84. diouloss_is_cls_agnostic=False,
  85. diouloss_use_complete_iou_loss=True):
  86. super(BBoxHead, self).__init__()
  87. self.head = head
  88. self.prior_box_var = prior_box_var
  89. self.code_type = code_type
  90. self.box_normalized = box_normalized
  91. self.axis = axis
  92. self.sigma = sigma
  93. self.num_classes = num_classes
  94. self.head_feat = None
  95. self.rcnn_bbox_loss = rcnn_bbox_loss
  96. self.diouloss_weight = diouloss_weight
  97. self.diouloss_is_cls_agnostic = diouloss_is_cls_agnostic
  98. self.diouloss_use_complete_iou_loss = diouloss_use_complete_iou_loss
  99. if self.rcnn_bbox_loss == 'DIoULoss':
  100. self.diou_loss = DiouLoss(
  101. loss_weight=self.diouloss_weight,
  102. is_cls_agnostic=self.diouloss_is_cls_agnostic,
  103. num_classes=num_classes,
  104. use_complete_iou_loss=self.diouloss_use_complete_iou_loss)
  105. if rcnn_nms == 'MultiClassNMS':
  106. self.nms = MultiClassNMS(
  107. score_threshold=score_threshold,
  108. keep_top_k=keep_top_k,
  109. nms_threshold=nms_threshold,
  110. normalized=normalized,
  111. nms_eta=nms_eta,
  112. background_label=background_label)
  113. elif rcnn_nms == 'MultiClassSoftNMS':
  114. self.nms = MultiClassSoftNMS(
  115. score_threshold=score_threshold,
  116. keep_top_k=keep_top_k,
  117. softnms_sigma=softnms_sigma,
  118. normalized=normalized,
  119. background_label=background_label)
  120. elif rcnn_nms == 'MatrixNMS':
  121. self.nms = MatrixNMS(
  122. score_threshold=score_threshold,
  123. post_threshold=post_threshold,
  124. keep_top_k=keep_top_k,
  125. normalized=normalized,
  126. background_label=background_label)
  127. elif rcnn_nms == 'MultiClassCiouNMS':
  128. self.nms = MultiClassDiouNMS(
  129. score_threshold=score_threshold,
  130. keep_top_k=keep_top_k,
  131. nms_threshold=nms_threshold,
  132. normalized=normalized,
  133. background_label=background_label)
  134. def get_head_feat(self, input=None):
  135. """
  136. Get the bbox head feature map.
  137. """
  138. if input is not None:
  139. feat = self.head(input)
  140. if isinstance(feat, OrderedDict):
  141. feat = list(feat.values())[0]
  142. self.head_feat = feat
  143. return self.head_feat
  144. def _get_output(self, roi_feat):
  145. """
  146. Get bbox head output.
  147. Args:
  148. roi_feat (Variable): RoI feature from RoIExtractor.
  149. Returns:
  150. cls_score(Variable): Output of rpn head with shape of
  151. [N, num_anchors, H, W].
  152. bbox_pred(Variable): Output of rpn head with shape of
  153. [N, num_anchors * 4, H, W].
  154. """
  155. head_feat = self.get_head_feat(roi_feat)
  156. # when ResNetC5 output a single feature map
  157. if not isinstance(self.head, TwoFCHead):
  158. head_feat = fluid.layers.pool2d(
  159. head_feat, pool_type='avg', global_pooling=True)
  160. cls_score = fluid.layers.fc(input=head_feat,
  161. size=self.num_classes,
  162. act=None,
  163. name='cls_score',
  164. param_attr=ParamAttr(
  165. name='cls_score_w',
  166. initializer=Normal(
  167. loc=0.0, scale=0.01)),
  168. bias_attr=ParamAttr(
  169. name='cls_score_b',
  170. learning_rate=2.,
  171. regularizer=L2Decay(0.)))
  172. bbox_pred = fluid.layers.fc(input=head_feat,
  173. size=4 * self.num_classes,
  174. act=None,
  175. name='bbox_pred',
  176. param_attr=ParamAttr(
  177. name='bbox_pred_w',
  178. initializer=Normal(
  179. loc=0.0, scale=0.001)),
  180. bias_attr=ParamAttr(
  181. name='bbox_pred_b',
  182. learning_rate=2.,
  183. regularizer=L2Decay(0.)))
  184. return cls_score, bbox_pred
  185. def get_loss(self, roi_feat, labels_int32, bbox_targets,
  186. bbox_inside_weights, bbox_outside_weights):
  187. """
  188. Get bbox_head loss.
  189. Args:
  190. roi_feat (Variable): RoI feature from RoIExtractor.
  191. labels_int32(Variable): Class label of a RoI with shape [P, 1].
  192. P is the number of RoI.
  193. bbox_targets(Variable): Box label of a RoI with shape
  194. [P, 4 * class_nums].
  195. bbox_inside_weights(Variable): Indicates whether a box should
  196. contribute to loss. Same shape as bbox_targets.
  197. bbox_outside_weights(Variable): Indicates whether a box should
  198. contribute to loss. Same shape as bbox_targets.
  199. Return:
  200. Type: Dict
  201. loss_cls(Variable): bbox_head loss.
  202. loss_bbox(Variable): bbox_head loss.
  203. """
  204. cls_score, bbox_pred = self._get_output(roi_feat)
  205. labels_int64 = fluid.layers.cast(x=labels_int32, dtype='int64')
  206. labels_int64.stop_gradient = True
  207. loss_cls = fluid.layers.softmax_with_cross_entropy(
  208. logits=cls_score, label=labels_int64, numeric_stable_mode=True)
  209. loss_cls = fluid.layers.reduce_mean(loss_cls)
  210. if self.rcnn_bbox_loss == 'SmoothL1Loss':
  211. loss_bbox = fluid.layers.smooth_l1(
  212. x=bbox_pred,
  213. y=bbox_targets,
  214. inside_weight=bbox_inside_weights,
  215. outside_weight=bbox_outside_weights,
  216. sigma=self.sigma)
  217. elif self.rcnn_bbox_loss == 'DIoULoss':
  218. loss_bbox = self.diou_loss(
  219. x=bbox_pred,
  220. y=bbox_targets,
  221. inside_weight=bbox_inside_weights,
  222. outside_weight=bbox_outside_weights)
  223. loss_bbox = fluid.layers.reduce_mean(loss_bbox)
  224. return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
  225. def get_prediction(self,
  226. roi_feat,
  227. rois,
  228. im_info,
  229. im_shape,
  230. return_box_score=False):
  231. """
  232. Get prediction bounding box in test stage.
  233. Args:
  234. roi_feat (Variable): RoI feature from RoIExtractor.
  235. rois (Variable): Output of generate_proposals in rpn head.
  236. im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
  237. number of input images, each element consists of im_height,
  238. im_width, im_scale.
  239. im_shape (Variable): Actual shape of original image with shape
  240. [B, 3]. B is the number of images, each element consists of
  241. original_height, original_width, 1
  242. Returns:
  243. pred_result(Variable): Prediction result with shape [N, 6]. Each
  244. row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
  245. N is the total number of prediction.
  246. """
  247. cls_score, bbox_pred = self._get_output(roi_feat)
  248. im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
  249. im_scale = fluid.layers.sequence_expand(im_scale, rois)
  250. boxes = rois / im_scale
  251. cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
  252. bbox_pred = fluid.layers.reshape(bbox_pred, (-1, self.num_classes, 4))
  253. decoded_box = fluid.layers.box_coder(
  254. prior_box=boxes,
  255. target_box=bbox_pred,
  256. prior_box_var=self.prior_box_var,
  257. code_type=self.code_type,
  258. box_normalized=self.box_normalized,
  259. axis=self.axis)
  260. cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
  261. if return_box_score:
  262. return {'bbox': cliped_box, 'score': cls_prob}
  263. pred_result = self.nms(bboxes=cliped_box, scores=cls_prob)
  264. return {'bbox': pred_result}