  1. # coding: utf8
  2. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. from collections import OrderedDict
  19. import paddle.fluid as fluid
  20. from .model_utils.libs import scope, name_scope
  21. from .model_utils.libs import bn, bn_relu, relu
  22. from .model_utils.libs import conv, max_pool, deconv
  23. from .model_utils.libs import sigmoid_to_softmax
  24. from .model_utils.loss import softmax_with_loss
  25. from .model_utils.loss import dice_loss
  26. from .model_utils.loss import bce_loss
  27. class UNet(object):
  28. """实现Unet模型
  29. `"U-Net: Convolutional Networks for Biomedical Image Segmentation"
  30. <https://arxiv.org/abs/1505.04597>`
  31. Args:
  32. num_classes (int): 类别数
  33. mode (str): 网络运行模式,根据mode构建网络的输入和返回。
  34. 当mode为'train'时,输入为image(-1, 3, -1, -1)和label (-1, 1, -1, -1) 返回loss。
  35. 当mode为'train'时,输入为image (-1, 3, -1, -1)和label (-1, 1, -1, -1),返回loss,
  36. pred (与网络输入label 相同大小的预测结果,值代表相应的类别),label,mask(非忽略值的mask,
  37. 与label相同大小,bool类型)。
  38. 当mode为'test'时,输入为image(-1, 3, -1, -1)返回pred (-1, 1, -1, -1)和
  39. logit (-1, num_classes, -1, -1) 通道维上代表每一类的概率值。
  40. upsample_mode (str): UNet decode时采用的上采样方式,取值为'bilinear'时利用双线行差值进行上菜样,
  41. 当输入其他选项时则利用反卷积进行上菜样,默认为'bilinear'。
  42. use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。
  43. use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
  44. 当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。
  45. class_weight (list/str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
  46. num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
  47. 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
  48. 即平时使用的交叉熵损失函数。
  49. ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。
  50. fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
  51. Raises:
  52. ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
  53. ValueError: class_weight为list, 但长度不等于num_class。
  54. class_weight为str, 但class_weight.low()不等于dynamic。
  55. TypeError: class_weight不为None时,其类型不是list或str。
  56. """
  57. def __init__(self,
  58. num_classes,
  59. input_channel=3,
  60. mode='train',
  61. upsample_mode='bilinear',
  62. use_bce_loss=False,
  63. use_dice_loss=False,
  64. class_weight=None,
  65. ignore_index=255,
  66. fixed_input_shape=None):
  67. # dice_loss或bce_loss只适用两类分割中
  68. if num_classes > 2 and (use_bce_loss or use_dice_loss):
  69. raise Exception(
  70. "dice loss and bce loss is only applicable to binary classfication"
  71. )
  72. if class_weight is not None:
  73. if isinstance(class_weight, list):
  74. if len(class_weight) != num_classes:
  75. raise ValueError(
  76. "Length of class_weight should be equal to number of classes"
  77. )
  78. elif isinstance(class_weight, str):
  79. if class_weight.lower() != 'dynamic':
  80. raise ValueError(
  81. "if class_weight is string, must be dynamic!")
  82. else:
  83. raise TypeError(
  84. 'Expect class_weight is a list or string but receive {}'.
  85. format(type(class_weight)))
  86. self.num_classes = num_classes
  87. self.input_channel = input_channel
  88. self.mode = mode
  89. self.upsample_mode = upsample_mode
  90. self.use_bce_loss = use_bce_loss
  91. self.use_dice_loss = use_dice_loss
  92. self.class_weight = class_weight
  93. self.ignore_index = ignore_index
  94. self.fixed_input_shape = fixed_input_shape
  95. def _double_conv(self, data, out_ch):
  96. param_attr = fluid.ParamAttr(
  97. name='weights',
  98. regularizer=fluid.regularizer.L2DecayRegularizer(
  99. regularization_coeff=0.0),
  100. initializer=fluid.initializer.TruncatedNormal(
  101. loc=0.0, scale=0.33))
  102. with scope("conv0"):
  103. data = bn_relu(
  104. conv(
  105. data,
  106. out_ch,
  107. 3,
  108. stride=1,
  109. padding=1,
  110. param_attr=param_attr))
  111. with scope("conv1"):
  112. data = bn_relu(
  113. conv(
  114. data,
  115. out_ch,
  116. 3,
  117. stride=1,
  118. padding=1,
  119. param_attr=param_attr))
  120. return data
  121. def _down(self, data, out_ch):
  122. # 下采样:max_pool + 2个卷积
  123. with scope("down"):
  124. data = max_pool(data, 2, 2, 0)
  125. data = self._double_conv(data, out_ch)
  126. return data
  127. def _up(self, data, short_cut, out_ch):
  128. # 上采样:data上采样(resize或deconv), 并与short_cut concat
  129. param_attr = fluid.ParamAttr(
  130. name='weights',
  131. regularizer=fluid.regularizer.L2DecayRegularizer(
  132. regularization_coeff=0.0),
  133. initializer=fluid.initializer.XavierInitializer(), )
  134. with scope("up"):
  135. if self.upsample_mode == 'bilinear':
  136. short_cut_shape = fluid.layers.shape(short_cut)
  137. data = fluid.layers.resize_bilinear(data, short_cut_shape[2:])
  138. else:
  139. data = deconv(
  140. data,
  141. out_ch // 2,
  142. filter_size=2,
  143. stride=2,
  144. padding=0,
  145. param_attr=param_attr)
  146. data = fluid.layers.concat([data, short_cut], axis=1)
  147. data = self._double_conv(data, out_ch)
  148. return data
  149. def _encode(self, data):
  150. # 编码器设置
  151. short_cuts = []
  152. with scope("encode"):
  153. with scope("block1"):
  154. data = self._double_conv(data, 64)
  155. short_cuts.append(data)
  156. with scope("block2"):
  157. data = self._down(data, 128)
  158. short_cuts.append(data)
  159. with scope("block3"):
  160. data = self._down(data, 256)
  161. short_cuts.append(data)
  162. with scope("block4"):
  163. data = self._down(data, 512)
  164. short_cuts.append(data)
  165. with scope("block5"):
  166. data = self._down(data, 512)
  167. return data, short_cuts
  168. def _decode(self, data, short_cuts):
  169. # 解码器设置,与编码器对称
  170. with scope("decode"):
  171. with scope("decode1"):
  172. data = self._up(data, short_cuts[3], 256)
  173. with scope("decode2"):
  174. data = self._up(data, short_cuts[2], 128)
  175. with scope("decode3"):
  176. data = self._up(data, short_cuts[1], 64)
  177. with scope("decode4"):
  178. data = self._up(data, short_cuts[0], 64)
  179. return data
  180. def _get_logit(self, data, num_classes):
  181. # 根据类别数设置最后一个卷积层输出
  182. param_attr = fluid.ParamAttr(
  183. name='weights',
  184. regularizer=fluid.regularizer.L2DecayRegularizer(
  185. regularization_coeff=0.0),
  186. initializer=fluid.initializer.TruncatedNormal(
  187. loc=0.0, scale=0.01))
  188. with scope("logit"):
  189. data = conv(
  190. data,
  191. num_classes,
  192. 3,
  193. stride=1,
  194. padding=1,
  195. param_attr=param_attr)
  196. return data
  197. def _get_loss(self, logit, label, mask):
  198. avg_loss = 0
  199. if not (self.use_dice_loss or self.use_bce_loss):
  200. avg_loss += softmax_with_loss(
  201. logit,
  202. label,
  203. mask,
  204. num_classes=self.num_classes,
  205. weight=self.class_weight,
  206. ignore_index=self.ignore_index)
  207. else:
  208. if self.use_dice_loss:
  209. avg_loss += dice_loss(logit, label, mask)
  210. if self.use_bce_loss:
  211. avg_loss += bce_loss(
  212. logit, label, mask, ignore_index=self.ignore_index)
  213. return avg_loss
  214. def generate_inputs(self):
  215. inputs = OrderedDict()
  216. if self.fixed_input_shape is not None:
  217. input_shape = [
  218. None, self.input_channel, self.fixed_input_shape[1],
  219. self.fixed_input_shape[0]
  220. ]
  221. inputs['image'] = fluid.data(
  222. dtype='float32', shape=input_shape, name='image')
  223. else:
  224. inputs['image'] = fluid.data(
  225. dtype='float32',
  226. shape=[None, self.input_channel, None, None],
  227. name='image')
  228. if self.mode == 'train':
  229. inputs['label'] = fluid.data(
  230. dtype='int32', shape=[None, 1, None, None], name='label')
  231. elif self.mode == 'eval':
  232. inputs['label'] = fluid.data(
  233. dtype='int32', shape=[None, 1, None, None], name='label')
  234. return inputs
  235. def build_net(self, inputs):
  236. # 在两类分割情况下,当loss函数选择dice_loss或bce_loss的时候,最后logit输出通道数设置为1
  237. if self.use_dice_loss or self.use_bce_loss:
  238. self.num_classes = 1
  239. image = inputs['image']
  240. encode_data, short_cuts = self._encode(image)
  241. decode_data = self._decode(encode_data, short_cuts)
  242. logit = self._get_logit(decode_data, self.num_classes)
  243. if self.num_classes == 1:
  244. out = sigmoid_to_softmax(logit)
  245. out = fluid.layers.transpose(out, [0, 2, 3, 1])
  246. else:
  247. out = fluid.layers.transpose(logit, [0, 2, 3, 1])
  248. pred = fluid.layers.argmax(out, axis=3)
  249. pred = fluid.layers.unsqueeze(pred, axes=[3])
  250. if self.mode == 'train':
  251. label = inputs['label']
  252. mask = label != self.ignore_index
  253. return self._get_loss(logit, label, mask)
  254. elif self.mode == 'eval':
  255. label = inputs['label']
  256. mask = label != self.ignore_index
  257. loss = self._get_loss(logit, label, mask)
  258. return loss, pred, label, mask
  259. else:
  260. if self.num_classes == 1:
  261. logit = sigmoid_to_softmax(logit)
  262. else:
  263. logit = fluid.layers.softmax(logit, axis=1)
  264. return pred, logit