mask_head.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay

from .fpn import ConvNorm

__all__ = ['MaskHead']


class MaskHead(object):
    """
    RCNN mask head.

    Args:
        num_convs (int): number of convolutions, 4 for FPN, 1 otherwise
        conv_dim (int): number of channels after the first convolution
        resolution (int): size of the output mask
        dilation (int): dilation rate
        num_classes (int): number of output classes
    """

    def __init__(self,
                 num_convs=0,
                 conv_dim=256,
                 resolution=14,
                 dilation=1,
                 num_classes=81,
                 norm_type=None):
        super(MaskHead, self).__init__()
        self.num_convs = num_convs
        self.conv_dim = conv_dim
        self.resolution = resolution
        self.dilation = dilation
        self.num_classes = num_classes
        self.norm_type = norm_type

    def _mask_conv_head(self, roi_feat, num_convs, norm_type):
        # Stack of 3x3 convs on the RoI feature: ConvNorm when group norm is
        # requested, plain conv2d otherwise.
        if norm_type == 'gn':
            for i in range(num_convs):
                layer_name = "mask_inter_feat_" + str(i + 1)
                fan = self.conv_dim * 3 * 3
                initializer = MSRA(uniform=False, fan_in=fan)
                roi_feat = ConvNorm(
                    roi_feat,
                    self.conv_dim,
                    3,
                    act='relu',
                    dilation=self.dilation,
                    initializer=initializer,
                    norm_type=self.norm_type,
                    name=layer_name,
                    norm_name=layer_name)
        else:
            for i in range(num_convs):
                layer_name = "mask_inter_feat_" + str(i + 1)
                fan = self.conv_dim * 3 * 3
                initializer = MSRA(uniform=False, fan_in=fan)
                roi_feat = fluid.layers.conv2d(
                    input=roi_feat,
                    num_filters=self.conv_dim,
                    filter_size=3,
                    padding=1 * self.dilation,
                    act='relu',
                    stride=1,
                    dilation=self.dilation,
                    name=layer_name,
                    param_attr=ParamAttr(
                        name=layer_name + '_w', initializer=initializer),
                    bias_attr=ParamAttr(
                        name=layer_name + '_b',
                        learning_rate=2.,
                        regularizer=L2Decay(0.)))
        # Stride-2 deconv upsamples the RoI feature map by 2x before the
        # per-class 1x1 logits conv in _get_output.
        fan = roi_feat.shape[1] * 2 * 2
        feat = fluid.layers.conv2d_transpose(
            input=roi_feat,
            num_filters=self.conv_dim,
            filter_size=2,
            stride=2,
            act='relu',
            param_attr=ParamAttr(
                name='conv5_mask_w',
                initializer=MSRA(uniform=False, fan_in=fan)),
            bias_attr=ParamAttr(
                name='conv5_mask_b',
                learning_rate=2.,
                regularizer=L2Decay(0.)))
        return feat

    def _get_output(self, roi_feat):
        class_num = self.num_classes
        # configure the conv number for FPN if necessary
        head_feat = self._mask_conv_head(roi_feat, self.num_convs,
                                         self.norm_type)
        # 1x1 conv that produces one mask logit map per class
        fan = class_num
        mask_logits = fluid.layers.conv2d(
            input=head_feat,
            num_filters=class_num,
            filter_size=1,
            act=None,
            param_attr=ParamAttr(
                name='mask_fcn_logits_w',
                initializer=MSRA(uniform=False, fan_in=fan)),
            bias_attr=ParamAttr(
                name="mask_fcn_logits_b",
                learning_rate=2.,
                regularizer=L2Decay(0.)))
        return mask_logits

    def get_loss(self, roi_feat, mask_int32):
        mask_logits = self._get_output(roi_feat)
        num_classes = self.num_classes
        resolution = self.resolution
        dim = num_classes * resolution * resolution
        # Flatten logits to [num_rois, num_classes * resolution * resolution]
        # so they line up with the rasterized per-class mask targets.
        mask_logits = fluid.layers.reshape(mask_logits, (-1, dim))
        mask_label = fluid.layers.cast(x=mask_int32, dtype='float32')
        mask_label.stop_gradient = True
        # Target entries equal to -1 are ignored; normalize=True divides the
        # elementwise loss by the number of non-ignored targets, so the
        # reduce_sum below yields a mean over valid elements.
        loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=mask_logits, label=mask_label, ignore_index=-1, normalize=True)
        loss_mask = fluid.layers.reduce_sum(loss_mask, name='loss_mask')
        return {'loss_mask': loss_mask}
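
    # Worked example for the target layout above: with the defaults
    # num_classes=81 and resolution=14, each RoI contributes a flattened
    # target of 81 * 14 * 14 = 15876 entries, and mask_logits is reshaped
    # to [num_rois, 15876] to match it.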

    def get_prediction(self, roi_feat, bbox_pred):
        """
        Get the predicted masks in the test stage.

        Args:
            roi_feat (Variable): RoI feature from RoIExtractor.
            bbox_pred (Variable): predicted bbox.

        Returns:
            mask_pred (Variable): Prediction mask with shape
                [N, num_classes, resolution, resolution].
        """
        mask_logits = self._get_output(roi_feat)
        mask_prob = fluid.layers.sigmoid(mask_logits)
        # Carry over the LoD (per-image RoI grouping) from bbox_pred onto the
        # mask probabilities.
        mask_prob = fluid.layers.lod_reset(mask_prob, bbox_pred)
        return mask_prob
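

# Usage sketch (illustrative): the tensors `roi_feat`, `mask_int32` and
# `bbox_pred` come from the surrounding detection pipeline, and the FPN-style
# settings below are assumptions rather than values fixed by this file.
#
#     head = MaskHead(num_convs=4, resolution=28, num_classes=81, norm_type=None)
#     loss_dict = head.get_loss(roi_feat, mask_int32)       # training: {'loss_mask': ...}
#     mask_pred = head.get_prediction(roi_feat, bbox_pred)  # inference: per-class mask probabilities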