unet_3plus.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlex.paddleseg.cvlibs import manager
from paddlex.paddleseg.cvlibs.param_init import constant_init, kaiming_normal_init
from paddlex.paddleseg.models.layers.layer_libs import SyncBatchNorm
  20. @manager.MODELS.add_component
  21. class UNet3Plus(nn.Layer):
  22. """
  23. The UNet3+ implementation based on PaddlePaddle.
  24. The original article refers to
  25. Huang H , Lin L , Tong R , et al. "UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation"
  26. (https://arxiv.org/abs/2004.08790).
  27. Args:
  28. in_channels (int, optional): The channel number of input image. Default: 3.
  29. num_classes (int, optional): The unique number of target classes. Default: 2.
  30. is_batchnorm (bool, optional): Use batchnorm after conv or not. Default: True.
  31. is_deepsup (bool, optional): Use deep supervision or not. Default: False.
  32. is_CGM (bool, optional): Use classification-guided module or not.
  33. If True, is_deepsup must be True. Default: False.
  34. """
  35. def __init__(self,
  36. in_channels=3,
  37. num_classes=2,
  38. is_batchnorm=True,
  39. is_deepsup=False,
  40. is_CGM=False):
  41. super(UNet3Plus, self).__init__()
  42. # parameters
  43. self.is_deepsup = True if is_CGM else is_deepsup
  44. self.is_CGM = is_CGM
  45. # internal definition
  46. self.filters = [64, 128, 256, 512, 1024]
  47. self.cat_channels = self.filters[0]
  48. self.cat_blocks = 5
  49. self.up_channels = self.cat_channels * self.cat_blocks
  50. # layers
  51. self.encoder = Encoder(in_channels, self.filters, is_batchnorm)
  52. self.decoder = Decoder(self.filters, self.cat_channels,
  53. self.up_channels)
  54. if self.is_deepsup:
  55. self.deepsup = DeepSup(self.up_channels, self.filters, num_classes)
  56. if self.is_CGM:
  57. self.cls = nn.Sequential(
  58. nn.Dropout(p=0.5), nn.Conv2D(self.filters[4], 2, 1),
  59. nn.AdaptiveMaxPool2D(1), nn.Sigmoid())
  60. else:
  61. self.outconv1 = nn.Conv2D(
  62. self.up_channels, num_classes, 3, padding=1)
  63. # initialise weights
  64. for sublayer in self.sublayers():
  65. if isinstance(sublayer, nn.Conv2D):
  66. kaiming_normal_init(sublayer.weight)
  67. elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
  68. kaiming_normal_init(sublayer.weight)
  69. def dotProduct(self, seg, cls):
  70. B, N, H, W = seg.shape
  71. seg = seg.reshape((B, N, H * W))
  72. clssp = paddle.ones([1, N])
  73. ecls = (cls * clssp).reshape([B, N, 1])
  74. final = seg * ecls
  75. final = final.reshape((B, N, H, W))
  76. return final
  77. def forward(self, inputs):
  78. hs = self.encoder(inputs)
  79. hds = self.decoder(hs)
  80. if self.is_deepsup:
  81. out = self.deepsup(hds)
  82. if self.is_CGM:
  83. # classification-guided module
  84. cls_branch = self.cls(hds[-1]).squeeze(3).squeeze(
  85. 2) # (B,N,1,1)->(B,N)
  86. cls_branch_max = cls_branch.argmax(axis=1)
  87. cls_branch_max = cls_branch_max.reshape((-1, 1)).astype('float')
  88. out = [self.dotProduct(d, cls_branch_max) for d in out]
  89. else:
  90. out = [self.outconv1(hds[0])] # d1->320*320*num_classes
  91. return out
  92. class Encoder(nn.Layer):
  93. def __init__(self, in_channels, filters, is_batchnorm):
  94. super(Encoder, self).__init__()
  95. self.conv1 = UnetConv2D(in_channels, filters[0], is_batchnorm)
  96. self.poolconv2 = MaxPoolConv2D(filters[0], filters[1], is_batchnorm)
  97. self.poolconv3 = MaxPoolConv2D(filters[1], filters[2], is_batchnorm)
  98. self.poolconv4 = MaxPoolConv2D(filters[2], filters[3], is_batchnorm)
  99. self.poolconv5 = MaxPoolConv2D(filters[3], filters[4], is_batchnorm)
  100. def forward(self, inputs):
  101. h1 = self.conv1(inputs) # h1->320*320*64
  102. h2 = self.poolconv2(h1) # h2->160*160*128
  103. h3 = self.poolconv3(h2) # h3->80*80*256
  104. h4 = self.poolconv4(h3) # h4->40*40*512
  105. hd5 = self.poolconv5(h4) # h5->20*20*1024
  106. return [h1, h2, h3, h4, hd5]
  107. class Decoder(nn.Layer):
  108. def __init__(self, filters, cat_channels, up_channels):
  109. super(Decoder, self).__init__()
  110. '''stage 4d'''
  111. # h1->320*320, hd4->40*40, Pooling 8 times
  112. self.h1_PT_hd4 = nn.MaxPool2D(8, 8, ceil_mode=True)
  113. self.h1_PT_hd4_cbr = ConvBnReLU2D(filters[0], cat_channels)
  114. # h2->160*160, hd4->40*40, Pooling 4 times
  115. self.h2_PT_hd4 = nn.MaxPool2D(4, 4, ceil_mode=True)
  116. self.h2_PT_hd4_cbr = ConvBnReLU2D(filters[1], cat_channels)
  117. # h3->80*80, hd4->40*40, Pooling 2 times
  118. self.h3_PT_hd4 = nn.MaxPool2D(2, 2, ceil_mode=True)
  119. self.h3_PT_hd4_cbr = ConvBnReLU2D(filters[2], cat_channels)
  120. # h4->40*40, hd4->40*40, Concatenation
  121. self.h4_Cat_hd4_cbr = ConvBnReLU2D(filters[3], cat_channels)
  122. # hd5->20*20, hd4->40*40, Upsample 2 times
  123. self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
  124. self.hd5_UT_hd4_cbr = ConvBnReLU2D(filters[4], cat_channels)
  125. # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
  126. self.cbr4d_1 = ConvBnReLU2D(up_channels, up_channels) # 16
  127. '''stage 3d'''
  128. # h1->320*320, hd3->80*80, Pooling 4 times
  129. self.h1_PT_hd3 = nn.MaxPool2D(4, 4, ceil_mode=True)
  130. self.h1_PT_hd3_cbr = ConvBnReLU2D(filters[0], cat_channels)
  131. # h2->160*160, hd3->80*80, Pooling 2 times
  132. self.h2_PT_hd3 = nn.MaxPool2D(2, 2, ceil_mode=True)
  133. self.h2_PT_hd3_cbr = ConvBnReLU2D(filters[1], cat_channels)
  134. # h3->80*80, hd3->80*80, Concatenation
  135. self.h3_Cat_hd3_cbr = ConvBnReLU2D(filters[2], cat_channels)
  136. # hd4->40*40, hd4->80*80, Upsample 2 times
  137. self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
  138. self.hd4_UT_hd3_cbr = ConvBnReLU2D(up_channels, cat_channels)
  139. # hd5->20*20, hd4->80*80, Upsample 4 times
  140. self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
  141. self.hd5_UT_hd3_cbr = ConvBnReLU2D(filters[4], cat_channels)
  142. # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
  143. self.cbr3d_1 = ConvBnReLU2D(up_channels, up_channels) # 16
  144. '''stage 2d '''
  145. # h1->320*320, hd2->160*160, Pooling 2 times
  146. self.h1_PT_hd2 = nn.MaxPool2D(2, 2, ceil_mode=True)
  147. self.h1_PT_hd2_cbr = ConvBnReLU2D(filters[0], cat_channels)
  148. # h2->160*160, hd2->160*160, Concatenation
  149. self.h2_Cat_hd2_cbr = ConvBnReLU2D(filters[1], cat_channels)
  150. # hd3->80*80, hd2->160*160, Upsample 2 times
  151. self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
  152. self.hd3_UT_hd2_cbr = ConvBnReLU2D(up_channels, cat_channels)
  153. # hd4->40*40, hd2->160*160, Upsample 4 times
  154. self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
  155. self.hd4_UT_hd2_cbr = ConvBnReLU2D(up_channels, cat_channels)
  156. # hd5->20*20, hd2->160*160, Upsample 8 times
  157. self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14
  158. self.hd5_UT_hd2_cbr = ConvBnReLU2D(filters[4], cat_channels)
  159. # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
  160. self.cbr2d_1 = ConvBnReLU2D(up_channels, up_channels) # 16
  161. '''stage 1d'''
  162. # h1->320*320, hd1->320*320, Concatenation
  163. self.h1_Cat_hd1_cbr = ConvBnReLU2D(filters[0], cat_channels)
  164. # hd2->160*160, hd1->320*320, Upsample 2 times
  165. self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
  166. self.hd2_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
  167. # hd3->80*80, hd1->320*320, Upsample 4 times
  168. self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
  169. self.hd3_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
  170. # hd4->40*40, hd1->320*320, Upsample 8 times
  171. self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14
  172. self.hd4_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
  173. # hd5->20*20, hd1->320*320, Upsample 16 times
  174. self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear') # 14*14
  175. self.hd5_UT_hd1_cbr = ConvBnReLU2D(filters[4], cat_channels)
  176. # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
  177. self.cbr1d_1 = ConvBnReLU2D(up_channels, up_channels) # 16
  178. def forward(self, inputs):
  179. h1, h2, h3, h4, hd5 = inputs
  180. h1_PT_hd4 = self.h1_PT_hd4_cbr(self.h1_PT_hd4(h1))
  181. h2_PT_hd4 = self.h2_PT_hd4_cbr(self.h2_PT_hd4(h2))
  182. h3_PT_hd4 = self.h3_PT_hd4_cbr(self.h3_PT_hd4(h3))
  183. h4_Cat_hd4 = self.h4_Cat_hd4_cbr(h4)
  184. hd5_UT_hd4 = self.hd5_UT_hd4_cbr(self.hd5_UT_hd4(hd5))
  185. # hd4->40*40*up_channels
  186. hd4 = self.cbr4d_1(
  187. paddle.concat(
  188. [h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4], 1))
  189. h1_PT_hd3 = self.h1_PT_hd3_cbr(self.h1_PT_hd3(h1))
  190. h2_PT_hd3 = self.h2_PT_hd3_cbr(self.h2_PT_hd3(h2))
  191. h3_Cat_hd3 = self.h3_Cat_hd3_cbr(h3)
  192. hd4_UT_hd3 = self.hd4_UT_hd3_cbr(self.hd4_UT_hd3(hd4))
  193. hd5_UT_hd3 = self.hd5_UT_hd3_cbr(self.hd5_UT_hd3(hd5))
  194. # hd3->80*80*up_channels
  195. hd3 = self.cbr3d_1(
  196. paddle.concat(
  197. [h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3], 1))
  198. h1_PT_hd2 = self.h1_PT_hd2_cbr(self.h1_PT_hd2(h1))
  199. h2_Cat_hd2 = self.h2_Cat_hd2_cbr(h2)
  200. hd3_UT_hd2 = self.hd3_UT_hd2_cbr(self.hd3_UT_hd2(hd3))
  201. hd4_UT_hd2 = self.hd4_UT_hd2_cbr(self.hd4_UT_hd2(hd4))
  202. hd5_UT_hd2 = self.hd5_UT_hd2_cbr(self.hd5_UT_hd2(hd5))
  203. # hd2->160*160*up_channels
  204. hd2 = self.cbr2d_1(
  205. paddle.concat(
  206. [h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2], 1))
  207. h1_Cat_hd1 = self.h1_Cat_hd1_cbr(h1)
  208. hd2_UT_hd1 = self.hd2_UT_hd1_cbr(self.hd2_UT_hd1(hd2))
  209. hd3_UT_hd1 = self.hd3_UT_hd1_cbr(self.hd3_UT_hd1(hd3))
  210. hd4_UT_hd1 = self.hd4_UT_hd1_cbr(self.hd4_UT_hd1(hd4))
  211. hd5_UT_hd1 = self.hd5_UT_hd1_cbr(self.hd5_UT_hd1(hd5))
  212. # hd1->320*320*up_channels
  213. hd1 = self.cbr1d_1(
  214. paddle.concat(
  215. [h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1],
  216. 1))
  217. return [hd1, hd2, hd3, hd4, hd5]
  218. class DeepSup(nn.Layer):
  219. def __init__(self, up_channels, filters, num_classes):
  220. super(DeepSup, self).__init__()
  221. self.convup5 = ConvUp2D(filters[4], num_classes, 16)
  222. self.convup4 = ConvUp2D(up_channels, num_classes, 8)
  223. self.convup3 = ConvUp2D(up_channels, num_classes, 4)
  224. self.convup2 = ConvUp2D(up_channels, num_classes, 2)
  225. self.outconv1 = nn.Conv2D(up_channels, num_classes, 3, padding=1)
  226. def forward(self, inputs):
  227. hd1, hd2, hd3, hd4, hd5 = inputs
  228. d5 = self.convup5(hd5) # 16->256
  229. d4 = self.convup4(hd4) # 32->256
  230. d3 = self.convup3(hd3) # 64->256
  231. d2 = self.convup2(hd2) # 128->256
  232. d1 = self.outconv1(hd1) # 256
  233. return [d1, d2, d3, d4, d5]
  234. class ConvBnReLU2D(nn.Sequential):
  235. def __init__(self, in_channels, out_channels):
  236. super(ConvBnReLU2D, self).__init__(
  237. nn.Conv2D(in_channels, out_channels, 3, padding=1),
  238. nn.BatchNorm(out_channels), nn.ReLU())
  239. class ConvUp2D(nn.Sequential):
  240. def __init__(self, in_channels, out_channels, scale_factor):
  241. super(ConvUp2D, self).__init__(
  242. nn.Conv2D(in_channels, out_channels, 3, padding=1),
  243. nn.Upsample(scale_factor=scale_factor, mode='bilinear'))
  244. class MaxPoolConv2D(nn.Sequential):
  245. def __init__(self, in_channels, out_channels, is_batchnorm):
  246. super(MaxPoolConv2D, self).__init__(
  247. nn.MaxPool2D(kernel_size=2),
  248. UnetConv2D(in_channels, out_channels, is_batchnorm))
  249. class UnetConv2D(nn.Layer):
  250. def __init__(self,
  251. in_channels,
  252. out_channels,
  253. is_batchnorm,
  254. num_conv=2,
  255. kernel_size=3,
  256. stride=1,
  257. padding=1):
  258. super(UnetConv2D, self).__init__()
  259. self.num_conv = num_conv
  260. for i in range(num_conv):
  261. conv = (nn.Sequential(nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding),
  262. nn.BatchNorm(out_channels),
  263. nn.ReLU()) \
  264. if is_batchnorm else \
  265. nn.Sequential(nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding),
  266. nn.ReLU()))
  267. setattr(self, 'conv%d' % (i + 1), conv)
  268. in_channels = out_channels
  269. # initialise the blocks
  270. for children in self.children():
  271. children.weight_attr = paddle.framework.ParamAttr(
  272. initializer=paddle.nn.initializer.KaimingNormal)
  273. children.bias_attr = paddle.framework.ParamAttr(
  274. initializer=paddle.nn.initializer.KaimingNormal)
  275. def forward(self, inputs):
  276. x = inputs
  277. for i in range(self.num_conv):
  278. conv = getattr(self, 'conv%d' % (i + 1))
  279. x = conv(x)
  280. return x