yolov3_mobilenet.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. #copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import argparse
  15. import os
  16. # 选择使用0号卡
  17. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  18. from paddlex.det import transforms
  19. import paddlex as pdx
  20. def train(model_dir, sensitivities_file, eval_metric_loss):
  21. # 下载和解压昆虫检测数据集
  22. insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
  23. pdx.utils.download_and_decompress(insect_dataset, path='./')
  24. # 定义训练和验证时的transforms
  25. train_transforms = transforms.Compose([
  26. transforms.MixupImage(mixup_epoch=250), transforms.RandomDistort(),
  27. transforms.RandomExpand(), transforms.RandomCrop(), transforms.Resize(
  28. target_size=608, interp='RANDOM'),
  29. transforms.RandomHorizontalFlip(), transforms.Normalize()
  30. ])
  31. eval_transforms = transforms.Compose([
  32. transforms.Resize(
  33. target_size=608, interp='CUBIC'), transforms.Normalize()
  34. ])
  35. # 定义训练和验证所用的数据集
  36. train_dataset = pdx.datasets.VOCDetection(
  37. data_dir='insect_det',
  38. file_list='insect_det/train_list.txt',
  39. label_list='insect_det/labels.txt',
  40. transforms=train_transforms,
  41. shuffle=True)
  42. eval_dataset = pdx.datasets.VOCDetection(
  43. data_dir='insect_det',
  44. file_list='insect_det/val_list.txt',
  45. label_list='insect_det/labels.txt',
  46. transforms=eval_transforms)
  47. if model_dir is None:
  48. # 使用imagenet数据集上的预训练权重
  49. pretrain_weights = "IMAGENET"
  50. else:
  51. assert os.path.isdir(model_dir), "Path {} is not a directory".format(
  52. model_dir)
  53. pretrain_weights = model_dir
  54. save_dir = "output/yolov3_mobile"
  55. if sensitivities_file is not None:
  56. if sensitivities_file != 'DEFAULT':
  57. assert os.path.exists(
  58. sensitivities_file), "Path {} not exist".format(
  59. sensitivities_file)
  60. save_dir = "output/yolov3_mobile_prune"
  61. num_classes = len(train_dataset.labels)
  62. model = pdx.det.YOLOv3(num_classes=num_classes)
  63. model.train(
  64. num_epochs=270,
  65. train_dataset=train_dataset,
  66. train_batch_size=8,
  67. eval_dataset=eval_dataset,
  68. learning_rate=0.000125,
  69. lr_decay_epochs=[210, 240],
  70. pretrain_weights=pretrain_weights,
  71. save_dir=save_dir,
  72. use_vdl=True,
  73. sensitivities_file=sensitivities_file,
  74. eval_metric_loss=eval_metric_loss)
  75. if __name__ == '__main__':
  76. parser = argparse.ArgumentParser(description=__doc__)
  77. parser.add_argument(
  78. "--model_dir", default=None, type=str, help="The model path.")
  79. parser.add_argument(
  80. "--sensitivities_file",
  81. default=None,
  82. type=str,
  83. help="The sensitivities file path.")
  84. parser.add_argument(
  85. "--eval_metric_loss",
  86. default=0.05,
  87. type=float,
  88. help="The loss threshold.")
  89. args = parser.parse_args()
  90. train(args.model_dir, args.sensitivities_file, args.eval_metric_loss)