pyramidal_embedding.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from paddle.nn.initializer import Normal, Constant
  21. from paddle import ParamAttr
  22. from .resnet import *
  23. from paddlex.ppdet.core.workspace import register
  24. __all__ = ['PCBPyramid']
  @register
  class PCBPyramid(nn.Layer):
      """
      PCB (Part-based Convolutional Baseline), see https://arxiv.org/abs/1711.09349,
      Pyramidal Person Re-IDentification, see https://arxiv.org/abs/1810.12193

      The backbone feature map is cut into ``num_stripes`` equal stripes along
      axis 2. Pyramid level ``k`` (0-based, fine to coarse) slides a window of
      ``k + 1`` consecutive stripes with a step of one stripe, which yields
      ``num_stripes - k`` branches on that level. Every branch has its own
      1x1 conv block and its own linear classifier.

      Args:
          input_ch (int): Number of channels of the input feature.
          num_stripes (int): Number of sub-parts.
          used_levels (tuple): Whether the level is used, 1 means used.
              NOTE(review): the value is only stored on the instance; no
              method of this class reads it, so every level is always built
              and evaluated. Confirm whether level masking is still wanted.
          num_classes (int): Number of classes for identities.
          last_conv_stride (int): Stride of the last conv.
          last_conv_dilation (int): Dilation of the last conv.
          num_conv_out_channels (int): Number of channels of conv feature.
      """

      def __init__(self,
                   input_ch=2048,
                   num_stripes=6,
                   used_levels=(1, 1, 1, 1, 1, 1),
                   num_classes=751,
                   last_conv_stride=1,
                   last_conv_dilation=1,
                   num_conv_out_channels=128):
          super(PCBPyramid, self).__init__()
          self.num_stripes = num_stripes
          self.used_levels = used_levels
          self.num_classes = num_classes

          # Branch count per level, fine to coarse: [n, n - 1, ..., 1].
          self.num_in_each_level = list(range(self.num_stripes, 0, -1))
          self.num_branches = sum(self.num_in_each_level)

          self.base = ResNet101(
              lr_mult=0.1,
              last_conv_stride=last_conv_stride,
              last_conv_dilation=last_conv_dilation)
          self.dropout_layer = nn.Dropout(p=0.2)
          self.pyramid_conv_list0, self.pyramid_fc_list0 = self.basic_branch(
              num_conv_out_channels, input_ch)

      def basic_branch(self, num_conv_out_channels, input_ch):
          """Build the per-branch conv blocks and classifiers.

          Args:
              num_conv_out_channels (int): Output channels of each 1x1 conv
                  block, which is also the input width of each classifier.
              input_ch (int): Input channels of each 1x1 conv block.

          Returns:
              tuple: ``(pyramid_conv_list, pyramid_fc_list)``, two
              ``nn.LayerList`` objects holding ``self.num_branches`` entries
              each, indexed by branch id.
          """
          # the level indexes are defined from fine to coarse,
          # the branch will contain one more part than that of its previous level
          # the sliding step is set to 1
          pyramid_conv_list = nn.LayerList()
          pyramid_fc_list = nn.LayerList()

          # All conv blocks are created before any classifier, in two separate
          # passes, so the order in which layers (and their parameters) are
          # instantiated stays exactly as before.
          for _ in range(self.num_branches):
              pyramid_conv_list.append(
                  nn.Sequential(
                      nn.Conv2D(input_ch, num_conv_out_channels, 1),
                      nn.BatchNorm2D(num_conv_out_channels), nn.ReLU()))

          for idx_branches in range(self.num_branches):
              # Classifier parameters get explicit, branch-indexed names.
              name = "Linear_branch_id_{}".format(idx_branches)
              fc = nn.Linear(
                  in_features=num_conv_out_channels,
                  out_features=self.num_classes,
                  weight_attr=ParamAttr(
                      name=name + "_weights",
                      initializer=Normal(
                          mean=0., std=0.001)),
                  bias_attr=ParamAttr(
                      name=name + "_bias", initializer=Constant(value=0.)))
              pyramid_fc_list.append(fc)
          return pyramid_conv_list, pyramid_fc_list

      def pyramid_forward(self, feat):
          """Run every pyramid branch on the backbone feature map.

          Args:
              feat (Tensor): 4-D feature map; axis 2 is split into stripes and
                  the last axis is pooled over completely.
                  Assumes NCHW layout -- TODO confirm against the backbone.

          Returns:
              tuple: ``(feat_list, logits_list)``, each a list with one entry
              per branch in branch-id order: the flattened branch embedding
              and the classifier output computed from it after dropout.
          """
          each_stripe_size = feat.shape[2] // self.num_stripes
          width = feat.shape[-1]
          feat_list, logits_list = [], []

          idx_branches = 0
          for idx_levels, num_in_level in enumerate(self.num_in_each_level):
              # A window on this level spans (idx_levels + 1) stripes; the
              # pooling kernel covers the whole window.
              stripe_size_in_each_level = each_stripe_size * (idx_levels + 1)
              kernel_size = (stripe_size_in_each_level, width)

              for idx_in_each_level in range(num_in_level):
                  # Windows slide by one stripe within a level.
                  start = idx_in_each_level * each_stripe_size
                  end = start + stripe_size_in_each_level
                  region = feat[:, :, start:end, :]

                  # Sum of average- and max-pooled responses of the window.
                  local_feat = (
                      F.avg_pool2d(region, kernel_size=kernel_size) +
                      F.max_pool2d(region, kernel_size=kernel_size))

                  local_feat = self.pyramid_conv_list0[idx_branches](local_feat)
                  local_feat = paddle.reshape(
                      local_feat, shape=[local_feat.shape[0], -1])
                  feat_list.append(local_feat)

                  local_logits = self.pyramid_fc_list0[idx_branches](
                      self.dropout_layer(local_feat))
                  logits_list.append(local_logits)

                  idx_branches += 1

          return feat_list, logits_list

      def forward(self, x):
          """Compute the concatenated pyramid embedding of ``x``.

          Args:
              x (Tensor): Input passed straight to the ResNet101 backbone.

          Returns:
              Tensor: All branch embeddings concatenated along the last axis.
              The per-branch logits are computed but not returned.

          Raises:
              AssertionError: If axis 2 of the backbone feature map is not a
                  multiple of ``num_stripes``.
          """
          feat = self.base(x)
          assert feat.shape[2] % self.num_stripes == 0, (
              "feature size {} along axis 2 is not divisible by "
              "num_stripes {}".format(feat.shape[2], self.num_stripes))
          feat_list, _ = self.pyramid_forward(feat)
          feat_out = paddle.concat(feat_list, axis=-1)
          return feat_out