# train.py
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os.path as osp
  16. import argparse
  17. from paddlex.seg import transforms
  18. import paddlex as pdx
  19. def parse_args():
  20. parser = argparse.ArgumentParser(description='RemoteSensing training')
  21. parser.add_argument(
  22. '--data_dir',
  23. dest='data_dir',
  24. help='dataset directory',
  25. default=None,
  26. type=str)
  27. parser.add_argument(
  28. '--train_file_list',
  29. dest='train_file_list',
  30. help='train file_list',
  31. default=None,
  32. type=str)
  33. parser.add_argument(
  34. '--eval_file_list',
  35. dest='eval_file_list',
  36. help='eval file_list',
  37. default=None,
  38. type=str)
  39. parser.add_argument(
  40. '--label_list',
  41. dest='label_list',
  42. help='label_list file',
  43. default=None,
  44. type=str)
  45. parser.add_argument(
  46. '--save_dir',
  47. dest='save_dir',
  48. help='model save directory',
  49. default=None,
  50. type=str)
  51. parser.add_argument(
  52. '--num_classes',
  53. dest='num_classes',
  54. help='Number of classes',
  55. default=None,
  56. type=int)
  57. parser.add_argument(
  58. '--channel',
  59. dest='channel',
  60. help='number of data channel',
  61. default=3,
  62. type=int)
  63. parser.add_argument(
  64. '--clip_min_value',
  65. dest='clip_min_value',
  66. help='Min values for clipping data',
  67. nargs='+',
  68. default=None,
  69. type=int)
  70. parser.add_argument(
  71. '--clip_max_value',
  72. dest='clip_max_value',
  73. help='Max values for clipping data',
  74. nargs='+',
  75. default=None,
  76. type=int)
  77. parser.add_argument(
  78. '--mean',
  79. dest='mean',
  80. help='Data means',
  81. nargs='+',
  82. default=None,
  83. type=float)
  84. parser.add_argument(
  85. '--std',
  86. dest='std',
  87. help='Data standard deviation',
  88. nargs='+',
  89. default=None,
  90. type=float)
  91. parser.add_argument(
  92. '--num_epochs',
  93. dest='num_epochs',
  94. help='number of traing epochs',
  95. default=100,
  96. type=int)
  97. parser.add_argument(
  98. '--train_batch_size',
  99. dest='train_batch_size',
  100. help='training batch size',
  101. default=4,
  102. type=int)
  103. parser.add_argument(
  104. '--lr', dest='lr', help='learning rate', default=0.01, type=float)
  105. return parser.parse_args()
# Parse CLI arguments once at script startup and unpack them into
# module-level names used by the configuration below.
args = parse_args()
data_dir = args.data_dir
train_list = args.train_file_list
val_list = args.eval_file_list
label_list = args.label_list
save_dir = args.save_dir
num_classes = args.num_classes
channel = args.channel
clip_min_value = args.clip_min_value  # per-channel lower clip bounds (list)
clip_max_value = args.clip_max_value  # per-channel upper clip bounds (list)
mean = args.mean  # per-channel means for normalization
std = args.std  # per-channel standard deviations for normalization
num_epochs = args.num_epochs
train_batch_size = args.train_batch_size
lr = args.lr
# Define the transforms used for training and evaluation.
# Training adds random flips, scale jitter and padded random crops for
# augmentation; both pipelines clip raw pixel values to
# [clip_min_value, clip_max_value] and then normalize with mean/std.
train_transforms = transforms.Compose([
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ResizeStepScaling(0.5, 2.0, 0.25),
    # Pads with a constant 1000 in every channel -- presumably a value far
    # outside the sensor range so padding is distinguishable; confirm
    # against the paddlex RandomPaddingCrop semantics.
    transforms.RandomPaddingCrop(im_padding_value=[1000] * channel),
    transforms.Clip(
        min_val=clip_min_value, max_val=clip_max_value),
    transforms.Normalize(
        min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
])
# Evaluation uses only the deterministic preprocessing (no augmentation).
eval_transforms = transforms.Compose([
    transforms.Clip(
        min_val=clip_min_value, max_val=clip_max_value),
    transforms.Normalize(
        min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
])
# Build the training dataset (shuffled) and the evaluation dataset from
# the file lists supplied on the command line; each applies its
# corresponding transform pipeline.
train_dataset = pdx.datasets.SegDataset(
    data_dir=data_dir,
    file_list=train_list,
    label_list=label_list,
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
    data_dir=data_dir,
    file_list=val_list,
    label_list=label_list,
    transforms=eval_transforms)
# Create a UNet segmentation model sized to the configured class and
# input-channel counts, then train it: checkpoints every 5 epochs,
# logging every 10 steps, with VisualDL logging enabled.
model = pdx.seg.UNet(num_classes=num_classes, input_channel=channel)
model.train(
    num_epochs=num_epochs,
    train_dataset=train_dataset,
    train_batch_size=train_batch_size,
    eval_dataset=eval_dataset,
    save_interval_epochs=5,
    log_interval_steps=10,
    save_dir=save_dir,
    learning_rate=lr,
    use_vdl=True)