# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import paddle.fluid as fluid

from .model_utils.libs import scope, name_scope
from .model_utils.libs import bn, bn_relu, relu
from .model_utils.libs import conv, max_pool, deconv
from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss
import paddlex.utils.logging as logging

class UNet(object):
    """U-Net implementation.

    `"U-Net: Convolutional Networks for Biomedical Image Segmentation"
    <https://arxiv.org/abs/1505.04597>`

    Args:
        num_classes (int): Number of classes.
        mode (str): Network mode; the network's inputs and outputs are built
            according to mode.
            When mode is 'train', the inputs are image (-1, 3, -1, -1) and
            label (-1, 1, -1, -1), and loss is returned.
            When mode is 'eval', the inputs are image (-1, 3, -1, -1) and
            label (-1, 1, -1, -1), and the returns are loss, pred (a
            prediction the same size as the input label, whose values are the
            predicted classes), label, and mask (a bool mask of the
            non-ignored values, the same size as label).
            When mode is 'test', the input is image (-1, 3, -1, -1), and the
            returns are pred (-1, 1, -1, -1) and logit
            (-1, num_classes, -1, -1), whose channel dimension holds the
            per-class probabilities.
        upsample_mode (str): Upsampling method used in the UNet decoder. When
            set to 'bilinear', bilinear interpolation is used; any other value
            selects transposed convolution. Defaults to 'bilinear'.
        use_bce_loss (bool): Whether to use bce loss as the loss function;
            only applicable to binary segmentation. Can be combined with dice
            loss.
        use_dice_loss (bool): Whether to use dice loss as the loss function;
            only applicable to binary segmentation. Can be combined with bce
            loss. When use_bce_loss and use_dice_loss are both False, softmax
            cross-entropy loss is used.
        class_weight (list|str): Per-class weights for the cross-entropy
            loss. When class_weight is a list, its length must equal
            num_classes. When class_weight is a str, class_weight.lower()
            must be 'dynamic', in which case the weights are recomputed each
            step from the per-class pixel ratios, each class weighted as the
            ratio of that class * num_classes. With the default None, every
            class has weight 1, i.e. the usual cross-entropy loss.
        ignore_index (int): Value ignored in label; pixels whose label equals
            ignore_index do not contribute to the loss.
        fixed_input_shape (list): A 1-D list of length 2, e.g. [640, 720],
            used to fix the shape of the model input 'image'. Defaults to
            None.

    Raises:
        ValueError: use_bce_loss or use_dice_loss is True and
            num_classes > 2.
        ValueError: class_weight is a list whose length does not equal
            num_classes, or class_weight is a str and class_weight.lower()
            is not 'dynamic'.
        TypeError: class_weight is not None and is neither a list nor a str.
    """

    def __init__(self,
                 num_classes,
                 mode='train',
                 upsample_mode='bilinear',
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 ignore_index=255,
                 fixed_input_shape=None):
        # dice loss and bce loss are only applicable to binary segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss are only applicable to binary classification"
            )
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes"
                    )
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.
                    format(type(class_weight)))
        self.num_classes = num_classes
        self.mode = mode
        self.upsample_mode = upsample_mode
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.fixed_input_shape = fixed_input_shape

    def _double_conv(self, data, out_ch):
        # Two 3x3 conv + bn + relu blocks; stride 1 and padding 1 keep the
        # spatial size unchanged
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0),
            initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.33))
        with scope("conv0"):
            data = bn_relu(
                conv(
                    data,
                    out_ch,
                    3,
                    stride=1,
                    padding=1,
                    param_attr=param_attr))
        with scope("conv1"):
            data = bn_relu(
                conv(
                    data,
                    out_ch,
                    3,
                    stride=1,
                    padding=1,
                    param_attr=param_attr))
        return data

    def _down(self, data, out_ch):
        # Downsampling: max_pool followed by two convolutions
        with scope("down"):
            data = max_pool(data, 2, 2, 0)
            data = self._double_conv(data, out_ch)
        return data

    def _up(self, data, short_cut, out_ch):
        # Upsampling: upsample data (resize or deconv), then concat with
        # short_cut
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0),
            initializer=fluid.initializer.XavierInitializer(),
        )
        with scope("up"):
            if self.upsample_mode == 'bilinear':
                short_cut_shape = fluid.layers.shape(short_cut)
                data = fluid.layers.resize_bilinear(data, short_cut_shape[2:])
            else:
                data = deconv(
                    data,
                    out_ch // 2,
                    filter_size=2,
                    stride=2,
                    padding=0,
                    param_attr=param_attr)
            data = fluid.layers.concat([data, short_cut], axis=1)
            data = self._double_conv(data, out_ch)
        return data

    def _encode(self, data):
        # Encoder
        short_cuts = []
        with scope("encode"):
            with scope("block1"):
                data = self._double_conv(data, 64)
                short_cuts.append(data)
            with scope("block2"):
                data = self._down(data, 128)
                short_cuts.append(data)
            with scope("block3"):
                data = self._down(data, 256)
                short_cuts.append(data)
            with scope("block4"):
                data = self._down(data, 512)
                short_cuts.append(data)
            with scope("block5"):
                data = self._down(data, 512)
        return data, short_cuts

    def _decode(self, data, short_cuts):
        # Decoder, symmetric to the encoder
        with scope("decode"):
            with scope("decode1"):
                data = self._up(data, short_cuts[3], 256)
            with scope("decode2"):
                data = self._up(data, short_cuts[2], 128)
            with scope("decode3"):
                data = self._up(data, short_cuts[1], 64)
            with scope("decode4"):
                data = self._up(data, short_cuts[0], 64)
        return data

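    # Shape trace (a sketch, assuming an NCHW input of 1 x 3 x 256 x 256;
    # spatial sizes halve at each _down and are restored by each _up):
    #   encode: short_cuts = [64@256x256, 128@128x128, 256@64x64, 512@32x32],
    #           bottom feature = 512@16x16
    #   decode: 256@32x32 -> 128@64x64 -> 64@128x128 -> 64@256x256
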
    def _get_logit(self, data, num_classes):
        # Final conv layer, with output channels set by the number of classes
        param_attr = fluid.ParamAttr(
            name='weights',
            regularizer=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0),
            initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
        with scope("logit"):
            data = conv(
                data,
                num_classes,
                3,
                stride=1,
                padding=1,
                param_attr=param_attr)
        return data

    def _get_loss(self, logit, label, mask):
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            avg_loss += softmax_with_loss(
                logit,
                label,
                mask,
                num_classes=self.num_classes,
                weight=self.class_weight,
                ignore_index=self.ignore_index)
        else:
            if self.use_dice_loss:
                avg_loss += dice_loss(logit, label, mask)
            if self.use_bce_loss:
                avg_loss += bce_loss(
                    logit, label, mask, ignore_index=self.ignore_index)
        return avg_loss

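    # For example, with use_dice_loss=True and use_bce_loss=True the returned
    # loss is dice_loss + bce_loss, computed on a single-channel logit
    # (build_net below forces num_classes to 1 in that configuration).
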
    def generate_inputs(self):
        inputs = OrderedDict()
        if self.fixed_input_shape is not None:
            input_shape = [
                None, 3, self.fixed_input_shape[0], self.fixed_input_shape[1]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32', shape=[None, 3, None, None], name='image')
        if self.mode in ('train', 'eval'):
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs

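    # For example (a sketch, assuming fixed_input_shape=[640, 720] and
    # mode='train'), generate_inputs() returns an OrderedDict holding:
    #   inputs['image']: float32, shape [None, 3, 640, 720]
    #   inputs['label']: int32, shape [None, 1, None, None]
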
    def build_net(self, inputs):
        # For binary segmentation, when dice_loss or bce_loss is selected the
        # final logit is produced with a single output channel
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        encode_data, short_cuts = self._encode(image)
        decode_data = self._decode(encode_data, short_cuts)
        logit = self._get_logit(decode_data, self.num_classes)
        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit)
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit, [0, 2, 3, 1])
        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])
        if self.mode == 'train':
            label = inputs['label']
            mask = label != self.ignore_index
            return self._get_loss(logit, label, mask)
        elif self.mode == 'eval':
            label = inputs['label']
            mask = label != self.ignore_index
            loss = self._get_loss(logit, label, mask)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit)
            else:
                logit = fluid.layers.softmax(logit, axis=1)
            return pred, logit
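
    # End-to-end inference sketch (hedged: program construction and execution
    # follow the usual fluid static-graph pattern and are outside this file):
    #
    #   model = UNet(num_classes=21, mode='test')
    #   inputs = model.generate_inputs()
    #   pred, logit = model.build_net(inputs)
    #   # pred:  (N, H, W, 1) argmax class map
    #   # logit: (N, num_classes, H, W) per-class probabilities after softmax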