# mobilenetv2.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import os
  16. # 选择使用0号卡
  17. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  18. from paddlex.cls import transforms
  19. import paddlex as pdx
  20. def train(model_dir=None, sensitivities_file=None, eval_metric_loss=0.05):
  21. # 下载和解压蔬菜分类数据集
  22. veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
  23. pdx.utils.download_and_decompress(veg_dataset, path='./')
  24. # 定义训练和验证时的transforms
  25. train_transforms = transforms.Compose([
  26. transforms.RandomCrop(crop_size=224),
  27. transforms.RandomHorizontalFlip(), transforms.Normalize()
  28. ])
  29. eval_transforms = transforms.Compose([
  30. transforms.ResizeByShort(short_size=256),
  31. transforms.CenterCrop(crop_size=224), transforms.Normalize()
  32. ])
  33. # 定义训练和验证所用的数据集
  34. train_dataset = pdx.datasets.ImageNet(
  35. data_dir='vegetables_cls',
  36. file_list='vegetables_cls/train_list.txt',
  37. label_list='vegetables_cls/labels.txt',
  38. transforms=train_transforms,
  39. shuffle=True)
  40. eval_dataset = pdx.datasets.ImageNet(
  41. data_dir='vegetables_cls',
  42. file_list='vegetables_cls/val_list.txt',
  43. label_list='vegetables_cls/labels.txt',
  44. transforms=eval_transforms)
  45. num_classes = len(train_dataset.labels)
  46. model = pdx.cls.MobileNetV2(num_classes=num_classes)
  47. if model_dir is None:
  48. # 使用imagenet数据集预训练模型权重
  49. pretrain_weights = "IMAGENET"
  50. else:
  51. # 使用传入的model_dir作为预训练模型权重
  52. assert os.path.isdir(model_dir), "Path {} is not a directory".format(
  53. model_dir)
  54. pretrain_weights = model_dir
  55. save_dir = './output/mobilenetv2'
  56. if sensitivities_file is not None:
  57. # DEFAULT 指使用模型预置的参数敏感度信息作为裁剪依据
  58. if sensitivities_file != "DEFAULT":
  59. assert os.path.exists(
  60. sensitivities_file), "Path {} not exist".format(
  61. sensitivities_file)
  62. save_dir = './output/mobilenetv2_prune'
  63. model.train(
  64. num_epochs=10,
  65. train_dataset=train_dataset,
  66. train_batch_size=32,
  67. eval_dataset=eval_dataset,
  68. lr_decay_epochs=[4, 6, 8],
  69. learning_rate=0.025,
  70. pretrain_weights=pretrain_weights,
  71. save_dir=save_dir,
  72. use_vdl=True,
  73. sensitivities_file=sensitivities_file,
  74. eval_metric_loss=eval_metric_loss)
  75. if __name__ == '__main__':
  76. parser = argparse.ArgumentParser(description=__doc__)
  77. parser.add_argument(
  78. "--model_dir", default=None, type=str, help="The model path.")
  79. parser.add_argument(
  80. "--sensitivities_file",
  81. default=None,
  82. type=str,
  83. help="The sensitivities file path.")
  84. parser.add_argument(
  85. "--eval_metric_loss",
  86. default=0.05,
  87. type=float,
  88. help="The loss threshold.")
  89. args = parser.parse_args()
  90. train(args.model_dir, args.sensitivities_file, args.eval_metric_loss)