unet.py

# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import paddle.fluid as fluid

from .model_utils.libs import scope, name_scope
from .model_utils.libs import bn, bn_relu, relu
from .model_utils.libs import conv, max_pool, deconv
from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss


class UNet(object):
  28. """实现Unet模型
  29. `"U-Net: Convolutional Networks for Biomedical Image Segmentation"
  30. <https://arxiv.org/abs/1505.04597>`
  31. Args:
  32. num_classes (int): 类别数
  33. mode (str): 网络运行模式,根据mode构建网络的输入和返回。
  34. 当mode为'train'时,输入为image(-1, 3, -1, -1)和label (-1, 1, -1, -1) 返回loss。
  35. 当mode为'train'时,输入为image (-1, 3, -1, -1)和label (-1, 1, -1, -1),返回loss,
  36. pred (与网络输入label 相同大小的预测结果,值代表相应的类别),label,mask(非忽略值的mask,
  37. 与label相同大小,bool类型)。
  38. 当mode为'test'时,输入为image(-1, 3, -1, -1)返回pred (-1, 1, -1, -1)和
  39. logit (-1, num_classes, -1, -1) 通道维上代表每一类的概率值。
  40. upsample_mode (str): UNet decode时采用的上采样方式,取值为'bilinear'时利用双线行差值进行上菜样,
  41. 当输入其他选项时则利用反卷积进行上菜样,默认为'bilinear'。
  42. use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。
  43. use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
  44. 当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。
  45. class_weight (list/str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
  46. num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
  47. 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
  48. 即平时使用的交叉熵损失函数。
  49. ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。
  50. fixed_input_shape (list): 长度为2,维度为1的list,如:[640,720],用来固定模型输入:'image'的shape,默认为None。
  51. Raises:
  52. ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
  53. ValueError: class_weight为list, 但长度不等于num_class。
  54. class_weight为str, 但class_weight.low()不等于dynamic。
  55. TypeError: class_weight不为None时,其类型不是list或str。
  56. """
    def __init__(self,
                 num_classes,
                 mode='train',
                 upsample_mode='bilinear',
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 ignore_index=255,
                 fixed_input_shape=None):
        # dice loss and bce loss are only applicable to binary segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss are only applicable to binary classification"
            )
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes"
                    )
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.
                    format(type(class_weight)))
        self.num_classes = num_classes
        self.mode = mode
        self.upsample_mode = upsample_mode
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.fixed_input_shape = fixed_input_shape
    def _double_conv(self, data, out_ch):
        # Two 3x3 conv + bn + relu blocks, the basic building block of UNet
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0),
            initializer=fluid.initializer.TruncatedNormal(
                loc=0.0, scale=0.33))
        with scope("conv0"):
            data = bn_relu(
                conv(
                    data,
                    out_ch,
                    3,
                    stride=1,
                    padding=1,
                    param_attr=param_attr))
        with scope("conv1"):
            data = bn_relu(
                conv(
                    data,
                    out_ch,
                    3,
                    stride=1,
                    padding=1,
                    param_attr=param_attr))
        return data
    def _down(self, data, out_ch):
        # Downsampling: max_pool followed by two convolutions
        with scope("down"):
            data = max_pool(data, 2, 2, 0)
            data = self._double_conv(data, out_ch)
        return data
    def _up(self, data, short_cut, out_ch):
        # Upsampling: upsample data (resize or deconv), then concat it with
        # the skip connection short_cut
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0),
            initializer=fluid.initializer.XavierInitializer())
        with scope("up"):
            if self.upsample_mode == 'bilinear':
                # Resize to the spatial size (H, W) of the skip connection
                short_cut_shape = fluid.layers.shape(short_cut)
                data = fluid.layers.resize_bilinear(data, short_cut_shape[2:])
            else:
                data = deconv(
                    data,
                    out_ch // 2,
                    filter_size=2,
                    stride=2,
                    padding=0,
                    param_attr=param_attr)
            data = fluid.layers.concat([data, short_cut], axis=1)
            data = self._double_conv(data, out_ch)
        return data
    def _encode(self, data):
        # Encoder: one double-conv stem followed by four downsampling blocks;
        # the first four outputs are kept as skip connections for the decoder
        short_cuts = []
        with scope("encode"):
            with scope("block1"):
                data = self._double_conv(data, 64)
                short_cuts.append(data)
            with scope("block2"):
                data = self._down(data, 128)
                short_cuts.append(data)
            with scope("block3"):
                data = self._down(data, 256)
                short_cuts.append(data)
            with scope("block4"):
                data = self._down(data, 512)
                short_cuts.append(data)
            with scope("block5"):
                data = self._down(data, 512)
        return data, short_cuts
    def _decode(self, data, short_cuts):
        # Decoder, symmetric to the encoder; the skip connections are
        # consumed in reverse order
        with scope("decode"):
            with scope("decode1"):
                data = self._up(data, short_cuts[3], 256)
            with scope("decode2"):
                data = self._up(data, short_cuts[2], 128)
            with scope("decode3"):
                data = self._up(data, short_cuts[1], 64)
            with scope("decode4"):
                data = self._up(data, short_cuts[0], 64)
        return data
    def _get_logit(self, data, num_classes):
        # Final convolution; the number of output channels equals the
        # number of classes
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0),
            initializer=fluid.initializer.TruncatedNormal(
                loc=0.0, scale=0.01))
        with scope("logit"):
            data = conv(
                data,
                num_classes,
                3,
                stride=1,
                padding=1,
                param_attr=param_attr)
        return data
    def _get_loss(self, logit, label, mask):
        # Fall back to softmax cross-entropy when neither dice loss nor bce
        # loss is selected; otherwise sum the selected losses
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            avg_loss += softmax_with_loss(
                logit,
                label,
                mask,
                num_classes=self.num_classes,
                weight=self.class_weight,
                ignore_index=self.ignore_index)
        else:
            if self.use_dice_loss:
                avg_loss += dice_loss(logit, label, mask)
            if self.use_bce_loss:
                avg_loss += bce_loss(
                    logit, label, mask, ignore_index=self.ignore_index)
        return avg_loss
    def generate_inputs(self):
        inputs = OrderedDict()
        if self.fixed_input_shape is not None:
            # fixed_input_shape is [width, height]; 'image' layout is NCHW
            input_shape = [
                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32', shape=[None, 3, None, None], name='image')
        if self.mode == 'train':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        elif self.mode == 'eval':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs
    def build_net(self, inputs):
        # For binary segmentation with dice_loss or bce_loss, the final
        # logit is produced with a single output channel
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        encode_data, short_cuts = self._encode(image)
        decode_data = self._decode(encode_data, short_cuts)
        logit = self._get_logit(decode_data, self.num_classes)
        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit)
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit, [0, 2, 3, 1])
        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])
        if self.mode == 'train':
            label = inputs['label']
            mask = label != self.ignore_index
            return self._get_loss(logit, label, mask)
        elif self.mode == 'eval':
            label = inputs['label']
            mask = label != self.ignore_index
            loss = self._get_loss(logit, label, mask)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit)
            else:
                logit = fluid.layers.softmax(logit, axis=1)
            return pred, logit
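
# Minimal usage sketch, assuming a PaddlePaddle 1.x static-graph environment
# and that this module sits inside its package next to model_utils, as the
# relative imports above require. Run it with `python -m <package>.unet`
# (package name is a placeholder) so that the relative imports resolve.
if __name__ == '__main__':
    # Build a binary-segmentation UNet in training mode.
    model = UNet(num_classes=2, mode='train')
    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    with fluid.program_guard(train_prog, startup_prog):
        # generate_inputs() declares the feed variables ('image', plus
        # 'label' in train/eval mode); build_net() assembles the graph.
        inputs = model.generate_inputs()
        loss = model.build_net(inputs)  # in 'train' mode this is the loss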