repvgg.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. import paddle.nn as nn
  2. import paddle
  3. import numpy as np
  4. __all__ = [
  5. 'RepVGG',
  6. 'RepVGG_A0',
  7. 'RepVGG_A1',
  8. 'RepVGG_A2',
  9. 'RepVGG_B0',
  10. 'RepVGG_B1',
  11. 'RepVGG_B2',
  12. 'RepVGG_B3',
  13. 'RepVGG_B1g2',
  14. 'RepVGG_B1g4',
  15. 'RepVGG_B2g2',
  16. 'RepVGG_B2g4',
  17. 'RepVGG_B3g2',
  18. 'RepVGG_B3g4',
  19. ]
  20. class ConvBN(nn.Layer):
  21. def __init__(self,
  22. in_channels,
  23. out_channels,
  24. kernel_size,
  25. stride,
  26. padding,
  27. groups=1):
  28. super(ConvBN, self).__init__()
  29. self.conv = nn.Conv2D(
  30. in_channels=in_channels,
  31. out_channels=out_channels,
  32. kernel_size=kernel_size,
  33. stride=stride,
  34. padding=padding,
  35. groups=groups,
  36. bias_attr=False)
  37. self.bn = nn.BatchNorm2D(num_features=out_channels)
  38. def forward(self, x):
  39. y = self.conv(x)
  40. y = self.bn(y)
  41. return y
  42. class RepVGGBlock(nn.Layer):
  43. def __init__(self,
  44. in_channels,
  45. out_channels,
  46. kernel_size,
  47. stride=1,
  48. padding=0,
  49. dilation=1,
  50. groups=1,
  51. padding_mode='zeros'):
  52. super(RepVGGBlock, self).__init__()
  53. self.in_channels = in_channels
  54. self.out_channels = out_channels
  55. self.kernel_size = kernel_size
  56. self.stride = stride
  57. self.padding = padding
  58. self.dilation = dilation
  59. self.groups = groups
  60. self.padding_mode = padding_mode
  61. assert kernel_size == 3
  62. assert padding == 1
  63. padding_11 = padding - kernel_size // 2
  64. self.nonlinearity = nn.ReLU()
  65. self.rbr_identity = nn.BatchNorm2D(
  66. num_features=in_channels
  67. ) if out_channels == in_channels and stride == 1 else None
  68. self.rbr_dense = ConvBN(
  69. in_channels=in_channels,
  70. out_channels=out_channels,
  71. kernel_size=kernel_size,
  72. stride=stride,
  73. padding=padding,
  74. groups=groups)
  75. self.rbr_1x1 = ConvBN(
  76. in_channels=in_channels,
  77. out_channels=out_channels,
  78. kernel_size=1,
  79. stride=stride,
  80. padding=padding_11,
  81. groups=groups)
  82. def forward(self, inputs):
  83. if not self.training:
  84. return self.nonlinearity(self.rbr_reparam(inputs))
  85. if self.rbr_identity is None:
  86. id_out = 0
  87. else:
  88. id_out = self.rbr_identity(inputs)
  89. return self.nonlinearity(
  90. self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
  91. def eval(self):
  92. if not hasattr(self, 'rbr_reparam'):
  93. self.rbr_reparam = nn.Conv2D(
  94. in_channels=self.in_channels,
  95. out_channels=self.out_channels,
  96. kernel_size=self.kernel_size,
  97. stride=self.stride,
  98. padding=self.padding,
  99. dilation=self.dilation,
  100. groups=self.groups,
  101. padding_mode=self.padding_mode)
  102. self.training = False
  103. kernel, bias = self.get_equivalent_kernel_bias()
  104. self.rbr_reparam.weight.set_value(kernel)
  105. self.rbr_reparam.bias.set_value(bias)
  106. for layer in self.sublayers():
  107. layer.eval()
  108. def get_equivalent_kernel_bias(self):
  109. kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
  110. kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
  111. kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
  112. return kernel3x3 + self._pad_1x1_to_3x3_tensor(
  113. kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
  114. def _pad_1x1_to_3x3_tensor(self, kernel1x1):
  115. if kernel1x1 is None:
  116. return 0
  117. else:
  118. return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
  119. def _fuse_bn_tensor(self, branch):
  120. if branch is None:
  121. return 0, 0
  122. if isinstance(branch, ConvBN):
  123. kernel = branch.conv.weight
  124. running_mean = branch.bn._mean
  125. running_var = branch.bn._variance
  126. gamma = branch.bn.weight
  127. beta = branch.bn.bias
  128. eps = branch.bn._epsilon
  129. else:
  130. assert isinstance(branch, nn.BatchNorm2D)
  131. if not hasattr(self, 'id_tensor'):
  132. input_dim = self.in_channels // self.groups
  133. kernel_value = np.zeros(
  134. (self.in_channels, input_dim, 3, 3), dtype=np.float32)
  135. for i in range(self.in_channels):
  136. kernel_value[i, i % input_dim, 1, 1] = 1
  137. self.id_tensor = paddle.to_tensor(kernel_value)
  138. kernel = self.id_tensor
  139. running_mean = branch._mean
  140. running_var = branch._variance
  141. gamma = branch.weight
  142. beta = branch.bias
  143. eps = branch._epsilon
  144. std = (running_var + eps).sqrt()
  145. t = (gamma / std).reshape((-1, 1, 1, 1))
  146. return kernel * t, beta - running_mean * gamma / std
  147. class RepVGG(nn.Layer):
  148. def __init__(self,
  149. num_blocks,
  150. width_multiplier=None,
  151. override_groups_map=None,
  152. class_dim=1000):
  153. super(RepVGG, self).__init__()
  154. assert len(width_multiplier) == 4
  155. self.override_groups_map = override_groups_map or dict()
  156. assert 0 not in self.override_groups_map
  157. self.in_planes = min(64, int(64 * width_multiplier[0]))
  158. self.stage0 = RepVGGBlock(
  159. in_channels=3,
  160. out_channels=self.in_planes,
  161. kernel_size=3,
  162. stride=2,
  163. padding=1)
  164. self.cur_layer_idx = 1
  165. self.stage1 = self._make_stage(
  166. int(64 * width_multiplier[0]), num_blocks[0], stride=2)
  167. self.stage2 = self._make_stage(
  168. int(128 * width_multiplier[1]), num_blocks[1], stride=2)
  169. self.stage3 = self._make_stage(
  170. int(256 * width_multiplier[2]), num_blocks[2], stride=2)
  171. self.stage4 = self._make_stage(
  172. int(512 * width_multiplier[3]), num_blocks[3], stride=2)
  173. self.gap = nn.AdaptiveAvgPool2D(output_size=1)
  174. self.linear = nn.Linear(int(512 * width_multiplier[3]), class_dim)
  175. def _make_stage(self, planes, num_blocks, stride):
  176. strides = [stride] + [1] * (num_blocks - 1)
  177. blocks = []
  178. for stride in strides:
  179. cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
  180. blocks.append(
  181. RepVGGBlock(
  182. in_channels=self.in_planes,
  183. out_channels=planes,
  184. kernel_size=3,
  185. stride=stride,
  186. padding=1,
  187. groups=cur_groups))
  188. self.in_planes = planes
  189. self.cur_layer_idx += 1
  190. return nn.Sequential(*blocks)
  191. def eval(self):
  192. self.training = False
  193. for layer in self.sublayers():
  194. layer.training = False
  195. layer.eval()
  196. def forward(self, x):
  197. out = self.stage0(x)
  198. out = self.stage1(out)
  199. out = self.stage2(out)
  200. out = self.stage3(out)
  201. out = self.stage4(out)
  202. out = self.gap(out)
  203. out = paddle.flatten(out, start_axis=1)
  204. out = self.linear(out)
  205. return out
  206. optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
  207. g2_map = {l: 2 for l in optional_groupwise_layers}
  208. g4_map = {l: 4 for l in optional_groupwise_layers}
  209. def RepVGG_A0(**kwargs):
  210. return RepVGG(
  211. num_blocks=[2, 4, 14, 1],
  212. width_multiplier=[0.75, 0.75, 0.75, 2.5],
  213. override_groups_map=None,
  214. **kwargs)
  215. def RepVGG_A1(**kwargs):
  216. return RepVGG(
  217. num_blocks=[2, 4, 14, 1],
  218. width_multiplier=[1, 1, 1, 2.5],
  219. override_groups_map=None,
  220. **kwargs)
  221. def RepVGG_A2(**kwargs):
  222. return RepVGG(
  223. num_blocks=[2, 4, 14, 1],
  224. width_multiplier=[1.5, 1.5, 1.5, 2.75],
  225. override_groups_map=None,
  226. **kwargs)
  227. def RepVGG_B0(**kwargs):
  228. return RepVGG(
  229. num_blocks=[4, 6, 16, 1],
  230. width_multiplier=[1, 1, 1, 2.5],
  231. override_groups_map=None,
  232. **kwargs)
  233. def RepVGG_B1(**kwargs):
  234. return RepVGG(
  235. num_blocks=[4, 6, 16, 1],
  236. width_multiplier=[2, 2, 2, 4],
  237. override_groups_map=None,
  238. **kwargs)
  239. def RepVGG_B1g2(**kwargs):
  240. return RepVGG(
  241. num_blocks=[4, 6, 16, 1],
  242. width_multiplier=[2, 2, 2, 4],
  243. override_groups_map=g2_map,
  244. **kwargs)
  245. def RepVGG_B1g4(**kwargs):
  246. return RepVGG(
  247. num_blocks=[4, 6, 16, 1],
  248. width_multiplier=[2, 2, 2, 4],
  249. override_groups_map=g4_map,
  250. **kwargs)
  251. def RepVGG_B2(**kwargs):
  252. return RepVGG(
  253. num_blocks=[4, 6, 16, 1],
  254. width_multiplier=[2.5, 2.5, 2.5, 5],
  255. override_groups_map=None,
  256. **kwargs)
  257. def RepVGG_B2g2(**kwargs):
  258. return RepVGG(
  259. num_blocks=[4, 6, 16, 1],
  260. width_multiplier=[2.5, 2.5, 2.5, 5],
  261. override_groups_map=g2_map,
  262. **kwargs)
  263. def RepVGG_B2g4(**kwargs):
  264. return RepVGG(
  265. num_blocks=[4, 6, 16, 1],
  266. width_multiplier=[2.5, 2.5, 2.5, 5],
  267. override_groups_map=g4_map,
  268. **kwargs)
  269. def RepVGG_B3(**kwargs):
  270. return RepVGG(
  271. num_blocks=[4, 6, 16, 1],
  272. width_multiplier=[3, 3, 3, 5],
  273. override_groups_map=None,
  274. **kwargs)
  275. def RepVGG_B3g2(**kwargs):
  276. return RepVGG(
  277. num_blocks=[4, 6, 16, 1],
  278. width_multiplier=[3, 3, 3, 5],
  279. override_groups_map=g2_map,
  280. **kwargs)
  281. def RepVGG_B3g4(**kwargs):
  282. return RepVGG(
  283. num_blocks=[4, 6, 16, 1],
  284. width_multiplier=[3, 3, 3, 5],
  285. override_groups_map=g4_map,
  286. **kwargs)