xception.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from paddle import ParamAttr
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  19. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  20. from paddle.nn.initializer import Uniform
  21. import math
  22. import sys
  # Names exported by ``from <module> import *``: the three model factories.
  __all__ = ["Xception41", "Xception65", "Xception71"]
  class ConvBNLayer(nn.Layer):
      """A bias-free ``Conv2D`` followed by ``BatchNorm``.

      Args:
          num_channels: Number of input channels of the convolution.
          num_filters: Number of output channels of the convolution; also the
              channel count handed to ``BatchNorm``.
          filter_size: Kernel size. Padding is ``(filter_size - 1) // 2``.
          stride: Convolution stride.
          groups: Convolution groups.
          act: Activation name forwarded to ``BatchNorm`` (``None`` for none).
          name: Prefix used to build explicit parameter names
              (``<name>_weights``, ``bn_<name>_scale``, ``bn_<name>_offset``,
              ``bn_<name>_mean``, ``bn_<name>_variance``). When ``None``, no
              explicit names are passed and naming is left to the framework
              defaults.
      """

      def __init__(self,
                   num_channels,
                   num_filters,
                   filter_size,
                   stride=1,
                   groups=1,
                   act=None,
                   name=None):
          super(ConvBNLayer, self).__init__()

          def _param_name(prefix, suffix):
              # ``name`` defaults to None; concatenating it with a string
              # would raise TypeError, so hand None through instead.
              if name is None:
                  return None
              return prefix + name + suffix

          self._conv = Conv2D(
              in_channels=num_channels,
              out_channels=num_filters,
              kernel_size=filter_size,
              stride=stride,
              padding=(filter_size - 1) // 2,
              groups=groups,
              weight_attr=ParamAttr(name=_param_name("", "_weights")),
              bias_attr=False)
          self._batch_norm = BatchNorm(
              num_filters,
              act=act,
              param_attr=ParamAttr(name=_param_name("bn_", "_scale")),
              bias_attr=ParamAttr(name=_param_name("bn_", "_offset")),
              moving_mean_name=_param_name("bn_", "_mean"),
              moving_variance_name=_param_name("bn_", "_variance"))

      def forward(self, inputs):
          """Apply the convolution, then batch normalization, to ``inputs``."""
          return self._batch_norm(self._conv(inputs))
  class SeparableConv(nn.Layer):
      """Two stacked ``ConvBNLayer``s: a 1x1 convolution, then a grouped 3x3.

      The first layer maps ``input_channels`` to ``output_channels`` with a
      kernel size of 1. The second keeps ``output_channels`` and uses a kernel
      size of 3 with ``groups=output_channels`` and the given ``stride``.

      Args:
          input_channels: Channels entering the 1x1 convolution.
          output_channels: Channels produced by both convolutions.
          stride: Stride of the 3x3 convolution.
          name: Name prefix; ``"_sep"`` and ``"_dw"`` are appended for the two
              sub-layers.
      """

      def __init__(self, input_channels, output_channels, stride=1, name=None):
          super(SeparableConv, self).__init__()
          pointwise_name = name + "_sep"
          self._pointwise_conv = ConvBNLayer(
              input_channels, output_channels, 1, name=pointwise_name)
          depthwise_name = name + "_dw"
          self._depthwise_conv = ConvBNLayer(
              output_channels,
              output_channels,
              3,
              stride=stride,
              groups=output_channels,
              name=depthwise_name)

      def forward(self, inputs):
          """Run the 1x1 layer and feed its result to the grouped 3x3 layer."""
          return self._depthwise_conv(self._pointwise_conv(inputs))
  class EntryFlowBottleneckBlock(nn.Layer):
      """Two ``SeparableConv`` layers plus max-pool, added to a 1x1 shortcut.

      Args:
          input_channels: Channels of the block input.
          output_channels: Channels produced by the shortcut and main branch.
          stride: Stride shared by the 1x1 shortcut convolution and the
              3x3 max-pool that ends the main branch.
          name: Prefix for the parameter names of the sub-layers.
          relu_first: When true, ReLU is applied to the input before the first
              separable convolution (the shortcut always sees the raw input).
      """

      def __init__(self,
                   input_channels,
                   output_channels,
                   stride=2,
                   name=None,
                   relu_first=False):
          super(EntryFlowBottleneckBlock, self).__init__()
          self.relu_first = relu_first

          # Shortcut branch: bias-free 1x1 convolution.
          self._short = Conv2D(
              in_channels=input_channels,
              out_channels=output_channels,
              kernel_size=1,
              stride=stride,
              padding=0,
              weight_attr=ParamAttr(name + "_branch1_weights"),
              bias_attr=False)

          # Main branch: two stride-1 separable convolutions, then max-pool.
          first_name = name + "_branch2a_weights"
          self._conv1 = SeparableConv(
              input_channels, output_channels, stride=1, name=first_name)
          second_name = name + "_branch2b_weights"
          self._conv2 = SeparableConv(
              output_channels, output_channels, stride=1, name=second_name)
          self._pool = MaxPool2D(kernel_size=3, stride=stride, padding=1)

      def forward(self, inputs):
          """Return the sum of the shortcut and the pooled main branch."""
          shortcut = self._short(inputs)
          branch = F.relu(inputs) if self.relu_first else inputs
          branch = self._conv1(branch)
          branch = self._conv2(F.relu(branch))
          return paddle.add(x=shortcut, y=self._pool(branch))
  class EntryFlow(nn.Layer):
      """Two ``ConvBNLayer``s followed by 3 or 5 ``EntryFlowBottleneckBlock``s.

      The bottleneck blocks are stored as ``_conv_0`` ... ``_conv_{n-1}``.
      The first block is built with ``relu_first=False``; all later ones use
      ``relu_first=True``.

      Args:
          block_num: Number of bottleneck blocks; must be 3 or 5.

      Raises:
          ValueError: If ``block_num`` is neither 3 nor 5. (Previously the
              process was terminated via ``sys.exit(-1)``.)
      """

      def __init__(self, block_num=3):
          super(EntryFlow, self).__init__()
          name = "entry_flow"

          # (input_channels, output_channels, stride) for each block.
          if block_num == 3:
              block_specs = ((64, 128, 2), (128, 256, 2), (256, 728, 2))
          elif block_num == 5:
              block_specs = ((64, 128, 2), (128, 256, 1), (256, 256, 2),
                             (256, 728, 1), (728, 728, 2))
          else:
              # Library code must not kill the interpreter; report the bad
              # argument to the caller instead.
              raise ValueError(
                  "EntryFlow supports block_num of 3 or 5, got {!r}".format(
                      block_num))

          self.block_num = block_num
          self._num_blocks = len(block_specs)
          self._conv1 = ConvBNLayer(
              3, 32, 3, stride=2, act="relu", name=name + "_conv1")
          self._conv2 = ConvBNLayer(32, 64, 3, act="relu", name=name + "_conv2")

          # Attribute names are kept as ``_conv_<i>`` so sub-layer names stay
          # the same as with the hand-written assignments.
          for idx, (in_ch, out_ch, stride) in enumerate(block_specs):
              block = EntryFlowBottleneckBlock(
                  in_ch,
                  out_ch,
                  stride=stride,
                  name=name + "_" + str(idx),
                  relu_first=idx > 0)
              setattr(self, "_conv_{}".format(idx), block)

      def forward(self, inputs):
          """Apply the two stem convolutions, then every bottleneck in order."""
          x = self._conv1(inputs)
          x = self._conv2(x)
          for idx in range(self._num_blocks):
              x = getattr(self, "_conv_{}".format(idx))(x)
          return x
  class MiddleFlowBottleneckBlock(nn.Layer):
      """Three ReLU + ``SeparableConv`` stages with the block input added back.

      Args:
          input_channels: Channels entering the first separable convolution.
          output_channels: Channels produced by all three separable
              convolutions.
          name: Prefix for the parameter names of the sub-layers.
      """

      def __init__(self, input_channels, output_channels, name):
          super(MiddleFlowBottleneckBlock, self).__init__()
          name_a = name + "_branch2a_weights"
          name_b = name + "_branch2b_weights"
          name_c = name + "_branch2c_weights"
          self._conv_0 = SeparableConv(
              input_channels, output_channels, stride=1, name=name_a)
          self._conv_1 = SeparableConv(
              output_channels, output_channels, stride=1, name=name_b)
          self._conv_2 = SeparableConv(
              output_channels, output_channels, stride=1, name=name_c)

      def forward(self, inputs):
          """Return ``inputs`` plus the output of the three-stage branch."""
          branch = self._conv_0(F.relu(inputs))
          branch = self._conv_1(F.relu(branch))
          branch = self._conv_2(F.relu(branch))
          return paddle.add(x=inputs, y=branch)
  class MiddleFlow(nn.Layer):
      """A chain of 728-channel ``MiddleFlowBottleneckBlock``s.

      Sixteen blocks are built when ``block_num == 16``; any other value
      yields eight. Blocks are stored as ``_conv_0`` ... ``_conv_{n-1}`` and
      named ``middle_flow_0`` ... ``middle_flow_{n-1}``.

      Args:
          block_num: Requested number of blocks (see above).
      """

      def __init__(self, block_num=8):
          super(MiddleFlow, self).__init__()
          self.block_num = block_num
          total = 16 if block_num == 16 else 8
          for idx in range(total):
              block = MiddleFlowBottleneckBlock(
                  728, 728, name="middle_flow_{}".format(idx))
              setattr(self, "_conv_{}".format(idx), block)

      def forward(self, inputs):
          """Pass ``inputs`` through every block in index order."""
          x = inputs
          total = 16 if self.block_num == 16 else 8
          for idx in range(total):
              x = getattr(self, "_conv_{}".format(idx))(x)
          return x
  class ExitFlowBottleneckBlock(nn.Layer):
      """ReLU + two ``SeparableConv``s + max-pool, added to a 1x1 shortcut.

      Both the shortcut convolution and the max-pool use a stride of 2.

      Args:
          input_channels: Channels of the block input.
          output_channels1: Channels produced by the first separable
              convolution.
          output_channels2: Channels produced by the second separable
              convolution and by the shortcut.
          name: Prefix for the parameter names of the sub-layers.
      """

      def __init__(self, input_channels, output_channels1, output_channels2,
                   name):
          super(ExitFlowBottleneckBlock, self).__init__()

          # Shortcut branch: bias-free 1x1 convolution with stride 2.
          self._short = Conv2D(
              in_channels=input_channels,
              out_channels=output_channels2,
              kernel_size=1,
              stride=2,
              padding=0,
              weight_attr=ParamAttr(name + "_branch1_weights"),
              bias_attr=False)

          # Main branch.
          first_name = name + "_branch2a_weights"
          self._conv_1 = SeparableConv(
              input_channels, output_channels1, stride=1, name=first_name)
          second_name = name + "_branch2b_weights"
          self._conv_2 = SeparableConv(
              output_channels1, output_channels2, stride=1, name=second_name)
          self._pool = MaxPool2D(kernel_size=3, stride=2, padding=1)

      def forward(self, inputs):
          """Return the sum of the shortcut and the pooled main branch."""
          shortcut = self._short(inputs)
          branch = self._conv_1(F.relu(inputs))
          branch = self._conv_2(F.relu(branch))
          return paddle.add(x=shortcut, y=self._pool(branch))
  class ExitFlow(nn.Layer):
      """Final stage: bottleneck, two separable convs, pooling and ``Linear``.

      Args:
          class_dim: Output size of the final ``Linear`` layer.
      """

      def __init__(self, class_dim):
          super(ExitFlow, self).__init__()
          name = "exit_flow"
          self._conv_0 = ExitFlowBottleneckBlock(
              728, 728, 1024, name=name + "_1")
          self._conv_1 = SeparableConv(1024, 1536, stride=1, name=name + "_2")
          self._conv_2 = SeparableConv(1536, 2048, stride=1, name=name + "_3")
          self._pool = AdaptiveAvgPool2D(1)

          # Linear weights are drawn uniformly from [-bound, bound], where
          # bound is 1 / sqrt(2048), 2048 being the Linear input size.
          bound = 1.0 / math.sqrt(2048 * 1.0)
          fc_weight_attr = ParamAttr(
              name="fc_weights", initializer=Uniform(-bound, bound))
          fc_bias_attr = ParamAttr(name="fc_offset")
          self._out = Linear(
              2048,
              class_dim,
              weight_attr=fc_weight_attr,
              bias_attr=fc_bias_attr)

      def forward(self, inputs):
          """Run the exit stage and return the ``Linear`` layer's output."""
          x = self._conv_0(inputs)
          x = F.relu(self._conv_1(x))
          x = F.relu(self._conv_2(x))
          x = self._pool(x)
          # Collapse every axis after the first before the Linear layer.
          x = paddle.flatten(x, start_axis=1, stop_axis=-1)
          return self._out(x)
  class Xception(nn.Layer):
      """Full network: ``EntryFlow`` -> ``MiddleFlow`` -> ``ExitFlow``.

      Args:
          entry_flow_block_num: Passed to ``EntryFlow``.
          middle_flow_block_num: Passed to ``MiddleFlow``.
          class_dim: Passed to ``ExitFlow``.
      """

      def __init__(self,
                   entry_flow_block_num=3,
                   middle_flow_block_num=8,
                   class_dim=1000):
          super(Xception, self).__init__()
          self.entry_flow_block_num = entry_flow_block_num
          self.middle_flow_block_num = middle_flow_block_num
          self._entry_flow = EntryFlow(entry_flow_block_num)
          self._middle_flow = MiddleFlow(middle_flow_block_num)
          self._exit_flow = ExitFlow(class_dim)

      def forward(self, inputs):
          """Feed ``inputs`` through the entry, middle and exit stages."""
          x = inputs
          for stage in (self._entry_flow, self._middle_flow, self._exit_flow):
              x = stage(x)
          return x
  def Xception41(**args):
      """Build an ``Xception`` with 3 entry-flow and 8 middle-flow blocks.

      Extra keyword arguments are forwarded to ``Xception`` unchanged.
      """
      return Xception(entry_flow_block_num=3, middle_flow_block_num=8, **args)
  def Xception65(**args):
      """Build an ``Xception`` with 3 entry-flow and 16 middle-flow blocks.

      Extra keyword arguments are forwarded to ``Xception`` unchanged.
      """
      return Xception(entry_flow_block_num=3, middle_flow_block_num=16, **args)
  def Xception71(**args):
      """Build an ``Xception`` with 5 entry-flow and 16 middle-flow blocks.

      Extra keyword arguments are forwarded to ``Xception`` unchanged.
      """
      return Xception(entry_flow_block_num=5, middle_flow_block_num=16, **args)