# xception_deeplab.py
  1. import paddle
  2. from paddle import ParamAttr
  3. import paddle.nn as nn
  4. import paddle.nn.functional as F
  5. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  6. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
# Public API: factory functions for the three DeepLab Xception variants.
__all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
  8. def check_data(data, number):
  9. if type(data) == int:
  10. return [data] * number
  11. assert len(data) == number
  12. return data
  13. def check_stride(s, os):
  14. if s <= os:
  15. return True
  16. else:
  17. return False
  18. def check_points(count, points):
  19. if points is None:
  20. return False
  21. else:
  22. if isinstance(points, list):
  23. return (True if count in points else False)
  24. else:
  25. return (True if count == points else False)
  26. def gen_bottleneck_params(backbone='xception_65'):
  27. if backbone == 'xception_65':
  28. bottleneck_params = {
  29. "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
  30. "middle_flow": (16, 1, 728),
  31. "exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
  32. }
  33. elif backbone == 'xception_41':
  34. bottleneck_params = {
  35. "entry_flow": (3, [2, 2, 2], [128, 256, 728]),
  36. "middle_flow": (8, 1, 728),
  37. "exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
  38. }
  39. elif backbone == 'xception_71':
  40. bottleneck_params = {
  41. "entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]),
  42. "middle_flow": (16, 1, 728),
  43. "exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
  44. }
  45. else:
  46. raise Exception(
  47. "xception backbont only support xception_41/xception_65/xception_71"
  48. )
  49. return bottleneck_params
  50. class ConvBNLayer(nn.Layer):
  51. def __init__(self,
  52. input_channels,
  53. output_channels,
  54. filter_size,
  55. stride=1,
  56. padding=0,
  57. act=None,
  58. name=None):
  59. super(ConvBNLayer, self).__init__()
  60. self._conv = Conv2D(
  61. in_channels=input_channels,
  62. out_channels=output_channels,
  63. kernel_size=filter_size,
  64. stride=stride,
  65. padding=padding,
  66. weight_attr=ParamAttr(name=name + "/weights"),
  67. bias_attr=False)
  68. self._bn = BatchNorm(
  69. num_channels=output_channels,
  70. act=act,
  71. epsilon=1e-3,
  72. momentum=0.99,
  73. param_attr=ParamAttr(name=name + "/BatchNorm/gamma"),
  74. bias_attr=ParamAttr(name=name + "/BatchNorm/beta"),
  75. moving_mean_name=name + "/BatchNorm/moving_mean",
  76. moving_variance_name=name + "/BatchNorm/moving_variance")
  77. def forward(self, inputs):
  78. return self._bn(self._conv(inputs))
