|
|
@@ -25,7 +25,7 @@ paddlex.cls.ResNet50(num_classes=1000)
|
|
|
### <h3 id="11">train</h3>
|
|
|
|
|
|
```python
|
|
|
-train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, optimizer=None, save_interval_epochs=1, log_interval_steps=10, save_dir='output', pretrain_weights='IMAGENET', learning_rate=.025, warmup_steps=0, warmup_start_lr=0.0, lr_decay_epochs=(30, 60, 90), lr_decay_gamma=0.1, early_stop=False, early_stop_patience=5, use_vdl=True)
|
|
|
+train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, optimizer=None, save_interval_epochs=1, log_interval_steps=10, save_dir='output', pretrain_weights='IMAGENET', learning_rate=.025, warmup_steps=0, warmup_start_lr=0.0, lr_decay_epochs=(30, 60, 90), lr_decay_gamma=0.1, label_smoothing=None, early_stop=False, early_stop_patience=5, use_vdl=True)
|
|
|
```
|
|
|
>
|
|
|
> **参数**
|
|
|
@@ -44,6 +44,7 @@ train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, o
|
|
|
- **warmup_start_lr**(float): 默认优化器的warmup起始学习率,默认为0.0。
|
|
|
- **lr_decay_epochs** (list): 默认优化器的学习率衰减轮数。默认为[30, 60, 90]。
|
|
|
- **lr_decay_gamma** (float): 默认优化器的学习率衰减率。默认为0.1。
|
|
|
+- **label_smoothing** (float, bool or None): 是否使用标签平滑。若为float,表示标签平滑系数。若为True,使用系数为0.1的标签平滑。若为None或False,则不采用标签平滑。默认为None。
|
|
|
- **early_stop** (bool): 是否使用提前终止训练策略。默认为False。
|
|
|
- **early_stop_patience** (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内连续下降或持平,则终止训练。默认为5。
|
|
|
- **use_vdl** (bool): 是否使用VisualDL进行可视化。默认为True。
|
|
|
@@ -171,43 +172,45 @@ quant_aware_train(self, num_epochs, train_dataset, train_batch_size=64, eval_dat
|
|
|
|
|
|
PaddleX提供了共计38种分类模型,所有分类模型均提供同`ResNet50`相同的训练`train`,评估`evaluate`,预测`predict`,敏感度分析`analyze_sensitivity`,剪裁`prune`和在线量化`quant_aware_train`接口,各模型效果可参考[模型库](../../appendix/model_zoo.md)。
|
|
|
|
|
|
-| 模型 | 接口 |
|
|
|
-| :---------------- | :---------------------- |
|
|
|
-| ResNet18 | paddlex.cls.ResNet18(num_classes=1000) |
|
|
|
-| ResNet18_vd | paddlex.cls.ResNet18_vd(num_classes=1000) |
|
|
|
-| ResNet34 | paddlex.cls.ResNet34(num_classes=1000) |
|
|
|
-| ResNet34_vd | paddlex.cls.ResNet34_vd(num_classes=1000) |
|
|
|
-| ResNet50 | paddlex.cls.ResNet50(num_classes=1000) |
|
|
|
-| ResNet50_vd | paddlex.cls.ResNet50_vd(num_classes=1000) |
|
|
|
-| ResNet50_vd_ssld | paddlex.cls.ResNet50_vd_ssld(num_classes=1000) |
|
|
|
-| ResNet101 | paddlex.cls.ResNet101(num_classes=1000) |
|
|
|
-| ResNet101_vd | paddlex.cls.ResNet101_vd(num_classes=1000) |
|
|
|
-| ResNet101_vd_ssld | paddlex.cls.ResNet101_vd_ssld(num_classes=1000) |
|
|
|
-| ResNet152 | paddlex.cls.ResNet152(num_classes=1000) |
|
|
|
-| ResNet152_vd | paddlex.cls.ResNet152_vd(num_classes=1000) |
|
|
|
-| ResNet200_vd | paddlex.cls.ResNet200_vd(num_classes=1000) |
|
|
|
-| DarkNet53 | paddlex.cls.DarkNet53(num_classes=1000) |
|
|
|
-| MobileNetV1 | paddlex.cls.MobileNetV1(num_classes=1000, scale=1.0) |
|
|
|
-| MobileNetV2 | paddlex.cls.MobileNetV2(num_classes=1000, scale=1.0) |
|
|
|
-| MobileNetV3_small | paddlex.cls.MobileNetV3_small(num_classes=1000, scale=1.0) |
|
|
|
-| MobileNetV3_small_ssld | paddlex.cls.MobileNetV3_small_ssld(num_classes=1000, scale=1.0) |
|
|
|
-| MobileNetV3_large | paddlex.cls.MobileNetV3_large(num_classes=1000, scale=1.0) |
|
|
|
-| MobileNetV3_large_ssld | paddlex.cls.MobileNetV3_large_ssld(num_classes=1000) |
|
|
|
-| Xception41 | paddlex.cls.Xception41(num_classes=1000) |
|
|
|
-| Xception65 | paddlex.cls.Xception65(num_classes=1000) |
|
|
|
-| Xception71 | paddlex.cls.Xception71(num_classes=1000) |
|
|
|
-| ShuffleNetV2 | paddlex.cls.ShuffleNetV2(num_classes=1000, scale=1.0) |
|
|
|
-| ShuffleNetV2_swish | paddlex.cls.ShuffleNetV2_swish(num_classes=1000) |
|
|
|
-| DenseNet121 | paddlex.cls.DenseNet121(num_classes=1000) |
|
|
|
-| DenseNet161 | paddlex.cls.DenseNet161(num_classes=1000) |
|
|
|
-| DenseNet169 | paddlex.cls.DenseNet169(num_classes=1000) |
|
|
|
-| DenseNet201 | paddlex.cls.DenseNet201(num_classes=1000) |
|
|
|
-| DenseNet264 | paddlex.cls.DenseNet264(num_classes=1000) |
|
|
|
-| HRNet_W18_C | paddlex.cls.HRNet_W18_C(num_classes=1000) |
|
|
|
-| HRNet_W30_C | paddlex.cls.HRNet_W30_C(num_classes=1000) |
|
|
|
-| HRNet_W32_C | paddlex.cls.HRNet_W32_C(num_classes=1000) |
|
|
|
-| HRNet_W40_C | paddlex.cls.HRNet_W40_C(num_classes=1000) |
|
|
|
-| HRNet_W44_C | paddlex.cls.HRNet_W44_C(num_classes=1000) |
|
|
|
-| HRNet_W48_C | paddlex.cls.HRNet_W48_C(num_classes=1000) |
|
|
|
-| HRNet_W64_C | paddlex.cls.HRNet_W64_C(num_classes=1000) |
|
|
|
-| AlexNet | paddlex.cls.AlexNet(num_classes=1000) |
|
|
|
+| 模型 | 接口 |
|
|
|
+|:-----------------------|:----------------------------------------------------------------|
|
|
|
+| PPLCNet | paddlex.cls.PPLCNet(num_classes=1000) |
|
|
|
+| PPLCNet_ssld | paddlex.cls.PPLCNet_ssld(num_classes=1000) |
|
|
|
+| ResNet18 | paddlex.cls.ResNet18(num_classes=1000) |
|
|
|
+| ResNet18_vd | paddlex.cls.ResNet18_vd(num_classes=1000) |
|
|
|
+| ResNet34 | paddlex.cls.ResNet34(num_classes=1000) |
|
|
|
+| ResNet34_vd | paddlex.cls.ResNet34_vd(num_classes=1000) |
|
|
|
+| ResNet50 | paddlex.cls.ResNet50(num_classes=1000) |
|
|
|
+| ResNet50_vd | paddlex.cls.ResNet50_vd(num_classes=1000) |
|
|
|
+| ResNet50_vd_ssld | paddlex.cls.ResNet50_vd_ssld(num_classes=1000) |
|
|
|
+| ResNet101 | paddlex.cls.ResNet101(num_classes=1000) |
|
|
|
+| ResNet101_vd | paddlex.cls.ResNet101_vd(num_classes=1000) |
|
|
|
+| ResNet101_vd_ssld | paddlex.cls.ResNet101_vd_ssld(num_classes=1000) |
|
|
|
+| ResNet152 | paddlex.cls.ResNet152(num_classes=1000) |
|
|
|
+| ResNet152_vd | paddlex.cls.ResNet152_vd(num_classes=1000) |
|
|
|
+| ResNet200_vd | paddlex.cls.ResNet200_vd(num_classes=1000) |
|
|
|
+| DarkNet53 | paddlex.cls.DarkNet53(num_classes=1000) |
|
|
|
+| MobileNetV1 | paddlex.cls.MobileNetV1(num_classes=1000, scale=1.0) |
|
|
|
+| MobileNetV2 | paddlex.cls.MobileNetV2(num_classes=1000, scale=1.0) |
|
|
|
+| MobileNetV3_small | paddlex.cls.MobileNetV3_small(num_classes=1000, scale=1.0) |
|
|
|
+| MobileNetV3_small_ssld | paddlex.cls.MobileNetV3_small_ssld(num_classes=1000, scale=1.0) |
|
|
|
+| MobileNetV3_large | paddlex.cls.MobileNetV3_large(num_classes=1000, scale=1.0) |
|
|
|
+| MobileNetV3_large_ssld | paddlex.cls.MobileNetV3_large_ssld(num_classes=1000) |
|
|
|
+| Xception41 | paddlex.cls.Xception41(num_classes=1000) |
|
|
|
+| Xception65 | paddlex.cls.Xception65(num_classes=1000) |
|
|
|
+| Xception71 | paddlex.cls.Xception71(num_classes=1000) |
|
|
|
+| ShuffleNetV2 | paddlex.cls.ShuffleNetV2(num_classes=1000, scale=1.0) |
|
|
|
+| ShuffleNetV2_swish | paddlex.cls.ShuffleNetV2_swish(num_classes=1000) |
|
|
|
+| DenseNet121 | paddlex.cls.DenseNet121(num_classes=1000) |
|
|
|
+| DenseNet161 | paddlex.cls.DenseNet161(num_classes=1000) |
|
|
|
+| DenseNet169 | paddlex.cls.DenseNet169(num_classes=1000) |
|
|
|
+| DenseNet201 | paddlex.cls.DenseNet201(num_classes=1000) |
|
|
|
+| DenseNet264 | paddlex.cls.DenseNet264(num_classes=1000) |
|
|
|
+| HRNet_W18_C | paddlex.cls.HRNet_W18_C(num_classes=1000) |
|
|
|
+| HRNet_W30_C | paddlex.cls.HRNet_W30_C(num_classes=1000) |
|
|
|
+| HRNet_W32_C | paddlex.cls.HRNet_W32_C(num_classes=1000) |
|
|
|
+| HRNet_W40_C | paddlex.cls.HRNet_W40_C(num_classes=1000) |
|
|
|
+| HRNet_W44_C | paddlex.cls.HRNet_W44_C(num_classes=1000) |
|
|
|
+| HRNet_W48_C | paddlex.cls.HRNet_W48_C(num_classes=1000) |
|
|
|
+| HRNet_W64_C | paddlex.cls.HRNet_W64_C(num_classes=1000) |
|
|
|
+| AlexNet | paddlex.cls.AlexNet(num_classes=1000) |
|