# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import math
from collections import OrderedDict

import paddle
import paddle.fluid as fluid

from .segmentation.model_utils.libs import scope, name_scope
from .segmentation.model_utils.libs import bn, bn_relu, relu
from .segmentation.model_utils.libs import conv
from .segmentation.model_utils.libs import separate_conv

__all__ = ['xception_65', 'xception_41', 'xception_71']
def check_data(data, number):
    """Broadcast an int to a list of length `number`, or validate a list's length."""
    if type(data) == int:
        return [data] * number
    assert len(data) == number
    return data


def check_stride(s, os):
    """Return True if the accumulated stride `s` does not exceed the target output stride `os`."""
    if s <= os:
        return True
    else:
        return False


def check_points(count, points):
    """Return True if `count` matches `points`, which may be an int or a list of ints."""
    if points is None:
        return False
    else:
        if isinstance(points, list):
            return (True if count in points else False)
        else:
            return (True if count == points else False)
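

# For example: check_data(728, 3) -> [728, 728, 728]; check_points(3, [2, 3]) -> True,
# check_points(3, 4) -> False, and check_points(3, None) -> False.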
class Xception():
    def __init__(self,
                 num_classes=None,
                 layers=65,
                 output_stride=32,
                 end_points=None,
                 decode_points=None):
        self.backbone = 'xception_' + str(layers)
        self.num_classes = num_classes
        self.output_stride = output_stride
        self.end_points = end_points
        self.decode_points = decode_points
        self.bottleneck_params = self.gen_bottleneck_params(self.backbone)

    def __call__(self, input):
        self.stride = 2
        self.block_point = 0
        self.short_cuts = dict()
        with scope(self.backbone):
            # Entry flow
            data = self.entry_flow(input)
            if check_points(self.block_point, self.end_points):
                return data, self.short_cuts

            # Middle flow
            data = self.middle_flow(data)
            if check_points(self.block_point, self.end_points):
                return data, self.short_cuts

            # Exit flow
            data = self.exit_flow(data)
            if check_points(self.block_point, self.end_points):
                return data, self.short_cuts

            if self.num_classes is not None:
                data = fluid.layers.reduce_mean(data, [2, 3], keep_dim=True)
                data = fluid.layers.dropout(data, 0.5)
                stdv = 1.0 / math.sqrt(data.shape[1] * 1.0)
                with scope("logit"):
                    out = fluid.layers.fc(
                        input=data,
                        size=self.num_classes,
                        act='softmax',
                        param_attr=fluid.param_attr.ParamAttr(
                            name='weights',
                            initializer=fluid.initializer.Uniform(-stdv, stdv)),
                        bias_attr=fluid.param_attr.ParamAttr(name='bias'))
                return OrderedDict([('logits', out)])
            else:
                return data
    def gen_bottleneck_params(self, backbone='xception_65'):
        if backbone == 'xception_65':
            bottleneck_params = {
                "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
                "middle_flow": (16, 1, 728),
                "exit_flow": (2, [2, 1], [[728, 1024, 1024],
                                          [1536, 1536, 2048]])
            }
        elif backbone == 'xception_41':
            bottleneck_params = {
                "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
                "middle_flow": (8, 1, 728),
                "exit_flow": (2, [2, 1], [[728, 1024, 1024],
                                          [1536, 1536, 2048]])
            }
        elif backbone == 'xception_71':
            bottleneck_params = {
                "entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]),
                "middle_flow": (16, 1, 728),
                "exit_flow": (2, [2, 1], [[728, 1024, 1024],
                                          [1536, 1536, 2048]])
            }
        else:
            raise Exception(
                "xception backbone only supports xception_41/xception_65/xception_71"
            )
        return bottleneck_params
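
    # Note: each bottleneck_params entry is consumed by the corresponding
    # *_flow method below as a (block_num, strides, channels) tuple: the number
    # of xception blocks, the per-block stride (an int is broadcast to all
    # blocks by check_data), and the per-block output channels.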
    def entry_flow(self, data):
        param_attr = fluid.ParamAttr(
            name=name_scope + 'weights',
            regularizer=None,
            initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.09))
        with scope("entry_flow"):
            with scope("conv1"):
                data = bn_relu(
                    conv(
                        data,
                        32,
                        3,
                        stride=2,
                        padding=1,
                        param_attr=param_attr),
                    eps=1e-3)
            with scope("conv2"):
                data = bn_relu(
                    conv(
                        data,
                        64,
                        3,
                        stride=1,
                        padding=1,
                        param_attr=param_attr),
                    eps=1e-3)

        # get entry flow params
        block_num = self.bottleneck_params["entry_flow"][0]
        strides = self.bottleneck_params["entry_flow"][1]
        chns = self.bottleneck_params["entry_flow"][2]
        strides = check_data(strides, block_num)
        chns = check_data(chns, block_num)

        # bookkeeping for the accumulated stride and block index
        s = self.stride
        block_point = self.block_point
        output_stride = self.output_stride
        with scope("entry_flow"):
            for i in range(block_num):
                block_point = block_point + 1
                with scope("block" + str(i + 1)):
                    stride = strides[i] if check_stride(
                        s * strides[i], output_stride) else 1
                    data, short_cuts = self.xception_block(
                        data, chns[i], [1, 1, stride])
                    s = s * stride
                    if check_points(block_point, self.decode_points):
                        self.short_cuts[block_point] = short_cuts[1]

        self.stride = s
        self.block_point = block_point
        return data
    def middle_flow(self, data):
        block_num = self.bottleneck_params["middle_flow"][0]
        strides = self.bottleneck_params["middle_flow"][1]
        chns = self.bottleneck_params["middle_flow"][2]
        strides = check_data(strides, block_num)
        chns = check_data(chns, block_num)

        # bookkeeping for the accumulated stride and block index
        s = self.stride
        block_point = self.block_point
        output_stride = self.output_stride
        with scope("middle_flow"):
            for i in range(block_num):
                block_point = block_point + 1
                with scope("block" + str(i + 1)):
                    stride = strides[i] if check_stride(
                        s * strides[i], output_stride) else 1
                    data, short_cuts = self.xception_block(
                        data, chns[i], [1, 1, strides[i]], skip_conv=False)
                    s = s * stride
                    if check_points(block_point, self.decode_points):
                        self.short_cuts[block_point] = short_cuts[1]

        self.stride = s
        self.block_point = block_point
        return data
    def exit_flow(self, data):
        block_num = self.bottleneck_params["exit_flow"][0]
        strides = self.bottleneck_params["exit_flow"][1]
        chns = self.bottleneck_params["exit_flow"][2]
        strides = check_data(strides, block_num)
        chns = check_data(chns, block_num)

        assert (block_num == 2)

        # bookkeeping for the accumulated stride and block index
        s = self.stride
        block_point = self.block_point
        output_stride = self.output_stride
        with scope("exit_flow"):
            with scope('block1'):
                block_point += 1
                stride = strides[0] if check_stride(s * strides[0],
                                                    output_stride) else 1
                data, short_cuts = self.xception_block(data, chns[0],
                                                       [1, 1, stride])
                s = s * stride
                if check_points(block_point, self.decode_points):
                    self.short_cuts[block_point] = short_cuts[1]
            with scope('block2'):
                block_point += 1
                stride = strides[1] if check_stride(s * strides[1],
                                                    output_stride) else 1
                data, short_cuts = self.xception_block(
                    data,
                    chns[1], [1, 1, stride],
                    dilation=2,
                    has_skip=False,
                    activation_fn_in_separable_conv=True)
                s = s * stride
                if check_points(block_point, self.decode_points):
                    self.short_cuts[block_point] = short_cuts[1]

        self.stride = s
        self.block_point = block_point
        return data
    def xception_block(self,
                       input,
                       channels,
                       strides=1,
                       filters=3,
                       dilation=1,
                       skip_conv=True,
                       has_skip=True,
                       activation_fn_in_separable_conv=False):
        """Three stacked separable convs, optionally followed by a residual add.

        Returns (output, results), where results holds the output of each of the
        three separable convs. The shortcut is a 1x1 conv + bn when skip_conv is
        True, the raw input otherwise; no shortcut is added when has_skip is False.
        """
        repeat_number = 3
        channels = check_data(channels, repeat_number)
        filters = check_data(filters, repeat_number)
        strides = check_data(strides, repeat_number)
        data = input
        results = []
        for i in range(repeat_number):
            with scope('separable_conv' + str(i + 1)):
                if not activation_fn_in_separable_conv:
                    data = relu(data)
                    data = separate_conv(
                        data,
                        channels[i],
                        strides[i],
                        filters[i],
                        dilation=dilation,
                        eps=1e-3)
                else:
                    data = separate_conv(
                        data,
                        channels[i],
                        strides[i],
                        filters[i],
                        dilation=dilation,
                        act=relu,
                        eps=1e-3)
                results.append(data)
        if not has_skip:
            return data, results
        if skip_conv:
            param_attr = fluid.ParamAttr(
                name=name_scope + 'weights',
                regularizer=None,
                initializer=fluid.initializer.TruncatedNormal(
                    loc=0.0, scale=0.09))
            with scope('shortcut'):
                skip = bn(
                    conv(
                        input,
                        channels[-1],
                        1,
                        strides[-1],
                        groups=1,
                        padding=0,
                        param_attr=param_attr),
                    eps=1e-3)
        else:
            skip = input
        return data + skip, results
def xception_65(num_classes=None):
    model = Xception(num_classes, 65)
    return model


def xception_41(num_classes=None):
    model = Xception(num_classes, 41)
    return model


def xception_71(num_classes=None):
    model = Xception(num_classes, 71)
    return model
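

# A minimal usage sketch (not part of the original file; assumes Paddle 1.x
# static-graph mode and that the relative `.segmentation.model_utils.libs`
# imports resolve in your package layout). With num_classes set, __call__
# returns an OrderedDict containing the softmax output under 'logits':
#
#     import paddle.fluid as fluid
#     image = fluid.data(name='image', shape=[-1, 3, 224, 224], dtype='float32')
#     model = xception_65(num_classes=1000)
#     logits = model(image)['logits']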