train.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. # 选择使用0号卡
  17. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  18. # 使用CPU
  19. #os.environ['CUDA_VISIBLE_DEVICES'] = ''
  20. import argparse
  21. import paddlex as pdx
  22. from paddlex.seg import transforms
  23. MODEL_TYPE = ['HumanSegMobile', 'HumanSegServer']
  24. def parse_args():
  25. parser = argparse.ArgumentParser(description='HumanSeg training')
  26. parser.add_argument(
  27. '--model_type',
  28. dest='model_type',
  29. help="Model type for traing, which is one of ('HumanSegMobile', 'HumanSegServer')",
  30. type=str,
  31. default='HumanSegMobile')
  32. parser.add_argument(
  33. '--data_dir',
  34. dest='data_dir',
  35. help='The root directory of dataset',
  36. type=str)
  37. parser.add_argument(
  38. '--train_list',
  39. dest='train_list',
  40. help='Train list file of dataset',
  41. type=str)
  42. parser.add_argument(
  43. '--val_list',
  44. dest='val_list',
  45. help='Val list file of dataset',
  46. type=str,
  47. default=None)
  48. parser.add_argument(
  49. '--save_dir',
  50. dest='save_dir',
  51. help='The directory for saving the model snapshot',
  52. type=str,
  53. default='./output')
  54. parser.add_argument(
  55. '--num_classes',
  56. dest='num_classes',
  57. help='Number of classes',
  58. type=int,
  59. default=2)
  60. parser.add_argument(
  61. "--image_shape",
  62. dest="image_shape",
  63. help="The image shape for net inputs.",
  64. nargs=2,
  65. default=[192, 192],
  66. type=int)
  67. parser.add_argument(
  68. '--num_epochs',
  69. dest='num_epochs',
  70. help='Number epochs for training',
  71. type=int,
  72. default=100)
  73. parser.add_argument(
  74. '--batch_size',
  75. dest='batch_size',
  76. help='Mini batch size',
  77. type=int,
  78. default=128)
  79. parser.add_argument(
  80. '--learning_rate',
  81. dest='learning_rate',
  82. help='Learning rate',
  83. type=float,
  84. default=0.01)
  85. parser.add_argument(
  86. '--pretrain_weights',
  87. dest='pretrain_weights',
  88. help='The path of pretrianed weight',
  89. type=str,
  90. default=None)
  91. parser.add_argument(
  92. '--resume_checkpoint',
  93. dest='resume_checkpoint',
  94. help='The path of resume checkpoint',
  95. type=str,
  96. default=None)
  97. parser.add_argument(
  98. '--use_vdl',
  99. dest='use_vdl',
  100. help='Whether to use visualdl',
  101. action='store_true')
  102. parser.add_argument(
  103. '--save_interval_epochs',
  104. dest='save_interval_epochs',
  105. help='The interval epochs for save a model snapshot',
  106. type=int,
  107. default=5)
  108. return parser.parse_args()
  109. def train(args):
  110. train_transforms = transforms.Compose([
  111. transforms.Resize(args.image_shape), transforms.RandomHorizontalFlip(),
  112. transforms.Normalize()
  113. ])
  114. eval_transforms = transforms.Compose(
  115. [transforms.Resize(args.image_shape), transforms.Normalize()])
  116. train_dataset = pdx.datasets.SegDataset(
  117. data_dir=args.data_dir,
  118. file_list=args.train_list,
  119. transforms=train_transforms,
  120. shuffle=True)
  121. eval_dataset = pdx.datasets.SegDataset(
  122. data_dir=args.data_dir,
  123. file_list=args.val_list,
  124. transforms=eval_transforms)
  125. if args.model_type == 'HumanSegMobile':
  126. model = pdx.seg.HRNet(
  127. num_classes=args.num_classes, width='18_small_v1')
  128. elif args.model_type == 'HumanSegServer':
  129. model = pdx.seg.DeepLabv3p(
  130. num_classes=args.num_classes, backbone='Xception65')
  131. else:
  132. raise ValueError(
  133. "--model_type: {} is set wrong, it shold be one of ('HumanSegMobile', "
  134. "'HumanSegLite', 'HumanSegServer')".format(args.model_type))
  135. model.train(
  136. num_epochs=args.num_epochs,
  137. train_dataset=train_dataset,
  138. train_batch_size=args.batch_size,
  139. eval_dataset=eval_dataset,
  140. save_interval_epochs=args.save_interval_epochs,
  141. learning_rate=args.learning_rate,
  142. pretrain_weights=args.pretrain_weights,
  143. resume_checkpoint=args.resume_checkpoint,
  144. save_dir=args.save_dir,
  145. use_vdl=args.use_vdl)
  146. if __name__ == '__main__':
  147. args = parse_args()
  148. train(args)