xception.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. # coding: utf8
  2. # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import contextlib
  19. import paddle
  20. import math
  21. import paddle.fluid as fluid
  22. from .segmentation.model_utils.libs import scope, name_scope
  23. from .segmentation.model_utils.libs import bn, bn_relu, relu
  24. from .segmentation.model_utils.libs import conv
  25. from .segmentation.model_utils.libs import separate_conv
# Public API for ``from <module> import *``: the three factory functions
# defined at the bottom of this file.
__all__ = ['xception_65', 'xception_41', 'xception_71']
  27. def check_data(data, number):
  28. if type(data) == int:
  29. return [data] * number
  30. assert len(data) == number
  31. return data
  32. def check_stride(s, os):
  33. if s <= os:
  34. return True
  35. else:
  36. return False
  37. def check_points(count, points):
  38. if points is None:
  39. return False
  40. else:
  41. if isinstance(points, list):
  42. return (True if count in points else False)
  43. else:
  44. return (True if count == points else False)
  45. class Xception():
  46. def __init__(self,
  47. num_classes=None,
  48. layers=65,
  49. output_stride=32,
  50. end_points=None,
  51. decode_points=None):
  52. self.backbone = 'xception_' + str(layers)
  53. self.num_classes = num_classes
  54. self.output_stride = output_stride
  55. self.output_stride = output_stride
  56. self.end_points = end_points
  57. self.decode_points = decode_points
  58. self.bottleneck_params = self.gen_bottleneck_params(self.backbone)
  59. def __call__(
  60. self,
  61. input,
  62. ):
  63. self.stride = 2
  64. self.block_point = 0
  65. self.short_cuts = dict()
  66. with scope(self.backbone):
  67. # Entry flow
  68. data = self.entry_flow(input)
  69. if check_points(self.block_point, self.end_points):
  70. return data, self.short_cuts
  71. # Middle flow
  72. data = self.middle_flow(data)
  73. if check_points(self.block_point, self.end_points):
  74. return data, self.short_cuts
  75. # Exit flow
  76. data = self.exit_flow(data)
  77. if check_points(self.block_point, self.end_points):
  78. return data, self.short_cuts
  79. if self.num_classes is not None:
  80. data = fluid.layers.reduce_mean(data, [2, 3], keep_dim=True)
  81. data = fluid.layers.dropout(data, 0.5)
  82. stdv = 1.0 / math.sqrt(data.shape[1] * 1.0)
  83. with scope("logit"):
  84. out = fluid.layers.fc(
  85. input=data,
  86. size=self.num_classes,
  87. act='softmax',
  88. param_attr=fluid.param_attr.ParamAttr(
  89. name='weights',
  90. initializer=fluid.initializer.Uniform(-stdv, stdv)),
  91. bias_attr=fluid.param_attr.ParamAttr(name='bias'))
  92. return out
  93. else:
  94. return data
  95. def gen_bottleneck_params(self, backbone='xception_65'):
  96. if backbone == 'xception_65':
  97. bottleneck_params = {
  98. "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
  99. "middle_flow": (16, 1, 728),
  100. "exit_flow": (2, [2, 1], [[728, 1024, 1024],
  101. [1536, 1536, 2048]])
  102. }
  103. elif backbone == 'xception_41':
  104. bottleneck_params = {
  105. "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
  106. "middle_flow": (8, 1, 728),
  107. "exit_flow": (2, [2, 1], [[728, 1024, 1024],
  108. [1536, 1536, 2048]])
  109. }
  110. elif backbone == 'xception_71':
  111. bottleneck_params = {
  112. "entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]),
  113. "middle_flow": (16, 1, 728),
  114. "exit_flow": (2, [2, 1], [[728, 1024, 1024],
  115. [1536, 1536, 2048]])
  116. }
  117. else:
  118. raise Exception(
  119. "xception backbont only support xception_41/xception_65/xception_71"
  120. )
  121. return bottleneck_params
  122. def entry_flow(self, data):
  123. param_attr = fluid.ParamAttr(
  124. name=name_scope + 'weights',
  125. regularizer=None,
  126. initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.09))
  127. with scope("entry_flow"):
  128. with scope("conv1"):
  129. data = bn_relu(
  130. conv(
  131. data,
  132. 32,
  133. 3,
  134. stride=2,
  135. padding=1,
  136. param_attr=param_attr),
  137. eps=1e-3)
  138. with scope("conv2"):
  139. data = bn_relu(
  140. conv(
  141. data,
  142. 64,
  143. 3,
  144. stride=1,
  145. padding=1,
  146. param_attr=param_attr),
  147. eps=1e-3)
  148. # get entry flow params
  149. block_num = self.bottleneck_params["entry_flow"][0]
  150. strides = self.bottleneck_params["entry_flow"][1]
  151. chns = self.bottleneck_params["entry_flow"][2]
  152. strides = check_data(strides, block_num)
  153. chns = check_data(chns, block_num)
  154. # params to control your flow
  155. s = self.stride
  156. block_point = self.block_point
  157. output_stride = self.output_stride
  158. with scope("entry_flow"):
  159. for i in range(block_num):
  160. block_point = block_point + 1
  161. with scope("block" + str(i + 1)):
  162. stride = strides[i] if check_stride(
  163. s * strides[i], output_stride) else 1
  164. data, short_cuts = self.xception_block(
  165. data, chns[i], [1, 1, stride])
  166. s = s * stride
  167. if check_points(block_point, self.decode_points):
  168. self.short_cuts[block_point] = short_cuts[1]
  169. self.stride = s
  170. self.block_point = block_point
  171. return data
  172. def middle_flow(self, data):
  173. block_num = self.bottleneck_params["middle_flow"][0]
  174. strides = self.bottleneck_params["middle_flow"][1]
  175. chns = self.bottleneck_params["middle_flow"][2]
  176. strides = check_data(strides, block_num)
  177. chns = check_data(chns, block_num)
  178. # params to control your flow
  179. s = self.stride
  180. block_point = self.block_point
  181. output_stride = self.output_stride
  182. with scope("middle_flow"):
  183. for i in range(block_num):
  184. block_point = block_point + 1
  185. with scope("block" + str(i + 1)):
  186. stride = strides[i] if check_stride(
  187. s * strides[i], output_stride) else 1
  188. data, short_cuts = self.xception_block(
  189. data, chns[i], [1, 1, strides[i]], skip_conv=False)
  190. s = s * stride
  191. if check_points(block_point, self.decode_points):
  192. self.short_cuts[block_point] = short_cuts[1]
  193. self.stride = s
  194. self.block_point = block_point
  195. return data
  196. def exit_flow(self, data):
  197. block_num = self.bottleneck_params["exit_flow"][0]
  198. strides = self.bottleneck_params["exit_flow"][1]
  199. chns = self.bottleneck_params["exit_flow"][2]
  200. strides = check_data(strides, block_num)
  201. chns = check_data(chns, block_num)
  202. assert (block_num == 2)
  203. # params to control your flow
  204. s = self.stride
  205. block_point = self.block_point
  206. output_stride = self.output_stride
  207. with scope("exit_flow"):
  208. with scope('block1'):
  209. block_point += 1
  210. stride = strides[0] if check_stride(s * strides[0],
  211. output_stride) else 1
  212. data, short_cuts = self.xception_block(data, chns[0],
  213. [1, 1, stride])
  214. s = s * stride
  215. if check_points(block_point, self.decode_points):
  216. self.short_cuts[block_point] = short_cuts[1]
  217. with scope('block2'):
  218. block_point += 1
  219. stride = strides[1] if check_stride(s * strides[1],
  220. output_stride) else 1
  221. data, short_cuts = self.xception_block(
  222. data,
  223. chns[1], [1, 1, stride],
  224. dilation=2,
  225. has_skip=False,
  226. activation_fn_in_separable_conv=True)
  227. s = s * stride
  228. if check_points(block_point, self.decode_points):
  229. self.short_cuts[block_point] = short_cuts[1]
  230. self.stride = s
  231. self.block_point = block_point
  232. return data
  233. def xception_block(self,
  234. input,
  235. channels,
  236. strides=1,
  237. filters=3,
  238. dilation=1,
  239. skip_conv=True,
  240. has_skip=True,
  241. activation_fn_in_separable_conv=False):
  242. repeat_number = 3
  243. channels = check_data(channels, repeat_number)
  244. filters = check_data(filters, repeat_number)
  245. strides = check_data(strides, repeat_number)
  246. data = input
  247. results = []
  248. for i in range(repeat_number):
  249. with scope('separable_conv' + str(i + 1)):
  250. if not activation_fn_in_separable_conv:
  251. data = relu(data)
  252. data = separate_conv(
  253. data,
  254. channels[i],
  255. strides[i],
  256. filters[i],
  257. dilation=dilation,
  258. eps=1e-3)
  259. else:
  260. data = separate_conv(
  261. data,
  262. channels[i],
  263. strides[i],
  264. filters[i],
  265. dilation=dilation,
  266. act=relu,
  267. eps=1e-3)
  268. results.append(data)
  269. if not has_skip:
  270. return data, results
  271. if skip_conv:
  272. param_attr = fluid.ParamAttr(
  273. name=name_scope + 'weights',
  274. regularizer=None,
  275. initializer=fluid.initializer.TruncatedNormal(
  276. loc=0.0, scale=0.09))
  277. with scope('shortcut'):
  278. skip = bn(
  279. conv(
  280. input,
  281. channels[-1],
  282. 1,
  283. strides[-1],
  284. groups=1,
  285. padding=0,
  286. param_attr=param_attr),
  287. eps=1e-3)
  288. else:
  289. skip = input
  290. return data + skip, results
  291. def xception_65(num_classes=None):
  292. model = Xception(num_classes, 65)
  293. return model
  294. def xception_41(num_classes=None):
  295. model = Xception(num_classes, 41)
  296. return model
  297. def xception_71(num_classes=None):
  298. model = Xception(num_classes, 71)
  299. return model