train.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import paddlex as pdx
  17. from paddlex.seg import transforms
  18. MODEL_TYPE = ['HumanSegMobile', 'HumanSegServer']
  19. def parse_args():
  20. parser = argparse.ArgumentParser(description='HumanSeg training')
  21. parser.add_argument(
  22. '--model_type',
  23. dest='model_type',
  24. help="Model type for traing, which is one of ('HumanSegMobile', 'HumanSegServer')",
  25. type=str,
  26. default='HumanSegMobile')
  27. parser.add_argument(
  28. '--data_dir',
  29. dest='data_dir',
  30. help='The root directory of dataset',
  31. type=str)
  32. parser.add_argument(
  33. '--train_list',
  34. dest='train_list',
  35. help='Train list file of dataset',
  36. type=str)
  37. parser.add_argument(
  38. '--val_list',
  39. dest='val_list',
  40. help='Val list file of dataset',
  41. type=str,
  42. default=None)
  43. parser.add_argument(
  44. '--save_dir',
  45. dest='save_dir',
  46. help='The directory for saving the model snapshot',
  47. type=str,
  48. default='./output')
  49. parser.add_argument(
  50. '--num_classes',
  51. dest='num_classes',
  52. help='Number of classes',
  53. type=int,
  54. default=2)
  55. parser.add_argument(
  56. "--image_shape",
  57. dest="image_shape",
  58. help="The image shape for net inputs.",
  59. nargs=2,
  60. default=[192, 192],
  61. type=int)
  62. parser.add_argument(
  63. '--num_epochs',
  64. dest='num_epochs',
  65. help='Number epochs for training',
  66. type=int,
  67. default=100)
  68. parser.add_argument(
  69. '--batch_size',
  70. dest='batch_size',
  71. help='Mini batch size',
  72. type=int,
  73. default=128)
  74. parser.add_argument(
  75. '--learning_rate',
  76. dest='learning_rate',
  77. help='Learning rate',
  78. type=float,
  79. default=0.01)
  80. parser.add_argument(
  81. '--pretrain_weights',
  82. dest='pretrain_weights',
  83. help='The path of pretrianed weight',
  84. type=str,
  85. default=None)
  86. parser.add_argument(
  87. '--resume_checkpoint',
  88. dest='resume_checkpoint',
  89. help='The path of resume checkpoint',
  90. type=str,
  91. default=None)
  92. parser.add_argument(
  93. '--use_vdl',
  94. dest='use_vdl',
  95. help='Whether to use visualdl',
  96. action='store_true')
  97. parser.add_argument(
  98. '--save_interval_epochs',
  99. dest='save_interval_epochs',
  100. help='The interval epochs for save a model snapshot',
  101. type=int,
  102. default=5)
  103. return parser.parse_args()
  104. def train(args):
  105. train_transforms = transforms.Compose([
  106. transforms.Resize(args.image_shape), transforms.RandomHorizontalFlip(),
  107. transforms.Normalize()
  108. ])
  109. eval_transforms = transforms.Compose(
  110. [transforms.Resize(args.image_shape), transforms.Normalize()])
  111. train_dataset = pdx.datasets.SegDataset(
  112. data_dir=args.data_dir,
  113. file_list=args.train_list,
  114. transforms=train_transforms,
  115. shuffle=True)
  116. eval_dataset = pdx.datasets.SegDataset(
  117. data_dir=args.data_dir,
  118. file_list=args.val_list,
  119. transforms=eval_transforms)
  120. if args.model_type == 'HumanSegMobile':
  121. model = pdx.seg.HRNet(
  122. num_classes=args.num_classes, width='18_small_v1')
  123. elif args.model_type == 'HumanSegServer':
  124. model = pdx.seg.DeepLabv3p(
  125. num_classes=args.num_classes, backbone='Xception65')
  126. else:
  127. raise ValueError(
  128. "--model_type: {} is set wrong, it shold be one of ('HumanSegMobile', "
  129. "'HumanSegLite', 'HumanSegServer')".format(args.model_type))
  130. model.train(
  131. num_epochs=args.num_epochs,
  132. train_dataset=train_dataset,
  133. train_batch_size=args.batch_size,
  134. eval_dataset=eval_dataset,
  135. save_interval_epochs=args.save_interval_epochs,
  136. learning_rate=args.learning_rate,
  137. pretrain_weights=args.pretrain_weights,
  138. resume_checkpoint=args.resume_checkpoint,
  139. save_dir=args.save_dir,
  140. use_vdl=args.use_vdl)
  141. if __name__ == '__main__':
  142. args = parse_args()
  143. train(args)