class Seperate_Conv(nn.Layer):
    """Depthwise-separable convolution: depthwise conv + BN, then 1x1 pointwise conv + BN.

    ("Seperate" spelling kept — it is the public class name.)  ``act`` is
    applied inside both BatchNorm layers; ``name`` is embedded in parameter
    names so weights line up with pretrained checkpoints.
    """

    def __init__(self,
                 input_channels,
                 output_channels,
                 stride,
                 filter,
                 dilation=1,
                 act=None,
                 name=None):
        super(Seperate_Conv, self).__init__()
        # Depthwise stage: one filter per channel (groups == in_channels);
        # padding = filter//2 * dilation keeps spatial size at stride 1 for
        # odd filter sizes.
        self._conv1 = Conv2D(
            in_channels=input_channels,
            out_channels=input_channels,
            kernel_size=filter,
            stride=stride,
            groups=input_channels,
            padding=(filter) // 2 * dilation,
            dilation=dilation,
            weight_attr=ParamAttr(name=name + "/depthwise/weights"),
            bias_attr=False)
        self._bn1 = BatchNorm(
            input_channels,
            act=act,
            epsilon=1e-3,
            momentum=0.99,
            param_attr=ParamAttr(name=name + "/depthwise/BatchNorm/gamma"),
            bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta"),
            moving_mean_name=name + "/depthwise/BatchNorm/moving_mean",
            moving_variance_name=name + "/depthwise/BatchNorm/moving_variance")
        # Pointwise stage: 1x1 conv mixing channels up to output_channels.
        self._conv2 = Conv2D(
            input_channels,
            output_channels,
            1,
            stride=1,
            groups=1,
            padding=0,
            weight_attr=ParamAttr(name=name + "/pointwise/weights"),
            bias_attr=False)
        self._bn2 = BatchNorm(
            output_channels,
            act=act,
            epsilon=1e-3,
            momentum=0.99,
            param_attr=ParamAttr(name=name + "/pointwise/BatchNorm/gamma"),
            bias_attr=ParamAttr(name=name + "/pointwise/BatchNorm/beta"),
            moving_mean_name=name + "/pointwise/BatchNorm/moving_mean",
            moving_variance_name=name + "/pointwise/BatchNorm/moving_variance")

    def forward(self, inputs):
        """Depthwise conv -> BN -> pointwise conv -> BN."""
        x = self._conv1(inputs)
        x = self._bn1(x)
        x = self._conv2(x)
        x = self._bn2(x)
        return x
  132. class Xception_Block(nn.Layer):
  133. def __init__(self,
  134. input_channels,
  135. output_channels,
  136. strides=1,
  137. filter_size=3,
  138. dilation=1,
  139. skip_conv=True,
  140. has_skip=True,
  141. activation_fn_in_separable_conv=False,
  142. name=None):
  143. super(Xception_Block, self).__init__()
  144. repeat_number = 3
  145. output_channels = check_data(output_channels, repeat_number)
  146. filter_size = check_data(filter_size, repeat_number)
  147. strides = check_data(strides, repeat_number)
  148. self.has_skip = has_skip
  149. self.skip_conv = skip_conv
  150. self.activation_fn_in_separable_conv = activation_fn_in_separable_conv
  151. if not activation_fn_in_separable_conv:
  152. self._conv1 = Seperate_Conv(
  153. input_channels,
  154. output_channels[0],
  155. stride=strides[0],
  156. filter=filter_size[0],
  157. dilation=dilation,
  158. name=name + "/separable_conv1")
  159. self._conv2 = Seperate_Conv(
  160. output_channels[0],
  161. output_channels[1],
  162. stride=strides[1],
  163. filter=filter_size[1],
  164. dilation=dilation,
  165. name=name + "/separable_conv2")
  166. self._conv3 = Seperate_Conv(
  167. output_channels[1],
  168. output_channels[2],
  169. stride=strides[2],
  170. filter=filter_size[2],
  171. dilation=dilation,
  172. name=name + "/separable_conv3")
  173. else:
  174. self._conv1 = Seperate_Conv(
  175. input_channels,
  176. output_channels[0],
  177. stride=strides[0],
  178. filter=filter_size[0],
  179. act="relu",
  180. dilation=dilation,
  181. name=name + "/separable_conv1")
  182. self._conv2 = Seperate_Conv(
  183. output_channels[0],
  184. output_channels[1],
  185. stride=strides[1],
  186. filter=filter_size[1],
  187. act="relu",
  188. dilation=dilation,
  189. name=name + "/separable_conv2")
  190. self._conv3 = Seperate_Conv(
  191. output_channels[1],
  192. output_channels[2],
  193. stride=strides[2],
  194. filter=filter_size[2],
  195. act="relu",
  196. dilation=dilation,
  197. name=name + "/separable_conv3")
  198. if has_skip and skip_conv:
  199. self._short = ConvBNLayer(
  200. input_channels,
  201. output_channels[-1],
  202. 1,
  203. stride=strides[-1],
  204. padding=0,
  205. name=name + "/shortcut")
  206. def forward(self, inputs):
  207. if not self.activation_fn_in_separable_conv:
  208. x = F.relu(inputs)
  209. x = self._conv1(x)
  210. x = F.relu(x)
  211. x = self._conv2(x)
  212. x = F.relu(x)
  213. x = self._conv3(x)
  214. else:
  215. x = self._conv1(inputs)
  216. x = self._conv2(x)
  217. x = self._conv3(x)
  218. if self.has_skip:
  219. if self.skip_conv:
  220. skip = self._short(inputs)
  221. else:
  222. skip = inputs
  223. return paddle.add(x, skip)
  224. else:
  225. return x
