# unet.py
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import paddle.fluid as fluid

from .model_utils.libs import scope, name_scope
from .model_utils.libs import bn, bn_relu, relu
from .model_utils.libs import conv, max_pool, deconv
from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss
  27. class UNet(object):
  28. """实现Unet模型
  29. `"U-Net: Convolutional Networks for Biomedical Image Segmentation"
  30. <https://arxiv.org/abs/1505.04597>`
  31. Args:
  32. num_classes (int): 类别数
  33. mode (str): 网络运行模式,根据mode构建网络的输入和返回。
  34. 当mode为'train'时,输入为image(-1, 3, -1, -1)和label (-1, 1, -1, -1) 返回loss。
  35. 当mode为'train'时,输入为image (-1, 3, -1, -1)和label (-1, 1, -1, -1),返回loss,
  36. pred (与网络输入label 相同大小的预测结果,值代表相应的类别),label,mask(非忽略值的mask,
  37. 与label相同大小,bool类型)。
  38. 当mode为'test'时,输入为image(-1, 3, -1, -1)返回pred (-1, 1, -1, -1)和
  39. logit (-1, num_classes, -1, -1) 通道维上代表每一类的概率值。
  40. upsample_mode (str): UNet decode时采用的上采样方式,取值为'bilinear'时利用双线行差值进行上菜样,
  41. 当输入其他选项时则利用反卷积进行上菜样,默认为'bilinear'。
  42. use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。
  43. use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
  44. 当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。
  45. class_weight (list/str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
  46. num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
  47. 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
  48. 即平时使用的交叉熵损失函数。
  49. ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。
  50. fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
  51. Raises:
  52. ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
  53. ValueError: class_weight为list, 但长度不等于num_class。
  54. class_weight为str, 但class_weight.low()不等于dynamic。
  55. TypeError: class_weight不为None时,其类型不是list或str。
  56. """
  57. def __init__(self,
  58. num_classes,
  59. mode='train',
  60. upsample_mode='bilinear',
  61. use_bce_loss=False,
  62. use_dice_loss=False,
  63. class_weight=None,
  64. ignore_index=255,
  65. fixed_input_shape=None):
  66. # dice_loss或bce_loss只适用两类分割中
  67. if num_classes > 2 and (use_bce_loss or use_dice_loss):
  68. raise Exception(
  69. "dice loss and bce loss is only applicable to binary classfication"
  70. )
  71. if class_weight is not None:
  72. if isinstance(class_weight, list):
  73. if len(class_weight) != num_classes:
  74. raise ValueError(
  75. "Length of class_weight should be equal to number of classes"
  76. )
  77. elif isinstance(class_weight, str):
  78. if class_weight.lower() != 'dynamic':
  79. raise ValueError(
  80. "if class_weight is string, must be dynamic!")
  81. else:
  82. raise TypeError(
  83. 'Expect class_weight is a list or string but receive {}'.
  84. format(type(class_weight)))
  85. self.num_classes = num_classes
  86. self.mode = mode
  87. self.upsample_mode = upsample_mode
  88. self.use_bce_loss = use_bce_loss
  89. self.use_dice_loss = use_dice_loss
  90. self.class_weight = class_weight
  91. self.ignore_index = ignore_index
  92. self.fixed_input_shape = fixed_input_shape
  93. def _double_conv(self, data, out_ch):
  94. param_attr = fluid.ParamAttr(
  95. name='weights',
  96. regularizer=fluid.regularizer.L2DecayRegularizer(
  97. regularization_coeff=0.0),
  98. initializer=fluid.initializer.TruncatedNormal(
  99. loc=0.0, scale=0.33))
  100. with scope("conv0"):
  101. data = bn_relu(
  102. conv(
  103. data,
  104. out_ch,
  105. 3,
  106. stride=1,
  107. padding=1,
  108. param_attr=param_attr))
  109. with scope("conv1"):
  110. data = bn_relu(
  111. conv(
  112. data,
  113. out_ch,
  114. 3,
  115. stride=1,
  116. padding=1,
  117. param_attr=param_attr))
  118. return data
  119. def _down(self, data, out_ch):
  120. # 下采样:max_pool + 2个卷积
  121. with scope("down"):
  122. data = max_pool(data, 2, 2, 0)
  123. data = self._double_conv(data, out_ch)
  124. return data
  125. def _up(self, data, short_cut, out_ch):
  126. # 上采样:data上采样(resize或deconv), 并与short_cut concat
  127. param_attr = fluid.ParamAttr(
  128. name='weights',
  129. regularizer=fluid.regularizer.L2DecayRegularizer(
  130. regularization_coeff=0.0),
  131. initializer=fluid.initializer.XavierInitializer(), )
  132. with scope("up"):
  133. if self.upsample_mode == 'bilinear':
  134. short_cut_shape = fluid.layers.shape(short_cut)
  135. data = fluid.layers.resize_bilinear(
  136. data, short_cut_shape[2:], align_corners=False)
  137. else:
  138. data = deconv(
  139. data,
  140. out_ch // 2,
  141. filter_size=2,
  142. stride=2,
  143. padding=0,
  144. param_attr=param_attr)
  145. data = fluid.layers.concat([data, short_cut], axis=1)
  146. data = self._double_conv(data, out_ch)
  147. return data
  148. def _encode(self, data):
  149. # 编码器设置
  150. short_cuts = []
  151. with scope("encode"):
  152. with scope("block1"):
  153. data = self._double_conv(data, 64)
  154. short_cuts.append(data)
  155. with scope("block2"):
  156. data = self._down(data, 128)
  157. short_cuts.append(data)
  158. with scope("block3"):
  159. data = self._down(data, 256)
  160. short_cuts.append(data)
  161. with scope("block4"):
  162. data = self._down(data, 512)
  163. short_cuts.append(data)
  164. with scope("block5"):
  165. data = self._down(data, 512)
  166. return data, short_cuts
  167. def _decode(self, data, short_cuts):
  168. # 解码器设置,与编码器对称
  169. with scope("decode"):
  170. with scope("decode1"):
  171. data = self._up(data, short_cuts[3], 256)
  172. with scope("decode2"):
  173. data = self._up(data, short_cuts[2], 128)
  174. with scope("decode3"):
  175. data = self._up(data, short_cuts[1], 64)
  176. with scope("decode4"):
  177. data = self._up(data, short_cuts[0], 64)
  178. return data
  179. def _get_logit(self, data, num_classes):
  180. # 根据类别数设置最后一个卷积层输出
  181. param_attr = fluid.ParamAttr(
  182. name='weights',
  183. regularizer=fluid.regularizer.L2DecayRegularizer(
  184. regularization_coeff=0.0),
  185. initializer=fluid.initializer.TruncatedNormal(
  186. loc=0.0, scale=0.01))
  187. with scope("logit"):
  188. data = conv(
  189. data,
  190. num_classes,
  191. 3,
  192. stride=1,
  193. padding=1,
  194. param_attr=param_attr)
  195. return data
  196. def _get_loss(self, logit, label, mask):
  197. avg_loss = 0
  198. if not (self.use_dice_loss or self.use_bce_loss):
  199. avg_loss += softmax_with_loss(
  200. logit,
  201. label,
  202. mask,
  203. num_classes=self.num_classes,
  204. weight=self.class_weight,
  205. ignore_index=self.ignore_index)
  206. else:
  207. if self.use_dice_loss:
  208. avg_loss += dice_loss(logit, label, mask)
  209. if self.use_bce_loss:
  210. avg_loss += bce_loss(
  211. logit, label, mask, ignore_index=self.ignore_index)
  212. return avg_loss
  213. def generate_inputs(self):
  214. inputs = OrderedDict()
  215. if self.fixed_input_shape is not None:
  216. input_shape = [
  217. None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
  218. ]
  219. inputs['image'] = fluid.data(
  220. dtype='float32', shape=input_shape, name='image')
  221. else:
  222. inputs['image'] = fluid.data(
  223. dtype='float32', shape=[None, 3, None, None], name='image')
  224. if self.mode == 'train':
  225. inputs['label'] = fluid.data(
  226. dtype='int32', shape=[None, 1, None, None], name='label')
  227. elif self.mode == 'eval':
  228. inputs['label'] = fluid.data(
  229. dtype='int32', shape=[None, 1, None, None], name='label')
  230. return inputs
  231. def build_net(self, inputs):
  232. # 在两类分割情况下,当loss函数选择dice_loss或bce_loss的时候,最后logit输出通道数设置为1
  233. if self.use_dice_loss or self.use_bce_loss:
  234. self.num_classes = 1
  235. image = inputs['image']
  236. encode_data, short_cuts = self._encode(image)
  237. decode_data = self._decode(encode_data, short_cuts)
  238. logit = self._get_logit(decode_data, self.num_classes)
  239. if self.num_classes == 1:
  240. out = sigmoid_to_softmax(logit)
  241. out = fluid.layers.transpose(out, [0, 2, 3, 1])
  242. else:
  243. out = fluid.layers.transpose(logit, [0, 2, 3, 1])
  244. pred = fluid.layers.argmax(out, axis=3)
  245. pred = fluid.layers.unsqueeze(pred, axes=[3])
  246. if self.mode == 'train':
  247. label = inputs['label']
  248. mask = label != self.ignore_index
  249. return self._get_loss(logit, label, mask)
  250. elif self.mode == 'eval':
  251. label = inputs['label']
  252. mask = label != self.ignore_index
  253. loss = self._get_loss(logit, label, mask)
  254. return loss, pred, label, mask
  255. else:
  256. if self.num_classes == 1:
  257. logit = sigmoid_to_softmax(logit)
  258. else:
  259. logit = fluid.layers.softmax(logit, axis=1)
  260. return pred, logit