classification.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp

from paddleslim import L1NormFilterPruner


def build_transforms(params):
    """Build the training and evaluation transform pipelines from `params`."""
    from paddlex import transforms as T
    # The network input is square; use the first dimension as the crop size.
    crop_size = params.image_shape[0]
    train_transforms = T.Compose([
        T.RandomCrop(
            crop_size=crop_size,
            scaling=[.88, 1.],
            aspect_ratio=[3. / 4, 4. / 3]),
        T.RandomHorizontalFlip(prob=params.horizontal_flip_prob),
        T.RandomVerticalFlip(prob=params.vertical_flip_prob),
        T.RandomDistort(
            brightness_range=params.brightness_range,
            brightness_prob=params.brightness_prob,
            contrast_range=params.contrast_range,
            contrast_prob=params.contrast_prob,
            saturation_range=params.saturation_range,
            saturation_prob=params.saturation_prob,
            hue_range=params.hue_range,
            hue_prob=params.hue_prob),
        T.Normalize(mean=params.image_mean, std=params.image_std)
    ])
    eval_transforms = T.Compose([
        # Resize the short side before the center crop (1.143 ~= 256 / 224).
        T.ResizeByShort(short_size=int(crop_size * 1.143)),
        T.CenterCrop(crop_size=crop_size),
        T.Normalize(mean=params.image_mean, std=params.image_std)
    ])
    return train_transforms, eval_transforms
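

# Worked example (illustrative) for the evaluation pipeline above: with
# params.image_shape == [224, 224], ResizeByShort scales the short side to
# int(224 * 1.143) == 256 before the 224 x 224 center crop, i.e. the classic
# ImageNet 256 -> 224 evaluation protocol. The call below is a sketch that
# assumes PaddleX 2.x sample-dict semantics, where a pipeline is applied to a
# dict keyed by 'image' (a file path or an HWC ndarray); 'example.jpg' is a
# hypothetical file:
#
#     _, eval_tf = build_transforms(params)
#     sample = eval_tf({'image': 'example.jpg'})
#     # sample['image'].shape -> (224, 224, 3)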


def build_datasets(dataset_path, train_transforms, eval_transforms):
    """Create ImageNet-format train/eval datasets from a dataset directory."""
    import paddlex as pdx
    train_file_list = osp.join(dataset_path, 'train_list.txt')
    eval_file_list = osp.join(dataset_path, 'val_list.txt')
    label_list = osp.join(dataset_path, 'labels.txt')
    train_dataset = pdx.datasets.ImageNet(
        data_dir=dataset_path,
        file_list=train_file_list,
        label_list=label_list,
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.ImageNet(
        data_dir=dataset_path,
        file_list=eval_file_list,
        label_list=label_list,
        transforms=eval_transforms)
    return train_dataset, eval_dataset
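

# Expected on-disk layout (a sketch of the PaddleX ImageNet list format; the
# three file names are fixed by build_datasets above, the contents shown are
# illustrative):
#
#     dataset_path/
#         train_list.txt   # one "relative/image/path label_index" per line
#         val_list.txt     # same format as train_list.txt
#         labels.txt       # one class name per line; its line number is the
#                          # label index used in the list files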


def build_optimizer(parameters, step_each_epoch, params):
    """Build a Momentum optimizer with the configured LR schedule."""
    import paddle
    from paddle.regularizer import L2Decay
    learning_rate = params.learning_rate
    num_epochs = params.num_epochs
    if params.lr_policy == 'Cosine':
        learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
            learning_rate=learning_rate, T_max=step_each_epoch * num_epochs)
    elif params.lr_policy == 'Linear':
        # A polynomial decay with power=1.0 is a linear decay to zero.
        learning_rate = paddle.optimizer.lr.PolynomialDecay(
            learning_rate=learning_rate,
            decay_steps=step_each_epoch * num_epochs,
            end_lr=0.0,
            power=1.0)
    elif params.lr_policy == 'Piecewise':
        lr_decay_epochs = params.lr_decay_epochs
        gamma = 0.1
        # Multiply the LR by `gamma` at each boundary (epochs -> global steps).
        boundaries = [step_each_epoch * e for e in lr_decay_epochs]
        values = [
            learning_rate * (gamma**i)
            for i in range(len(lr_decay_epochs) + 1)
        ]
        learning_rate = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=boundaries, values=values)
    optimizer = paddle.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=.9,
        weight_decay=L2Decay(1e-04),
        parameters=parameters)
    return optimizer
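

# Worked example (illustrative values) for the 'Piecewise' policy above: with
# learning_rate=0.1, lr_decay_epochs=[30, 60] and step_each_epoch=100,
# PiecewiseDecay receives boundaries=[3000, 6000] and
# values=[0.1, 0.01, 0.001], i.e. the LR is divided by 10 at each boundary.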


def train(task_path, dataset_path, params):
    """Train (and optionally prune) a PaddleX image classifier."""
    import paddlex as pdx
    pdx.log_level = 3
    train_transforms, eval_transforms = build_transforms(params)
    train_dataset, eval_dataset = build_datasets(
        dataset_path=dataset_path,
        train_transforms=train_transforms,
        eval_transforms=eval_transforms)
    step_each_epoch = train_dataset.num_samples // params.batch_size
    save_interval_epochs = params.save_interval_epochs
    save_dir = osp.join(task_path, 'output')
    pretrain_weights = params.pretrain_weights
    # If `pretrain_weights` is a local directory, point it at the weights file.
    if pretrain_weights is not None and osp.exists(pretrain_weights):
        pretrain_weights = osp.join(pretrain_weights, 'model.pdparams')
    classifier = getattr(pdx.cls, params.model)
    sensitivities_path = params.sensitivities_path
    pruned_flops = params.pruned_flops
    model = classifier(num_classes=len(train_dataset.labels))
    if sensitivities_path is not None:
        # Load pretrained weights before pruning, then clear the reference so
        # model.train() does not re-initialize the pruned network.
        model.net_initialize(pretrain_weights=pretrain_weights)
        pretrain_weights = None
        # Prune filters by L1 norm until the target FLOPs reduction is met.
        inputs = [1, 3] + list(eval_dataset[0]['image'].shape[:2])
        model.pruner = L1NormFilterPruner(
            model.net, inputs=inputs, sen_file=sensitivities_path)
        model.prune(pruned_flops=pruned_flops)
    optimizer = build_optimizer(model.net.parameters(), step_each_epoch,
                                params)
    model.train(
        num_epochs=params.num_epochs,
        train_dataset=train_dataset,
        train_batch_size=params.batch_size,
        eval_dataset=eval_dataset,
        save_interval_epochs=save_interval_epochs,
        log_interval_steps=2,
        save_dir=save_dir,
        pretrain_weights=pretrain_weights,
        optimizer=optimizer,
        use_vdl=True,
        resume_checkpoint=params.resume_checkpoint)
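

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). `params` is assumed to be a
# plain attribute container; the fields below are exactly those read by the
# functions above, with illustrative values. 'MobileNetV3_small' stands in
# for any classifier exposed under pdx.cls, and both paths are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    params = SimpleNamespace(
        model='MobileNetV3_small',
        image_shape=[224, 224],
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
        horizontal_flip_prob=0.5,
        vertical_flip_prob=0.0,
        brightness_range=0.9, brightness_prob=0.1,
        contrast_range=0.9, contrast_prob=0.1,
        saturation_range=0.9, saturation_prob=0.1,
        hue_range=18, hue_prob=0.1,
        lr_policy='Piecewise',
        learning_rate=0.025,
        lr_decay_epochs=[30, 60, 90],
        num_epochs=120,
        batch_size=32,
        save_interval_epochs=1,
        pretrain_weights='IMAGENET',
        sensitivities_path=None,  # set a sensitivities file to enable pruning
        pruned_flops=None,
        resume_checkpoint=None)
    train(task_path='./output/task1', dataset_path='./dataset', params=params)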