# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os.path as osp

import paddlex as pdx
import paddlex.RemoteSensing.transforms as custom_transforms
from paddlex.seg import transforms
  20. def parse_args():
  21. parser = argparse.ArgumentParser(description='RemoteSensing training')
  22. parser.add_argument(
  23. '--data_dir',
  24. dest='data_dir',
  25. help='dataset directory',
  26. default=None,
  27. type=str)
  28. parser.add_argument(
  29. '--save_dir',
  30. dest='save_dir',
  31. help='model save directory',
  32. default=None,
  33. type=str)
  34. parser.add_argument(
  35. '--num_classes',
  36. dest='num_classes',
  37. help='Number of classes',
  38. default=None,
  39. type=int)
  40. parser.add_argument(
  41. '--channel',
  42. dest='channel',
  43. help='number of data channel',
  44. default=3,
  45. type=int)
  46. parser.add_argument(
  47. '--clip_min_value',
  48. dest='clip_min_value',
  49. help='Min values for clipping data',
  50. nargs='+',
  51. default=None,
  52. type=int)
  53. parser.add_argument(
  54. '--clip_max_value',
  55. dest='clip_max_value',
  56. help='Max values for clipping data',
  57. nargs='+',
  58. default=None,
  59. type=int)
  60. parser.add_argument(
  61. '--mean',
  62. dest='mean',
  63. help='Data means',
  64. nargs='+',
  65. default=None,
  66. type=float)
  67. parser.add_argument(
  68. '--std',
  69. dest='std',
  70. help='Data standard deviation',
  71. nargs='+',
  72. default=None,
  73. type=float)
  74. parser.add_argument(
  75. '--num_epochs',
  76. dest='num_epochs',
  77. help='number of traing epochs',
  78. default=100,
  79. type=int)
  80. parser.add_argument(
  81. '--train_batch_size',
  82. dest='train_batch_size',
  83. help='training batch size',
  84. default=4,
  85. type=int)
  86. parser.add_argument(
  87. '--lr', dest='lr', help='learning rate', default=0.01, type=float)
  88. return parser.parse_args()
  89. args = parse_args()
  90. data_dir = args.data_dir
  91. save_dir = args.save_dir
  92. num_classes = args.num_classes
  93. channel = args.channel
  94. clip_min_value = args.clip_min_value
  95. clip_max_value = args.clip_max_value
  96. mean = args.mean
  97. std = args.std
  98. num_epochs = args.num_epochs
  99. train_batch_size = args.train_batch_size
  100. lr = args.lr
  101. # 定义训练和验证时的transforms
  102. train_transforms = transforms.Compose([
  103. transforms.RandomVerticalFlip(0.5),
  104. transforms.RandomHorizontalFlip(0.5),
  105. transforms.ResizeStepScaling(0.5, 2.0, 0.25),
  106. transforms.RandomPaddingCrop(im_padding_value=[1000] * channel),
  107. custom_transforms.Clip(
  108. min_val=clip_min_value, max_val=clip_max_value),
  109. transforms.Normalize(
  110. min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
  111. ])
  112. train_transforms.decode_image = custom_transforms.decode_image
  113. eval_transforms = transforms.Compose([
  114. custom_transforms.Clip(
  115. min_val=clip_min_value, max_val=clip_max_value),
  116. transforms.Normalize(
  117. min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
  118. ])
  119. eval_transforms.decode_image = custom_transforms.decode_image
  120. train_list = osp.join(data_dir, 'train.txt')
  121. val_list = osp.join(data_dir, 'val.txt')
  122. label_list = osp.join(data_dir, 'labels.txt')
  123. train_dataset = pdx.datasets.SegDataset(
  124. data_dir=data_dir,
  125. file_list=train_list,
  126. label_list=label_list,
  127. transforms=train_transforms,
  128. shuffle=True)
  129. eval_dataset = pdx.datasets.SegDataset(
  130. data_dir=data_dir,
  131. file_list=val_list,
  132. label_list=label_list,
  133. transforms=eval_transforms)
  134. model = pdx.seg.UNet(num_classes=num_classes, input_channel=channel)
  135. model.train(
  136. num_epochs=num_epochs,
  137. train_dataset=train_dataset,
  138. train_batch_size=train_batch_size,
  139. eval_dataset=eval_dataset,
  140. save_interval_epochs=5,
  141. log_interval_steps=10,
  142. save_dir=save_dir,
  143. learning_rate=lr,
  144. use_vdl=True)