db_fpn.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. from ..backbones.det_mobilenet_v3 import SEModule
  5. from ..necks.intracl import IntraCLBlock
  6. def hard_swish(x, inplace=True):
  7. return x * F.relu6(x + 3.0, inplace=inplace) / 6.0
  8. class DSConv(nn.Module):
  9. def __init__(
  10. self,
  11. in_channels,
  12. out_channels,
  13. kernel_size,
  14. padding,
  15. stride=1,
  16. groups=None,
  17. if_act=True,
  18. act="relu",
  19. **kwargs
  20. ):
  21. super(DSConv, self).__init__()
  22. if groups == None:
  23. groups = in_channels
  24. self.if_act = if_act
  25. self.act = act
  26. self.conv1 = nn.Conv2d(
  27. in_channels=in_channels,
  28. out_channels=in_channels,
  29. kernel_size=kernel_size,
  30. stride=stride,
  31. padding=padding,
  32. groups=groups,
  33. bias=False,
  34. )
  35. self.bn1 = nn.BatchNorm2d(in_channels)
  36. self.conv2 = nn.Conv2d(
  37. in_channels=in_channels,
  38. out_channels=int(in_channels * 4),
  39. kernel_size=1,
  40. stride=1,
  41. bias=False,
  42. )
  43. self.bn2 = nn.BatchNorm2d(int(in_channels * 4))
  44. self.conv3 = nn.Conv2d(
  45. in_channels=int(in_channels * 4),
  46. out_channels=out_channels,
  47. kernel_size=1,
  48. stride=1,
  49. bias=False,
  50. )
  51. self._c = [in_channels, out_channels]
  52. if in_channels != out_channels:
  53. self.conv_end = nn.Conv2d(
  54. in_channels=in_channels,
  55. out_channels=out_channels,
  56. kernel_size=1,
  57. stride=1,
  58. bias=False,
  59. )
  60. def forward(self, inputs):
  61. x = self.conv1(inputs)
  62. x = self.bn1(x)
  63. x = self.conv2(x)
  64. x = self.bn2(x)
  65. if self.if_act:
  66. if self.act == "relu":
  67. x = F.relu(x)
  68. elif self.act == "hardswish":
  69. x = hard_swish(x)
  70. else:
  71. print(
  72. "The activation function({}) is selected incorrectly.".format(
  73. self.act
  74. )
  75. )
  76. exit()
  77. x = self.conv3(x)
  78. if self._c[0] != self._c[1]:
  79. x = x + self.conv_end(inputs)
  80. return x
  81. class DBFPN(nn.Module):
  82. def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
  83. super(DBFPN, self).__init__()
  84. self.out_channels = out_channels
  85. self.use_asf = use_asf
  86. self.in2_conv = nn.Conv2d(
  87. in_channels=in_channels[0],
  88. out_channels=self.out_channels,
  89. kernel_size=1,
  90. bias=False,
  91. )
  92. self.in3_conv = nn.Conv2d(
  93. in_channels=in_channels[1],
  94. out_channels=self.out_channels,
  95. kernel_size=1,
  96. bias=False,
  97. )
  98. self.in4_conv = nn.Conv2d(
  99. in_channels=in_channels[2],
  100. out_channels=self.out_channels,
  101. kernel_size=1,
  102. bias=False,
  103. )
  104. self.in5_conv = nn.Conv2d(
  105. in_channels=in_channels[3],
  106. out_channels=self.out_channels,
  107. kernel_size=1,
  108. bias=False,
  109. )
  110. self.p5_conv = nn.Conv2d(
  111. in_channels=self.out_channels,
  112. out_channels=self.out_channels // 4,
  113. kernel_size=3,
  114. padding=1,
  115. bias=False,
  116. )
  117. self.p4_conv = nn.Conv2d(
  118. in_channels=self.out_channels,
  119. out_channels=self.out_channels // 4,
  120. kernel_size=3,
  121. padding=1,
  122. bias=False,
  123. )
  124. self.p3_conv = nn.Conv2d(
  125. in_channels=self.out_channels,
  126. out_channels=self.out_channels // 4,
  127. kernel_size=3,
  128. padding=1,
  129. bias=False,
  130. )
  131. self.p2_conv = nn.Conv2d(
  132. in_channels=self.out_channels,
  133. out_channels=self.out_channels // 4,
  134. kernel_size=3,
  135. padding=1,
  136. bias=False,
  137. )
  138. if self.use_asf is True:
  139. self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
  140. def forward(self, x):
  141. c2, c3, c4, c5 = x
  142. in5 = self.in5_conv(c5)
  143. in4 = self.in4_conv(c4)
  144. in3 = self.in3_conv(c3)
  145. in2 = self.in2_conv(c2)
  146. out4 = in4 + F.interpolate(
  147. in5,
  148. scale_factor=2,
  149. mode="nearest",
  150. ) # align_mode=1) # 1/16
  151. out3 = in3 + F.interpolate(
  152. out4,
  153. scale_factor=2,
  154. mode="nearest",
  155. ) # align_mode=1) # 1/8
  156. out2 = in2 + F.interpolate(
  157. out3,
  158. scale_factor=2,
  159. mode="nearest",
  160. ) # align_mode=1) # 1/4
  161. p5 = self.p5_conv(in5)
  162. p4 = self.p4_conv(out4)
  163. p3 = self.p3_conv(out3)
  164. p2 = self.p2_conv(out2)
  165. p5 = F.interpolate(
  166. p5,
  167. scale_factor=8,
  168. mode="nearest",
  169. ) # align_mode=1)
  170. p4 = F.interpolate(
  171. p4,
  172. scale_factor=4,
  173. mode="nearest",
  174. ) # align_mode=1)
  175. p3 = F.interpolate(
  176. p3,
  177. scale_factor=2,
  178. mode="nearest",
  179. ) # align_mode=1)
  180. fuse = torch.cat([p5, p4, p3, p2], dim=1)
  181. if self.use_asf is True:
  182. fuse = self.asf(fuse, [p5, p4, p3, p2])
  183. return fuse
  184. class RSELayer(nn.Module):
  185. def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
  186. super(RSELayer, self).__init__()
  187. self.out_channels = out_channels
  188. self.in_conv = nn.Conv2d(
  189. in_channels=in_channels,
  190. out_channels=self.out_channels,
  191. kernel_size=kernel_size,
  192. padding=int(kernel_size // 2),
  193. bias=False,
  194. )
  195. self.se_block = SEModule(self.out_channels)
  196. self.shortcut = shortcut
  197. def forward(self, ins):
  198. x = self.in_conv(ins)
  199. if self.shortcut:
  200. out = x + self.se_block(x)
  201. else:
  202. out = self.se_block(x)
  203. return out
  204. class RSEFPN(nn.Module):
  205. def __init__(self, in_channels, out_channels, shortcut=True, **kwargs):
  206. super(RSEFPN, self).__init__()
  207. self.out_channels = out_channels
  208. self.ins_conv = nn.ModuleList()
  209. self.inp_conv = nn.ModuleList()
  210. self.intracl = False
  211. if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
  212. self.intracl = kwargs["intracl"]
  213. self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
  214. self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
  215. self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
  216. self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
  217. for i in range(len(in_channels)):
  218. self.ins_conv.append(
  219. RSELayer(in_channels[i], out_channels, kernel_size=1, shortcut=shortcut)
  220. )
  221. self.inp_conv.append(
  222. RSELayer(
  223. out_channels, out_channels // 4, kernel_size=3, shortcut=shortcut
  224. )
  225. )
  226. def forward(self, x):
  227. c2, c3, c4, c5 = x
  228. in5 = self.ins_conv[3](c5)
  229. in4 = self.ins_conv[2](c4)
  230. in3 = self.ins_conv[1](c3)
  231. in2 = self.ins_conv[0](c2)
  232. out4 = in4 + F.interpolate(in5, scale_factor=2, mode="nearest") # 1/16
  233. out3 = in3 + F.interpolate(out4, scale_factor=2, mode="nearest") # 1/8
  234. out2 = in2 + F.interpolate(out3, scale_factor=2, mode="nearest") # 1/4
  235. p5 = self.inp_conv[3](in5)
  236. p4 = self.inp_conv[2](out4)
  237. p3 = self.inp_conv[1](out3)
  238. p2 = self.inp_conv[0](out2)
  239. if self.intracl is True:
  240. p5 = self.incl4(p5)
  241. p4 = self.incl3(p4)
  242. p3 = self.incl2(p3)
  243. p2 = self.incl1(p2)
  244. p5 = F.interpolate(p5, scale_factor=8, mode="nearest")
  245. p4 = F.interpolate(p4, scale_factor=4, mode="nearest")
  246. p3 = F.interpolate(p3, scale_factor=2, mode="nearest")
  247. fuse = torch.cat([p5, p4, p3, p2], dim=1)
  248. return fuse
  249. class LKPAN(nn.Module):
  250. def __init__(self, in_channels, out_channels, mode="large", **kwargs):
  251. super(LKPAN, self).__init__()
  252. self.out_channels = out_channels
  253. self.ins_conv = nn.ModuleList()
  254. self.inp_conv = nn.ModuleList()
  255. # pan head
  256. self.pan_head_conv = nn.ModuleList()
  257. self.pan_lat_conv = nn.ModuleList()
  258. if mode.lower() == "lite":
  259. p_layer = DSConv
  260. elif mode.lower() == "large":
  261. p_layer = nn.Conv2d
  262. else:
  263. raise ValueError(
  264. "mode can only be one of ['lite', 'large'], but received {}".format(
  265. mode
  266. )
  267. )
  268. for i in range(len(in_channels)):
  269. self.ins_conv.append(
  270. nn.Conv2d(
  271. in_channels=in_channels[i],
  272. out_channels=self.out_channels,
  273. kernel_size=1,
  274. bias=False,
  275. )
  276. )
  277. self.inp_conv.append(
  278. p_layer(
  279. in_channels=self.out_channels,
  280. out_channels=self.out_channels // 4,
  281. kernel_size=9,
  282. padding=4,
  283. bias=False,
  284. )
  285. )
  286. if i > 0:
  287. self.pan_head_conv.append(
  288. nn.Conv2d(
  289. in_channels=self.out_channels // 4,
  290. out_channels=self.out_channels // 4,
  291. kernel_size=3,
  292. padding=1,
  293. stride=2,
  294. bias=False,
  295. )
  296. )
  297. self.pan_lat_conv.append(
  298. p_layer(
  299. in_channels=self.out_channels // 4,
  300. out_channels=self.out_channels // 4,
  301. kernel_size=9,
  302. padding=4,
  303. bias=False,
  304. )
  305. )
  306. self.intracl = False
  307. if "intracl" in kwargs.keys() and kwargs["intracl"] is True:
  308. self.intracl = kwargs["intracl"]
  309. self.incl1 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
  310. self.incl2 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
  311. self.incl3 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
  312. self.incl4 = IntraCLBlock(self.out_channels // 4, reduce_factor=2)
  313. def forward(self, x):
  314. c2, c3, c4, c5 = x
  315. in5 = self.ins_conv[3](c5)
  316. in4 = self.ins_conv[2](c4)
  317. in3 = self.ins_conv[1](c3)
  318. in2 = self.ins_conv[0](c2)
  319. out4 = in4 + F.interpolate(in5, scale_factor=2, mode="nearest") # 1/16
  320. out3 = in3 + F.interpolate(out4, scale_factor=2, mode="nearest") # 1/8
  321. out2 = in2 + F.interpolate(out3, scale_factor=2, mode="nearest") # 1/4
  322. f5 = self.inp_conv[3](in5)
  323. f4 = self.inp_conv[2](out4)
  324. f3 = self.inp_conv[1](out3)
  325. f2 = self.inp_conv[0](out2)
  326. pan3 = f3 + self.pan_head_conv[0](f2)
  327. pan4 = f4 + self.pan_head_conv[1](pan3)
  328. pan5 = f5 + self.pan_head_conv[2](pan4)
  329. p2 = self.pan_lat_conv[0](f2)
  330. p3 = self.pan_lat_conv[1](pan3)
  331. p4 = self.pan_lat_conv[2](pan4)
  332. p5 = self.pan_lat_conv[3](pan5)
  333. if self.intracl is True:
  334. p5 = self.incl4(p5)
  335. p4 = self.incl3(p4)
  336. p3 = self.incl2(p3)
  337. p2 = self.incl1(p2)
  338. p5 = F.interpolate(p5, scale_factor=8, mode="nearest")
  339. p4 = F.interpolate(p4, scale_factor=4, mode="nearest")
  340. p3 = F.interpolate(p3, scale_factor=2, mode="nearest")
  341. fuse = torch.cat([p5, p4, p3, p2], dim=1)
  342. return fuse
  343. class ASFBlock(nn.Module):
  344. """
  345. This code is refered from:
  346. https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
  347. """
  348. def __init__(self, in_channels, inter_channels, out_features_num=4):
  349. """
  350. Adaptive Scale Fusion (ASF) block of DBNet++
  351. Args:
  352. in_channels: the number of channels in the input data
  353. inter_channels: the number of middle channels
  354. out_features_num: the number of fused stages
  355. """
  356. super(ASFBlock, self).__init__()
  357. self.in_channels = in_channels
  358. self.inter_channels = inter_channels
  359. self.out_features_num = out_features_num
  360. self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
  361. self.spatial_scale = nn.Sequential(
  362. # Nx1xHxW
  363. nn.Conv2d(
  364. in_channels=1,
  365. out_channels=1,
  366. kernel_size=3,
  367. bias=False,
  368. padding=1,
  369. ),
  370. nn.ReLU(),
  371. nn.Conv2d(
  372. in_channels=1,
  373. out_channels=1,
  374. kernel_size=1,
  375. bias=False,
  376. ),
  377. nn.Sigmoid(),
  378. )
  379. self.channel_scale = nn.Sequential(
  380. nn.Conv2d(
  381. in_channels=inter_channels,
  382. out_channels=out_features_num,
  383. kernel_size=1,
  384. bias=False,
  385. ),
  386. nn.Sigmoid(),
  387. )
  388. def forward(self, fuse_features, features_list):
  389. fuse_features = self.conv(fuse_features)
  390. spatial_x = torch.mean(fuse_features, dim=1, keepdim=True)
  391. attention_scores = self.spatial_scale(spatial_x) + fuse_features
  392. attention_scores = self.channel_scale(attention_scores)
  393. assert len(features_list) == self.out_features_num
  394. out_list = []
  395. for i in range(self.out_features_num):
  396. out_list.append(attention_scores[:, i : i + 1] * features_list[i])
  397. return torch.cat(out_list, dim=1)