regnet.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. from paddle import ParamAttr
  20. import paddle.nn as nn
  21. import paddle.nn.functional as F
  22. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  23. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  24. from paddle.nn.initializer import Uniform
  25. import math
  26. __all__ = [
  27. "RegNetX_200MF", "RegNetX_4GF", "RegNetX_32GF", "RegNetY_200MF",
  28. "RegNetY_4GF", "RegNetY_32GF"
  29. ]
  30. def quantize_float(f, q):
  31. """Converts a float to closest non-zero int divisible by q."""
  32. return int(round(f / q) * q)
  33. def adjust_ws_gs_comp(ws, bms, gs):
  34. """Adjusts the compatibility of widths and groups."""
  35. ws_bot = [int(w * b) for w, b in zip(ws, bms)]
  36. gs = [min(g, w_bot) for g, w_bot in zip(gs, ws_bot)]
  37. ws_bot = [quantize_float(w_bot, g) for w_bot, g in zip(ws_bot, gs)]
  38. ws = [int(w_bot / b) for w_bot, b in zip(ws_bot, bms)]
  39. return ws, gs
  40. def get_stages_from_blocks(ws, rs):
  41. """Gets ws/ds of network at each stage from per block values."""
  42. ts = [
  43. w != wp or r != rp
  44. for w, wp, r, rp in zip(ws + [0], [0] + ws, rs + [0], [0] + rs)
  45. ]
  46. s_ws = [w for w, t in zip(ws, ts[:-1]) if t]
  47. s_ds = np.diff([d for d, t in zip(range(len(ts)), ts) if t]).tolist()
  48. return s_ws, s_ds
  49. def generate_regnet(w_a, w_0, w_m, d, q=8):
  50. """Generates per block ws from RegNet parameters."""
  51. assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
  52. ws_cont = np.arange(d) * w_a + w_0
  53. ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
  54. ws = w_0 * np.power(w_m, ks)
  55. ws = np.round(np.divide(ws, q)) * q
  56. num_stages, max_stage = len(np.unique(ws)), ks.max() + 1
  57. ws, ws_cont = ws.astype(int).tolist(), ws_cont.tolist()
  58. return ws, num_stages, max_stage, ws_cont
  59. class ConvBNLayer(nn.Layer):
  60. def __init__(self,
  61. num_channels,
  62. num_filters,
  63. filter_size,
  64. stride=1,
  65. groups=1,
  66. padding=0,
  67. act=None,
  68. name=None):
  69. super(ConvBNLayer, self).__init__()
  70. self._conv = Conv2D(
  71. in_channels=num_channels,
  72. out_channels=num_filters,
  73. kernel_size=filter_size,
  74. stride=stride,
  75. padding=padding,
  76. groups=groups,
  77. weight_attr=ParamAttr(name=name + ".conv2d.output.1.w_0"),
  78. bias_attr=ParamAttr(name=name + ".conv2d.output.1.b_0"))
  79. bn_name = name + "_bn"
  80. self._batch_norm = BatchNorm(
  81. num_filters,
  82. act=act,
  83. param_attr=ParamAttr(name=bn_name + ".output.1.w_0"),
  84. bias_attr=ParamAttr(bn_name + ".output.1.b_0"),
  85. moving_mean_name=bn_name + "_mean",
  86. moving_variance_name=bn_name + "_variance")
  87. def forward(self, inputs):
  88. y = self._conv(inputs)
  89. y = self._batch_norm(y)
  90. return y
  91. class BottleneckBlock(nn.Layer):
  92. def __init__(self,
  93. num_channels,
  94. num_filters,
  95. stride,
  96. bm,
  97. gw,
  98. se_on,
  99. se_r,
  100. shortcut=True,
  101. name=None):
  102. super(BottleneckBlock, self).__init__()
  103. # Compute the bottleneck width
  104. w_b = int(round(num_filters * bm))
  105. # Compute the number of groups
  106. num_gs = w_b // gw
  107. self.se_on = se_on
  108. self.conv0 = ConvBNLayer(
  109. num_channels=num_channels,
  110. num_filters=w_b,
  111. filter_size=1,
  112. padding=0,
  113. act="relu",
  114. name=name + "_branch2a")
  115. self.conv1 = ConvBNLayer(
  116. num_channels=w_b,
  117. num_filters=w_b,
  118. filter_size=3,
  119. stride=stride,
  120. padding=1,
  121. groups=num_gs,
  122. act="relu",
  123. name=name + "_branch2b")
  124. if se_on:
  125. w_se = int(round(num_channels * se_r))
  126. self.se_block = SELayer(
  127. num_channels=w_b,
  128. num_filters=w_b,
  129. reduction_ratio=w_se,
  130. name=name + "_branch2se")
  131. self.conv2 = ConvBNLayer(
  132. num_channels=w_b,
  133. num_filters=num_filters,
  134. filter_size=1,
  135. act=None,
  136. name=name + "_branch2c")
  137. if not shortcut:
  138. self.short = ConvBNLayer(
  139. num_channels=num_channels,
  140. num_filters=num_filters,
  141. filter_size=1,
  142. stride=stride,
  143. name=name + "_branch1")
  144. self.shortcut = shortcut
  145. def forward(self, inputs):
  146. y = self.conv0(inputs)
  147. conv1 = self.conv1(y)
  148. if self.se_on:
  149. conv1 = self.se_block(conv1)
  150. conv2 = self.conv2(conv1)
  151. if self.shortcut:
  152. short = inputs
  153. else:
  154. short = self.short(inputs)
  155. y = paddle.add(x=short, y=conv2)
  156. y = F.relu(y)
  157. return y
  158. class SELayer(nn.Layer):
  159. def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
  160. super(SELayer, self).__init__()
  161. self.pool2d_gap = AdaptiveAvgPool2D(1)
  162. self._num_channels = num_channels
  163. med_ch = int(num_channels / reduction_ratio)
  164. stdv = 1.0 / math.sqrt(num_channels * 1.0)
  165. self.squeeze = Linear(
  166. num_channels,
  167. med_ch,
  168. weight_attr=ParamAttr(
  169. initializer=Uniform(-stdv, stdv), name=name + "_sqz_weights"),
  170. bias_attr=ParamAttr(name=name + "_sqz_offset"))
  171. stdv = 1.0 / math.sqrt(med_ch * 1.0)
  172. self.excitation = Linear(
  173. med_ch,
  174. num_filters,
  175. weight_attr=ParamAttr(
  176. initializer=Uniform(-stdv, stdv), name=name + "_exc_weights"),
  177. bias_attr=ParamAttr(name=name + "_exc_offset"))
  178. def forward(self, input):
  179. pool = self.pool2d_gap(input)
  180. pool = paddle.reshape(pool, shape=[-1, self._num_channels])
  181. squeeze = self.squeeze(pool)
  182. squeeze = F.relu(squeeze)
  183. excitation = self.excitation(squeeze)
  184. excitation = F.sigmoid(excitation)
  185. excitation = paddle.reshape(
  186. excitation, shape=[-1, self._num_channels, 1, 1])
  187. out = input * excitation
  188. return out
  189. class RegNet(nn.Layer):
  190. def __init__(self,
  191. w_a,
  192. w_0,
  193. w_m,
  194. d,
  195. group_w,
  196. bot_mul,
  197. q=8,
  198. se_on=False,
  199. class_dim=1000):
  200. super(RegNet, self).__init__()
  201. # Generate RegNet ws per block
  202. b_ws, num_s, max_s, ws_cont = generate_regnet(w_a, w_0, w_m, d, q)
  203. # Convert to per stage format
  204. ws, ds = get_stages_from_blocks(b_ws, b_ws)
  205. # Generate group widths and bot muls
  206. gws = [group_w for _ in range(num_s)]
  207. bms = [bot_mul for _ in range(num_s)]
  208. # Adjust the compatibility of ws and gws
  209. ws, gws = adjust_ws_gs_comp(ws, bms, gws)
  210. # Use the same stride for each stage
  211. ss = [2 for _ in range(num_s)]
  212. # Use SE for RegNetY
  213. se_r = 0.25
  214. # Construct the model
  215. # Group params by stage
  216. stage_params = list(zip(ds, ws, ss, bms, gws))
  217. # Construct the stem
  218. stem_type = "simple_stem_in"
  219. stem_w = 32
  220. block_type = "res_bottleneck_block"
  221. self.conv = ConvBNLayer(
  222. num_channels=3,
  223. num_filters=stem_w,
  224. filter_size=3,
  225. stride=2,
  226. padding=1,
  227. act="relu",
  228. name="stem_conv")
  229. self.block_list = []
  230. for block, (d, w_out, stride, bm, gw) in enumerate(stage_params):
  231. shortcut = False
  232. for i in range(d):
  233. num_channels = stem_w if block == i == 0 else in_channels
  234. # Stride apply to the first block of the stage
  235. b_stride = stride if i == 0 else 1
  236. conv_name = "s" + str(block + 1) + "_b" + str(i +
  237. 1) # chr(97 + i)
  238. bottleneck_block = self.add_sublayer(
  239. conv_name,
  240. BottleneckBlock(
  241. num_channels=num_channels,
  242. num_filters=w_out,
  243. stride=b_stride,
  244. bm=bm,
  245. gw=gw,
  246. se_on=se_on,
  247. se_r=se_r,
  248. shortcut=shortcut,
  249. name=conv_name))
  250. in_channels = w_out
  251. self.block_list.append(bottleneck_block)
  252. shortcut = True
  253. self.pool2d_avg = AdaptiveAvgPool2D(1)
  254. self.pool2d_avg_channels = w_out
  255. stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
  256. self.out = Linear(
  257. self.pool2d_avg_channels,
  258. class_dim,
  259. weight_attr=ParamAttr(
  260. initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
  261. bias_attr=ParamAttr(name="fc_0.b_0"))
  262. def forward(self, inputs):
  263. y = self.conv(inputs)
  264. for block in self.block_list:
  265. y = block(y)
  266. y = self.pool2d_avg(y)
  267. y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
  268. y = self.out(y)
  269. return y
  270. def RegNetX_200MF(**args):
  271. model = RegNet(
  272. w_a=36.44, w_0=24, w_m=2.49, d=13, group_w=8, bot_mul=1.0, q=8, **args)
  273. return model
  274. def RegNetX_4GF(**args):
  275. model = RegNet(
  276. w_a=38.65,
  277. w_0=96,
  278. w_m=2.43,
  279. d=23,
  280. group_w=40,
  281. bot_mul=1.0,
  282. q=8,
  283. **args)
  284. return model
  285. def RegNetX_32GF(**args):
  286. model = RegNet(
  287. w_a=69.86,
  288. w_0=320,
  289. w_m=2.0,
  290. d=23,
  291. group_w=168,
  292. bot_mul=1.0,
  293. q=8,
  294. **args)
  295. return model
  296. def RegNetY_200MF(**args):
  297. model = RegNet(
  298. w_a=36.44,
  299. w_0=24,
  300. w_m=2.49,
  301. d=13,
  302. group_w=8,
  303. bot_mul=1.0,
  304. q=8,
  305. se_on=True,
  306. **args)
  307. return model
  308. def RegNetY_4GF(**args):
  309. model = RegNet(
  310. w_a=31.41,
  311. w_0=96,
  312. w_m=2.24,
  313. d=22,
  314. group_w=64,
  315. bot_mul=1.0,
  316. q=8,
  317. se_on=True,
  318. **args)
  319. return model
  320. def RegNetY_32GF(**args):
  321. model = RegNet(
  322. w_a=115.89,
  323. w_0=232,
  324. w_m=2.53,
  325. d=20,
  326. group_w=232,
  327. bot_mul=1.0,
  328. q=8,
  329. se_on=True,
  330. **args)
  331. return model