# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp


def build_transforms(params):
    """Build the train/eval preprocessing pipelines from `params`."""
    from paddlex.cls import transforms
    crop_size = params.image_shape[0]
    train_transforms = transforms.Compose([
        # Random scale/aspect-ratio crop for training-time augmentation.
        transforms.RandomCrop(
            crop_size=crop_size,
            lower_scale=0.88,
            lower_ratio=3. / 4,
            upper_ratio=4. / 3),
        transforms.RandomHorizontalFlip(prob=params.horizontal_flip_prob),
        transforms.RandomVerticalFlip(prob=params.vertical_flip_prob),
        # Photometric jitter: brightness, contrast, saturation and hue.
        transforms.RandomDistort(
            brightness_range=params.brightness_range,
            brightness_prob=params.brightness_prob,
            contrast_range=params.contrast_range,
            contrast_prob=params.contrast_prob,
            saturation_range=params.saturation_range,
            saturation_prob=params.saturation_prob,
            hue_range=params.hue_range,
            hue_prob=params.hue_prob),
        transforms.RandomRotate(
            rotate_range=params.rotate_range, prob=params.rotate_prob),
        transforms.Normalize(mean=params.image_mean, std=params.image_std)
    ])
    eval_transforms = transforms.Compose([
        # Resize the short side, then take a deterministic center crop.
        transforms.ResizeByShort(short_size=int(crop_size * 1.143)),
        transforms.CenterCrop(crop_size=crop_size),
        transforms.Normalize(mean=params.image_mean, std=params.image_std)
    ])
    return train_transforms, eval_transforms
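
# Note: both pipelines end in Normalize, so the model sees inputs on the same
# scale at train and eval time. For a 224-pixel crop, the eval resize targets a
# short side of int(224 * 1.143) == 256, the conventional ImageNet evaluation
# setup (resize short side to 256, center-crop 224).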


def build_datasets(dataset_path, train_transforms, eval_transforms):
    """Create ImageNet-format train/eval datasets rooted at `dataset_path`."""
    import paddlex as pdx
    train_file_list = osp.join(dataset_path, 'train_list.txt')
    eval_file_list = osp.join(dataset_path, 'val_list.txt')
    label_list = osp.join(dataset_path, 'labels.txt')
    train_dataset = pdx.datasets.ImageNet(
        data_dir=dataset_path,
        file_list=train_file_list,
        label_list=label_list,
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.ImageNet(
        data_dir=dataset_path,
        file_list=eval_file_list,
        label_list=label_list,
        transforms=eval_transforms)
    return train_dataset, eval_dataset
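
# Expected layout (PaddleX ImageNet-style format): each line of train_list.txt
# and val_list.txt is "relative/image/path label_index", and labels.txt holds
# one class name per line, in index order, e.g.:
#
#   images/cat/0001.jpg 0
#   images/dog/0002.jpg 1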


def build_optimizer(step_each_epoch, params):
    """Build a Momentum optimizer with the learning-rate policy in `params`."""
    import paddle.fluid as fluid
    from paddle.fluid.regularizer import L2Decay
    learning_rate = params.learning_rate
    num_epochs = params.num_epochs
    if params.lr_policy == 'Cosine':
        learning_rate = fluid.layers.cosine_decay(
            learning_rate=learning_rate,
            step_each_epoch=step_each_epoch,
            epochs=num_epochs)
    elif params.lr_policy == 'Linear':
        # Polynomial decay with power=1.0 decays linearly to zero.
        learning_rate = fluid.layers.polynomial_decay(
            learning_rate=learning_rate,
            decay_steps=step_each_epoch * num_epochs,
            end_learning_rate=0.0,
            power=1.0)
    elif params.lr_policy == 'Piecewise':
        # Drop the learning rate by 10x at each boundary epoch.
        lr_decay_epochs = params.lr_decay_epochs
        values = [
            learning_rate * (0.1**i) for i in range(len(lr_decay_epochs) + 1)
        ]
        boundaries = [b * step_each_epoch for b in lr_decay_epochs]
        learning_rate = fluid.layers.piecewise_decay(
            boundaries=boundaries, values=values)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=0.9,
        regularization=L2Decay(1e-04))
    return optimizer
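
# Worked example (illustrative numbers): with learning_rate=0.1,
# lr_decay_epochs=[30, 60] and step_each_epoch=100, the 'Piecewise' policy
# yields
#   values     = [0.1, 0.01, 0.001]
#   boundaries = [3000, 6000]
# i.e. the learning rate drops by 10x at steps 3000 and 6000.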


def train(task_path, dataset_path, params):
    """Run classifier training; checkpoints land in `<task_path>/output`."""
    import paddlex as pdx
    pdx.log_level = 3
    train_transforms, eval_transforms = build_transforms(params)
    train_dataset, eval_dataset = build_datasets(
        dataset_path=dataset_path,
        train_transforms=train_transforms,
        eval_transforms=eval_transforms)
    step_each_epoch = train_dataset.num_samples // params.batch_size
    save_interval_epochs = params.save_interval_epochs
    save_dir = osp.join(task_path, 'output')
    pretrain_weights = params.pretrain_weights
    optimizer = build_optimizer(step_each_epoch, params)
    # Look up the model class (e.g. pdx.cv.models.MobileNetV2) by name.
    classifier = getattr(pdx.cv.models, params.model)
    sensitivities_path = params.sensitivities_path
    eval_metric_loss = params.eval_metric_loss
    if eval_metric_loss is None:
        eval_metric_loss = 0.05
    model = classifier(num_classes=len(train_dataset.labels))
    model.train(
        num_epochs=params.num_epochs,
        train_dataset=train_dataset,
        train_batch_size=params.batch_size,
        eval_dataset=eval_dataset,
        save_interval_epochs=save_interval_epochs,
        log_interval_steps=2,
        save_dir=save_dir,
        pretrain_weights=pretrain_weights,
        optimizer=optimizer,
        use_vdl=True,
        sensitivities_file=sensitivities_path,
        eval_metric_loss=eval_metric_loss,
        resume_checkpoint=params.resume_checkpoint)
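

if __name__ == '__main__':
    # Minimal invocation sketch. Every value below is a placeholder assumption,
    # not something shipped with this module; in the real service the caller
    # supplies a `params` object with these attributes already populated.
    from types import SimpleNamespace
    params = SimpleNamespace(
        model='MobileNetV2',
        image_shape=[224, 224],
        num_epochs=10,
        batch_size=32,
        learning_rate=0.025,
        lr_policy='Piecewise',
        lr_decay_epochs=[4, 8],
        save_interval_epochs=1,
        pretrain_weights='IMAGENET',
        sensitivities_path=None,
        eval_metric_loss=None,
        resume_checkpoint=None,
        horizontal_flip_prob=0.5,
        vertical_flip_prob=0.0,
        brightness_range=0.9,
        brightness_prob=0.5,
        contrast_range=0.9,
        contrast_prob=0.5,
        saturation_range=0.9,
        saturation_prob=0.5,
        hue_range=18,
        hue_prob=0.5,
        rotate_range=30,
        rotate_prob=0.5,
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225])
    train(task_path='./my_task', dataset_path='./my_dataset', params=params)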