fast_scnn.py

# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import paddle.fluid as fluid

from .model_utils.libs import scope
from .model_utils.libs import bn, bn_relu, relu, conv_bn_layer
from .model_utils.libs import conv, avg_pool
from .model_utils.libs import separate_conv
from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss
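
# Fast-SCNN (Poudel et al., "Fast-SCNN: Fast Semantic Segmentation Network",
# arXiv:1902.04502) is a real-time segmentation network built from four
# modules: a learning-to-downsample head, a global feature extractor, a
# feature fusion module, and a lightweight classifier. The class below wires
# those modules together on top of paddle.fluid.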
class FastSCNN(object):
    def __init__(self,
                 num_classes,
                 input_channel=3,
                 mode='train',
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 multi_loss_weight=[1.0],
                 ignore_index=255,
                 fixed_input_shape=None):
        # dice_loss and bce_loss are only applicable to binary segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss are only applicable to binary classification"
            )
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes"
                    )
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.
                    format(type(class_weight)))
        self.num_classes = num_classes
        self.input_channel = input_channel
        self.mode = mode
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.multi_loss_weight = multi_loss_weight
        self.fixed_input_shape = fixed_input_shape
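
    # Assembles the full network: downsample 8x, extract global context,
    # fuse the two branches, classify, and upsample the logits back to the
    # input size. With 2 or 3 entries in multi_loss_weight, auxiliary heads
    # on the intermediate features contribute additional weighted losses.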
    def build_net(self, inputs):
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        size = fluid.layers.shape(image)[2:]
        with scope('learning_to_downsample'):
            higher_res_features = self._learning_to_downsample(image, 32, 48,
                                                               64)
        with scope('global_feature_extractor'):
            lower_res_feature = self._global_feature_extractor(
                higher_res_features, 64, [64, 96, 128], 128, 6, [3, 3, 3])
        with scope('feature_fusion'):
            x = self._feature_fusion(higher_res_features, lower_res_feature,
                                     64, 128, 128)
        with scope('classifier'):
            logit = self._classifier(x, 128)
        logit = fluid.layers.resize_bilinear(logit, size, align_mode=0)

        if len(self.multi_loss_weight) == 3:
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_mode=0)
            with scope('aux_layer_lower'):
                lower_logit = self._aux_layer(lower_res_feature,
                                              self.num_classes)
                lower_logit = fluid.layers.resize_bilinear(
                    lower_logit, size, align_mode=0)
            logit = (logit, higher_logit, lower_logit)
        elif len(self.multi_loss_weight) == 2:
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_mode=0)
            logit = (logit, higher_logit)
        else:
            logit = (logit, )

        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit[0])
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit[0], [0, 2, 3, 1])
        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])

        if self.mode == 'train':
            label = inputs['label']
            return self._get_loss(logit, label)
        elif self.mode == 'eval':
            label = inputs['label']
            # valid-pixel mask: True where the label is not ignore_index
            mask = label != self.ignore_index
            loss = self._get_loss(logit, label)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit[0])
            else:
                logit = fluid.layers.softmax(logit[0], axis=1)
            return pred, logit
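
    # Declares the feed variables: an NCHW float32 image (with a fixed
    # spatial shape if fixed_input_shape is given) and, in train/eval mode,
    # an int32 label map.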
    def generate_inputs(self):
        inputs = OrderedDict()
        if self.fixed_input_shape is not None:
            input_shape = [
                None, self.input_channel, self.fixed_input_shape[1],
                self.fixed_input_shape[0]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32',
                shape=[None, self.input_channel, None, None],
                name='image')
        if self.mode == 'train':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        elif self.mode == 'eval':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs
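
    # Sums the per-head losses, weighting head i by multi_loss_weight[i];
    # pixels equal to ignore_index are masked out of every term.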
    def _get_loss(self, logits, label):
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            for i, logit in enumerate(logits):
                logit_mask = (
                    label.astype('int32') != self.ignore_index).astype('int32')
                loss = softmax_with_loss(
                    logit,
                    label,
                    logit_mask,
                    num_classes=self.num_classes,
                    weight=self.class_weight,
                    ignore_index=self.ignore_index)
                avg_loss += self.multi_loss_weight[i] * loss
        else:
            if self.use_dice_loss:
                for i, logit in enumerate(logits):
                    logit_mask = (label.astype('int32') != self.ignore_index
                                  ).astype('int32')
                    loss = dice_loss(logit, label, logit_mask)
                    avg_loss += self.multi_loss_weight[i] * loss
            if self.use_bce_loss:
                for i, logit in enumerate(logits):
                    logit_mask = (label.astype('int32') != self.ignore_index
                                  ).astype('int32')
                    loss = bce_loss(
                        logit,
                        label,
                        logit_mask,
                        ignore_index=self.ignore_index)
                    avg_loss += self.multi_loss_weight[i] * loss
        return avg_loss
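
    # Learning-to-downsample head: one standard conv followed by two
    # depthwise-separable convs, each with stride 2, so the output sits at
    # 1/8 of the input resolution.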
    def _learning_to_downsample(self,
                                x,
                                dw_channels1=32,
                                dw_channels2=48,
                                out_channels=64):
        x = relu(bn(conv(x, dw_channels1, 3, 2)))
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
        return x

    def _shortcut(self, input, data_residual):
        return fluid.layers.elementwise_add(input, data_residual)
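
    # Channel-wise (spatial) dropout: draws one Bernoulli(keep_prob) sample
    # per (batch, channel) pair and zeroes whole feature maps, rescaling the
    # survivors by 1/keep_prob so the expected activation is unchanged.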
    def _dropout2d(self, input, prob, is_train=False):
        if not is_train:
            return input
        keep_prob = 1.0 - prob
        shape = fluid.layers.shape(input)
        channels = shape[1]
        random_tensor = keep_prob + fluid.layers.uniform_random(
            [shape[0], channels, 1, 1], min=0., max=1.)
        binary_tensor = fluid.layers.floor(random_tensor)
        output = input / keep_prob * binary_tensor
        return output
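
    # MobileNetV2-style inverted residual: 1x1 expansion by expansion_factor,
    # 3x3 depthwise conv, then a linear 1x1 projection, with an identity
    # shortcut when ifshortcut is set.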
    def _inverted_residual_unit(self,
                                input,
                                num_in_filter,
                                num_filters,
                                ifshortcut,
                                stride,
                                filter_size,
                                padding,
                                expansion_factor,
                                name=None):
        num_expfilter = int(round(num_in_filter * expansion_factor))
        channel_expand = conv_bn_layer(
            input=input,
            num_filters=num_expfilter,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=True,
            name=name + '_expand')
        bottleneck_conv = conv_bn_layer(
            input=channel_expand,
            num_filters=num_expfilter,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            num_groups=num_expfilter,
            if_act=True,
            name=name + '_dwise',
            use_cudnn=False)
        depthwise_output = bottleneck_conv
        linear_out = conv_bn_layer(
            input=bottleneck_conv,
            num_filters=num_filters,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=False,
            name=name + '_linear')
        if ifshortcut:
            out = self._shortcut(input=input, data_residual=linear_out)
            return out, depthwise_output
        else:
            return linear_out, depthwise_output
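
    # Stacks n inverted residual units: the first adjusts channels and
    # stride, the remaining n - 1 run at stride 1 with identity shortcuts.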
    def _inverted_blocks(self, input, in_c, t, c, n, s, name=None):
        first_block, depthwise_output = self._inverted_residual_unit(
            input=input,
            num_in_filter=in_c,
            num_filters=c,
            ifshortcut=False,
            stride=s,
            filter_size=3,
            padding=1,
            expansion_factor=t,
            name=name + '_1')
        last_residual_block = first_block
        last_c = c
        for i in range(1, n):
            last_residual_block, depthwise_output = self._inverted_residual_unit(
                input=last_residual_block,
                num_in_filter=last_c,
                num_filters=c,
                ifshortcut=True,
                stride=1,
                filter_size=3,
                padding=1,
                expansion_factor=t,
                name=name + '_' + str(i + 1))
        return last_residual_block, depthwise_output
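
    # Pyramid pooling module (as in PSPNet): adaptive average pooling over
    # 1x1, 2x2, 3x3 and 6x6 grids, a 1x1 conv + BN + ReLU on each branch,
    # bilinear upsampling back to the input size, and channel-wise
    # concatenation with the input.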
    def _psp_module(self, input, out_features):
        cat_layers = []
        sizes = (1, 2, 3, 6)
        for size in sizes:
            psp_name = "psp" + str(size)
            with scope(psp_name):
                pool = fluid.layers.adaptive_pool2d(
                    input,
                    pool_size=[size, size],
                    pool_type='avg',
                    name=psp_name + '_adapool')
                data = conv(
                    pool,
                    out_features,
                    filter_size=1,
                    bias_attr=False,
                    name=psp_name + '_conv')
                data_bn = bn(data, act='relu')
                interp = fluid.layers.resize_bilinear(
                    data_bn,
                    out_shape=fluid.layers.shape(input)[2:],
                    name=psp_name + '_interp',
                    align_mode=0)
            cat_layers.append(interp)
        cat_layers = [input] + cat_layers
        out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
        return out
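
    # Auxiliary head for deep supervision: 3x3 conv + BN + ReLU, spatial
    # dropout, then a 1x1 conv producing class logits.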
    def _aux_layer(self, x, num_classes):
        x = relu(bn(conv(x, 32, 3, padding=1)))
        x = self._dropout2d(x, 0.1, is_train=(self.mode == 'train'))
        with scope('logit'):
            x = conv(x, num_classes, 1, bias_attr=True)
        return x
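
    # Feature fusion module: upsamples the low-resolution branch to the
    # high-resolution spatial size, projects both branches to out_channels
    # with 1x1 convs, and merges them by element-wise addition + ReLU.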
    def _feature_fusion(self,
                        higher_res_feature,
                        lower_res_feature,
                        higher_in_channels,
                        lower_in_channels,
                        out_channels,
                        scale_factor=4):
        shape = fluid.layers.shape(higher_res_feature)
        w = shape[-1]
        h = shape[-2]
        lower_res_feature = fluid.layers.resize_bilinear(
            lower_res_feature, [h, w], align_mode=0)
        with scope('dwconv'):
            lower_res_feature = relu(
                bn(conv(lower_res_feature, out_channels, 1)))
        with scope('conv_lower_res'):
            lower_res_feature = bn(
                conv(lower_res_feature, out_channels, 1, bias_attr=True))
        with scope('conv_higher_res'):
            higher_res_feature = bn(
                conv(higher_res_feature, out_channels, 1, bias_attr=True))
        out = higher_res_feature + lower_res_feature
        return relu(out)
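
    # Global feature extractor: three stacks of inverted residual blocks
    # (strides 2, 2, 1), the pyramid pooling module, and a final 1x1
    # conv + BN + ReLU projection to out_channels.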
    def _global_feature_extractor(self,
                                  x,
                                  in_channels=64,
                                  block_channels=(64, 96, 128),
                                  out_channels=128,
                                  t=6,
                                  num_blocks=(3, 3, 3)):
        x, _ = self._inverted_blocks(x, in_channels, t, block_channels[0],
                                     num_blocks[0], 2, 'inverted_block_1')
        x, _ = self._inverted_blocks(x, block_channels[0], t,
                                     block_channels[1], num_blocks[1], 2,
                                     'inverted_block_2')
        x, _ = self._inverted_blocks(x, block_channels[1], t,
                                     block_channels[2], num_blocks[2], 1,
                                     'inverted_block_3')
        x = self._psp_module(x, block_channels[2] // 4)
        with scope('out'):
            x = relu(bn(conv(x, out_channels, 1)))
        return x
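
    # Classifier head: two depthwise-separable convs, spatial dropout, and
    # a final 1x1 conv mapping to num_classes logits.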
    def _classifier(self, x, dw_channels, stride=1):
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        x = self._dropout2d(x, 0.1, is_train=self.mode == 'train')
        x = conv(x, self.num_classes, 1, bias_attr=True)
        return x
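
# A minimal usage sketch (an assumption about the surrounding PaddleX-style
# driver code, not part of this module): build the train program and fetch
# the scalar loss.
#
#   model = FastSCNN(num_classes=19, mode='train')
#   inputs = model.generate_inputs()  # feeds: 'image' and 'label'
#   loss = model.build_net(inputs)    # weighted sum over the logit heads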