fast_scnn.py

# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import paddle.fluid as fluid

from .model_utils.libs import scope
from .model_utils.libs import bn, bn_relu, relu, conv_bn_layer
from .model_utils.libs import conv, avg_pool
from .model_utils.libs import separate_conv
from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss

class FastSCNN(object):
    def __init__(self,
                 num_classes,
                 mode='train',
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 multi_loss_weight=[1.0],
                 ignore_index=255,
                 fixed_input_shape=None):
        # dice_loss and bce_loss are only applicable to binary segmentation.
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss are only applicable to binary classification"
            )
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes"
                    )
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is a string, it must be 'dynamic'!")
            else:
                raise TypeError(
                    "Expect class_weight to be a list or string but received {}".
                    format(type(class_weight)))
        self.num_classes = num_classes
        self.mode = mode
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.multi_loss_weight = multi_loss_weight
        self.fixed_input_shape = fixed_input_shape

    def build_net(self, inputs):
        # dice_loss and bce_loss both operate on a single-channel logit.
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        size = fluid.layers.shape(image)[2:]
        with scope('learning_to_downsample'):
            higher_res_features = self._learning_to_downsample(image, 32, 48,
                                                               64)
        with scope('global_feature_extractor'):
            lower_res_feature = self._global_feature_extractor(
                higher_res_features, 64, [64, 96, 128], 128, 6, [3, 3, 3])
        with scope('feature_fusion'):
            x = self._feature_fusion(higher_res_features, lower_res_feature,
                                     64, 128, 128)
        with scope('classifier'):
            logit = self._classifier(x, 128)
        logit = fluid.layers.resize_bilinear(logit, size, align_mode=0)

        if len(self.multi_loss_weight) == 3:
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_mode=0)
            with scope('aux_layer_lower'):
                lower_logit = self._aux_layer(lower_res_feature,
                                              self.num_classes)
                lower_logit = fluid.layers.resize_bilinear(
                    lower_logit, size, align_mode=0)
            logit = (logit, higher_logit, lower_logit)
        elif len(self.multi_loss_weight) == 2:
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_mode=0)
            logit = (logit, higher_logit)
        else:
            logit = (logit, )

        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit[0])
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit[0], [0, 2, 3, 1])
        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])

        if self.mode == 'train':
            label = inputs['label']
            return self._get_loss(logit, label)
        elif self.mode == 'eval':
            label = inputs['label']
            # Valid-pixel mask: positions labeled ignore_index are excluded
            # from evaluation.
            mask = label != self.ignore_index
            loss = self._get_loss(logit, label)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit[0])
            else:
                logit = fluid.layers.softmax(logit[0], axis=1)
            return pred, logit
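
    # NOTE on build_net: the forward pass follows the Fast-SCNN pipeline:
    # learning-to-downsample -> global feature extractor -> feature fusion ->
    # classifier, with the final logit bilinearly upsampled to the input
    # size. The length of multi_loss_weight (1, 2, or 3) controls how many
    # auxiliary heads are attached for deep supervision.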

    def generate_inputs(self):
        inputs = OrderedDict()
        if self.fixed_input_shape is not None:
            input_shape = [
                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32', shape=[None, 3, None, None], name='image')
        if self.mode == 'train':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        elif self.mode == 'eval':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs
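
    # NOTE on generate_inputs: tensors use NCHW layout, and fixed_input_shape
    # is interpreted as [width, height]; its entries are swapped into the
    # (H, W) slots of the static shape above.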

    def _get_loss(self, logits, label):
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            for i, logit in enumerate(logits):
                logit_mask = (
                    label.astype('int32') != self.ignore_index).astype('int32')
                loss = softmax_with_loss(
                    logit,
                    label,
                    logit_mask,
                    num_classes=self.num_classes,
                    weight=self.class_weight,
                    ignore_index=self.ignore_index)
                avg_loss += self.multi_loss_weight[i] * loss
        else:
            if self.use_dice_loss:
                for i, logit in enumerate(logits):
                    logit_mask = (label.astype('int32') != self.ignore_index
                                  ).astype('int32')
                    loss = dice_loss(logit, label, logit_mask)
                    avg_loss += self.multi_loss_weight[i] * loss
            if self.use_bce_loss:
                for i, logit in enumerate(logits):
                    # logit_label = fluid.layers.resize_nearest(label, logit_shape[2:])
                    logit_mask = (label.astype('int32') != self.ignore_index
                                  ).astype('int32')
                    loss = bce_loss(
                        logit,
                        label,
                        logit_mask,
                        ignore_index=self.ignore_index)
                    avg_loss += self.multi_loss_weight[i] * loss
        return avg_loss
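
    # NOTE on _get_loss: the total loss is a weighted sum over the logits
    # tuple (main logit first, then any auxiliary logits):
    #     avg_loss = sum_i multi_loss_weight[i] * loss_i
    # Pixels whose label equals ignore_index are masked out of every term.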

    def _learning_to_downsample(self,
                                x,
                                dw_channels1=32,
                                dw_channels2=48,
                                out_channels=64):
        x = relu(bn(conv(x, dw_channels1, 3, 2)))
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
        return x
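
    # NOTE on _learning_to_downsample: one strided conv followed by two
    # strided depthwise-separable convs; three stride-2 stages in total, so
    # the returned features sit at 1/8 of the input resolution (2*2*2 = 8).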

    def _shortcut(self, input, data_residual):
        return fluid.layers.elementwise_add(input, data_residual)

    def _dropout2d(self, input, prob, is_train=False):
        if not is_train:
            return input
        keep_prob = 1.0 - prob
        shape = fluid.layers.shape(input)
        channels = shape[1]
        random_tensor = keep_prob + fluid.layers.uniform_random(
            [shape[0], channels, 1, 1], min=0., max=1.)
        binary_tensor = fluid.layers.floor(random_tensor)
        output = input / keep_prob * binary_tensor
        return output
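
    # NOTE on _dropout2d: this is spatial (channel-wise) dropout. For each
    # (sample, channel) pair, floor(keep_prob + U[0, 1)) is 1 with
    # probability keep_prob and 0 otherwise, so whole channels are dropped
    # together; survivors are scaled by 1 / keep_prob to keep the expected
    # activation unchanged.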

    def _inverted_residual_unit(self,
                                input,
                                num_in_filter,
                                num_filters,
                                ifshortcut,
                                stride,
                                filter_size,
                                padding,
                                expansion_factor,
                                name=None):
        num_expfilter = int(round(num_in_filter * expansion_factor))
        channel_expand = conv_bn_layer(
            input=input,
            num_filters=num_expfilter,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=True,
            name=name + '_expand')
        bottleneck_conv = conv_bn_layer(
            input=channel_expand,
            num_filters=num_expfilter,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            num_groups=num_expfilter,
            if_act=True,
            name=name + '_dwise',
            use_cudnn=False)
        depthwise_output = bottleneck_conv
        linear_out = conv_bn_layer(
            input=bottleneck_conv,
            num_filters=num_filters,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=False,
            name=name + '_linear')
        if ifshortcut:
            out = self._shortcut(input=input, data_residual=linear_out)
            return out, depthwise_output
        else:
            return linear_out, depthwise_output
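
    # NOTE on _inverted_residual_unit: a MobileNetV2-style bottleneck, which
    # is what Fast-SCNN uses in its global feature extractor: 1x1 expansion
    # by expansion_factor, depthwise conv (num_groups == num_expfilter), then
    # a linear 1x1 projection; an identity shortcut is added when ifshortcut
    # is True.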

    def _inverted_blocks(self, input, in_c, t, c, n, s, name=None):
        first_block, depthwise_output = self._inverted_residual_unit(
            input=input,
            num_in_filter=in_c,
            num_filters=c,
            ifshortcut=False,
            stride=s,
            filter_size=3,
            padding=1,
            expansion_factor=t,
            name=name + '_1')
        last_residual_block = first_block
        last_c = c
        for i in range(1, n):
            last_residual_block, depthwise_output = self._inverted_residual_unit(
                input=last_residual_block,
                num_in_filter=last_c,
                num_filters=c,
                ifshortcut=True,
                stride=1,
                filter_size=3,
                padding=1,
                expansion_factor=t,
                name=name + '_' + str(i + 1))
        return last_residual_block, depthwise_output

    def _psp_module(self, input, out_features):
        cat_layers = []
        sizes = (1, 2, 3, 6)
        for size in sizes:
            psp_name = "psp" + str(size)
            with scope(psp_name):
                pool = fluid.layers.adaptive_pool2d(
                    input,
                    pool_size=[size, size],
                    pool_type='avg',
                    name=psp_name + '_adapool')
                data = conv(
                    pool,
                    out_features,
                    filter_size=1,
                    bias_attr=False,
                    name=psp_name + '_conv')
                data_bn = bn(data, act='relu')
                interp = fluid.layers.resize_bilinear(
                    data_bn,
                    out_shape=fluid.layers.shape(input)[2:],
                    name=psp_name + '_interp',
                    align_mode=0)
            cat_layers.append(interp)
        cat_layers = [input] + cat_layers
        out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
        return out
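
    # NOTE on _psp_module: pyramid pooling over 1x1, 2x2, 3x3 and 6x6 grids.
    # Each pooled map is reduced to out_features channels, upsampled back to
    # the input size, and concatenated with the input, so the output has
    # C_in + 4 * out_features channels.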

    def _aux_layer(self, x, num_classes):
        x = relu(bn(conv(x, 32, 3, padding=1)))
        x = self._dropout2d(x, 0.1, is_train=(self.mode == 'train'))
        with scope('logit'):
            x = conv(x, num_classes, 1, bias_attr=True)
        return x

    def _feature_fusion(self,
                        higher_res_feature,
                        lower_res_feature,
                        higher_in_channels,
                        lower_in_channels,
                        out_channels,
                        scale_factor=4):
        shape = fluid.layers.shape(higher_res_feature)
        w = shape[-1]
        h = shape[-2]
        lower_res_feature = fluid.layers.resize_bilinear(
            lower_res_feature, [h, w], align_mode=0)
        with scope('dwconv'):
            lower_res_feature = relu(
                bn(conv(lower_res_feature, out_channels, 1)))
        with scope('conv_lower_res'):
            lower_res_feature = bn(
                conv(
                    lower_res_feature, out_channels, 1, bias_attr=True))
        with scope('conv_higher_res'):
            higher_res_feature = bn(
                conv(
                    higher_res_feature, out_channels, 1, bias_attr=True))
        out = higher_res_feature + lower_res_feature
        return relu(out)
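
    # NOTE on _feature_fusion: the low-resolution branch is upsampled to the
    # high-resolution branch's spatial size, each branch goes through a
    # 1x1 conv + BN, and the two are fused by elementwise addition followed
    # by ReLU.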

    def _global_feature_extractor(self,
                                  x,
                                  in_channels=64,
                                  block_channels=(64, 96, 128),
                                  out_channels=128,
                                  t=6,
                                  num_blocks=(3, 3, 3)):
        x, _ = self._inverted_blocks(x, in_channels, t, block_channels[0],
                                     num_blocks[0], 2, 'inverted_block_1')
        x, _ = self._inverted_blocks(x, block_channels[0], t,
                                     block_channels[1], num_blocks[1], 2,
                                     'inverted_block_2')
        x, _ = self._inverted_blocks(x, block_channels[1], t,
                                     block_channels[2], num_blocks[2], 1,
                                     'inverted_block_3')
        x = self._psp_module(x, block_channels[2] // 4)
        with scope('out'):
            x = relu(bn(conv(x, out_channels, 1)))
        return x
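
    # NOTE on _global_feature_extractor: two stride-2 inverted-block stages
    # followed by a stride-1 stage, so its output is at 1/4 of its own input
    # (1/32 of the image when fed from _learning_to_downsample), followed by
    # pyramid pooling and a 1x1 projection to out_channels.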

    def _classifier(self, x, dw_channels, stride=1):
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        x = self._dropout2d(x, 0.1, is_train=self.mode == 'train')
        x = conv(x, self.num_classes, 1, bias_attr=True)
        return x
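

# Illustrative usage sketch (an addition, not part of the original module):
# builds the network in 'test' mode and runs one forward pass on random
# data. It assumes PaddlePaddle 1.x static-graph APIs and that this file is
# executed as a module inside its package (e.g. `python -m <package>.fast_scnn`)
# so the relative model_utils imports resolve; the class count (19) and the
# input size are arbitrary choices.
if __name__ == '__main__':
    import numpy as np

    model = FastSCNN(
        num_classes=19, mode='test',
        fixed_input_shape=[1024, 512])  # [width, height]
    startup_prog = fluid.Program()
    infer_prog = fluid.Program()
    with fluid.program_guard(infer_prog, startup_prog):
        inputs = model.generate_inputs()
        pred, logit = model.build_net(inputs)
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)
    fake_image = np.random.rand(1, 3, 512, 1024).astype('float32')
    pred_np, logit_np = exe.run(
        infer_prog, feed={'image': fake_image}, fetch_list=[pred, logit])
    print('pred shape:', pred_np.shape)    # expected (1, 512, 1024, 1)
    print('logit shape:', logit_np.shape)  # expected (1, 19, 512, 1024)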