|
@@ -17,16 +17,16 @@ tar xzvf vegetables_cls.tar.gz
|
|
|
|
|
|
|
|
## 3. 训练代码开发
|
|
## 3. 训练代码开发
|
|
|
PaddleX的所有模型训练和预测均只涉及到5个API接口,分别是
|
|
PaddleX的所有模型训练和预测均只涉及到5个API接口,分别是
|
|
|
-> - [transforms](apis/transforms/index) 图像数据处理
|
|
|
|
|
-> - [datasets](apis/datasets.md) 数据集加载
|
|
|
|
|
-> - [models](apis/models.md) 模型类型定义
|
|
|
|
|
-> - [train](apis/models.md) 开始训练
|
|
|
|
|
-> - [predict](apis/models.md) 模型预测
|
|
|
|
|
|
|
+> - [transforms](apis/transforms/index.html) 图像数据处理
|
|
|
|
|
+> - [datasets](apis/datasets/classification.md) 数据集加载
|
|
|
|
|
+> - [models](apis/models/classification.md) 模型类型定义
|
|
|
|
|
+> - [train](apis/models/classification.html#train) 开始训练
|
|
|
|
|
+> - [predict](apis/models/classification.html#predict) 模型预测
|
|
|
|
|
|
|
|
在本示例,通过如下`train.py`代码进行训练, 训练环境为1张Tesla P40 GPU卡。
|
|
在本示例,通过如下`train.py`代码进行训练, 训练环境为1张Tesla P40 GPU卡。
|
|
|
|
|
|
|
|
### 3.1 定义`transforms`数据处理流程
|
|
### 3.1 定义`transforms`数据处理流程
|
|
|
-由于训练时数据增强操作的加入,因此模型在训练和验证过程中,数据处理流程需要分别进行定义。如下所示,代码在`train_transforms`中加入了[RandomCrop](apis/transforms/cls_transforms.html#RandomCrop)和[RandomHorizontalFlip](apis/transforms/cls_transforms.html#RandomHorizontalFlip)两种数据增强方式
|
|
|
|
|
|
|
+由于训练时数据增强操作的加入,因此模型在训练和验证过程中,数据处理流程需要分别进行定义。如下所示,代码在`train_transforms`中加入了[RandomCrop](apis/transforms/cls_transforms.html#RandomCrop)和[RandomHorizontalFlip](apis/transforms/cls_transforms.html#RandomHorizontalFlip)两种数据增强方式, 更多方法可以参考[数据增强文档](apis/transforms/augment.md)。
|
|
|
```
|
|
```
|
|
|
from paddlex.cls import transforms
|
|
from paddlex.cls import transforms
|
|
|
train_transforms = transforms.Compose([
|
|
train_transforms = transforms.Compose([
|
|
@@ -41,7 +41,8 @@ eval_transforms = transforms.Compose([
|
|
|
])
|
|
])
|
|
|
```
|
|
```
|
|
|
|
|
|
|
|
-> 定义数据集,`pdx.datasets.ImageNet`表示读取ImageNet格式的分类数据集
|
|
|
|
|
|
|
+### 3.2 定义`dataset`加载数据集
|
|
|
|
|
+定义数据集,`pdx.datasets.ImageNet`表示读取ImageNet格式的分类数据集, 更多数据集细节可以查阅[数据集格式说明](datasets.md)和[ImageNet接口文档](apis/datasets/classification.md)
|
|
|
```
|
|
```
|
|
|
train_dataset = pdx.datasets.ImageNet(
|
|
train_dataset = pdx.datasets.ImageNet(
|
|
|
data_dir='vegetables_cls',
|
|
data_dir='vegetables_cls',
|
|
@@ -55,11 +56,17 @@ eval_dataset = pdx.datasets.ImageNet(
|
|
|
label_list='vegetables_cls/labels.txt',
|
|
label_list='vegetables_cls/labels.txt',
|
|
|
transforms=eval_transforms)
|
|
transforms=eval_transforms)
|
|
|
```
|
|
```
|
|
|
-> 模型训练
|
|
|
|
|
|
|
|
|
|
|
|
+### 3.3 定义分类模型
|
|
|
|
|
+本文档中使用百度基于蒸馏方法得到的MobileNetV3预训练模型,模型结构与MobileNetV3一致,但精度更高。PaddleX内置了20多种分类模型,查阅[PaddleX模型库](appendix/model_zoo.md)了解更多分类模型。
|
|
|
```
|
|
```
|
|
|
num_classes = len(train_dataset.labels)
|
|
num_classes = len(train_dataset.labels)
|
|
|
-model = pdx.cls.MobileNetV2(num_classes=num_classes)
|
|
|
|
|
|
|
+model.pdx.cls.MobileNetV3_small_ssld(num_classes=num_classes)
|
|
|
|
|
+```
|
|
|
|
|
+
|
|
|
|
|
+### 3.4 定义训练参数
|
|
|
|
|
+定义好模型后,即可直接调用`train`接口,定义训练时的参数,分类模型内置了`piecewise_decay`学习率衰减策略,相关参数见[分类train接口文档](apis/models/classification.html#train)。
|
|
|
|
|
+```
|
|
|
model.train(num_epochs=10,
|
|
model.train(num_epochs=10,
|
|
|
train_dataset=train_dataset,
|
|
train_dataset=train_dataset,
|
|
|
train_batch_size=32,
|
|
train_batch_size=32,
|
|
@@ -70,19 +77,21 @@ model.train(num_epochs=10,
|
|
|
use_vdl=True)
|
|
use_vdl=True)
|
|
|
```
|
|
```
|
|
|
|
|
|
|
|
-## 3. 模型训练
|
|
|
|
|
-> `train.py`与解压后的数据集目录`vegetables_cls`放在同一目录下,在此目录下运行`train.py`即可开始训练。如果您的电脑上有GPU,这将会在10分钟内训练完成,如果为CPU也大概会在30分钟内训练完毕。
|
|
|
|
|
|
|
+## 4. 模型开始训练
|
|
|
|
|
+`train.py`与解压后的数据集目录`vegetables_cls`放在同一目录下,在此目录下运行`train.py`即可开始训练。如果您的电脑上有GPU,这将会在10分钟内训练完成,如果为CPU也大概会在30分钟内训练完毕。
|
|
|
```
|
|
```
|
|
|
python train.py
|
|
python train.py
|
|
|
```
|
|
```
|
|
|
-## 4. 训练过程中查看训练指标
|
|
|
|
|
-> 模型在训练过程中,所有的迭代信息将以标注输出流的形式,输出到命令执行的终端上,用户也可通过visualdl以可视化的方式查看训练指标的变化,通过如下方式启动visualdl后,在浏览器打开https://0.0.0.0:8001即可。
|
|
|
|
|
|
|
+
|
|
|
|
|
+## 5. 训练过程中查看训练指标
|
|
|
|
|
+模型在训练过程中,所有的迭代信息将以标注输出流的形式,输出到命令执行的终端上,用户也可通过visualdl以可视化的方式查看训练指标的变化,通过如下方式启动visualdl后,在浏览器打开https://0.0.0.0:8001 (或 https://localhost:8001)即可。
|
|
|
```
|
|
```
|
|
|
visualdl --logdir output/mobilenetv2/vdl_log --port 8000
|
|
visualdl --logdir output/mobilenetv2/vdl_log --port 8000
|
|
|
```
|
|
```
|
|
|

|
|

|
|
|
-## 5. 训练完成使用模型进行测试
|
|
|
|
|
-> 如使用训练过程中第8轮保存的模型进行测试
|
|
|
|
|
|
|
+
|
|
|
|
|
+## 6. 训练完成使用模型进行测试
|
|
|
|
|
+如下代码使用训练过程中第8轮保存的模型进行测试。
|
|
|
```
|
|
```
|
|
|
import paddlex as pdx
|
|
import paddlex as pdx
|
|
|
model = pdx.load_model('output/mobilenetv2/epoch_8')
|
|
model = pdx.load_model('output/mobilenetv2/epoch_8')
|