# rec_lcnetv3.py
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import torch
import torch.nn.functional as F
from torch import nn

from ..common import Activation

NET_CONFIG_det = {
    "blocks2":
    # k, in_c, out_c, s, use_se
    [[3, 16, 32, 1, False]],
    "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
    "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
    "blocks5": [
        [3, 128, 256, 2, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
    ],
    "blocks6": [
        [5, 256, 512, 2, True],
        [5, 512, 512, 1, True],
        [5, 512, 512, 1, False],
        [5, 512, 512, 1, False],
    ],
}

NET_CONFIG_rec = {
    "blocks2":
    # k, in_c, out_c, s, use_se
    [[3, 16, 32, 1, False]],
    "blocks3": [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
    "blocks4": [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
    "blocks5": [
        [3, 128, 256, (1, 2), False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
        [5, 256, 256, 1, False],
    ],
    "blocks6": [
        [5, 256, 512, (2, 1), True],
        [5, 512, 512, 1, True],
        [5, 512, 512, (2, 1), False],
        [5, 512, 512, 1, False],
    ],
}
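
# Note: unlike NET_CONFIG_det, the recognition config uses asymmetric strides
# such as (2, 1) and (1, 2), downsampling height faster than width so that a
# text-line image keeps horizontal resolution for per-character features.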


def make_divisible(v, divisor=16, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
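

# Illustrative sanity checks (not part of the original module): channel counts
# are snapped to the nearest multiple of `divisor` (default 16), floored at
# `min_value`, and bumped up one step if rounding would lose more than 10% of
# the requested width.
assert make_divisible(24) == 32    # rounds up to the next multiple of 16
assert make_divisible(2) == 16     # floored at min_value (= divisor)
assert make_divisible(100, 8) == 104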


class LearnableAffineBlock(nn.Module):
    def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.1):
        super().__init__()
        # lr_mult and lab_lr are accepted to stay signature-compatible with
        # the Paddle original (per-parameter learning rates); they are unused
        # in this port.
        self.scale = nn.Parameter(torch.Tensor([scale_value]))
        self.bias = nn.Parameter(torch.Tensor([bias_value]))

    def forward(self, x):
        # learnable elementwise affine transform: y = scale * x + bias
        return self.scale * x + self.bias
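

# A tiny illustrative check (not part of the original module): the block is a
# single learnable scale and bias applied elementwise.
_lab = LearnableAffineBlock(scale_value=2.0, bias_value=1.0)
assert torch.allclose(_lab(torch.ones(3)), torch.full((3,), 3.0))
del _lab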


class ConvBNLayer(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride, groups=1, lr_mult=1.0
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Act(nn.Module):
    def __init__(self, act="hswish", lr_mult=1.0, lab_lr=0.1):
        super().__init__()
        if act == "hswish":
            self.act = nn.Hardswish(inplace=True)
        else:
            assert act == "relu"
            self.act = Activation(act)
        self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)

    def forward(self, x):
        return self.lab(self.act(x))


class LearnableRepLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        groups=1,
        num_conv_branches=1,
        lr_mult=1.0,
        lab_lr=0.1,
    ):
        super().__init__()
        self.is_repped = False
        self.groups = groups
        self.stride = stride
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches
        self.padding = (kernel_size - 1) // 2
        # BN-only identity branch, present only when a residual is shape-safe
        self.identity = (
            nn.BatchNorm2d(num_features=in_channels)
            if out_channels == in_channels and stride == 1
            else None
        )
        self.conv_kxk = nn.ModuleList(
            [
                ConvBNLayer(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    groups=groups,
                    lr_mult=lr_mult,
                )
                for _ in range(self.num_conv_branches)
            ]
        )
        # extra 1x1 branch, only when the main kernel is larger than 1x1
        self.conv_1x1 = (
            ConvBNLayer(
                in_channels, out_channels, 1, stride, groups=groups, lr_mult=lr_mult
            )
            if kernel_size > 1
            else None
        )
        self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
        self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)

    def forward(self, x):
        # repped path, for export: a single fused conv
        if self.is_repped:
            out = self.lab(self.reparam_conv(x))
            if self.stride != 2:
                out = self.act(out)
            return out

        # training path: sum the identity, 1x1 and kxk branch outputs
        out = 0
        if self.identity is not None:
            out += self.identity(x)
        if self.conv_1x1 is not None:
            out += self.conv_1x1(x)
        for conv in self.conv_kxk:
            out += conv(x)
        out = self.lab(out)
        if self.stride != 2:
            out = self.act(out)
        return out

    def rep(self):
        # Fuse all branches into a single conv (+ bias) for inference/export.
        if self.is_repped:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias
        self.is_repped = True

    def _pad_kernel_1x1_to_kxk(self, kernel1x1, pad):
        if not isinstance(kernel1x1, torch.Tensor):
            return 0
        else:
            return nn.functional.pad(kernel1x1, [pad, pad, pad, pad])

    def _get_kernel_bias(self):
        kernel_conv_1x1, bias_conv_1x1 = self._fuse_bn_tensor(self.conv_1x1)
        kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(
            kernel_conv_1x1, self.kernel_size // 2
        )
        kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
        kernel_conv_kxk = 0
        bias_conv_kxk = 0
        for conv in self.conv_kxk:
            kernel, bias = self._fuse_bn_tensor(conv)
            kernel_conv_kxk += kernel
            bias_conv_kxk += bias
        kernel_reparam = kernel_conv_kxk + kernel_conv_1x1 + kernel_identity
        bias_reparam = bias_conv_kxk + bias_conv_1x1 + bias_identity
        return kernel_reparam, bias_reparam

    def _fuse_bn_tensor(self, branch):
        if not branch:
            return 0, 0
        elif isinstance(branch, ConvBNLayer):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups
                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_channels):
                    kernel_value[
                        i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
                    ] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape((-1, 1, 1, 1))
        return kernel * t, beta - running_mean * gamma / std
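

# A minimal re-parameterization check, illustrative and not part of the
# original module (it is never called here): after rep(), the single fused
# conv should reproduce the multi-branch output in eval mode, where BN uses
# its running statistics.
def _check_rep_equivalence():
    layer = LearnableRepLayer(8, 8, kernel_size=3, stride=1, groups=8)
    layer.eval()
    x = torch.randn(1, 8, 16, 16)
    with torch.no_grad():
        y_branches = layer(x)   # identity + 1x1 + kxk branches
        layer.rep()             # fuse into layer.reparam_conv
        y_fused = layer(x)
    assert torch.allclose(y_branches, y_fused, atol=1e-5)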


class SELayer(nn.Module):
    def __init__(self, channel, reduction=4, lr_mult=1.0):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.hardsigmoid = nn.Hardsigmoid(inplace=True)

    def forward(self, x):
        identity = x
        x = self.avg_pool(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.hardsigmoid(x)
        x = identity * x
        return x


class LCNetV3Block(nn.Module):
    # depthwise k x k rep-conv -> optional SE -> pointwise 1 x 1 rep-conv
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        dw_size,
        use_se=False,
        conv_kxk_num=4,
        lr_mult=1.0,
        lab_lr=0.1,
    ):
        super().__init__()
        self.use_se = use_se
        self.dw_conv = LearnableRepLayer(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=dw_size,
            stride=stride,
            groups=in_channels,
            num_conv_branches=conv_kxk_num,
            lr_mult=lr_mult,
            lab_lr=lab_lr,
        )
        if use_se:
            self.se = SELayer(in_channels, lr_mult=lr_mult)
        self.pw_conv = LearnableRepLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            num_conv_branches=conv_kxk_num,
            lr_mult=lr_mult,
            lab_lr=lab_lr,
        )

    def forward(self, x):
        x = self.dw_conv(x)
        if self.use_se:
            x = self.se(x)
        x = self.pw_conv(x)
        return x


class PPLCNetV3(nn.Module):
    def __init__(
        self,
        scale=1.0,
        conv_kxk_num=4,
        lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        lab_lr=0.1,
        det=False,
        **kwargs
    ):
        super().__init__()
        self.scale = scale
        self.lr_mult_list = lr_mult_list
        self.det = det
        self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
        assert isinstance(
            self.lr_mult_list, (list, tuple)
        ), "lr_mult_list should be in (list, tuple) but got {}".format(
            type(self.lr_mult_list)
        )
        assert (
            len(self.lr_mult_list) == 6
        ), "lr_mult_list length should be 6 but got {}".format(len(self.lr_mult_list))
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=make_divisible(16 * scale),
            kernel_size=3,
            stride=2,
            lr_mult=self.lr_mult_list[0],
        )
        self.blocks2 = nn.Sequential(
            *[
                LCNetV3Block(
                    in_channels=make_divisible(in_c * scale),
                    out_channels=make_divisible(out_c * scale),
                    dw_size=k,
                    stride=s,
                    use_se=se,
                    conv_kxk_num=conv_kxk_num,
                    lr_mult=self.lr_mult_list[1],
                    lab_lr=lab_lr,
                )
                for (k, in_c, out_c, s, se) in self.net_config["blocks2"]
            ]
        )
        self.blocks3 = nn.Sequential(
            *[
                LCNetV3Block(
                    in_channels=make_divisible(in_c * scale),
                    out_channels=make_divisible(out_c * scale),
                    dw_size=k,
                    stride=s,
                    use_se=se,
                    conv_kxk_num=conv_kxk_num,
                    lr_mult=self.lr_mult_list[2],
                    lab_lr=lab_lr,
                )
                for (k, in_c, out_c, s, se) in self.net_config["blocks3"]
            ]
        )
        self.blocks4 = nn.Sequential(
            *[
                LCNetV3Block(
                    in_channels=make_divisible(in_c * scale),
                    out_channels=make_divisible(out_c * scale),
                    dw_size=k,
                    stride=s,
                    use_se=se,
                    conv_kxk_num=conv_kxk_num,
                    lr_mult=self.lr_mult_list[3],
                    lab_lr=lab_lr,
                )
                for (k, in_c, out_c, s, se) in self.net_config["blocks4"]
            ]
        )
        self.blocks5 = nn.Sequential(
            *[
                LCNetV3Block(
                    in_channels=make_divisible(in_c * scale),
                    out_channels=make_divisible(out_c * scale),
                    dw_size=k,
                    stride=s,
                    use_se=se,
                    conv_kxk_num=conv_kxk_num,
                    lr_mult=self.lr_mult_list[4],
                    lab_lr=lab_lr,
                )
                for (k, in_c, out_c, s, se) in self.net_config["blocks5"]
            ]
        )
        self.blocks6 = nn.Sequential(
            *[
                LCNetV3Block(
                    in_channels=make_divisible(in_c * scale),
                    out_channels=make_divisible(out_c * scale),
                    dw_size=k,
                    stride=s,
                    use_se=se,
                    conv_kxk_num=conv_kxk_num,
                    lr_mult=self.lr_mult_list[5],
                    lab_lr=lab_lr,
                )
                for (k, in_c, out_c, s, se) in self.net_config["blocks6"]
            ]
        )
        self.out_channels = make_divisible(512 * scale)
        if self.det:
            mv_c = [16, 24, 56, 480]
            self.out_channels = [
                make_divisible(self.net_config["blocks3"][-1][2] * scale),
                make_divisible(self.net_config["blocks4"][-1][2] * scale),
                make_divisible(self.net_config["blocks5"][-1][2] * scale),
                make_divisible(self.net_config["blocks6"][-1][2] * scale),
            ]
            self.layer_list = nn.ModuleList(
                [
                    nn.Conv2d(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
                    nn.Conv2d(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
                    nn.Conv2d(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
                    nn.Conv2d(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0),
                ]
            )
            self.out_channels = [
                int(mv_c[0] * scale),
                int(mv_c[1] * scale),
                int(mv_c[2] * scale),
                int(mv_c[3] * scale),
            ]

    def forward(self, x):
        out_list = []
        x = self.conv1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        out_list.append(x)
        x = self.blocks4(x)
        out_list.append(x)
        x = self.blocks5(x)
        out_list.append(x)
        x = self.blocks6(x)
        out_list.append(x)

        if self.det:
            out_list[0] = self.layer_list[0](out_list[0])
            out_list[1] = self.layer_list[1](out_list[1])
            out_list[2] = self.layer_list[2](out_list[2])
            out_list[3] = self.layer_list[3](out_list[3])
            return out_list

        if self.training:
            x = F.adaptive_avg_pool2d(x, [1, 40])
        else:
            x = F.avg_pool2d(x, [3, 2])
        return x
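

# A minimal usage sketch (shapes are illustrative, not from the original
# module; the helper is never called here). In eval mode the recognition
# branch ends with F.avg_pool2d(x, [3, 2]), which expects a feature map
# 3 pixels tall; a 48-pixel-high input produces that after the stride-2
# stem and the three (2, 1) stages.
def _demo_rec_backbone():
    model = PPLCNetV3(scale=0.95, det=False)
    model.eval()
    x = torch.randn(1, 3, 48, 320)  # (batch, channels, height, width)
    with torch.no_grad():
        feats = model(x)
    # For scale=0.95 the result should be (1, 480, 1, 40).
    print(feats.shape)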