# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import paddle.fluid as fluid

from .model_utils.libs import scope
from .model_utils.libs import bn, bn_relu, relu, conv_bn_layer
from .model_utils.libs import conv, avg_pool
from .model_utils.libs import separate_conv
from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss

class FastSCNN(object):
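    """Fast-SCNN semantic segmentation network (Poudel et al., 2019,
    https://arxiv.org/abs/1902.04502), built on the paddle.fluid static
    graph API.

    The network has four stages: a learning-to-downsample module, a global
    feature extractor, a feature fusion module, and a classifier head.

    Args:
        num_classes (int): number of segmentation classes.
        mode (str): 'train', 'eval', or any other value for inference.
        use_bce_loss (bool): use sigmoid BCE loss; binary segmentation only.
        use_dice_loss (bool): use dice loss; binary segmentation only.
        class_weight (list|str|None): per-class loss weights, or 'dynamic'.
        multi_loss_weight (list): weights for the main logit and up to two
            auxiliary logits; its length selects how many aux heads are built.
        ignore_index (int): label value excluded from the loss.
        fixed_input_shape (list|None): optional fixed [width, height] of the
            input image.
    """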

    def __init__(self,
                 num_classes,
                 mode='train',
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 multi_loss_weight=[1.0],
                 ignore_index=255,
                 fixed_input_shape=None):
        # dice_loss and bce_loss only apply to binary segmentation.
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss are only applicable to binary segmentation"
            )
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes"
                    )
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "If class_weight is a string, it must be 'dynamic'!")
            else:
                raise TypeError(
                    'Expected class_weight to be a list or string, but received {}'.
                    format(type(class_weight)))
        self.num_classes = num_classes
        self.mode = mode
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.multi_loss_weight = multi_loss_weight
        self.fixed_input_shape = fixed_input_shape
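
    # build_net wires the four Fast-SCNN stages together and, depending on
    # len(multi_loss_weight), attaches zero, one, or two auxiliary heads on
    # the intermediate features. All logits are bilinearly upsampled back to
    # the input resolution before the loss or prediction is computed.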
    def build_net(self, inputs):
        # For binary segmentation with dice/bce loss, a single sigmoid
        # channel replaces the softmax over classes.
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        size = fluid.layers.shape(image)[2:]
        with scope('learning_to_downsample'):
            higher_res_features = self._learning_to_downsample(image, 32, 48,
                                                               64)
        with scope('global_feature_extractor'):
            lower_res_feature = self._global_feature_extractor(
                higher_res_features, 64, [64, 96, 128], 128, 6, [3, 3, 3])
        with scope('feature_fusion'):
            x = self._feature_fusion(higher_res_features, lower_res_feature,
                                     64, 128, 128)
        with scope('classifier'):
            logit = self._classifier(x, 128)
            logit = fluid.layers.resize_bilinear(
                logit, size, align_corners=False, align_mode=1)

        if len(self.multi_loss_weight) == 3:
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_corners=False, align_mode=1)
            with scope('aux_layer_lower'):
                lower_logit = self._aux_layer(lower_res_feature,
                                              self.num_classes)
                lower_logit = fluid.layers.resize_bilinear(
                    lower_logit, size, align_corners=False, align_mode=1)
            logit = (logit, higher_logit, lower_logit)
        elif len(self.multi_loss_weight) == 2:
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_corners=False, align_mode=1)
            logit = (logit, higher_logit)
        else:
            logit = (logit, )

        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit[0])
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit[0], [0, 2, 3, 1])
        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])

        if self.mode == 'train':
            label = inputs['label']
            return self._get_loss(logit, label)
        elif self.mode == 'eval':
            label = inputs['label']
            # Mask out pixels carrying ignore_index so they are excluded
            # from the evaluation metrics.
            mask = label != self.ignore_index
            loss = self._get_loss(logit, label)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit[0])
            else:
                logit = fluid.layers.softmax(logit[0], axis=1)
            return pred, logit
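
    # generate_inputs declares the static-graph input tensors: an NCHW float
    # image and, in train/eval mode, an int32 label map of shape [N, 1, H, W].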
    def generate_inputs(self):
        inputs = OrderedDict()
        if self.fixed_input_shape is not None:
            input_shape = [
                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32', shape=[None, 3, None, None], name='image')
        if self.mode in ('train', 'eval'):
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs
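
    # _get_loss computes a weighted sum over the main and auxiliary logits:
    # softmax cross-entropy by default, or dice/bce loss for binary
    # segmentation; multi_loss_weight[i] scales the i-th logit's loss.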
    def _get_loss(self, logits, label):
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            for i, logit in enumerate(logits):
                logit_mask = (
                    label.astype('int32') != self.ignore_index).astype('int32')
                loss = softmax_with_loss(
                    logit,
                    label,
                    logit_mask,
                    num_classes=self.num_classes,
                    weight=self.class_weight,
                    ignore_index=self.ignore_index)
                avg_loss += self.multi_loss_weight[i] * loss
        else:
            if self.use_dice_loss:
                for i, logit in enumerate(logits):
                    logit_mask = (label.astype('int32') != self.ignore_index
                                  ).astype('int32')
                    loss = dice_loss(logit, label, logit_mask)
                    avg_loss += self.multi_loss_weight[i] * loss
            if self.use_bce_loss:
                for i, logit in enumerate(logits):
                    logit_mask = (label.astype('int32') != self.ignore_index
                                  ).astype('int32')
                    loss = bce_loss(
                        logit,
                        label,
                        logit_mask,
                        ignore_index=self.ignore_index)
                    avg_loss += self.multi_loss_weight[i] * loss
        return avg_loss
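
    # The learning-to-downsample module: one standard conv followed by two
    # depthwise-separable convs, each with stride 2, for an overall 1/8
    # downsampling of the input.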
    def _learning_to_downsample(self,
                                x,
                                dw_channels1=32,
                                dw_channels2=48,
                                out_channels=64):
        x = relu(bn(conv(x, dw_channels1, 3, 2)))
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
        return x

    def _shortcut(self, input, data_residual):
        return fluid.layers.elementwise_add(input, data_residual)
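
    # Spatial (2D) dropout: zeroes entire channels rather than individual
    # elements, then rescales by 1/keep_prob so activations keep their
    # expected value. Identity in inference mode.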
    def _dropout2d(self, input, prob, is_train=False):
        if not is_train:
            return input
        keep_prob = 1.0 - prob
        shape = fluid.layers.shape(input)
        channels = shape[1]
        random_tensor = keep_prob + fluid.layers.uniform_random(
            [shape[0], channels, 1, 1], min=0., max=1.)
        binary_tensor = fluid.layers.floor(random_tensor)
        output = input / keep_prob * binary_tensor
        return output
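
    # MobileNetV2-style inverted residual unit: a 1x1 expansion conv, a
    # depthwise conv (num_groups == num_expfilter), and a linear 1x1
    # projection, with an identity shortcut when ifshortcut is True.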
    def _inverted_residual_unit(self,
                                input,
                                num_in_filter,
                                num_filters,
                                ifshortcut,
                                stride,
                                filter_size,
                                padding,
                                expansion_factor,
                                name=None):
        num_expfilter = int(round(num_in_filter * expansion_factor))
        channel_expand = conv_bn_layer(
            input=input,
            num_filters=num_expfilter,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=True,
            name=name + '_expand')
        bottleneck_conv = conv_bn_layer(
            input=channel_expand,
            num_filters=num_expfilter,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            num_groups=num_expfilter,
            if_act=True,
            name=name + '_dwise',
            use_cudnn=False)
        depthwise_output = bottleneck_conv
        linear_out = conv_bn_layer(
            input=bottleneck_conv,
            num_filters=num_filters,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=False,
            name=name + '_linear')
        if ifshortcut:
            out = self._shortcut(input=input, data_residual=linear_out)
            return out, depthwise_output
        else:
            return linear_out, depthwise_output
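
    # Stacks n inverted residual units with expansion factor t and c output
    # channels; only the first unit applies stride s, the rest keep stride 1
    # and use identity shortcuts.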
    def _inverted_blocks(self, input, in_c, t, c, n, s, name=None):
        first_block, depthwise_output = self._inverted_residual_unit(
            input=input,
            num_in_filter=in_c,
            num_filters=c,
            ifshortcut=False,
            stride=s,
            filter_size=3,
            padding=1,
            expansion_factor=t,
            name=name + '_1')
        last_residual_block = first_block
        last_c = c
        for i in range(1, n):
            last_residual_block, depthwise_output = self._inverted_residual_unit(
                input=last_residual_block,
                num_in_filter=last_c,
                num_filters=c,
                ifshortcut=True,
                stride=1,
                filter_size=3,
                padding=1,
                expansion_factor=t,
                name=name + '_' + str(i + 1))
        return last_residual_block, depthwise_output
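
    # Pyramid pooling module (as in PSPNet): adaptive average pooling at
    # 1x1, 2x2, 3x3, and 6x6 grids, a 1x1 conv on each, bilinear upsampling
    # back to the input size, then channel-wise concatenation with the input.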
    def _psp_module(self, input, out_features):
        cat_layers = []
        sizes = (1, 2, 3, 6)
        for size in sizes:
            psp_name = "psp" + str(size)
            with scope(psp_name):
                pool = fluid.layers.adaptive_pool2d(
                    input,
                    pool_size=[size, size],
                    pool_type='avg',
                    name=psp_name + '_adapool')
                data = conv(
                    pool,
                    out_features,
                    filter_size=1,
                    bias_attr=False,
                    name=psp_name + '_conv')
                data_bn = bn(data, act='relu')
                interp = fluid.layers.resize_bilinear(
                    data_bn,
                    out_shape=fluid.layers.shape(input)[2:],
                    name=psp_name + '_interp',
                    align_corners=False,
                    align_mode=1)
                cat_layers.append(interp)
        cat_layers = [input] + cat_layers
        out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
        return out
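
    # Auxiliary classification head used for deep supervision during training.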
    def _aux_layer(self, x, num_classes):
        x = relu(bn(conv(x, 32, 3, padding=1)))
        x = self._dropout2d(x, 0.1, is_train=(self.mode == 'train'))
        with scope('logit'):
            x = conv(x, num_classes, 1, bias_attr=True)
        return x
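
    # The feature fusion module: upsamples the low-resolution branch to the
    # high-resolution spatial size, projects both branches to out_channels
    # with 1x1 convs, and merges them by element-wise addition.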
    def _feature_fusion(self,
                        higher_res_feature,
                        lower_res_feature,
                        higher_in_channels,
                        lower_in_channels,
                        out_channels,
                        scale_factor=4):
        shape = fluid.layers.shape(higher_res_feature)
        w = shape[-1]
        h = shape[-2]
        lower_res_feature = fluid.layers.resize_bilinear(
            lower_res_feature, [h, w], align_corners=False, align_mode=1)
        with scope('dwconv'):
            lower_res_feature = relu(
                bn(conv(lower_res_feature, out_channels, 1)))
        with scope('conv_lower_res'):
            lower_res_feature = bn(
                conv(lower_res_feature, out_channels, 1, bias_attr=True))
        with scope('conv_higher_res'):
            higher_res_feature = bn(
                conv(higher_res_feature, out_channels, 1, bias_attr=True))
        out = higher_res_feature + lower_res_feature
        return relu(out)
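
    # The global feature extractor: three groups of inverted residual blocks
    # (the first two groups with stride 2), followed by the pyramid pooling
    # module and a 1x1 output conv.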
    def _global_feature_extractor(self,
                                  x,
                                  in_channels=64,
                                  block_channels=(64, 96, 128),
                                  out_channels=128,
                                  t=6,
                                  num_blocks=(3, 3, 3)):
        x, _ = self._inverted_blocks(x, in_channels, t, block_channels[0],
                                     num_blocks[0], 2, 'inverted_block_1')
        x, _ = self._inverted_blocks(x, block_channels[0], t,
                                     block_channels[1], num_blocks[1], 2,
                                     'inverted_block_2')
        x, _ = self._inverted_blocks(x, block_channels[1], t,
                                     block_channels[2], num_blocks[2], 1,
                                     'inverted_block_3')
        x = self._psp_module(x, block_channels[2] // 4)
        with scope('out'):
            x = relu(bn(conv(x, out_channels, 1)))
        return x
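
    # The classifier head: two depthwise-separable convs, spatial dropout,
    # and a final 1x1 conv producing num_classes channels.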
    def _classifier(self, x, dw_channels, stride=1):
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        x = self._dropout2d(x, 0.1, is_train=self.mode == 'train')
        x = conv(x, self.num_classes, 1, bias_attr=True)
        return x
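
# A minimal usage sketch, kept as a comment because it rests on assumptions
# not in this file: that the module is imported from its original package (so
# the relative `model_utils` imports resolve), that paddle.fluid 1.x is
# installed, and that the import path and variable names are placeholders.
#
#     import paddle.fluid as fluid
#     from models.fast_scnn import FastSCNN
#
#     model = FastSCNN(num_classes=19, mode='train')
#     startup_prog = fluid.Program()
#     train_prog = fluid.Program()
#     with fluid.program_guard(train_prog, startup_prog):
#         inputs = model.generate_inputs()   # OrderedDict with 'image'/'label'
#         loss = model.build_net(inputs)     # scalar training loss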