yolov3_mobilenet.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import os
  16. # 选择使用0号卡
  17. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  18. from paddlex.det import transforms
  19. import paddlex as pdx
  20. def train(model_dir, sensitivities_file, eval_metric_loss):
  21. # 下载和解压昆虫检测数据集
  22. insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
  23. pdx.utils.download_and_decompress(insect_dataset, path='./')
  24. # 定义训练和验证时的transforms
  25. train_transforms = transforms.Compose([
  26. transforms.MixupImage(mixup_epoch=250),
  27. transforms.RandomDistort(),
  28. transforms.RandomExpand(),
  29. transforms.RandomCrop(),
  30. transforms.Resize(target_size=608, interp='RANDOM'),
  31. transforms.RandomHorizontalFlip(),
  32. transforms.Normalize()
  33. ])
  34. eval_transforms = transforms.Compose([
  35. transforms.Resize(target_size=608, interp='CUBIC'),
  36. transforms.Normalize()
  37. ])
  38. # 定义训练和验证所用的数据集
  39. train_dataset = pdx.datasets.VOCDetection(
  40. data_dir='insect_det',
  41. file_list='insect_det/train_list.txt',
  42. label_list='insect_det/labels.txt',
  43. transforms=train_transforms,
  44. shuffle=True)
  45. eval_dataset = pdx.datasets.VOCDetection(
  46. data_dir='insect_det',
  47. file_list='insect_det/val_list.txt',
  48. label_list='insect_det/labels.txt',
  49. transforms=eval_transforms)
  50. if model_dir is None:
  51. # 使用imagenet数据集上的预训练权重
  52. pretrain_weights = "IMAGENET"
  53. else:
  54. assert os.path.isdir(model_dir), "Path {} is not a directory".format(
  55. model_dir)
  56. pretrain_weights = model_dir
  57. save_dir = "output/yolov3_mobile"
  58. if sensitivities_file is not None:
  59. if sensitivities_file != 'DEFAULT':
  60. assert os.path.exists(
  61. sensitivities_file), "Path {} not exist".format(
  62. sensitivities_file)
  63. save_dir = "output/yolov3_mobile_prune"
  64. num_classes = len(train_dataset.labels)
  65. model = pdx.det.YOLOv3(num_classes=num_classes)
  66. model.train(
  67. num_epochs=270,
  68. train_dataset=train_dataset,
  69. train_batch_size=8,
  70. eval_dataset=eval_dataset,
  71. learning_rate=0.000125,
  72. lr_decay_epochs=[210, 240],
  73. pretrain_weights=pretrain_weights,
  74. save_dir=save_dir,
  75. use_vdl=True,
  76. sensitivities_file=sensitivities_file,
  77. eval_metric_loss=eval_metric_loss)
  78. if __name__ == '__main__':
  79. parser = argparse.ArgumentParser(description=__doc__)
  80. parser.add_argument(
  81. "--model_dir", default=None, type=str, help="The model path.")
  82. parser.add_argument(
  83. "--sensitivities_file",
  84. default=None,
  85. type=str,
  86. help="The sensitivities file path.")
  87. parser.add_argument(
  88. "--eval_metric_loss",
  89. default=0.05,
  90. type=float,
  91. help="The loss threshold.")
  92. args = parser.parse_args()
  93. train(args.model_dir, args.sensitivities_file, args.eval_metric_loss)