resnet_vd.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform

__all__ = [
    "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet50_vd_ssld",
    "ResNet101_vd", "ResNet101_vd_ssld", "ResNet152_vd", "ResNet200_vd"
]
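
# ResNet-vd ("D" variant): compared with the plain ResNet, the 7x7 stem
# convolution is replaced by three 3x3 convolutions (conv1_1..conv1_3), and
# downsampling shortcut branches apply a 2x2 average pool (is_vd_mode) before
# a stride-1 1x1 convolution.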


class ConvBNLayer(nn.Layer):
    # Conv2D + BatchNorm helper. When is_vd_mode is True, a 2x2 average pool
    # (stride 2) runs before the convolution, which is how the vd variant
    # downsamples its shortcut branches.
    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 is_vd_mode=False,
                 act=None,
                 lr_mult=1.0,
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self._conv = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(
                name=name + "_weights", learning_rate=lr_mult),
            bias_attr=False)
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        self._batch_norm = BatchNorm(
            num_filters,
            act=act,
            param_attr=ParamAttr(
                name=bn_name + '_scale', learning_rate=lr_mult),
            bias_attr=ParamAttr(
                bn_name + '_offset', learning_rate=lr_mult),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, inputs):
        if self.is_vd_mode:
            inputs = self._pool2d_avg(inputs)
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class BottleneckBlock(nn.Layer):
    # 1x1 -> 3x3 -> 1x1 bottleneck residual block with 4x channel expansion,
    # used for the 50/101/152/200-layer variants.
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 if_first=False,
                 lr_mult=1.0,
                 name=None):
        super(BottleneckBlock, self).__init__()
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act='relu',
            lr_mult=lr_mult,
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu',
            lr_mult=lr_mult,
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None,
            lr_mult=lr_mult,
            name=name + "_branch2c")
        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                lr_mult=lr_mult,
                name=name + "_branch1")
        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)
        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    # Two 3x3 convolutions with a residual connection, used for the
    # 18/34-layer variants.
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 if_first=False,
                 lr_mult=1.0,
                 name=None):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu',
            lr_mult=lr_mult,
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            act=None,
            lr_mult=lr_mult,
            name=name + "_branch2b")
        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters,
                filter_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                lr_mult=lr_mult,
                name=name + "_branch1")
        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv1)
        y = F.relu(y)
        return y


class ResNet_vd(nn.Layer):
    def __init__(self,
                 layers=50,
                 class_dim=1000,
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
        super(ResNet_vd, self).__init__()

        self.layers = layers
        supported_layers = [18, 34, 50, 101, 152, 200]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        # One learning-rate multiplier for the stem plus one per stage.
        self.lr_mult_list = lr_mult_list
        assert isinstance(self.lr_mult_list, (
            list, tuple
        )), "lr_mult_list should be in (list, tuple) but got {}".format(
            type(self.lr_mult_list))
        assert len(
            self.lr_mult_list
        ) == 5, "lr_mult_list length should be 5 but got {}".format(
            len(self.lr_mult_list))

        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        elif layers == 200:
            depth = [3, 12, 48, 3]
        num_channels = [64, 256, 512,
                        1024] if layers >= 50 else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512]

        # Deep stem: three 3x3 convolutions instead of a single 7x7 convolution.
        self.conv1_1 = ConvBNLayer(
            num_channels=3,
            num_filters=32,
            filter_size=3,
            stride=2,
            act='relu',
            lr_mult=self.lr_mult_list[0],
            name="conv1_1")
        self.conv1_2 = ConvBNLayer(
            num_channels=32,
            num_filters=32,
            filter_size=3,
            stride=1,
            act='relu',
            lr_mult=self.lr_mult_list[0],
            name="conv1_2")
        self.conv1_3 = ConvBNLayer(
            num_channels=32,
            num_filters=64,
            filter_size=3,
            stride=1,
            act='relu',
            lr_mult=self.lr_mult_list[0],
            name="conv1_3")
        self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)

        # Stages res2-res5: bottleneck blocks for 50+ layers, basic blocks otherwise.
        self.block_list = []
        if layers >= 50:
            for block in range(len(depth)):
                shortcut = False
                for i in range(depth[block]):
                    if layers in [101, 152, 200] and block == 2:
                        if i == 0:
                            conv_name = "res" + str(block + 2) + "a"
                        else:
                            conv_name = "res" + str(block + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    bottleneck_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BottleneckBlock(
                            num_channels=num_channels[block]
                            if i == 0 else num_filters[block] * 4,
                            num_filters=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            lr_mult=self.lr_mult_list[block + 1],
                            name=conv_name))
                    self.block_list.append(bottleneck_block)
                    shortcut = True
        else:
            for block in range(len(depth)):
                shortcut = False
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    basic_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BasicBlock(
                            num_channels=num_channels[block]
                            if i == 0 else num_filters[block],
                            num_filters=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name,
                            lr_mult=self.lr_mult_list[block + 1]))
                    self.block_list.append(basic_block)
                    shortcut = True

        self.pool2d_avg = AdaptiveAvgPool2D(1)
        self.pool2d_avg_channels = num_channels[-1] * 2
        stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
        self.out = Linear(
            self.pool2d_avg_channels,
            class_dim,
            weight_attr=ParamAttr(
                initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
            bias_attr=ParamAttr(name="fc_0.b_0"))

    def forward(self, inputs):
        y = self.conv1_1(inputs)
        y = self.conv1_2(y)
        y = self.conv1_3(y)
        y = self.pool2d_max(y)
        for block in self.block_list:
            y = block(y)
        y = self.pool2d_avg(y)
        y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
        y = self.out(y)
        return y


def ResNet18_vd(**args):
    model = ResNet_vd(layers=18, **args)
    return model


def ResNet34_vd(**args):
    model = ResNet_vd(layers=34, **args)
    return model


def ResNet50_vd(**args):
    model = ResNet_vd(layers=50, **args)
    return model


def ResNet101_vd(**args):
    model = ResNet_vd(layers=101, **args)
    return model


def ResNet152_vd(**args):
    model = ResNet_vd(layers=152, **args)
    return model


def ResNet200_vd(**args):
    model = ResNet_vd(layers=200, **args)
    return model
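

# SSLD variants: identical topology, but with smaller per-stage learning-rate
# multipliers, as commonly used when fine-tuning from SSLD-distilled
# pretrained weights.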
def ResNet50_vd_ssld(**args):
    model = ResNet_vd(layers=50, lr_mult_list=[.1, .1, .2, .2, .3], **args)
    return model


def ResNet101_vd_ssld(**args):
    model = ResNet_vd(layers=101, lr_mult_list=[.1, .1, .2, .2, .3], **args)
    return model
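

# Minimal usage sketch: build a ResNet50_vd classifier and run a forward pass
# on a random batch. Assumes a working PaddlePaddle >= 2.0 environment; the
# 224x224 input size and class_dim=1000 below are illustrative choices, not
# requirements of the model.
if __name__ == "__main__":
    model = ResNet50_vd(class_dim=1000)
    model.eval()
    x = paddle.rand([1, 3, 224, 224])  # NCHW dummy image batch
    with paddle.no_grad():
        logits = model(x)
    print(logits.shape)  # expected: [1, 1000]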