rec_hgnet.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. class ConvBNAct(nn.Module):
  5. def __init__(
  6. self, in_channels, out_channels, kernel_size, stride, groups=1, use_act=True
  7. ):
  8. super().__init__()
  9. self.use_act = use_act
  10. self.conv = nn.Conv2d(
  11. in_channels,
  12. out_channels,
  13. kernel_size,
  14. stride,
  15. padding=(kernel_size - 1) // 2,
  16. groups=groups,
  17. bias=False,
  18. )
  19. self.bn = nn.BatchNorm2d(out_channels)
  20. if self.use_act:
  21. self.act = nn.ReLU()
  22. def forward(self, x):
  23. x = self.conv(x)
  24. x = self.bn(x)
  25. if self.use_act:
  26. x = self.act(x)
  27. return x
  28. class ESEModule(nn.Module):
  29. def __init__(self, channels):
  30. super().__init__()
  31. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  32. self.conv = nn.Conv2d(
  33. in_channels=channels,
  34. out_channels=channels,
  35. kernel_size=1,
  36. stride=1,
  37. padding=0,
  38. )
  39. self.sigmoid = nn.Sigmoid()
  40. def forward(self, x):
  41. identity = x
  42. x = self.avg_pool(x)
  43. x = self.conv(x)
  44. x = self.sigmoid(x)
  45. return x * identity
  46. class HG_Block(nn.Module):
  47. def __init__(
  48. self,
  49. in_channels,
  50. mid_channels,
  51. out_channels,
  52. layer_num,
  53. identity=False,
  54. ):
  55. super().__init__()
  56. self.identity = identity
  57. self.layers = nn.ModuleList()
  58. self.layers.append(
  59. ConvBNAct(
  60. in_channels=in_channels,
  61. out_channels=mid_channels,
  62. kernel_size=3,
  63. stride=1,
  64. )
  65. )
  66. for _ in range(layer_num - 1):
  67. self.layers.append(
  68. ConvBNAct(
  69. in_channels=mid_channels,
  70. out_channels=mid_channels,
  71. kernel_size=3,
  72. stride=1,
  73. )
  74. )
  75. # feature aggregation
  76. total_channels = in_channels + layer_num * mid_channels
  77. self.aggregation_conv = ConvBNAct(
  78. in_channels=total_channels,
  79. out_channels=out_channels,
  80. kernel_size=1,
  81. stride=1,
  82. )
  83. self.att = ESEModule(out_channels)
  84. def forward(self, x):
  85. identity = x
  86. output = []
  87. output.append(x)
  88. for layer in self.layers:
  89. x = layer(x)
  90. output.append(x)
  91. x = torch.cat(output, dim=1)
  92. x = self.aggregation_conv(x)
  93. x = self.att(x)
  94. if self.identity:
  95. x += identity
  96. return x
  97. class HG_Stage(nn.Module):
  98. def __init__(
  99. self,
  100. in_channels,
  101. mid_channels,
  102. out_channels,
  103. block_num,
  104. layer_num,
  105. downsample=True,
  106. stride=[2, 1],
  107. ):
  108. super().__init__()
  109. self.downsample = downsample
  110. if downsample:
  111. self.downsample = ConvBNAct(
  112. in_channels=in_channels,
  113. out_channels=in_channels,
  114. kernel_size=3,
  115. stride=stride,
  116. groups=in_channels,
  117. use_act=False,
  118. )
  119. blocks_list = []
  120. blocks_list.append(
  121. HG_Block(in_channels, mid_channels, out_channels, layer_num, identity=False)
  122. )
  123. for _ in range(block_num - 1):
  124. blocks_list.append(
  125. HG_Block(
  126. out_channels, mid_channels, out_channels, layer_num, identity=True
  127. )
  128. )
  129. self.blocks = nn.Sequential(*blocks_list)
  130. def forward(self, x):
  131. if self.downsample:
  132. x = self.downsample(x)
  133. x = self.blocks(x)
  134. return x
  135. class PPHGNet(nn.Module):
  136. """
  137. PPHGNet
  138. Args:
  139. stem_channels: list. Stem channel list of PPHGNet.
  140. stage_config: dict. The configuration of each stage of PPHGNet. such as the number of channels, stride, etc.
  141. layer_num: int. Number of layers of HG_Block.
  142. use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer.
  143. class_expand: int=2048. Number of channels for the last 1x1 convolutional layer.
  144. dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used.
  145. class_num: int=1000. The number of classes.
  146. Returns:
  147. model: nn.Layer. Specific PPHGNet model depends on args.
  148. """
  149. def __init__(
  150. self,
  151. stem_channels,
  152. stage_config,
  153. layer_num,
  154. in_channels=3,
  155. det=False,
  156. out_indices=None,
  157. ):
  158. super().__init__()
  159. self.det = det
  160. self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
  161. # stem
  162. stem_channels.insert(0, in_channels)
  163. self.stem = nn.Sequential(
  164. *[
  165. ConvBNAct(
  166. in_channels=stem_channels[i],
  167. out_channels=stem_channels[i + 1],
  168. kernel_size=3,
  169. stride=2 if i == 0 else 1,
  170. )
  171. for i in range(len(stem_channels) - 1)
  172. ]
  173. )
  174. if self.det:
  175. self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  176. # stages
  177. self.stages = nn.ModuleList()
  178. self.out_channels = []
  179. for block_id, k in enumerate(stage_config):
  180. (
  181. in_channels,
  182. mid_channels,
  183. out_channels,
  184. block_num,
  185. downsample,
  186. stride,
  187. ) = stage_config[k]
  188. self.stages.append(
  189. HG_Stage(
  190. in_channels,
  191. mid_channels,
  192. out_channels,
  193. block_num,
  194. layer_num,
  195. downsample,
  196. stride,
  197. )
  198. )
  199. if block_id in self.out_indices:
  200. self.out_channels.append(out_channels)
  201. if not self.det:
  202. self.out_channels = stage_config["stage4"][2]
  203. self._init_weights()
  204. def _init_weights(self):
  205. for m in self.modules():
  206. if isinstance(m, nn.Conv2d):
  207. nn.init.kaiming_normal_(m.weight)
  208. elif isinstance(m, nn.BatchNorm2d):
  209. nn.init.ones_(m.weight)
  210. nn.init.zeros_(m.bias)
  211. elif isinstance(m, nn.Linear):
  212. nn.init.zeros_(m.bias)
  213. def forward(self, x):
  214. x = self.stem(x)
  215. if self.det:
  216. x = self.pool(x)
  217. out = []
  218. for i, stage in enumerate(self.stages):
  219. x = stage(x)
  220. if self.det and i in self.out_indices:
  221. out.append(x)
  222. if self.det:
  223. return out
  224. if self.training:
  225. x = F.adaptive_avg_pool2d(x, [1, 40])
  226. else:
  227. x = F.avg_pool2d(x, [3, 2])
  228. return x
  229. def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
  230. """
  231. PPHGNet_small
  232. Args:
  233. pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
  234. If str, means the path of the pretrained model.
  235. use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
  236. Returns:
  237. model: nn.Layer. Specific `PPHGNet_small` model depends on args.
  238. """
  239. stage_config_det = {
  240. # in_channels, mid_channels, out_channels, blocks, downsample
  241. "stage1": [128, 128, 256, 1, False, 2],
  242. "stage2": [256, 160, 512, 1, True, 2],
  243. "stage3": [512, 192, 768, 2, True, 2],
  244. "stage4": [768, 224, 1024, 1, True, 2],
  245. }
  246. stage_config_rec = {
  247. # in_channels, mid_channels, out_channels, blocks, downsample
  248. "stage1": [128, 128, 256, 1, True, [2, 1]],
  249. "stage2": [256, 160, 512, 1, True, [1, 2]],
  250. "stage3": [512, 192, 768, 2, True, [2, 1]],
  251. "stage4": [768, 224, 1024, 1, True, [2, 1]],
  252. }
  253. model = PPHGNet(
  254. stem_channels=[64, 64, 128],
  255. stage_config=stage_config_det if det else stage_config_rec,
  256. layer_num=6,
  257. det=det,
  258. **kwargs
  259. )
  260. return model