seg.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from . import cv
  15. from .cv.models.utils.visualize import visualize_segmentation
  16. from paddlex.cv.transforms import seg_transforms
  17. import paddlex.utils.logging as logging
  # Module-level aliases: expose the segmentation transforms module and the
  # segmentation visualization function under the short names ``transforms``
  # and ``visualize``.
  transforms, visualize = seg_transforms, visualize_segmentation
  class UNet(cv.models.UNet):
      """PaddleX 1.x-compatible constructor for the 2.0 UNet model.

      Translates the legacy 1.x keyword arguments into the arguments accepted
      by ``cv.models.UNet``. Arguments without a 2.0 equivalent are still
      accepted so old call sites keep working; a warning is logged whenever
      one of them is given a non-None value.

      Args:
          num_classes (int): Number of segmentation classes. Defaults to 2.
          upsample_mode (str): ``'bilinear'`` disables deconvolution; any other
              value enables it (``use_deconv=True``).
          use_bce_loss (bool): Add ``CrossEntropyLoss`` to the loss mix. Only
              takes effect when ``num_classes`` is 2.
          use_dice_loss (bool): Add ``DiceLoss`` to the loss mix. Only takes
              effect when ``num_classes`` is 2.
          class_weight: Unsupported; ignored with a warning when not None.
          ignore_index: Deprecated; ignored with a warning when not None.
          input_channel: Deprecated; ignored with a warning when not None.

      Raises:
          ValueError: If ``num_classes > 2`` and BCE or Dice loss is requested.
      """

      def __init__(self,
                   num_classes=2,
                   upsample_mode='bilinear',
                   use_bce_loss=False,
                   use_dice_loss=False,
                   class_weight=None,
                   ignore_index=None,
                   input_channel=None):
          if num_classes > 2 and (use_bce_loss or use_dice_loss):
              raise ValueError(
                  "dice loss and bce loss is only applicable to binary classification"
              )

          # Extra losses are only wired up for the binary case; an empty
          # selection collapses to False, meaning "no mixed loss".
          use_mixed_loss = False
          if num_classes == 2:
              chosen = [(loss_name, 1)
                        for loss_name, enabled in (
                            ('CrossEntropyLoss', use_bce_loss),
                            ('DiceLoss', use_dice_loss))
                        if enabled]
              use_mixed_loss = chosen or False

          # Legacy arguments that are accepted but not forwarded, in the order
          # their warnings should be emitted.
          ignored_args = (
              (class_weight,
               "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."),
              (ignore_index,
               "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."),
              (input_channel,
               "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."),
          )
          for value, message in ignored_args:
              if value is not None:
                  logging.warning(message)

          super().__init__(
              num_classes=num_classes,
              use_mixed_loss=use_mixed_loss,
              use_deconv=upsample_mode != 'bilinear')
  class DeepLabv3p(cv.models.DeepLabV3P):
      """PaddleX 1.x-compatible constructor for the 2.0 DeepLabV3P model.

      ``num_classes``, ``backbone`` and ``output_stride`` are forwarded to
      ``cv.models.DeepLabV3P`` unchanged, and the BCE/Dice flags are turned
      into a ``use_mixed_loss`` spec. Every other legacy argument is accepted
      but ignored; a warning is logged whenever one is given a non-None value.

      Args:
          num_classes (int): Number of segmentation classes. Defaults to 2.
          backbone (str): Backbone name. Defaults to ``'ResNet50_vd'``.
          output_stride (int): Forwarded unchanged. Defaults to 8.
          aspp_with_sep_conv: Deprecated; ignored with a warning when not None.
          decoder_use_sep_conv: Deprecated; ignored with a warning when not None.
          encoder_with_aspp: Deprecated; ignored with a warning when not None.
          enable_decoder: Deprecated; ignored with a warning when not None.
          use_bce_loss (bool): Add ``CrossEntropyLoss`` to the loss mix. Only
              takes effect when ``num_classes`` is 2.
          use_dice_loss (bool): Add ``DiceLoss`` to the loss mix. Only takes
              effect when ``num_classes`` is 2.
          class_weight: Unsupported; ignored with a warning when not None.
          ignore_index: Deprecated; ignored with a warning when not None.
          pooling_crop_size: Ignored with a warning when not None.
          input_channel: Deprecated; ignored with a warning when not None.

      Raises:
          ValueError: If ``num_classes > 2`` and BCE or Dice loss is requested.
      """

      def __init__(self,
                   num_classes=2,
                   backbone='ResNet50_vd',
                   output_stride=8,
                   aspp_with_sep_conv=None,
                   decoder_use_sep_conv=None,
                   encoder_with_aspp=None,
                   enable_decoder=None,
                   use_bce_loss=False,
                   use_dice_loss=False,
                   class_weight=None,
                   ignore_index=None,
                   pooling_crop_size=None,
                   input_channel=None):
          if num_classes > 2 and (use_bce_loss or use_dice_loss):
              raise ValueError(
                  "dice loss and bce loss is only applicable to binary classification"
              )

          # Extra losses are only wired up for the binary case; an empty
          # selection collapses to False, meaning "no mixed loss".
          use_mixed_loss = False
          if num_classes == 2:
              chosen = [(loss_name, 1)
                        for loss_name, enabled in (
                            ('CrossEntropyLoss', use_bce_loss),
                            ('DiceLoss', use_dice_loss))
                        if enabled]
              use_mixed_loss = chosen or False

          # Legacy arguments that are accepted but not forwarded, in the order
          # their warnings should be emitted.
          ignored_args = (
              (aspp_with_sep_conv,
               "`aspp_with_sep_conv` is deprecated in PaddleX 2.0 and will not take effect. "
               "Defaults to True"),
              (decoder_use_sep_conv,
               "`decoder_use_sep_conv` is deprecated in PaddleX 2.0 and will not take effect. "
               "Defaults to True"),
              (encoder_with_aspp,
               "`encoder_with_aspp` is deprecated in PaddleX 2.0 and will not take effect. "
               "Defaults to True"),
              (enable_decoder,
               "`enable_decoder` is deprecated in PaddleX 2.0 and will not take effect. "
               "Defaults to True"),
              (class_weight,
               "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."),
              (ignore_index,
               "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."),
              (pooling_crop_size,
               "Backbone 'MobileNetV3_large_x1_0_ssld' is currently not supported in PaddleX 2.0. "
               "`pooling_crop_size` will not take effect. Defaults to None"),
              (input_channel,
               "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."),
          )
          for value, message in ignored_args:
              if value is not None:
                  logging.warning(message)

          super().__init__(
              num_classes=num_classes,
              backbone=backbone,
              use_mixed_loss=use_mixed_loss,
              output_stride=output_stride)
  class HRNet(cv.models.HRNet):
      """PaddleX 1.x-compatible constructor for the 2.0 HRNet model.

      ``num_classes`` and ``width`` are forwarded to ``cv.models.HRNet``
      unchanged, and the BCE/Dice flags are turned into a ``use_mixed_loss``
      spec. The remaining legacy arguments are accepted but ignored; a warning
      is logged whenever one is given a non-None value.

      Args:
          num_classes (int): Number of segmentation classes. Defaults to 2.
          width (int): Forwarded unchanged. Defaults to 18.
          use_bce_loss (bool): Add ``CrossEntropyLoss`` to the loss mix. Only
              takes effect when ``num_classes`` is 2.
          use_dice_loss (bool): Add ``DiceLoss`` to the loss mix. Only takes
              effect when ``num_classes`` is 2.
          class_weight: Unsupported; ignored with a warning when not None.
          ignore_index: Deprecated; ignored with a warning when not None.
          input_channel: Deprecated; ignored with a warning when not None.

      Raises:
          ValueError: If ``num_classes > 2`` and BCE or Dice loss is requested.
      """

      def __init__(self,
                   num_classes=2,
                   width=18,
                   use_bce_loss=False,
                   use_dice_loss=False,
                   class_weight=None,
                   ignore_index=None,
                   input_channel=None):
          if num_classes > 2 and (use_bce_loss or use_dice_loss):
              raise ValueError(
                  "dice loss and bce loss is only applicable to binary classification"
              )

          # Extra losses are only wired up for the binary case; an empty
          # selection collapses to False, meaning "no mixed loss".
          use_mixed_loss = False
          if num_classes == 2:
              chosen = [(loss_name, 1)
                        for loss_name, enabled in (
                            ('CrossEntropyLoss', use_bce_loss),
                            ('DiceLoss', use_dice_loss))
                        if enabled]
              use_mixed_loss = chosen or False

          # Legacy arguments that are accepted but not forwarded, in the order
          # their warnings should be emitted.
          ignored_args = (
              (class_weight,
               "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."),
              (ignore_index,
               "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."),
              (input_channel,
               "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."),
          )
          for value, message in ignored_args:
              if value is not None:
                  logging.warning(message)

          super().__init__(
              num_classes=num_classes,
              width=width,
              use_mixed_loss=use_mixed_loss)
  class FastSCNN(cv.models.FastSCNN):
      """PaddleX 1.x-compatible constructor for the 2.0 FastSCNN model.

      ``num_classes`` is forwarded to ``cv.models.FastSCNN`` unchanged, and
      the BCE/Dice flags are turned into a ``use_mixed_loss`` spec. The other
      legacy arguments are accepted so old call sites keep working, but they
      are ignored. A warning is logged only when the caller passes a value
      that would have meant something different from PaddleX 2.0's behavior.

      Args:
          num_classes (int): Number of segmentation classes. Defaults to 2.
          use_bce_loss (bool): Add ``CrossEntropyLoss`` to the loss mix. Only
              takes effect when ``num_classes`` is 2.
          use_dice_loss (bool): Add ``DiceLoss`` to the loss mix. Only takes
              effect when ``num_classes`` is 2.
          class_weight: Unsupported; ignored with a warning when not None.
          ignore_index: Deprecated and ignored. The legacy default (255) is
              accepted silently; any other non-None value logs a warning.
          multi_loss_weight: Deprecated; ignored with a warning when not None.
          input_channel: Deprecated and ignored. The legacy default (3) is
              accepted silently; any other non-None value logs a warning.

      Raises:
          ValueError: If ``num_classes > 2`` and BCE or Dice loss is requested.
      """

      def __init__(self,
                   num_classes=2,
                   use_bce_loss=False,
                   use_dice_loss=False,
                   class_weight=None,
                   ignore_index=255,
                   multi_loss_weight=None,
                   input_channel=3):
          if num_classes > 2 and (use_bce_loss or use_dice_loss):
              raise ValueError(
                  "dice loss and bce loss is only applicable to binary classification"
              )
          elif num_classes == 2:
              if use_bce_loss and use_dice_loss:
                  use_mixed_loss = [('CrossEntropyLoss', 1), ('DiceLoss', 1)]
              elif use_bce_loss:
                  use_mixed_loss = [('CrossEntropyLoss', 1)]
              elif use_dice_loss:
                  use_mixed_loss = [('DiceLoss', 1)]
              else:
                  use_mixed_loss = False
          else:
              # Extra losses are only wired up for the binary case.
              use_mixed_loss = False

          if class_weight is not None:
              logging.warning(
                  "`class_weight` is not supported in PaddleX 2.0 currently and is forcibly set to None."
              )
          # The legacy defaults for ``ignore_index`` (255) and ``input_channel``
          # (3) are the same values PaddleX 2.0 uses anyway, so warn only when
          # the caller asks for something different. An ``is not None`` check
          # alone would log both warnings on every default construction.
          if ignore_index is not None and ignore_index != 255:
              logging.warning(
                  "`ignore_index` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 255."
              )
          if multi_loss_weight is not None:
              logging.warning(
                  "`multi_loss_weight` is deprecated in PaddleX 2.0 and will not take effect. "
                  "Defaults to [1.0, 0.4]")
          if input_channel is not None and input_channel != 3:
              logging.warning(
                  "`input_channel` is deprecated in PaddleX 2.0 and won't take effect. Defaults to 3."
              )

          super().__init__(
              num_classes=num_classes, use_mixed_loss=use_mixed_loss)