# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
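
# A module docstring summarizing the script; argparse already refers to it via
# description=__doc__ below. The summary is drawn from the code in this file.
"""Train a PaddleX YOLOv3 detector on the insect detection dataset.

Without extra arguments the model is trained from ImageNet pretrained
weights. Passing --model_dir continues from a previously trained model,
and passing --sensitivities_file (a path or the literal string DEFAULT)
enables sensitivity-based channel pruning, with --eval_metric_loss as
the tolerated drop in the evaluation metric.
"""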

import argparse
import os

# Use GPU card 0 (set before paddlex/paddle is imported)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from paddlex.det import transforms
import paddlex as pdx


def train(model_dir, sensitivities_file, eval_metric_loss):
    # Download and decompress the insect detection dataset
    insect_dataset = 'https://bj.bcebos.com/paddlex/datasets/insect_det.tar.gz'
    pdx.utils.download_and_decompress(insect_dataset, path='./')

    # Define the transforms used for training and evaluation
    train_transforms = transforms.Compose([
        transforms.MixupImage(mixup_epoch=250), transforms.RandomDistort(),
        transforms.RandomExpand(), transforms.RandomCrop(),
        transforms.Resize(target_size=608, interp='RANDOM'),
        transforms.RandomHorizontalFlip(), transforms.Normalize()
    ])
    eval_transforms = transforms.Compose([
        transforms.Resize(target_size=608, interp='CUBIC'),
        transforms.Normalize()
    ])

    # Define the training and evaluation datasets
    train_dataset = pdx.datasets.VOCDetection(
        data_dir='insect_det',
        file_list='insect_det/train_list.txt',
        label_list='insect_det/labels.txt',
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.VOCDetection(
        data_dir='insect_det',
        file_list='insect_det/val_list.txt',
        label_list='insect_det/labels.txt',
        transforms=eval_transforms)

    if model_dir is None:
        # Use weights pretrained on the ImageNet dataset
        pretrain_weights = "IMAGENET"
    else:
        # Continue from a previously trained model
        assert os.path.isdir(model_dir), "Path {} is not a directory".format(
            model_dir)
        pretrain_weights = model_dir

    save_dir = "output/yolov3_mobile"
    if sensitivities_file is not None:
        if sensitivities_file != 'DEFAULT':
            assert os.path.exists(
                sensitivities_file), "Path {} does not exist".format(
                    sensitivities_file)
        save_dir = "output/yolov3_mobile_prune"

    # Build the YOLOv3 detector and train; sensitivities_file and
    # eval_metric_loss enable channel pruning when provided
    num_classes = len(train_dataset.labels)
    model = pdx.det.YOLOv3(num_classes=num_classes)
    model.train(
        num_epochs=270,
        train_dataset=train_dataset,
        train_batch_size=8,
        eval_dataset=eval_dataset,
        learning_rate=0.000125,
        lr_decay_epochs=[210, 240],
        pretrain_weights=pretrain_weights,
        save_dir=save_dir,
        use_vdl=True,
        sensitivities_file=sensitivities_file,
        eval_metric_loss=eval_metric_loss)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model_dir", default=None, type=str, help="The model path.")
    parser.add_argument(
        "--sensitivities_file",
        default=None,
        type=str,
        help="The sensitivities file path.")
    parser.add_argument(
        "--eval_metric_loss",
        default=0.05,
        type=float,
        help="The acceptable evaluation metric loss when pruning.")
    args = parser.parse_args()

    train(args.model_dir, args.sensitivities_file, args.eval_metric_loss)
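
# Example invocations, based on the arguments defined above. The script
# filename "train.py" and the "best_model" subdirectory are assumptions;
# substitute your actual filename and output paths.
#
#   # Regular training from ImageNet pretrained weights:
#   python train.py
#
#   # Training with channel pruning, reusing weights from a previous run
#   # and the default (downloaded) sensitivities file:
#   python train.py --model_dir output/yolov3_mobile/best_model \
#       --sensitivities_file DEFAULT --eval_metric_loss 0.05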