# rec_mv1_enhance.py

import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..common import Activation


class ConvBNLayer(nn.Module):
    """Conv2d + BatchNorm2d, followed by an optional activation."""

    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 channels=None,  # accepted but unused
                 num_groups=1,
                 act='hard_swish'):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias=False)
        self._batch_norm = nn.BatchNorm2d(num_filters)
        if self.act is not None:
            self._act = Activation(act_type=act, inplace=True)

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        if self.act is not None:
            y = self._act(y)
        return y


class DepthwiseSeparable(nn.Module):
    """Depthwise ConvBNLayer, optional SE block, then a 1x1 pointwise ConvBNLayer."""

    def __init__(self,
                 num_channels,
                 num_filters1,
                 num_filters2,
                 num_groups,
                 stride,
                 scale,
                 dw_size=3,
                 padding=1,
                 use_se=False):
        super(DepthwiseSeparable, self).__init__()
        self.use_se = use_se
        self._depthwise_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=int(num_filters1 * scale),
            filter_size=dw_size,
            stride=stride,
            padding=padding,
            num_groups=int(num_groups * scale))
        if use_se:
            self._se = SEModule(int(num_filters1 * scale))
        self._pointwise_conv = ConvBNLayer(
            num_channels=int(num_filters1 * scale),
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0)

    def forward(self, inputs):
        y = self._depthwise_conv(inputs)
        if self.use_se:
            y = self._se(y)
        y = self._pointwise_conv(y)
        return y


class MobileNetV1Enhance(nn.Module):
    """MobileNetV1-style backbone with enhanced depthwise blocks.

    The (2, 1) strides downsample the feature-map height only, preserving
    horizontal resolution for text-line (sequence) recognition.
    """

    def __init__(self,
                 in_channels=3,
                 scale=0.5,
                 last_conv_stride=1,
                 last_pool_type='max',
                 **kwargs):
        super().__init__()
        self.scale = scale
        self.block_list = []

        self.conv1 = ConvBNLayer(
            num_channels=in_channels,
            filter_size=3,
            channels=3,
            num_filters=int(32 * scale),
            stride=2,
            padding=1)

        conv2_1 = DepthwiseSeparable(
            num_channels=int(32 * scale),
            num_filters1=32,
            num_filters2=64,
            num_groups=32,
            stride=1,
            scale=scale)
        self.block_list.append(conv2_1)

        conv2_2 = DepthwiseSeparable(
            num_channels=int(64 * scale),
            num_filters1=64,
            num_filters2=128,
            num_groups=64,
            stride=1,
            scale=scale)
        self.block_list.append(conv2_2)

        conv3_1 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=128,
            num_groups=128,
            stride=1,
            scale=scale)
        self.block_list.append(conv3_1)

        conv3_2 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=256,
            num_groups=128,
            stride=(2, 1),
            scale=scale)
        self.block_list.append(conv3_2)

        conv4_1 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=256,
            num_groups=256,
            stride=1,
            scale=scale)
        self.block_list.append(conv4_1)

        conv4_2 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=512,
            num_groups=256,
            stride=(2, 1),
            scale=scale)
        self.block_list.append(conv4_2)

        # Five 5x5 depthwise blocks at constant width.
        for _ in range(5):
            conv5 = DepthwiseSeparable(
                num_channels=int(512 * scale),
                num_filters1=512,
                num_filters2=512,
                num_groups=512,
                stride=1,
                dw_size=5,
                padding=2,
                scale=scale,
                use_se=False)
            self.block_list.append(conv5)

        conv5_6 = DepthwiseSeparable(
            num_channels=int(512 * scale),
            num_filters1=512,
            num_filters2=1024,
            num_groups=512,
            stride=(2, 1),
            dw_size=5,
            padding=2,
            scale=scale,
            use_se=True)
        self.block_list.append(conv5_6)

        conv6 = DepthwiseSeparable(
            num_channels=int(1024 * scale),
            num_filters1=1024,
            num_filters2=1024,
            num_groups=1024,
            stride=last_conv_stride,
            dw_size=5,
            padding=2,
            use_se=True,
            scale=scale)
        self.block_list.append(conv6)

        self.block_list = nn.Sequential(*self.block_list)
        if last_pool_type == 'avg':
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.out_channels = int(1024 * scale)

    def forward(self, inputs):
        y = self.conv1(inputs)
        y = self.block_list(y)
        y = self.pool(y)
        return y


def hardsigmoid(x):
    # Hard-sigmoid gate: relu6(x + 3) / 6, used by SEModule below.
    return F.relu6(x + 3., inplace=True) / 6.


class SEModule(nn.Module):
    """Squeeze-and-Excitation: global average pool, two 1x1 convs, hard-sigmoid gate."""

    def __init__(self, channel, reduction=4):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.conv2 = nn.Conv2d(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

    def forward(self, inputs):
        outputs = self.avg_pool(inputs)
        outputs = self.conv1(outputs)
        outputs = F.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = hardsigmoid(outputs)
        # Re-weight the input channels with the learned gate.
        x = torch.mul(inputs, outputs)
        return x
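

# Minimal shape smoke test (a sketch, not part of the original module). It
# assumes the file is run in its package context (e.g. via `python -m ...`) so
# the relative import of Activation resolves, and that Activation supports
# act_type='hard_swish'. With scale=0.5 and a typical 32x320 recognition crop,
# conv1 halves H and W, the three (2, 1) strides halve H only, and the final
# 2x2 pool yields a (N, 512, 1, 80) feature map.
if __name__ == "__main__":
    model = MobileNetV1Enhance(in_channels=3, scale=0.5)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 32, 320))
    print(feats.shape)           # expected: torch.Size([1, 512, 1, 80])
    print(model.out_channels)    # 512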