segmentation.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path as osp


def build_transforms(params):
    """Build the train/eval transform pipelines for segmentation."""
    from paddlex.seg import transforms
    seg_list = []
    # Randomly rescale the long side to 80%-120% of the target size.
    min_value = max(params.image_shape) * 4 // 5
    max_value = max(params.image_shape) * 6 // 5
    seg_list.extend([
        transforms.ResizeRangeScaling(
            min_value=min_value, max_value=max_value),
        transforms.RandomBlur(prob=params.blur_prob)
    ])
    if params.rotate:
        seg_list.append(
            transforms.RandomRotate(rotate_range=params.max_rotation))
    if params.scale_aspect:
        seg_list.append(
            transforms.RandomScaleAspect(
                min_scale=params.min_ratio, aspect_ratio=params.aspect_ratio))
    seg_list.extend([
        transforms.RandomDistort(
            brightness_range=params.brightness_range,
            brightness_prob=params.brightness_prob,
            contrast_range=params.contrast_range,
            contrast_prob=params.contrast_prob,
            saturation_range=params.saturation_range,
            saturation_prob=params.saturation_prob,
            hue_range=params.hue_range,
            hue_prob=params.hue_prob),
        transforms.RandomVerticalFlip(prob=params.vertical_flip_prob),
        transforms.RandomHorizontalFlip(prob=params.horizontal_flip_prob),
        transforms.RandomPaddingCrop(crop_size=max(params.image_shape)),
        transforms.Normalize(mean=params.image_mean, std=params.image_std)
    ])
    train_transforms = transforms.Compose(seg_list)
    # Evaluation uses a deterministic resize-and-pad pipeline.
    eval_transforms = transforms.Compose([
        transforms.ResizeByLong(long_size=max(params.image_shape)),
        transforms.Padding(target_size=max(params.image_shape)),
        transforms.Normalize(mean=params.image_mean, std=params.image_std)
    ])
    return train_transforms, eval_transforms
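
# --- Worked example (not part of the original file) ---
# For a hypothetical params.image_shape of [512, 512], the random resize
# range for the long side works out to
#   min_value = 512 * 4 // 5 = 409
#   max_value = 512 * 6 // 5 = 614
# so training images get a long side in [409, 614] before the 512-pixel
# RandomPaddingCrop, while evaluation deterministically resizes and pads
# to 512.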


def build_datasets(dataset_path, train_transforms, eval_transforms):
    """Create the train/eval SegDataset objects from a PaddleX dataset dir."""
    import paddlex as pdx
    train_file_list = osp.join(dataset_path, 'train_list.txt')
    eval_file_list = osp.join(dataset_path, 'val_list.txt')
    label_list = osp.join(dataset_path, 'labels.txt')
    train_dataset = pdx.datasets.SegDataset(
        data_dir=dataset_path,
        file_list=train_file_list,
        label_list=label_list,
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.SegDataset(
        data_dir=dataset_path,
        file_list=eval_file_list,
        label_list=label_list,
        transforms=eval_transforms)
    return train_dataset, eval_dataset
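
# --- Dataset layout sketch (not part of the original file) ---
# build_datasets expects the standard PaddleX segmentation layout under
# dataset_path (the file names come from the code above; the line format
# shown is the usual PaddleX convention):
#
#   dataset_path/
#     train_list.txt   # one "<image path> <mask path>" pair per line
#     val_list.txt     # same format, used for evaluation
#     labels.txt       # one class name per line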


def build_optimizer(step_each_epoch, params):
    """Build an optimizer with the configured learning-rate schedule."""
    import paddle.fluid as fluid
    if params.lr_policy == 'Piecewise':
        # Step decay: multiply the learning rate by 0.1 at each boundary.
        gamma = 0.1
        bd = [step_each_epoch * e for e in params.lr_decay_epochs]
        lr = [params.learning_rate * (gamma**i) for i in range(len(bd) + 1)]
        decayed_lr = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
    elif params.lr_policy == 'Polynomial':
        decay_step = params.num_epochs * step_each_epoch
        decayed_lr = fluid.layers.polynomial_decay(
            params.learning_rate, decay_step, end_learning_rate=0, power=0.9)
    elif params.lr_policy == 'Cosine':
        decayed_lr = fluid.layers.cosine_decay(
            params.learning_rate, step_each_epoch, params.num_epochs)
    else:
        raise Exception(
            'lr_policy only supports Piecewise, Polynomial or Cosine, '
            'but you set {}'.format(params.lr_policy))
    if params.optimizer.lower() == 'sgd':
        momentum = 0.9
        regularize_coef = 1e-4
        optimizer = fluid.optimizer.Momentum(
            learning_rate=decayed_lr,
            momentum=momentum,
            regularization=fluid.regularizer.L2Decay(
                regularization_coeff=regularize_coef))
    elif params.optimizer.lower() == 'adam':
        momentum = 0.9
        momentum2 = 0.999
        regularize_coef = 1e-4
        optimizer = fluid.optimizer.Adam(
            learning_rate=decayed_lr,
            beta1=momentum,
            beta2=momentum2,
            regularization=fluid.regularizer.L2Decay(
                regularization_coeff=regularize_coef))
    else:
        # Guard against returning an unbound name for unknown optimizers.
        raise Exception('optimizer only supports SGD or Adam, but you set '
                        '{}'.format(params.optimizer))
    return optimizer
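
# --- Worked example (not part of the original file) ---
# With the 'Piecewise' policy, hypothetical values learning_rate=0.01,
# lr_decay_epochs=[8, 11] and step_each_epoch=100 yield
#   bd = [800, 1100]
#   lr = [0.01, 0.001, 0.0001]   # gamma = 0.1 per boundary
# i.e. the learning rate drops by 10x at steps 800 and 1100.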


def train(task_path, dataset_path, params):
    """Train a segmentation model and save checkpoints under task_path."""
    import paddlex as pdx
    pdx.log_level = 3
    train_transforms, eval_transforms = build_transforms(params)
    train_dataset, eval_dataset = build_datasets(
        dataset_path=dataset_path,
        train_transforms=train_transforms,
        eval_transforms=eval_transforms)
    step_each_epoch = train_dataset.num_samples // params.batch_size
    save_interval_epochs = params.save_interval_epochs
    save_dir = osp.join(task_path, 'output')
    pretrain_weights = params.pretrain_weights
    optimizer = build_optimizer(step_each_epoch, params)
    # All HRNet variants map to the single pdx.cv.models.HRNet class.
    segmenter = getattr(pdx.cv.models, 'HRNet'
                        if params.model.startswith('HRNet') else params.model)
    use_dice_loss, use_bce_loss = params.loss_type
    backbone = params.backbone
    sensitivities_path = params.sensitivities_path
    eval_metric_loss = params.eval_metric_loss
    if eval_metric_loss is None:
        eval_metric_loss = 0.05
    if params.model in ['UNet', 'HRNet_W18', 'FastSCNN']:
        model = segmenter(
            num_classes=len(train_dataset.labels),
            use_bce_loss=use_bce_loss,
            use_dice_loss=use_dice_loss)
    elif params.model == 'DeepLabv3p':
        model = segmenter(
            num_classes=len(train_dataset.labels),
            backbone=backbone,
            use_bce_loss=use_bce_loss,
            use_dice_loss=use_dice_loss)
        if backbone == 'MobileNetV3_large_x1_0_ssld':
            model.pooling_crop_size = params.image_shape
    model.train(
        num_epochs=params.num_epochs,
        train_dataset=train_dataset,
        train_batch_size=params.batch_size,
        eval_dataset=eval_dataset,
        save_interval_epochs=save_interval_epochs,
        log_interval_steps=2,
        save_dir=save_dir,
        pretrain_weights=pretrain_weights,
        optimizer=optimizer,
        use_vdl=True,
        sensitivities_file=sensitivities_path,
        eval_metric_loss=eval_metric_loss,
        resume_checkpoint=params.resume_checkpoint)
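

# --- Invocation sketch (not part of the original file) ---
# A minimal, hypothetical way to drive train() end to end. The attribute
# names below mirror exactly what the functions above read from `params`;
# the values and the './my_seg_task' / './my_seg_dataset' paths are
# placeholders, not values prescribed by PaddleX.
if __name__ == '__main__':
    from types import SimpleNamespace
    params = SimpleNamespace(
        # model / loss
        model='UNet',
        backbone=None,
        loss_type=(False, False),  # unpacked as (use_dice_loss, use_bce_loss)
        # input and augmentation
        image_shape=[512, 512],
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
        blur_prob=0.1,
        rotate=False,
        max_rotation=15,
        scale_aspect=False,
        min_ratio=0.5,
        aspect_ratio=0.33,
        brightness_range=0.5,
        brightness_prob=0.5,
        contrast_range=0.5,
        contrast_prob=0.5,
        saturation_range=0.5,
        saturation_prob=0.5,
        hue_range=18,
        hue_prob=0.5,
        vertical_flip_prob=0.1,
        horizontal_flip_prob=0.5,
        # optimization
        optimizer='SGD',
        lr_policy='Piecewise',
        learning_rate=0.01,
        lr_decay_epochs=[8, 11],
        num_epochs=12,
        batch_size=4,
        # checkpointing and pruning
        save_interval_epochs=1,
        pretrain_weights='COCO',
        sensitivities_path=None,
        eval_metric_loss=None,
        resume_checkpoint=None)
    train('./my_seg_task', './my_seg_dataset', params)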