class XceptionDeeplab(nn.Layer):
    """DeepLab-style Xception backbone with a classification head.

    The network is assembled from three stages (entry / middle / exit flow)
    whose block counts, strides and channel widths come from
    ``gen_bottleneck_params(backbone)``.  Sub-layer names embed ``backbone``
    so parameters line up with pretrained checkpoints.
    """

    def __init__(self, backbone, class_dim=1000):
        # backbone: one of 'xception_41' / 'xception_65' / 'xception_71'
        # class_dim: number of classes for the final Linear head
        super(XceptionDeeplab, self).__init__()
        bottleneck_params = gen_bottleneck_params(backbone)
        self.backbone = backbone
        # Stem: two 3x3 convs; the first downsamples by 2.
        self._conv1 = ConvBNLayer(
            3,
            32,
            3,
            stride=2,
            padding=1,
            act="relu",
            name=self.backbone + "/entry_flow/conv1")
        self._conv2 = ConvBNLayer(
            32,
            64,
            3,
            stride=1,
            padding=1,
            act="relu",
            name=self.backbone + "/entry_flow/conv2")
        # ---- entry flow ----
        self.block_num = bottleneck_params["entry_flow"][0]
        self.strides = bottleneck_params["entry_flow"][1]
        self.chns = bottleneck_params["entry_flow"][2]
        self.strides = check_data(self.strides, self.block_num)
        self.chns = check_data(self.chns, self.block_num)
        self.entry_flow = []
        self.middle_flow = []
        # `s` accumulates the product of applied strides; downsampling stops
        # once it would exceed output_stride.
        self.stride = 2
        self.output_stride = 32
        s = self.stride
        for i in range(self.block_num):
            stride = self.strides[i] if check_stride(s * self.strides[i],
                                                     self.output_stride) else 1
            # NOTE(review): the block receives self.stride (fixed at 2 here)
            # rather than the per-iteration `stride` computed above; looks
            # suspicious but matches the reference implementation — confirm
            # against pretrained weights before changing.
            xception_block = self.add_sublayer(
                self.backbone + "/entry_flow/block" + str(i + 1),
                Xception_Block(
                    input_channels=64 if i == 0 else self.chns[i - 1],
                    output_channels=self.chns[i],
                    strides=[1, 1, self.stride],
                    name=self.backbone + "/entry_flow/block" + str(i + 1)))
            self.entry_flow.append(xception_block)
            s = s * stride
        self.stride = s
        # ---- middle flow (no channel change, identity shortcuts) ----
        self.block_num = bottleneck_params["middle_flow"][0]
        self.strides = bottleneck_params["middle_flow"][1]
        self.chns = bottleneck_params["middle_flow"][2]
        self.strides = check_data(self.strides, self.block_num)
        self.chns = check_data(self.chns, self.block_num)
        s = self.stride
        for i in range(self.block_num):
            stride = self.strides[i] if check_stride(s * self.strides[i],
                                                     self.output_stride) else 1
            xception_block = self.add_sublayer(
                self.backbone + "/middle_flow/block" + str(i + 1),
                Xception_Block(
                    input_channels=728,
                    output_channels=728,
                    strides=[1, 1, self.strides[i]],
                    skip_conv=False,
                    name=self.backbone + "/middle_flow/block" + str(i + 1)))
            self.middle_flow.append(xception_block)
            s = s * stride
        self.stride = s
        # ---- exit flow: two blocks built directly (not via add_sublayer) ----
        self.block_num = bottleneck_params["exit_flow"][0]
        self.strides = bottleneck_params["exit_flow"][1]
        self.chns = bottleneck_params["exit_flow"][2]
        self.strides = check_data(self.strides, self.block_num)
        self.chns = check_data(self.chns, self.block_num)
        s = self.stride
        stride = self.strides[0] if check_stride(s * self.strides[0],
                                                 self.output_stride) else 1
        self._exit_flow_1 = Xception_Block(
            728,
            self.chns[0], [1, 1, stride],
            name=self.backbone + "/exit_flow/block1")
        s = s * stride
        stride = self.strides[1] if check_stride(s * self.strides[1],
                                                 self.output_stride) else 1
        # Second exit block: dilated, no residual, ReLU inside the
        # separable convs.
        self._exit_flow_2 = Xception_Block(
            self.chns[0][-1],
            self.chns[1], [1, 1, stride],
            dilation=2,
            has_skip=False,
            activation_fn_in_separable_conv=True,
            name=self.backbone + "/exit_flow/block2")
        s = s * stride
        self.stride = s
        # Classification head: dropout -> global average pool -> FC.
        self._drop = Dropout(p=0.5, mode="downscale_in_infer")
        self._pool = AdaptiveAvgPool2D(1)
        self._fc = Linear(
            self.chns[1][-1],
            class_dim,
            weight_attr=ParamAttr(name="fc_weights"),
            bias_attr=ParamAttr(name="fc_bias"))

    def forward(self, inputs):
        """Run stem, the three flows, and the head; returns (N, class_dim) logits."""
        x = self._conv1(inputs)
        x = self._conv2(x)
        for ef in self.entry_flow:
            x = ef(x)
        for mf in self.middle_flow:
            x = mf(x)
        x = self._exit_flow_1(x)
        x = self._exit_flow_2(x)
        x = self._drop(x)
        x = self._pool(x)
        # Collapse the 1x1 spatial dims left by global average pooling.
        x = paddle.squeeze(x, axis=[2, 3])
        x = self._fc(x)
        return x
  335. def Xception41_deeplab(**args):
  336. model = XceptionDeeplab('xception_41', **args)
  337. return model
  338. def Xception65_deeplab(**args):
  339. model = XceptionDeeplab("xception_65", **args)
  340. return model
  341. def Xception71_deeplab(**args):
  342. model = XceptionDeeplab("xception_71", **args)
  343. return model