# unet.py — PaddleX UNet segmentation training/pruning tutorial script.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import os
  16. # 选择使用0号卡
  17. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  18. from paddlex.seg import transforms
  19. import paddlex as pdx
  20. def train(model_dir, sensitivities_file, eval_metric_loss):
  21. # 下载和解压视盘分割数据集
  22. optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
  23. pdx.utils.download_and_decompress(optic_dataset, path='./')
  24. # 定义训练和验证时的transforms
  25. train_transforms = transforms.Compose([
  26. transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
  27. transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
  28. ])
  29. eval_transforms = transforms.Compose([
  30. transforms.ResizeByLong(long_size=512),
  31. transforms.Padding(target_size=512), transforms.Normalize()
  32. ])
  33. # 定义训练和验证所用的数据集
  34. train_dataset = pdx.datasets.SegDataset(
  35. data_dir='optic_disc_seg',
  36. file_list='optic_disc_seg/train_list.txt',
  37. label_list='optic_disc_seg/labels.txt',
  38. transforms=train_transforms,
  39. shuffle=True)
  40. eval_dataset = pdx.datasets.SegDataset(
  41. data_dir='optic_disc_seg',
  42. file_list='optic_disc_seg/val_list.txt',
  43. label_list='optic_disc_seg/labels.txt',
  44. transforms=eval_transforms)
  45. if model_dir is None:
  46. # 使用coco数据集上的预训练权重
  47. pretrain_weights = "COCO"
  48. else:
  49. assert os.path.isdir(model_dir), "Path {} is not a directory".format(
  50. model_dir)
  51. pretrain_weights = model_dir
  52. save_dir = "output/unet"
  53. if sensitivities_file is not None:
  54. if sensitivities_file != 'DEFAULT':
  55. assert os.path.exists(
  56. sensitivities_file), "Path {} not exist".format(
  57. sensitivities_file)
  58. save_dir = "output/unet_prune"
  59. num_classes = len(train_dataset.labels)
  60. model = pdx.seg.UNet(num_classes=num_classes)
  61. model.train(
  62. num_epochs=20,
  63. train_dataset=train_dataset,
  64. train_batch_size=4,
  65. eval_dataset=eval_dataset,
  66. learning_rate=0.01,
  67. pretrain_weights=pretrain_weights,
  68. save_dir=save_dir,
  69. use_vdl=True,
  70. sensitivities_file=sensitivities_file,
  71. eval_metric_loss=eval_metric_loss)
  72. if __name__ == '__main__':
  73. parser = argparse.ArgumentParser(description=__doc__)
  74. parser.add_argument(
  75. "--model_dir", default=None, type=str, help="The model path.")
  76. parser.add_argument(
  77. "--sensitivities_file",
  78. default=None,
  79. type=str,
  80. help="The sensitivities file path.")
  81. parser.add_argument(
  82. "--eval_metric_loss",
  83. default=0.05,
  84. type=float,
  85. help="The loss threshold.")
  86. args = parser.parse_args()
  87. train(args.model_dir, args.sensitivities_file, args.eval_metric_loss)