hrnet.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import paddle.fluid as fluid
  16. import paddlex
  17. from collections import OrderedDict
  18. from .deeplabv3p import DeepLabv3p
  19. class HRNet(DeepLabv3p):
  20. """实现HRNet网络的构建并进行训练、评估、预测和模型导出。
  21. Args:
  22. num_classes (int): 类别数。
  23. width (int|str): 高分辨率分支中特征层的通道数量。默认值为18。可选择取值为[18, 30, 32, 40, 44, 48, 60, 64, '18_small_v1']。
  24. '18_small_v1'是18的轻量级版本,默认18。
  25. use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。默认False。
  26. use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
  27. 当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
  28. class_weight (list|str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
  29. num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
  30. 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
  31. 即平时使用的交叉熵损失函数。
  32. ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
  33. input_channel (int): 输入图像通道数。默认值3。
  34. Raises:
  35. ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
  36. ValueError: class_weight为list, 但长度不等于num_class。
  37. class_weight为str, 但class_weight.low()不等于dynamic。
  38. TypeError: class_weight不为None时,其类型不是list或str。
  39. """
  40. def __init__(self,
  41. num_classes=2,
  42. width=18,
  43. use_bce_loss=False,
  44. use_dice_loss=False,
  45. class_weight=None,
  46. ignore_index=255,
  47. input_channel=3):
  48. self.init_params = locals()
  49. super(DeepLabv3p, self).__init__('segmenter')
  50. # dice_loss或bce_loss只适用两类分割中
  51. if num_classes > 2 and (use_bce_loss or use_dice_loss):
  52. raise ValueError(
  53. "dice loss and bce loss is only applicable to binary classfication"
  54. )
  55. if class_weight is not None:
  56. if isinstance(class_weight, list):
  57. if len(class_weight) != num_classes:
  58. raise ValueError(
  59. "Length of class_weight should be equal to number of classes"
  60. )
  61. elif isinstance(class_weight, str):
  62. if class_weight.lower() != 'dynamic':
  63. raise ValueError(
  64. "if class_weight is string, must be dynamic!")
  65. else:
  66. raise TypeError(
  67. 'Expect class_weight is a list or string but receive {}'.
  68. format(type(class_weight)))
  69. self.num_classes = num_classes
  70. self.width = width
  71. self.use_bce_loss = use_bce_loss
  72. self.use_dice_loss = use_dice_loss
  73. self.class_weight = class_weight
  74. self.ignore_index = ignore_index
  75. self.labels = None
  76. self.fixed_input_shape = None
  77. self.input_channel = input_channel
  78. def build_net(self, mode='train'):
  79. model = paddlex.cv.nets.segmentation.HRNet(
  80. self.num_classes,
  81. width=self.width,
  82. mode=mode,
  83. use_bce_loss=self.use_bce_loss,
  84. use_dice_loss=self.use_dice_loss,
  85. class_weight=self.class_weight,
  86. ignore_index=self.ignore_index,
  87. fixed_input_shape=self.fixed_input_shape,
  88. input_channel=self.input_channel)
  89. inputs = model.generate_inputs()
  90. model_out = model.build_net(inputs)
  91. outputs = OrderedDict()
  92. if mode == 'train':
  93. self.optimizer.minimize(model_out)
  94. outputs['loss'] = model_out
  95. else:
  96. outputs['pred'] = model_out[0]
  97. outputs['logit'] = model_out[1]
  98. return inputs, outputs
  99. def default_optimizer(self,
  100. learning_rate,
  101. num_epochs,
  102. num_steps_each_epoch,
  103. lr_decay_power=0.9):
  104. decay_step = num_epochs * num_steps_each_epoch
  105. lr_decay = fluid.layers.polynomial_decay(
  106. learning_rate,
  107. decay_step,
  108. end_learning_rate=0,
  109. power=lr_decay_power)
  110. optimizer = fluid.optimizer.Momentum(
  111. lr_decay,
  112. momentum=0.9,
  113. regularization=fluid.regularizer.L2Decay(
  114. regularization_coeff=5e-04))
  115. return optimizer
  116. def train(self,
  117. num_epochs,
  118. train_dataset,
  119. train_batch_size=2,
  120. eval_dataset=None,
  121. save_interval_epochs=1,
  122. log_interval_steps=2,
  123. save_dir='output',
  124. pretrain_weights='IMAGENET',
  125. optimizer=None,
  126. learning_rate=0.01,
  127. lr_decay_power=0.9,
  128. use_vdl=False,
  129. sensitivities_file=None,
  130. eval_metric_loss=0.05,
  131. early_stop=False,
  132. early_stop_patience=5,
  133. resume_checkpoint=None):
  134. """训练。
  135. Args:
  136. num_epochs (int): 训练迭代轮数。
  137. train_dataset (paddlex.datasets): 训练数据读取器。
  138. train_batch_size (int): 训练数据batch大小。同时作为验证数据batch大小。默认2。
  139. eval_dataset (paddlex.datasets): 评估数据读取器。
  140. save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
  141. log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
  142. save_dir (str): 模型保存路径。默认'output'。
  143. pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
  144. 则自动下载在IMAGENET图片数据上预训练的模型权重;若为字符串'CITYSCAPES'
  145. 则自动下载在CITYSCAPES图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'。
  146. optimizer (paddle.fluid.optimizer): 优化器。当改参数为None时,使用默认的优化器:使用
  147. fluid.optimizer.Momentum优化方法,polynomial的学习率衰减策略。
  148. learning_rate (float): 默认优化器的初始学习率。默认0.01。
  149. lr_decay_power (float): 默认优化器学习率多项式衰减系数。默认0.9。
  150. use_vdl (bool): 是否使用VisualDL进行可视化。默认False。
  151. sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
  152. 则自动下载在Cityscapes图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
  153. eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
  154. early_stop (bool): 是否使用提前终止训练策略。默认值为False。
  155. early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
  156. 连续下降或持平,则终止训练。默认值为5。
  157. resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
  158. Raises:
  159. ValueError: 模型从inference model进行加载。
  160. """
  161. return super(HRNet, self).train(
  162. num_epochs, train_dataset, train_batch_size, eval_dataset,
  163. save_interval_epochs, log_interval_steps, save_dir,
  164. pretrain_weights, optimizer, learning_rate, lr_decay_power,
  165. use_vdl, sensitivities_file, eval_metric_loss, early_stop,
  166. early_stop_patience, resume_checkpoint)