rpn_head.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import Normal
  18. from paddlex.ppdet.core.workspace import register
  19. from .anchor_generator import AnchorGenerator
  20. from .target_layer import RPNTargetAssign
  21. from .proposal_generator import ProposalGenerator
  22. class RPNFeat(nn.Layer):
  23. """
  24. Feature extraction in RPN head
  25. Args:
  26. in_channel (int): Input channel
  27. out_channel (int): Output channel
  28. """
  29. def __init__(self, in_channel=1024, out_channel=1024):
  30. super(RPNFeat, self).__init__()
  31. # rpn feat is shared with each level
  32. self.rpn_conv = nn.Conv2D(
  33. in_channels=in_channel,
  34. out_channels=out_channel,
  35. kernel_size=3,
  36. padding=1,
  37. weight_attr=paddle.ParamAttr(initializer=Normal(
  38. mean=0., std=0.01)))
  39. self.rpn_conv.skip_quant = True
  40. def forward(self, feats):
  41. rpn_feats = []
  42. for feat in feats:
  43. rpn_feats.append(F.relu(self.rpn_conv(feat)))
  44. return rpn_feats
  45. @register
  46. class RPNHead(nn.Layer):
  47. """
  48. Region Proposal Network
  49. Args:
  50. anchor_generator (dict): configure of anchor generation
  51. rpn_target_assign (dict): configure of rpn targets assignment
  52. train_proposal (dict): configure of proposals generation
  53. at the stage of training
  54. test_proposal (dict): configure of proposals generation
  55. at the stage of prediction
  56. in_channel (int): channel of input feature maps which can be
  57. derived by from_config
  58. """
  59. def __init__(self,
  60. anchor_generator=AnchorGenerator().__dict__,
  61. rpn_target_assign=RPNTargetAssign().__dict__,
  62. train_proposal=ProposalGenerator(12000, 2000).__dict__,
  63. test_proposal=ProposalGenerator().__dict__,
  64. in_channel=1024):
  65. super(RPNHead, self).__init__()
  66. self.anchor_generator = anchor_generator
  67. self.rpn_target_assign = rpn_target_assign
  68. self.train_proposal = train_proposal
  69. self.test_proposal = test_proposal
  70. if isinstance(anchor_generator, dict):
  71. self.anchor_generator = AnchorGenerator(**anchor_generator)
  72. if isinstance(rpn_target_assign, dict):
  73. self.rpn_target_assign = RPNTargetAssign(**rpn_target_assign)
  74. if isinstance(train_proposal, dict):
  75. self.train_proposal = ProposalGenerator(**train_proposal)
  76. if isinstance(test_proposal, dict):
  77. self.test_proposal = ProposalGenerator(**test_proposal)
  78. num_anchors = self.anchor_generator.num_anchors
  79. self.rpn_feat = RPNFeat(in_channel, in_channel)
  80. # rpn head is shared with each level
  81. # rpn roi classification scores
  82. self.rpn_rois_score = nn.Conv2D(
  83. in_channels=in_channel,
  84. out_channels=num_anchors,
  85. kernel_size=1,
  86. padding=0,
  87. weight_attr=paddle.ParamAttr(initializer=Normal(
  88. mean=0., std=0.01)))
  89. self.rpn_rois_score.skip_quant = True
  90. # rpn roi bbox regression deltas
  91. self.rpn_rois_delta = nn.Conv2D(
  92. in_channels=in_channel,
  93. out_channels=4 * num_anchors,
  94. kernel_size=1,
  95. padding=0,
  96. weight_attr=paddle.ParamAttr(initializer=Normal(
  97. mean=0., std=0.01)))
  98. self.rpn_rois_delta.skip_quant = True
  99. @classmethod
  100. def from_config(cls, cfg, input_shape):
  101. # FPN share same rpn head
  102. if isinstance(input_shape, (list, tuple)):
  103. input_shape = input_shape[0]
  104. return {'in_channel': input_shape.channels}
  105. def forward(self, feats, inputs):
  106. rpn_feats = self.rpn_feat(feats)
  107. scores = []
  108. deltas = []
  109. for rpn_feat in rpn_feats:
  110. rrs = self.rpn_rois_score(rpn_feat)
  111. rrd = self.rpn_rois_delta(rpn_feat)
  112. scores.append(rrs)
  113. deltas.append(rrd)
  114. anchors = self.anchor_generator(rpn_feats)
  115. rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs)
  116. if self.training:
  117. loss = self.get_loss(scores, deltas, anchors, inputs)
  118. return rois, rois_num, loss
  119. else:
  120. return rois, rois_num, None
  121. def _gen_proposal(self, scores, bbox_deltas, anchors, inputs):
  122. """
  123. scores (list[Tensor]): Multi-level scores prediction
  124. bbox_deltas (list[Tensor]): Multi-level deltas prediction
  125. anchors (list[Tensor]): Multi-level anchors
  126. inputs (dict): ground truth info
  127. """
  128. prop_gen = self.train_proposal if self.training else self.test_proposal
  129. im_shape = inputs['im_shape']
  130. # Collect multi-level proposals for each batch
  131. # Get 'topk' of them as final output
  132. bs_rois_collect = []
  133. bs_rois_num_collect = []
  134. batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
  135. # Generate proposals for each level and each batch.
  136. # Discard batch-computing to avoid sorting bbox cross different batches.
  137. for i in range(batch_size):
  138. rpn_rois_list = []
  139. rpn_prob_list = []
  140. rpn_rois_num_list = []
  141. for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
  142. anchors):
  143. rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = prop_gen(
  144. scores=rpn_score[i:i + 1],
  145. bbox_deltas=rpn_delta[i:i + 1],
  146. anchors=anchor,
  147. im_shape=im_shape[i:i + 1])
  148. if rpn_rois.shape[0] > 0:
  149. rpn_rois_list.append(rpn_rois)
  150. rpn_prob_list.append(rpn_rois_prob)
  151. rpn_rois_num_list.append(rpn_rois_num)
  152. if len(scores) > 1:
  153. rpn_rois = paddle.concat(rpn_rois_list)
  154. rpn_prob = paddle.concat(rpn_prob_list).flatten()
  155. if rpn_prob.shape[0] > post_nms_top_n:
  156. topk_prob, topk_inds = paddle.topk(rpn_prob,
  157. post_nms_top_n)
  158. topk_rois = paddle.gather(rpn_rois, topk_inds)
  159. else:
  160. topk_rois = rpn_rois
  161. topk_prob = rpn_prob
  162. else:
  163. topk_rois = rpn_rois_list[0]
  164. topk_prob = rpn_prob_list[0].flatten()
  165. bs_rois_collect.append(topk_rois)
  166. bs_rois_num_collect.append(paddle.shape(topk_rois)[0])
  167. bs_rois_num_collect = paddle.concat(bs_rois_num_collect)
  168. return bs_rois_collect, bs_rois_num_collect
  169. def get_loss(self, pred_scores, pred_deltas, anchors, inputs):
  170. """
  171. pred_scores (list[Tensor]): Multi-level scores prediction
  172. pred_deltas (list[Tensor]): Multi-level deltas prediction
  173. anchors (list[Tensor]): Multi-level anchors
  174. inputs (dict): ground truth info, including im, gt_bbox, gt_score
  175. """
  176. anchors = [paddle.reshape(a, shape=(-1, 4)) for a in anchors]
  177. anchors = paddle.concat(anchors)
  178. scores = [
  179. paddle.reshape(
  180. paddle.transpose(
  181. v, perm=[0, 2, 3, 1]),
  182. shape=(v.shape[0], -1, 1)) for v in pred_scores
  183. ]
  184. scores = paddle.concat(scores, axis=1)
  185. deltas = [
  186. paddle.reshape(
  187. paddle.transpose(
  188. v, perm=[0, 2, 3, 1]),
  189. shape=(v.shape[0], -1, 4)) for v in pred_deltas
  190. ]
  191. deltas = paddle.concat(deltas, axis=1)
  192. score_tgt, bbox_tgt, loc_tgt, norm = self.rpn_target_assign(inputs,
  193. anchors)
  194. scores = paddle.reshape(x=scores, shape=(-1, ))
  195. deltas = paddle.reshape(x=deltas, shape=(-1, 4))
  196. score_tgt = paddle.concat(score_tgt)
  197. score_tgt.stop_gradient = True
  198. pos_mask = score_tgt == 1
  199. pos_ind = paddle.nonzero(pos_mask)
  200. valid_mask = score_tgt >= 0
  201. valid_ind = paddle.nonzero(valid_mask)
  202. # cls loss
  203. if valid_ind.shape[0] == 0:
  204. loss_rpn_cls = paddle.zeros([1], dtype='float32')
  205. else:
  206. score_pred = paddle.gather(scores, valid_ind)
  207. score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
  208. score_label.stop_gradient = True
  209. loss_rpn_cls = F.binary_cross_entropy_with_logits(
  210. logit=score_pred, label=score_label, reduction="sum")
  211. # reg loss
  212. if pos_ind.shape[0] == 0:
  213. loss_rpn_reg = paddle.zeros([1], dtype='float32')
  214. else:
  215. loc_pred = paddle.gather(deltas, pos_ind)
  216. loc_tgt = paddle.concat(loc_tgt)
  217. loc_tgt = paddle.gather(loc_tgt, pos_ind)
  218. loc_tgt.stop_gradient = True
  219. loss_rpn_reg = paddle.abs(loc_pred - loc_tgt).sum()
  220. return {
  221. 'loss_rpn_cls': loss_rpn_cls / norm,
  222. 'loss_rpn_reg': loss_rpn_reg / norm
  223. }