
polish code and docs

cuicheng01 · 1 year ago
Parent commit: 0080205d3c
60 changed files with 1600 additions and 1019 deletions
  1. dataset/.gitkeep (+0 -0)
  2. docs/tutorials/INSTALL.md (+12 -6)
  3. docs/tutorials/QUCK_STARTED.md (+0 -174)
  4. docs/tutorials/base/README.md (+17 -4)
  5. docs/tutorials/base/hyperparameters_introduction.md (+0 -0)
  6. docs/tutorials/base/model_optimize.md (+1 -0)
  7. docs/tutorials/data/dataset_check.md (+649 -20)
  8. docs/tutorials/inference/model_inference_api.md (+1 -0)
  9. docs/tutorials/inference/model_inference_tools.md (+29 -0)
  10. docs/tutorials/inference/pipeline_inference_api.md (+1 -0)
  11. docs/tutorials/inference/pipeline_inference_tools.md (+202 -0)
  12. docs/tutorials/pipeline.md (+0 -174)
  13. docs/tutorials/wheel.md (+0 -63)
  14. paddlex/engine.py (+8 -8)
  15. paddlex/modules/__init__.py (+9 -0)
  16. paddlex/modules/base/build_model.py (+1 -1)
  17. paddlex/modules/base/dataset_checker/__init__.py (+1 -150)
  18. paddlex/modules/base/dataset_checker/dataset_checker.py (+166 -0)
  19. paddlex/modules/base/evaluator.py (+1 -1)
  20. paddlex/modules/base/predictor/io/readers.py (+0 -14)
  21. paddlex/modules/base/predictor/kernel_option.py (+14 -2)
  22. paddlex/modules/base/predictor/predictor.py (+54 -36)
  23. paddlex/modules/base/predictor/transforms/image_common.py (+31 -1)
  24. paddlex/modules/base/predictor/utils/paddle_inference_predictor.py (+3 -3)
  25. paddlex/modules/base/trainer/train_deamon.py (+1 -1)
  26. paddlex/modules/base/trainer/trainer.py (+1 -1)
  27. paddlex/modules/image_classification/predictor/keys.py (+0 -1)
  28. paddlex/modules/image_classification/predictor/predictor.py (+13 -25)
  29. paddlex/modules/image_classification/predictor/transforms.py (+102 -36)
  30. paddlex/modules/image_classification/predictor/utils.py (+18 -2)
  31. paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py (+1 -1)
  32. paddlex/modules/instance_segmentation/predictor/keys.py (+2 -2)
  33. paddlex/modules/instance_segmentation/predictor/predictor.py (+1 -1)
  34. paddlex/modules/object_detection/predictor/keys.py (+3 -4)
  35. paddlex/modules/object_detection/predictor/predictor.py (+14 -23)
  36. paddlex/modules/object_detection/predictor/transforms.py (+65 -42)
  37. paddlex/modules/semantic_segmentation/predictor/predictor.py (+16 -23)
  38. paddlex/modules/table_recognition/predictor/predictor.py (+13 -23)
  39. paddlex/modules/text_detection/predictor/predictor.py (+11 -24)
  40. paddlex/modules/text_detection/predictor/transforms.py (+5 -0)
  41. paddlex/modules/text_recognition/predictor/predictor.py (+10 -18)
  42. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py (+2 -4)
  43. paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py (+5 -8)
  44. paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py (+6 -24)
  45. paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py (+2 -4)
  46. paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py (+5 -10)
  47. paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py (+2 -4)
  48. paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py (+5 -8)
  49. paddlex/modules/ts_forecast/predictor.py (+3 -3)
  50. paddlex/paddlex_cli.py (+3 -3)
  51. paddlex/pipelines/PPOCR/pipeline.py (+22 -13)
  52. paddlex/pipelines/base/pipeline.py (+6 -0)
  53. paddlex/pipelines/image_classification/pipeline.py (+15 -12)
  54. paddlex/pipelines/instance_segmentation/pipeline.py (+14 -13)
  55. paddlex/pipelines/object_detection/pipeline.py (+13 -12)
  56. paddlex/pipelines/semantic_segmentation/pipeline.py (+13 -12)
  57. paddlex/repo_apis/PaddleOCR_api/text_rec/config.py (+3 -0)
  58. paddlex/utils/config.py (+1 -1)
  59. paddlex/utils/device.py (+2 -2)
  60. paddlex/utils/errors/others.py (+2 -2)

+ 0 - 0
dataset/.gitkeep


+ 12 - 6
docs/tutorials/INSTALL.md

@@ -73,24 +73,30 @@ python -c "import paddle; print(paddle.__version__)"
 git clone https://github.com/PaddlePaddle/PaddleX.git
 ```
 
-<!-- #### Download from Gitee
+#### Download from Gitee
 
-If access to GitHub is slow, you can download from Gitee instead (the Gitee source is synced daily). Command:
+If access to GitHub is slow, you can download from Gitee instead. Command:
 
 ```shell
 git clone https://gitee.com/paddlepaddle/PaddleX.git
-``` -->
+```
 
 ### 2.2 Installation and Dependencies
 
 Run the commands below and follow the prompts to install the PaddleX dependencies.
 
-<!-- The state of a successful installation needs to be described here -->
+<!-- The state of a successful installation needs to be described here, 廷权 -->
 ```bash
 cd PaddleX
 # Install third-party dependencies
 pip install -r requirements.txt
 
-# Fetch and install the PaddlePaddle development suites
-python install_pdx.py
+# Install the PaddleX wheel
+# -e: install in editable mode, so any code change in this project takes effect in the installed PaddleX wheel
+pip install -e .
+
+# Install PaddleX-related dependencies
+paddlex --install
 ```
+
+**Note:** During installation, the official Paddle model suites are cloned. `--platform` specifies the clone source; the options are `github.com` and `gitee.com`, meaning the suites are cloned from GitHub or Gitee respectively. The default is `github.com`.

+ 0 - 174
docs/tutorials/QUCK_STARTED.md

@@ -1,174 +0,0 @@
-# Quick Start
-
-Follow this tutorial to quickly try out PaddleX and walk through the full deep-learning model development workflow. It uses an image classification task as an example: training an image classification model to solve a flower classification problem on the commonly used `Flowers102` dataset.
-
-The `Flowers102` dataset contains thousands of flower images covering 102 flower species; the training set and the validation set each have 1020 images. The model is `PP-LCNet_x1_0`, an ultra-lightweight image classification model that is fast to train and run; PaddleX ships a built-in config file for it (`paddlex/configs/image_classification/PP-LCNet_x1_0.yaml`).
-
-Starting from environment setup, the following steps complete the model development workflow and produce a model that solves this problem.
-
-## 1. Environment Setup and Installation
-
-See the [documentation](./INSTALL.md) to set up the environment and install.
-
-## 2. Preparing Data
-
-Download the dataset archive [Flowers102 dataset](https://paddle-model-ecology.bj.bcebos.com/paddlex/data/cls_flowers_examples.tar) and extract it into the `PaddleX/dataset/` directory. On Linux or macOS you can also run:
-
-```bash
-cd PaddleX/dataset/
-
-wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/cls_flowers_examples.tar
-
-tar xf cls_flowers_examples.tar
-```
-
-After preparation, the dataset directory structure should be:
-
-```
-PaddleX
-└── dataset
-    └── cls_flowers_examples
-        ├── images
-        ├── label.txt
-        ├── train.txt
-        └── val.txt
-```
-
-## 3. Dataset Validation
-
-PaddleX provides a dataset validation feature that inspects and analyzes the dataset, confirms its format meets PaddleX requirements, and reports summary information. Run:
-
-```bash
-python main.py -c paddlex/configs/image_classification/PP-LCNet_x1_0.yaml -o Global.mode=check_dataset -o Global.dataset_dir=dataset/cls_flowers_examples
-```
-
-After validation, a result file `output/check_dataset_result.json` is generated with the following content:
-
-```json
-{
-  "done_flag": true,
-  "check_pass": true,
-  "attributes": {
-    "label_file": "dataset/label.txt",
-    "num_classes": 102,
-    "train_samples": 1020,
-    "train_sample_paths": [
-      "tmp/image_01904.jpg",
-      "tmp/image_06940.jpg"
-    ],
-    "val_samples": 1020,
-    "val_sample_paths": [
-      "tmp/image_01937.jpg",
-      "tmp/image_06958.jpg"
-    ]
-  },
-  "analysis": {
-    "histogram": "histogram.png"
-  },
-  "dataset_path": "dataset",
-  "show_type": "image",
-  "dataset_type": "ClsDataset"
-}
-```
-
-In the result above, `check_pass` being `True` means the dataset format meets the requirements; some of the other fields are explained as follows:
-* attributes.num_classes: the dataset has 102 classes;
-* attributes.train_samples: the training set has 1020 samples;
-* attributes.val_samples: the validation set has 1020 samples;
-
-In addition, validation analyzes the sample count distribution across all classes and plots a distribution histogram (histogram.png):
-
-![Sample distribution histogram](https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/doc_images/open_source/quick_start/histogram.png)
-
-
-### 4. Model Training
-
-Once the dataset passes validation, you can train a model on it. Training a model with PaddleX takes just one command:
-
-```bash
-python main.py -c paddlex/configs/image_classification/PP-LCNet_x1_0.yaml -o Global.mode=train -o Global.dataset_dir=dataset/cls_flowers_examples
-```
-
-After training, a result file `output/train_result.json` is generated with the following content:
-
-```json
-{
-  "model_name": "PP-LCNet_x1_0",
-  "done_flag": true,
-  "config": "config.yaml",
-  "label_dict": "label_dict.txt",
-  "train_log": "train.log",
-  "visualdl_log": "vdlrecords.1717143354.log",
-  "models": {
-    "last_1": {
-      "score": 0.6137255430221558,
-      "pdparams": "epoch_20.pdparams",
-      "pdema": "",
-      "pdopt": "epoch_20.pdopt",
-      "pdstates": "epoch_20.pdstates",
-      "inference_config": "epoch_20/inference.yml",
-      "pdmodel": "epoch_20/inference.pdmodel",
-      "pdiparams": "epoch_20/inference.pdiparams",
-      "pdiparams.info": "epoch_20/inference.pdiparams.info"
-    },
-    "best": {
-      "score": 0.6137255430221558,
-      "pdparams": "best_model.pdparams",
-      "pdema": "",
-      "pdopt": "best_model.pdopt",
-      "pdstates": "best_model.pdstates",
-      "inference_config": "best_model/inference.yml",
-      "pdmodel": "best_model/inference.pdmodel",
-      "pdiparams": "best_model/inference.pdiparams",
-      "pdiparams.info": "best_model/inference.pdiparams.info"
-    }
-  }
-}
-```
-
-Some fields in the training result file:
-* train_log: the training log file is at `output/train.log`;
-* models: some of the model files produced by training, where:
-  * last_1: the model produced by the final training epoch;
-  * best: the best model produced during training, i.e. the one with the highest accuracy on the validation set, generally used as the final model for later steps;
-
-After training, you can evaluate the trained model:
-
-```bash
-python main.py -c paddlex/configs/image_classification/PP-LCNet_x1_0.yaml -o Global.mode=evaluate -o Global.dataset_dir=dataset/cls_flowers_examples
-```
-
-After evaluation, a result file `output/evaluate_result.json` is generated with the following content:
-
-```json
-{
-  "done_flag": true,
-  "metrics": {
-    "val.top1": 0.62059,
-    "val.top5": 0.84118
-  }
-}
-```
-
-The evaluation result file shows the model's accuracy on the validation set:
-* val.top1: Top-1 classification accuracy on the validation set;
-* val.top5: Top-5 classification accuracy on the validation set;
-
-### 5. Model Inference
-
-Once you have a satisfactory model, you can use it for inference:
-
-```bash
-python main.py -c paddlex/configs/image_classification/PP-LCNet_x1_0.yaml -o Global.mode=predict -o Predict.model_dir="output/best_model" -o Predict.input_path="/paddle/dataset/paddlex/cls/cls_flowers_examples/images/image_00002.jpg"
-```
-
-In the command above, inference parameters can be set by modifying the config file (`paddlex/configs/image_classification/PP-LCNet_x1_0.yaml`) or appending `-o` arguments:
-
-* `Predict.model_dir`: directory of the inference model files. After training, the best model's inference files are saved in `output/best_model` by default; the inference model files are `inference.pdparams`, `inference.pdmodel`, etc.;
-* `Predict.input_path`: path of the image to predict;
-
-After running the command, the prediction result is printed to the console, for example:
-
-```bash
-[{'class_ids': [76], 'scores': [0.66833], 'label_names': ['西番莲']}]
-```

+ 17 - 4
docs/tutorials/train/README.md → docs/tutorials/base/README.md

@@ -1,4 +1,4 @@
-# PaddleX Model Training and Evaluation
+# PaddleX Model Training, Evaluation, and Inference
 
 Before training, make sure your dataset has passed [data validation](../data/README.md); only a validated dataset can be used for training. PaddleX provides many task modules, each with a range of widely validated built-in models covering high accuracy, high efficiency, and a balance of the two. Training a model for any task takes just one command. This document uses the `PP-LCNet_x1_0` model of the image classification task module as the training and evaluation example; training for other task modules is similar. Once you have prepared training data following [PaddleX dataset annotation](../data/annotation/README.md) and [PaddleX dataset validation](../data/dataset_check.md), you can follow this document to train any model supported by PaddleX.
 
@@ -44,10 +44,23 @@ python main.py -c paddlex/configs/image_classification/PP-LCNet_x1_0.yaml \
 
 **Note:** Model evaluation requires the path of a model weight file. Every config file has a built-in default weight save path; to change it, append a command-line argument, e.g. `-o Evaluate.weight_path=./output/best_model.pdparams`.
 
-## 3. Notes
-### 3.1 Training Notes
+## 3. Model Inference
+
+After training is complete, you can run inference with the trained model weights. Running inference with a PaddleX model from the command line takes just one command:
+
+```bash
+python main.py -c paddlex/configs/image_classification/PP-LCNet_x1_0.yaml \
+    -o Global.mode=predict \
+    -o Predict.model_dir="/output/best_model" \
+    -o Predict.input_path="https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_classification_001.jpg"
+```
+
+**Note:** PaddleX also supports inference via its wheel package. Once you have validated your model, you can use the PaddleX wheel for inference to conveniently integrate the model into your own project. For the inference method, see [PaddleX single-model development tools](../tools/model_tools.md).
+
+## 4. Notes
+### 4.1 Training Notes
 - To train other models, you need to specify the corresponding config file; for the mapping between models and config files, see the [model zoo](../models/support_model_list.md).
 - PaddleX hides the distinction between dynamic-graph and static-graph weights. Both are produced during training; for inference, static-graph weights are selected by default.
 <!-- Needs further explanation here, 廷权 -->
-### 3.2 Training Outputs
+### 4.2 Training Outputs
 <!-- Needs further explanation here, 廷权 -->

+ 0 - 0
docs/tutorials/train/hyperparameters_introduction.md → docs/tutorials/base/hyperparameters_introduction.md


+ 1 - 0
docs/tutorials/base/model_optimize.md

@@ -0,0 +1 @@
+# Model Optimization Methods

+ 649 - 20
docs/tutorials/data/dataset_check.md

@@ -73,14 +73,14 @@ python main.py -c paddlex/configs/image_classification/PP-LCNet_x1_0.yaml \
 
 * `CheckDataset`:
     * `convert`:
-        * `enable`: whether to convert the dataset format; conversion is performed when `True`; defaults to `False`;
-        * `src_dataset_type`: the source dataset format, required when format conversion is enabled;
+        * `enable`: whether to convert the dataset format; image classification does not support format conversion; defaults to `False`;
+        * `src_dataset_type`: the source dataset format when conversion is enabled; image classification does not support conversion; defaults to `null`;
     * `split`:
         * `enable`: whether to re-split the dataset; re-splitting is performed when `True`; defaults to `False`;
-        * `train_percent`: the training set percentage when re-splitting, any decimal between 0 and 1; must sum to 1 with `val_percent`;
-        * `val_percent`: the validation set percentage when re-splitting, any decimal between 0 and 1; must sum to 1 with `train_percent`;
+        * `train_percent`: the training set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `val_percent`;
+        * `val_percent`: the validation set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `train_percent`;
 
-Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split=True -o CheckDataset.train_percent=0.8 -o CheckDataset.val_percent=0.2`.
+Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
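The integer-percentage constraint in the changed lines above can be checked before launching `main.py`. Below is a minimal sketch (not part of PaddleX; the helper name `build_split_overrides` is hypothetical) that assembles the `-o` override arguments and rejects invalid splits:

```python
# Hypothetical helper (not part of PaddleX): build the `-o` overrides for
# re-splitting a dataset, enforcing the documented constraints that
# train_percent and val_percent are integers in [0, 100] summing to 100.
def build_split_overrides(train_percent, val_percent):
    for name, value in (("train_percent", train_percent),
                        ("val_percent", val_percent)):
        if not isinstance(value, int) or not 0 <= value <= 100:
            raise ValueError(f"{name} must be an integer in [0, 100], got {value!r}")
    if train_percent + val_percent != 100:
        raise ValueError("train_percent and val_percent must sum to 100")
    return [
        "-o", "CheckDataset.split.enable=True",
        "-o", f"CheckDataset.split.train_percent={train_percent}",
        "-o", f"CheckDataset.split.val_percent={val_percent}",
    ]

print(" ".join(build_split_overrides(80, 20)))
```

The returned arguments would then be appended to the `python main.py -c ... -o Global.mode=check_dataset ...` command shown earlier.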
 
 
 ## 2. Object Detection Task Module Data Validation
@@ -154,13 +154,13 @@ python main.py -c paddlex/configs/object_detection/PicoDet-S.yaml \
 * `CheckDataset`:
     * `convert`:
         * `enable`: whether to convert the dataset format; conversion is performed when `True`; defaults to `False`;
-        * `src_dataset_type`: the source dataset format, required when format conversion is enabled;
+        * `src_dataset_type`: the source dataset format when conversion is enabled; the available source formats are `LabelMe`, `LabelMeWithUnlabeled`, `VOC`, and `VOCWithUnlabeled`;
     * `split`:
         * `enable`: whether to re-split the dataset; re-splitting is performed when `True`; defaults to `False`;
-        * `train_percent`: the training set percentage when re-splitting, any decimal between 0 and 1; must sum to 1 with `val_percent`;
-        * `val_percent`: the validation set percentage when re-splitting, any decimal between 0 and 1; must sum to 1 with `train_percent`;
+        * `train_percent`: the training set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `val_percent`;
+        * `val_percent`: the validation set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `train_percent`;
 
-Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split=True -o CheckDataset.train_percent=0.8 -o CheckDataset.val_percent=0.2`.
+Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
 
 ## 3. Semantic Segmentation Task Module Data Validation
 
@@ -233,24 +233,653 @@ python main.py -c paddlex/configs/semantic_segmentation/PP-LiteSeg-T.yaml \
 * `CheckDataset`:
     * `convert`:
         * `enable`: whether to convert the dataset format; conversion is performed when `True`; defaults to `False`;
-        * `src_dataset_type`: the source dataset format, required when format conversion is enabled;
+        * `src_dataset_type`: the source dataset format when conversion is enabled; the available source format is `LabelMe`;
+    * `split`:
+        * `enable`: whether to re-split the dataset; re-splitting is performed when `True`; defaults to `False`;
+        * `train_percent`: the training set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `val_percent`;
+        * `val_percent`: the validation set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `train_percent`;
+
+Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
+
+## 4. Instance Segmentation Task Module Data Validation
+
+### 4.1 Data Preparation
+
+Prepare your data in a format supported by PaddleX. For data annotation, see [PaddleX data annotation](./annotation/README.md); for an introduction to data formats, see [PaddleX data format introduction](./dataset_format.md). We have prepared instance segmentation demo data for you:
+
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/instance_seg_coco_examples.tar -P ./dataset
+tar -xf ./dataset/instance_seg_coco_examples.tar -C ./dataset/
+```
+
+### 4.2 Dataset Validation
+
+Validating the dataset takes just one command:
+
+```bash
+python main.py -c paddlex/configs/instance_segmentation/Mask-RT-DETR-L.yaml \
+    -o Global.mode=check_dataset \
+    -o Global.dataset_dir=./dataset/instance_seg_coco_examples
+```
+
+The command above validates the dataset and gathers its basic statistics. On success, `Check dataset passed !` is printed in the log, and related outputs are saved under `./output/check_dataset` in the current directory, including visualized example sample images and a sample distribution histogram. The validation result file is saved to `./output/check_dataset_result.json`, with the following content:
+```
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "num_classes": 2,
+    "train_samples": 79,
+    "train_sample_paths": [
+      "check_dataset/demo_img/pexels-photo-634007.jpeg",
+      "check_dataset/demo_img/pexels-photo-59576.png"
+    ],
+    "val_samples": 19,
+    "val_sample_paths": [
+      "check_dataset/demo_img/peasant-farmer-farmer-romania-botiza-47862.jpeg",
+      "check_dataset/demo_img/pexels-photo-715546.png"
+    ]
+  },
+  "analysis": {
+    "histogram": "check_dataset/histogram.png"
+  },
+  "dataset_path": "./dataset/instance_seg_coco_examples",
+  "show_type": "image",
+  "dataset_type": "COCOInstSegDataset"
+}
+```
+In the result above, `check_pass` being `True` means the dataset format meets the requirements; some of the other fields are explained as follows:
+
+- attributes.num_classes: the dataset has 2 classes;
+- attributes.train_samples: the training set has 79 samples;
+- attributes.val_samples: the validation set has 19 samples;
+- attributes.train_sample_paths: a list of relative paths to visualized training samples;
+- attributes.val_sample_paths: a list of relative paths to visualized validation samples;
+
+In addition, validation analyzes the sample count distribution across all classes and plots a distribution histogram (histogram.png):
+![Sample distribution histogram](https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/doc_images/open_source/tutorials/data/dataset_check/instance_segmentation/histogram.png)
+
+**Note**: Only data that passes validation can be used for training and evaluation.
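Since only validated data can be trained on, a driver script can gate training on the result file described above. A minimal sketch (not part of PaddleX; the function name is ours) that parses a result like the one shown:

```python
import json

# Hypothetical gate (not part of PaddleX): a dataset is ready for training
# only when the checker finished (done_flag) and the check passed
# (check_pass) in check_dataset_result.json.
def dataset_is_ready(result_text):
    result = json.loads(result_text)
    return bool(result.get("done_flag")) and bool(result.get("check_pass"))

sample = """{
  "done_flag": true,
  "check_pass": true,
  "attributes": {"num_classes": 2, "train_samples": 79, "val_samples": 19}
}"""
print(dataset_is_ready(sample))  # True
```

In practice you would pass the contents of `./output/check_dataset_result.json` instead of the inline sample.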
+
+
+### 4.3 Dataset Format Conversion / Dataset Splitting (Optional)
+
+To convert the dataset format or re-split the dataset, modify the config file or append hyperparameters.
+
+Dataset-validation parameters can be set via the fields under `CheckDataset` in the config file; some of them are explained below:
+
+* `CheckDataset`:
+    * `convert`:
+        * `enable`: whether to convert the dataset format; conversion is performed when `True`; defaults to `False`;
+        * `src_dataset_type`: the source dataset format when conversion is enabled; the available source format is `LabelMe`;
+    * `split`:
+        * `enable`: whether to re-split the dataset; re-splitting is performed when `True`; defaults to `False`;
+        * `train_percent`: the training set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `val_percent`;
+        * `val_percent`: the validation set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `train_percent`;
+
+Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
+
+## 5. Text Detection Task Module Data Validation
+
+### 5.1 Data Preparation
+
+Prepare your data in a format supported by PaddleX. For data annotation, see [PaddleX data annotation](./annotation/README.md); for an introduction to data formats, see [PaddleX data format introduction](./dataset_format.md). We have prepared text detection demo data for you:
+
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ocr_det_dataset_examples.tar -P ./dataset
+tar -xf ./dataset/ocr_det_dataset_examples.tar -C ./dataset/
+```
+
+### 5.2 Dataset Validation
+
+Validating the dataset takes just one command:
+
+```bash
+python main.py -c paddlex/configs/text_detection/PP-OCRv4_mobile_det.yaml \
+    -o Global.mode=check_dataset \
+    -o Global.dataset_dir=./dataset/ocr_det_dataset_examples
+```
+
+The command above validates the dataset and gathers its basic statistics. On success, `Check dataset passed !` is printed in the log, and related outputs are saved under `./output/check_dataset` in the current directory, including visualized example sample images and a sample distribution histogram. The validation result file is saved to `./output/check_dataset_result.json`, with the following content:
+```
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "train_samples": 200,
+    "train_sample_paths": [
+      "../dataset/ocr_det_dataset_examples/images/train_img_61.jpg",
+      "../dataset/ocr_det_dataset_examples/images/train_img_289.jpg"
+    ],
+    "val_samples": 50,
+    "val_sample_paths": [
+      "../dataset/ocr_det_dataset_examples/images/val_img_61.jpg",
+      "../dataset/ocr_det_dataset_examples/images/val_img_137.jpg"
+    ]
+  },
+  "analysis": {
+    "histogram": "check_dataset/histogram.png"
+  },
+  "dataset_path": "./dataset/ocr_det_dataset_examples",
+  "show_type": "image",
+  "dataset_type": "TextDetDataset"
+}
+```
+In the result above, `check_pass` being `True` means the dataset format meets the requirements; some of the other fields are explained as follows:
+
+- attributes.train_samples: the training set has 200 samples;
+- attributes.val_samples: the validation set has 50 samples;
+- attributes.train_sample_paths: a list of relative paths to visualized training samples;
+- attributes.val_sample_paths: a list of relative paths to visualized validation samples;
+
+In addition, validation analyzes the sample count distribution across all classes and plots a distribution histogram (histogram.png):
+![Sample distribution histogram](https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/doc_images/open_source/tutorials/data/dataset_check/text_detection/histogram.png)
+
+**Note**: Only data that passes validation can be used for training and evaluation.
+
+
+### 5.3 Dataset Format Conversion / Dataset Splitting (Optional)
+
+To convert the dataset format or re-split the dataset, modify the config file or append hyperparameters.
+
+Dataset-validation parameters can be set via the fields under `CheckDataset` in the config file; some of them are explained below:
+
+* `CheckDataset`:
+    * `convert`:
+        * `enable`: whether to convert the dataset format; text detection does not support format conversion; defaults to `False`;
+        * `src_dataset_type`: the source dataset format when conversion is enabled; text detection does not support format conversion; defaults to `null`;
+    * `split`:
+        * `enable`: whether to re-split the dataset; re-splitting is performed when `True`; defaults to `False`;
+        * `train_percent`: the training set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `val_percent`;
+        * `val_percent`: the validation set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `train_percent`;
+
+Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
+
+## 6. Text Recognition Task Module Data Validation
+
+### 6.1 Data Preparation
+
+Prepare your data in a format supported by PaddleX. For data annotation, see [PaddleX data annotation](./annotation/README.md); for an introduction to data formats, see [PaddleX data format introduction](./dataset_format.md). We have prepared text recognition demo data for you:
+
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ocr_rec_dataset_examples.tar -P ./dataset
+tar -xf ./dataset/ocr_rec_dataset_examples.tar -C ./dataset/
+```
+
+### 6.2 Dataset Validation
+
+Validating the dataset takes just one command:
+
+```bash
+python main.py -c paddlex/configs/text_recognition/PP-OCRv4_mobile_rec.yaml \
+    -o Global.mode=check_dataset \
+    -o Global.dataset_dir=./dataset/ocr_rec_dataset_examples
+```
+
+The command above validates the dataset and gathers its basic statistics. On success, `Check dataset passed !` is printed in the log, and related outputs are saved under `./output/check_dataset` in the current directory, including visualized example sample images and a sample distribution histogram. The validation result file is saved to `./output/check_dataset_result.json`, with the following content:
+```
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "train_samples": 4468,
+    "train_sample_paths": [
+      "../dataset/ocr_rec_dataset_examples/images/train_word_1.png",
+      "../dataset/ocr_rec_dataset_examples/images/train_word_10.png"
+    ],
+    "val_samples": 2077,
+    "val_sample_paths": [
+      "../dataset/ocr_rec_dataset_examples/images/val_word_1.png",
+      "../dataset/ocr_rec_dataset_examples/images/val_word_10.png"
+    ]
+  },
+  "analysis": {
+    "histogram": "check_dataset/histogram.png"
+  },
+  "dataset_path": "./dataset/ocr_rec_dataset_examples",
+  "show_type": "image",
+  "dataset_type": "MSTextRecDataset"
+}
+```
+In the result above, `check_pass` being `True` means the dataset format meets the requirements; some of the other fields are explained as follows:
+
+- attributes.train_samples: the training set has 4468 samples;
+- attributes.val_samples: the validation set has 2077 samples;
+- attributes.train_sample_paths: a list of relative paths to visualized training samples;
+- attributes.val_sample_paths: a list of relative paths to visualized validation samples;
+
+In addition, validation analyzes the sample count distribution across all classes and plots a distribution histogram (histogram.png):
+![Sample distribution histogram](https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/doc_images/open_source/tutorials/data/dataset_check/text_recognition/histogram.png)
+
+**Note**: Only data that passes validation can be used for training and evaluation.
+
+
+### 6.3 Dataset Format Conversion / Dataset Splitting (Optional)
+
+To convert the dataset format or re-split the dataset, modify the config file or append hyperparameters.
+
+Dataset-validation parameters can be set via the fields under `CheckDataset` in the config file; some of them are explained below:
+
+* `CheckDataset`:
+    * `convert`:
+        * `enable`: whether to convert the dataset format; text recognition does not support format conversion; defaults to `False`;
+        * `src_dataset_type`: the source dataset format when conversion is enabled; text recognition does not support format conversion; defaults to `null`;
+    * `split`:
+        * `enable`: whether to re-split the dataset; re-splitting is performed when `True`; defaults to `False`;
+        * `train_percent`: the training set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `val_percent`;
+        * `val_percent`: the validation set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `train_percent`;
+
+Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
+
+## 7. Table Recognition Task Module Data Validation
+
+### 7.1 Data Preparation
+
+Prepare your data in a format supported by PaddleX. For data annotation, see [PaddleX data annotation](./annotation/README.md); for an introduction to data formats, see [PaddleX data format introduction](./dataset_format.md). We have prepared table recognition demo data for you:
+
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/table_rec_dataset_examples.tar -P ./dataset
+tar -xf ./dataset/table_rec_dataset_examples.tar -C ./dataset/
+```
+
+### 7.2 Dataset Validation
+
+Validating the dataset takes just one command:
+
+```bash
+python main.py -c paddlex/configs/table_recognition/SLANet.yaml \
+    -o Global.mode=check_dataset \
+    -o Global.dataset_dir=./dataset/table_rec_dataset_examples
+```
+
+The command above validates the dataset and gathers its basic statistics. On success, `Check dataset passed !` is printed in the log, and related outputs are saved under `./output/check_dataset` in the current directory, including visualized example sample images and a sample distribution histogram. The validation result file is saved to `./output/check_dataset_result.json`, with the following content:
+```
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "train_samples": 2000,
+    "train_sample_paths": [
+      "../dataset/table_rec_dataset_examples/images/border_right_7384_X9UFEPKVMLALY7DDB11A.jpg",
+      "../dataset/table_rec_dataset_examples/images/no_border_5223_HLG406UK35UD5EUYC2AV.jpg"
+    ],
+    "val_samples": 100,
+    "val_sample_paths": [
+      "../dataset/table_rec_dataset_examples/images/border_2945_L7MSRHBZRW6Y347G39O6.jpg",
+      "../dataset/table_rec_dataset_examples/images/no_border_288_6LK683JUCMOQ38V5BV29.jpg"
+    ]
+  },
+  "analysis": {},
+  "dataset_path": "./dataset/table_rec_dataset_examples",
+  "show_type": "image",
+  "dataset_type": "PubTabTableRecDataset"
+}
+```
+In the result above, `check_pass` being `True` means the dataset format meets the requirements; some of the other fields are explained as follows:
+
+- attributes.train_samples: the training set has 2000 samples;
+- attributes.val_samples: the validation set has 100 samples;
+- attributes.train_sample_paths: a list of relative paths to visualized training samples;
+- attributes.val_sample_paths: a list of relative paths to visualized validation samples;
+
+**Note**: Only data that passes validation can be used for training and evaluation.
+
+
+### 7.3 Dataset Format Conversion / Dataset Splitting (Optional)
+
+To convert the dataset format or re-split the dataset, modify the config file or append hyperparameters.
+
+Dataset-validation parameters can be set via the fields under `CheckDataset` in the config file; some of them are explained below:
+
+* `CheckDataset`:
+    * `convert`:
+        * `enable`: whether to convert the dataset format; table recognition does not support format conversion; defaults to `False`;
+        * `src_dataset_type`: the source dataset format when conversion is enabled; table recognition does not support format conversion; defaults to `null`;
+    * `split`:
+        * `enable`: whether to re-split the dataset; re-splitting is performed when `True`; defaults to `False`;
+        * `train_percent`: the training set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `val_percent`;
+        * `val_percent`: the validation set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `train_percent`;
+
+Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
+
+## 8. Time Series Forecasting Task Module Data Validation
+
+### 8.1 Data Preparation
+
+Prepare your data in a format supported by PaddleX. For data annotation, see [PaddleX data annotation](./annotation/README.md); for an introduction to data formats, see [PaddleX data format introduction](./dataset_format.md). We have prepared time series forecasting demo data for you:
+
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ts_dataset_examples.tar -P ./dataset
+tar -xf ./dataset/ts_dataset_examples.tar -C ./dataset/
+```
+
+### 8.2 Dataset Validation
+
+Validating the dataset takes just one command:
+
+```bash
+python main.py -c paddlex/configs/ts_forecast/DLinear.yaml \
+    -o Global.mode=check_dataset \
+    -o Global.dataset_dir=./dataset/ts_dataset_examples
+```
+
+The command above validates the dataset and gathers its basic statistics. On success, `Check dataset passed !` is printed in the log, and related outputs are saved under `./output/check_dataset` in the current directory. The validation result file is saved to `./output/check_dataset_result.json`, with the following content:
+```
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "train_samples": 12194,
+    "train_table": [
+      [
+        "date",
+        "HUFL",
+        "HULL",
+        "MUFL",
+        "MULL",
+        "LUFL",
+        "LULL",
+        "OT"
+      ],
+      [
+        "2016-07-01 00:00:00",
+        5.827000141143799,
+        2.009000062942505,
+        1.5989999771118164,
+        0.4620000123977661,
+        4.203000068664552,
+        1.3400000333786009,
+        30.5310001373291
+      ],
+      [
+        "2016-07-01 01:00:00",
+        5.692999839782715,
+        2.075999975204468,
+        1.4919999837875366,
+        0.4259999990463257,
+        4.142000198364259,
+        1.371000051498413,
+        27.78700065612793
+      ]
+    ],
+    "val_samples": 3484,
+    "val_table": [
+      [
+        "date",
+        "HUFL",
+        "HULL",
+        "MUFL",
+        "MULL",
+        "LUFL",
+        "LULL",
+        "OT"
+      ],
+      [
+        "2017-11-21 02:00:00",
+        12.994000434875488,
+        4.889999866485597,
+        10.055999755859377,
+        2.878000020980835,
+        2.559000015258789,
+        1.2489999532699585,
+        4.7129998207092285
+      ],
+      [
+        "2017-11-21 03:00:00",
+        11.92199993133545,
+        4.554999828338623,
+        9.097000122070312,
+        3.0920000076293945,
+        2.559000015258789,
+        1.2790000438690186,
+        4.8540000915527335
+      ]
+    ]
+  },
+  "analysis": {
+    "histogram": ""
+  },
+  "dataset_path": ".\/dataset\/ts_dataset_examples",
+  "show_type": "csv",
+  "dataset_type": "TSDataset"
+}
+```
+In the result above, `check_pass` being `True` means the dataset format meets the requirements; some of the other fields are explained as follows:
+
+- attributes.train_samples: the training set has 12194 samples;
+- attributes.val_samples: the validation set has 3484 samples;
+- attributes.train_table: example rows from the training set data table;
+- attributes.val_table: example rows from the validation set data table;
+
+**Note**: Only data that passes validation can be used for training and evaluation.
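The `train_table`/`val_table` fields above mirror the structure of the underlying csv files: a header row whose first column is a time column, followed by numeric feature columns. A minimal sketch (not part of PaddleX; the function is ours, and the column name `date` follows the example above) of that shape check:

```python
import csv
import io

# Hypothetical shape check (not part of PaddleX): verify that a
# TSDataset-style csv has the expected time column first and numeric values
# in every other column, and return the number of data rows.
def check_ts_csv(text, time_col="date"):
    rows = list(csv.reader(io.StringIO(text)))
    header, data = rows[0], rows[1:]
    if header[0] != time_col:
        raise ValueError(f"first column must be {time_col!r}, got {header[0]!r}")
    for row in data:
        for value in row[1:]:
            float(value)  # a non-numeric cell raises ValueError here
    return len(data)

sample = (
    "date,HUFL,OT\n"
    "2016-07-01 00:00:00,5.827,30.531\n"
    "2016-07-01 01:00:00,5.693,27.787\n"
)
print(check_ts_csv(sample))  # 2
```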
+
+
+### 8.3 Dataset Format Conversion / Dataset Splitting (Optional)
+
+To convert the dataset format or re-split the dataset, modify the config file or append hyperparameters.
+
+Dataset-validation parameters can be set via the fields under `CheckDataset` in the config file; some of them are explained below:
+
+* `CheckDataset`:
+    * `convert`:
+        * `enable`: whether to convert the dataset format; conversion is performed when `True`; defaults to `False`;
+        * `src_dataset_type`: for time series forecasting, conversion only supports converting xlsx annotation files to xls, so no source dataset format needs to be set; defaults to `null`;
     * `split`:
         * `enable`: whether to re-split the dataset; re-splitting is performed when `True`; defaults to `False`;
-        * `train_percent`: the training set percentage when re-splitting, any decimal between 0 and 1; must sum to 1 with `val_percent`;
-        * `val_percent`: the validation set percentage when re-splitting, any decimal between 0 and 1; must sum to 1 with `train_percent`;
+        * `train_percent`: the training set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `val_percent`;
+        * `val_percent`: the validation set percentage when re-splitting, any integer between 0 and 100; must sum to 100 with `train_percent`;
 
-Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split=True -o CheckDataset.train_percent=0.8 -o CheckDataset.val_percent=0.2`.
+Format conversion and dataset splitting can be enabled at the same time. When splitting, the original annotation files are renamed to `xxx.bak` in place. The parameters above can also be set by appending command-line arguments, e.g. to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
 
-## Instance Segmentation Task Module Data Validation
+## 9. Time Series Anomaly Detection Task Module Data Validation
 
-## Text Detection Task Module Data Validation
+### 9.1 Data Preparation
 
-## Text Recognition Task Module Data Validation
+Prepare your data in a format supported by PaddleX. For data annotation, see [PaddleX data annotation](./annotation/README.md); for an introduction to data formats, see [PaddleX data format introduction](./dataset_format.md). We have prepared time series anomaly detection demo data for you:
 
-## Table Recognition Task Module Data Validation
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ts_anomaly_examples.tar -P ./dataset
+tar -xf ./dataset/ts_anomaly_examples.tar -C ./dataset/
+```
 
-## Time Series Forecasting Task Module Data Validation
+### 9.2 Dataset Validation
 
-## Time Series Anomaly Detection Task Module Data Validation
+Validating the dataset takes just one command:
+
+```bash
+python main.py -c paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml \
+    -o Global.mode=check_dataset \
+    -o Global.dataset_dir=./dataset/ts_anomaly_examples
+```
+
+The command above validates the dataset and gathers its basic statistics. On success, `Check dataset passed !` is printed in the log, and related outputs are saved under `./output/check_dataset` in the current directory. The validation result file is saved to `./output/check_dataset_result.json`, with the following content:
+```
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "train_samples": 22032,
+    "train_table": [
+      [
+        "timestamp",
+        "feature_0",
+        "...",
+        "feature_24",
+        "label"
+      ],
+      [
+        0.0,
+        0.7326893750079723,
+        "...",
+        0.1382488479262673,
+        0.0
+      ]
+    ],
+    "val_samples": 198290,
+    "val_table": [
+      [
+        "timestamp",
+        "feature_0",
+        "...",
+        "feature_24",
+        "label"
+      ],
+      [
+        22032.0,
+        0.8604795809835284,
+        "...",
+        0.1428571428571428,
+        0.0
+      ]
+    ]
+  },
+  "analysis": {
+    "histogram": ""
+  },
+  "dataset_path": "./dataset/ts_anomaly_examples",
+  "show_type": "csv",
+  "dataset_type": "TSADDataset"
+}
+```
+In the validation result above, `check_pass` being `True` means the dataset format meets the requirements. Explanations of the other metrics are as follows:
+
+- attributes.train_samples: the number of training samples in this dataset is 22032;
+- attributes.val_samples: the number of validation samples in this dataset is 198290;
+- attributes.train_table: sample rows from the training set of this dataset;
+- attributes.val_table: sample rows from the validation set of this dataset;
+
+**Note**: Only data that passes validation can be used for training and evaluation.
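Since the result file is plain JSON, it is easy to consume programmatically. The sketch below is purely illustrative — the `summarize_check_result` helper is not part of PaddleX — and shows one way to turn a parsed `check_dataset_result.json` into a one-line summary:

```python
import json

def summarize_check_result(result: dict) -> str:
    """Build a one-line summary from a parsed check_dataset_result.json dict."""
    attrs = result.get("attributes", {})
    status = "passed" if result.get("check_pass") else "failed"
    return (f"check {status}: {attrs.get('train_samples')} train / "
            f"{attrs.get('val_samples')} val samples ({result.get('dataset_type')})")

# On a real run, load the file first (path taken from the docs above):
# with open("./output/check_dataset_result.json", encoding="utf-8") as f:
#     print(summarize_check_result(json.load(f)))
```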
+
+
+### 9.3 Dataset Format Conversion / Dataset Splitting (Optional)
+
+To convert the dataset format or re-split the dataset, you can either modify the configuration file or append hyperparameters on the command line.
+
+Parameters related to dataset validation can be set via the fields under `CheckDataset` in the configuration file. Example explanations of some parameters in the configuration file are as follows:
+
+* `CheckDataset`:
+    * `convert`:
+        * `enable`: whether to convert the dataset format; `True` enables conversion, default is `False`;
+        * `src_dataset_type`: when converting, time series anomaly detection only supports converting xlsx annotation files to xls, so the source dataset format does not need to be set; default is `null`;
+    * `split`:
+        * `enable`: whether to re-split the dataset; `True` enables re-splitting, default is `False`;
+        * `train_percent`: when re-splitting, the training set percentage must be set as an integer between 0 and 100; it must sum to 100 with `val_percent`;
+        * `val_percent`: when re-splitting, the validation set percentage must be set as an integer between 0 and 100; it must sum to 100 with `train_percent`;
+
+Data conversion and data splitting can be enabled at the same time. When re-splitting, the original annotation files are renamed to `xxx.bak` in their original location. The parameters above can also be set by appending command-line arguments, for example to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
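The rule that `train_percent` and `val_percent` are integers summing to 100 can be illustrated with a small stand-alone sketch. This is only an illustration of the splitting rule, not PaddleX's actual implementation:

```python
def split_indices(n_samples: int, train_percent: int, val_percent: int):
    """Split sample indices into train/val ranges by integer percentages."""
    if train_percent + val_percent != 100:
        raise ValueError("train_percent and val_percent must sum to 100")
    n_train = n_samples * train_percent // 100
    return list(range(n_train)), list(range(n_train, n_samples))

# An 80/20 split of 10 samples puts the first 8 in train, the last 2 in val.
train_idx, val_idx = split_indices(10, 80, 20)
```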
+
+## 10. Time Series Classification Task Module Data Validation
+
+### 10.1 Data Preparation
+
+You need to prepare data according to the data format requirements supported by PaddleX. For data annotation, see [PaddleX Data Annotation](./annotation/README.md); for an introduction to the data formats, see [PaddleX Data Format Introduction](./dataset_format.md). We have prepared a time series classification demo dataset here for you to use.
+
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ts_classify_examples.tar -P ./dataset
+tar -xf ./dataset/ts_classify_examples.tar -C ./dataset/
+```
+
+### 10.2 Dataset Validation
+
+Validating the dataset takes just one command:
+
+```bash
+python main.py -c paddlex/configs/ts_classify_examples/DLinear_ad.yaml \
+    -o Global.mode=check_dataset \
+    -o Global.dataset_dir=./dataset/ts_classify_examples
+```
+
+After the command above is executed, PaddleX validates the dataset and collects its basic statistics. On success, the message `Check dataset passed !` is printed in the log, and the related outputs are saved under `./output/check_dataset` in the current directory, including visualized sample images and a sample distribution histogram. The validation result file is saved to `./output/check_dataset_result.json`, with the following contents
+```
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "train_samples": 82620,
+    "train_table": [
+      [
+        "Unnamed: 0",
+        "group_id",
+        "dim_0",
+        "...",
+        "dim_60",
+        "label",
+        "time"
+      ],
+      [
+        0.0,
+        0.0,
+        0.000949,
+        "...",
+        0.12107,
+        1.0,
+        0.0
+      ]
+    ],
+    "val_samples": 83025,
+    "val_table": [
+      [
+        "Unnamed: 0",
+        "group_id",
+        "dim_0",
+        "...",
+        "dim_60",
+        "label",
+        "time"
+      ],
+      [
+        0.0,
+        0.0,
+        0.004578,
+        "...",
+        0.15728,
+        1.0,
+        0.0
+      ]
+    ]
+  },
+  "analysis": {
+    "histogram": "check_dataset/histogram.png"
+  },
+  "dataset_path": "./dataset/ts_classify_examples",
+  "show_type": "csv",
+  "dataset_type": "TSCLSDataset"
+}
+```
+In the validation result above, `check_pass` being `True` means the dataset format meets the requirements. Explanations of the other metrics are as follows:
+
+- attributes.train_samples: the number of training samples in this dataset is 82620;
+- attributes.val_samples: the number of validation samples in this dataset is 83025;
+- attributes.train_table: sample rows from the training set of this dataset;
+- attributes.val_table: sample rows from the validation set of this dataset;
+
+
+In addition, dataset validation analyzes the distribution of sample counts across all classes in the dataset and plots a distribution histogram (histogram.png):
+![Sample distribution histogram](https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/doc_images/open_source/tutorials/data/dataset_check/ts_classify_examples/histogram.png)
+
+**Note**: Only data that passes validation can be used for training and evaluation.
+
+
+### 10.3 Dataset Format Conversion / Dataset Splitting (Optional)
+
+To convert the dataset format or re-split the dataset, you can either modify the configuration file or append hyperparameters on the command line.
+
+Parameters related to dataset validation can be set via the fields under `CheckDataset` in the configuration file. Example explanations of some parameters in the configuration file are as follows:
+
+* `CheckDataset`:
+    * `convert`:
+        * `enable`: whether to convert the dataset format; `True` enables conversion, default is `False`;
+        * `src_dataset_type`: when converting, time series classification only supports converting xlsx annotation files to xls, so the source dataset format does not need to be set; default is `null`;
+    * `split`:
+        * `enable`: whether to re-split the dataset; `True` enables re-splitting, default is `False`;
+        * `train_percent`: when re-splitting, the training set percentage must be set as an integer between 0 and 100; it must sum to 100 with `val_percent`;
+        * `val_percent`: when re-splitting, the validation set percentage must be set as an integer between 0 and 100; it must sum to 100 with `train_percent`;
 
-## 时序分类任务模块数据校验
+Data conversion and data splitting can be enabled at the same time. When re-splitting, the original annotation files are renamed to `xxx.bak` in their original location. The parameters above can also be set by appending command-line arguments, for example to re-split the dataset with an 80/20 train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.

+ 1 - 0
docs/tutorials/inference/model_inference_api.md

@@ -0,0 +1 @@
+# Single-Model Python API Inference Documentation

+ 29 - 0
docs/tutorials/inference/model_inference_tools.md

@@ -0,0 +1,29 @@
+# PaddleX Single-Model Development Tool Inference and Prediction
+
+PaddleX provides a rich set of single models — the smallest sub-module units that accomplish a given class of task. Once developed, a model can be conveniently integrated into all kinds of systems. Every model in PaddleX comes with official weights and supports both direct inference from the command line and prediction via the Python API. Command-line inference lets you quickly try out a model, while the Python API makes it easy to integrate prediction into your own projects.
+
+## 1. Install PaddleX
+Before using the single-model development tool, you first need to install the PaddleX wheel package; see the [PaddleX Installation Documentation](./INSTALL.md).
+
+## 2. Using the PaddleX Single-Model Development Tool
+
+### 2.1 Inference and Prediction
+
+PaddleX provides a unified inference Python API for single models. With the Python API you can adjust more settings, chain multiple models together, and build custom pipeline tasks. Using the Python API takes only a few lines of code, as shown below:
+
+```python
+from paddlex import PaddleInferenceOption, create_model
+
+model_name = "PP-LCNet_x1_0"
+
+# Instantiate PaddleInferenceOption to configure inference
+kernel_option = PaddleInferenceOption()
+kernel_option.set_device("gpu")
+
+model = create_model(model_name=model_name, kernel_option=kernel_option)
+
+# Run prediction
+result = model.predict({'input_path': "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_classification_001.jpg"})
+```
+
+All models provided by PaddleX support the Python API calls above. For the model list, see the [PaddleX Model List](../models/support_model_list.md); for more on the Python API, see the [PaddleX Model Inference API](../API.md).

+ 1 - 0
docs/tutorials/inference/pipeline_inference_api.md

@@ -0,0 +1 @@
+# Pipeline Python API Inference Documentation

+ 202 - 0
docs/tutorials/inference/pipeline_inference_tools.md

@@ -0,0 +1,202 @@
+# PaddleX Pipeline Development Tool Inference and Prediction
+
+A pipeline is a model, or combination of models, that can independently accomplish a class of tasks and is ready for deployment; it can be a single-model pipeline or a multi-model pipeline. PaddleX provides a rich set of pipelines that make AI inference and deployment easy. Every pipeline in PaddleX offers multiple models to choose from, all with official weights, and supports both direct inference from the command line and prediction via the Python API. Command-line inference lets you quickly try out a model, while the Python API makes it easy to integrate prediction into your own projects.
+
+## 1. Install PaddleX
+Before using the pipeline development tool, you first need to install the PaddleX wheel package; see the [PaddleX Installation Documentation](./INSTALL.md).
+
+
+## 2. Using the PaddleX Pipeline Development Tool
+
+### 2.1 Image Classification Pipeline
+The image classification pipeline has multiple built-in single models for image classification, including the `ResNet`, `PP-LCNet`, `MobileNetV2`, `MobileNetV3`, `ConvNeXt`, `SwinTransformer`, `PP-HGNet`, `PP-HGNetV2`, and `CLIP` series. For the full list of supported classification models, see the [Model Zoo](./models/support_model_list.md). You can run inference in either of the two ways below. If these models do not meet the needs of your scenario, you can train your own model following the [PaddleX Model Training Documentation](./train/README.md); the trained model can be integrated into this pipeline very easily.
+
+<details>
+<summary><b> Command-Line Usage </b></summary>
+You can classify an image from the command line as follows:
+
+```
+paddlex --pipeline image_classification --model PP-LCNet_x1_0 --input https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_classification_001.jpg
+```
+Parameter explanations:
+- `pipeline`: pipeline name; currently supported pipeline names are `image_classification`, `object_detection`, `semantic_segmentation`, `instance_segmentation`, and `ocr`.
+- `model`: model name; each pipeline supports different model names, see the [PaddleX Pipeline Documentation](./models/support_model_list.md). For multi-model pipelines, multiple model names must be specified, separated by spaces.
+- `input`: input image path or URL.
+</details>
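When scripting many such invocations, the flags above can also be assembled programmatically. The helper below is a hypothetical convenience (not part of the PaddleX CLI itself) that builds the command string, with multiple models space-separated as described:

```python
import shlex

def build_paddlex_cmd(pipeline: str, models: list, input_path: str, output: str = None) -> str:
    """Assemble a `paddlex` CLI invocation from its parts."""
    parts = ["paddlex", "--pipeline", pipeline, "--model", *models, "--input", input_path]
    if output is not None:
        parts += ["--output", output]
    # shlex.quote keeps each argument shell-safe
    return " ".join(shlex.quote(p) for p in parts)

# e.g. the image classification command shown above:
cmd = build_paddlex_cmd(
    "image_classification", ["PP-LCNet_x1_0"],
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_classification_001.jpg")
```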
+
+<details>
+<summary><b> Python API Usage</b></summary>
+
+
+```python
+from paddlex import ClsPipeline
+from paddlex import PaddleInferenceOption
+
+model_name = "PP-LCNet_x1_0"
+pipeline = ClsPipeline(model_name, kernel_option=PaddleInferenceOption())
+result = pipeline.predict(
+        {'input_path': "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_classification_001.jpg"}
+    )
+print(result["cls_result"])
+```    
+</details>
+
+### 2.2 Object Detection Pipeline
+
+
+The object detection pipeline has multiple built-in single models for object detection, including the `PicoDet`, `RT-DETR`, and `PP-YOLO-E` series. For the full list of supported object detection models, see the [Model Zoo](./models/support_model_list.md). You can run inference in either of the two ways below. If these models do not meet the needs of your scenario, you can train your own model following the [PaddleX Model Training Documentation](./train/README.md); the trained model can be integrated into this pipeline very easily.
+
+<details>
+<summary><b> Command-Line Usage </b></summary>
+You can detect the objects in an image from the command line as follows:
+
+```
+paddlex --pipeline object_detection --model RT-DETR-L --input https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png
+
+```
+Parameter explanations:
+- `pipeline`: pipeline name; currently supported pipeline names are `image_classification`, `object_detection`, `semantic_segmentation`, `instance_segmentation`, and `ocr`.
+- `model`: model name; each pipeline supports different model names, see the [PaddleX Pipeline Documentation](./models/support_model_list.md). For multi-model pipelines, multiple model names must be specified, separated by spaces.
+- `input`: input image path or URL.
+</details>
+
+<details>
+<summary><b> Python API Usage</b></summary>
+
+```python
+from pathlib import Path
+from paddlex import DetPipeline
+from paddlex import PaddleInferenceOption
+
+model_name =  "RT-DETR-L"
+output_base = Path("output")
+
+output_dir = output_base / model_name
+pipeline = DetPipeline(model_name, output_dir=output_dir, kernel_option=PaddleInferenceOption())
+result = pipeline.predict(
+        {"input_path": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"})
+print(result["boxes"])
+
+```
+</details>
+
+
+### 2.3 Semantic Segmentation Pipeline
+
+
+The semantic segmentation pipeline has multiple built-in single models for semantic segmentation, including the `PP-LiteSeg`, `OCRNet`, and `DeepLabv3` series. For the full list of supported semantic segmentation models, see the [Model Zoo](./models/support_model_list.md). You can run inference in either of the two ways below. If these models do not meet the needs of your scenario, you can train your own model following the [PaddleX Model Training Documentation](./train/README.md); the trained model can be integrated into this pipeline very easily.
+
+<details>
+<summary><b> Command-Line Usage </b></summary>
+You can segment the semantic regions of an image from the command line as follows:
+
+```
+paddlex --pipeline semantic_segmentation --model PP-LiteSeg-T --input https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_semantic_segmentation_002.png
+
+```
+Parameter explanations:
+- `pipeline`: pipeline name; currently supported pipeline names are `image_classification`, `object_detection`, `semantic_segmentation`, `instance_segmentation`, and `ocr`.
+- `model`: model name; each pipeline supports different model names, see the [PaddleX Pipeline Documentation](./models/support_model_list.md). For multi-model pipelines, multiple model names must be specified, separated by spaces.
+- `input`: input image path or URL.
+</details>
+
+<details>
+<summary><b> Python API Usage</b></summary>
+
+```python
+from pathlib import Path
+from paddlex import SegPipeline
+from paddlex import PaddleInferenceOption
+
+
+model_name = "PP-LiteSeg-T"
+output_base = Path("output")
+output_dir = output_base / model_name
+pipeline = SegPipeline(model_name, output_dir=output_dir, kernel_option=PaddleInferenceOption())
+result = pipeline.predict(
+    {"input_path": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_semantic_segmentation_002.png"}
+)
+print(result["seg_map"])
+
+```
+</details>
+
+
+### 2.4 Instance Segmentation Pipeline
+
+
+The instance segmentation pipeline has two built-in single models that are currently state of the art: `Mask-RT-DETR-L` and `Mask-RT-DETR-H`. You can run inference in either of the two ways below. If these models do not meet the needs of your scenario, you can train your own model following the [PaddleX Model Training Documentation](./train/README.md); the trained model can be integrated into this pipeline very easily.
+
+<details>
+<summary><b> Command-Line Usage </b></summary>
+You can segment the instances in an image from the command line as follows:
+
+```
+paddlex --pipeline instance_segmentation --model Mask-RT-DETR-L --input https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_instance_segmentation_004.png
+
+```
+Parameter explanations:
+- `pipeline`: pipeline name; currently supported pipeline names are `image_classification`, `object_detection`, `semantic_segmentation`, `instance_segmentation`, and `ocr`.
+- `model`: model name; each pipeline supports different model names, see the [PaddleX Pipeline Documentation](./models/support_model_list.md). For multi-model pipelines, multiple model names must be specified, separated by spaces.
+- `input`: input image path or URL.
+</details>
+
+<details>
+<summary><b> Python API Usage</b></summary>
+
+```python
+from pathlib import Path
+from paddlex import InstanceSegPipeline
+from paddlex import PaddleInferenceOption
+
+model_name =  "Mask-RT-DETR-L"
+output_base = Path("output")
+
+output_dir = output_base / model_name
+pipeline = InstanceSegPipeline(model_name, output_dir=output_dir, kernel_option=PaddleInferenceOption())
+result = pipeline.predict(
+    {"input_path": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_instance_segmentation_004.png"})
+print(result["masks"])
+
+```
+</details>
+
+### 2.5 OCR Pipeline
+The OCR pipeline has the PP-OCRv4 models built in, covering both text detection and text recognition. Supported text detection models are `PP-OCRv4_mobile_det` and `PP-OCRv4_server_det`; supported text recognition models are `PP-OCRv4_mobile_rec` and `PP-OCRv4_server_rec`. You can run inference in either of the two ways below. If these models do not meet the needs of your scenario, you can train your own model following the [PaddleX Model Training Documentation](./train/README.md); the trained model can be integrated into this pipeline very easily.
+
+<details>
+<summary><b> Command-Line Usage </b></summary>
+You can recognize the text in an image from the command line as follows:
+
+```
+paddlex --pipeline ocr --model PP-OCRv4_mobile_det PP-OCRv4_mobile_rec --input https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_002.png --output ./
+```
+Parameter explanations:
+- `pipeline`: pipeline name; currently supported pipeline names are `image_classification`, `object_detection`, `semantic_segmentation`, `instance_segmentation`, and `ocr`.
+- `model`: model name; each pipeline supports different model names, see the [PaddleX Pipeline Documentation](./models/support_model_list.md). For multi-model pipelines, multiple model names must be specified, separated by spaces.
+- `input`: input image path or URL.
+- `output`: output path for the visualized image.
+</details>
+
+<details>
+<summary><b> Python API Usage</b></summary>
+
+```python
+import cv2
+from paddlex import OCRPipeline
+from paddlex import PaddleInferenceOption
+from paddlex.pipelines.PPOCR.utils import draw_ocr_box_txt
+
+pipeline = OCRPipeline(
+    'PP-OCRv4_mobile_det',
+    'PP-OCRv4_mobile_rec',
+    text_det_kernel_option=PaddleInferenceOption(),
+    text_rec_kernel_option=PaddleInferenceOption(),)
+result = pipeline.predict(
+    {"input_path": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_002.png"},
+)
+
+draw_img = draw_ocr_box_txt(result['original_image'],result['dt_polys'], result["rec_text"])
+cv2.imwrite("ocr_result.jpg", draw_img[:, :, ::-1])
+```
+</details>

+ 0 - 174
docs/tutorials/pipeline.md

@@ -1,174 +0,0 @@
-# PaddleX Pipeline Development Tool
-
-PaddleX provides multiple pipelines, including OCR, image classification, object detection, instance segmentation, semantic segmentation, and more. Each pipeline offers multiple models to choose from, all with official weights, and supports direct inference from the command line as well as Python API prediction. Command-line inference lets you quickly try out a model, while the Python API makes it easy to integrate prediction into your own projects.
-
-## 1. Install PaddleX
-Before using the pipeline tool, you first need to install PaddleX; see the [PaddleX Installation Documentation](xxx).
-
-
-## 2. Using the PaddleX Pipeline Tool
-### 2.1 OCR Pipeline
-The OCR pipeline has the PP-OCRv4 models built in, covering both text detection and text recognition. Supported text detection models are `PP-OCRv4_mobile_det` and `PP-OCRv4_server_det`; supported text recognition models are `PP-OCRv4_mobile_rec` and `PP-OCRv4_server_rec`. You can run inference in either of the two ways below. If these models do not meet the needs of your scenario, you can train your own model following the [PaddleX Model Training Documentation](./train/README.md); the trained model can be integrated into this pipeline very easily.
-
-<details>
-<summary><b> Command-Line Usage </b></summary>
-You can recognize the text in an image from the command line as follows:
-
-```
-paddlex --task ocrdet --model PP-OCRv4_mobile_det --image /paddle/dataset/paddlex/ocr_det/ocr_det_dataset/xxx
-```
-Parameter explanations:
-- `task`: task type; currently `ocrdet` is supported
-- `model`: model name; currently `PP-OCRv4_mobile_det` and `PP-OCRv4_mobile_rec` are supported.
-</details>
-
-<details>
-<summary><b> Python API Usage</b></summary>
-
-```python
-import cv2
-from paddlex import OCRPipeline
-from paddlex import PaddleInferenceOption
-from paddle.pipelines.PPOCR.utils import draw_ocr_box_txt
-
-pipeline = OCRPipeline(
-    'PP-OCRv4_mobile_det',
-    'PP-OCRv4_mobile_rec',
-    text_det_kernel_option=PaddleInferenceOption(),
-    text_rec_kernel_option=PaddleInferenceOption(),)
-result = pipeline(
-    "/paddle/dataset/paddlex/ocr_det/ocr_det_dataset_examples/images/train_img_100.jpg",
-)
-
-draw_img = draw_ocr_box_txt(result['original_image'],result['dt_polys'], result["rec_text"])
-cv2.imwrite("ocr_result.jpg", draw_img[:, :, ::-1], )
-```
-
-Parameter explanations:
-- `task`: task type; currently `ocrdet` is supported
-- `model`: model name; currently `PP-OCRv4_mobile_det` and `PP-OCRv4_mobile_rec` are supported.
-</details>
-
-
-## 2.2 Image Classification Pipeline
-The image classification pipeline has multiple built-in single models for image classification, including the `ResNet`, `PP-LCNet`, `MobileNetV2`, `MobileNetV3`, `ConvNeXt`, `SwinTransformer`, `PP-HGNet`, `PP-HGNetV2`, and `CLIP` series. For the full list of supported classification models, see the [Model Zoo](./models/support_model_list.md). You can run inference in either of the two ways below. If these models do not meet the needs of your scenario, you can train your own model following the [PaddleX Model Training Documentation](./train/README.md); the trained model can be integrated into this pipeline very easily.
-
-<details>
-<summary><b> Command-Line Usage </b></summary>
-You can recognize the text in an image from the command line as follows:
-
-```
-paddlex --task ocrdet --model PP-OCRv4_mobile_det --image /paddle/dataset/paddlex/ocr_det/ocr_det_dataset/xxx
-```
-Parameter explanations:
-- `task`: task type; currently `ocrdet` is supported
-- `model`: model name; currently `PP-OCRv4_mobile_det` and `PP-OCRv4_mobile_rec` are supported.
-</details>
-
-
-<details>
-<summary><b> Python API Usage</b></summary>
-
-```python
-from paddlex import ClsPipeline
-from paddlex import PaddleInferenceOption
-
-pipeline = ClsPipeline(model_name, kernel_option=PaddleInferenceOption())
-    result = pipeline(
-        "/paddle/dataset/paddlex/cls/cls_flowers_examples/images/image_00006.jpg"
-    )
-    print(result["cls_result"])
-
-</details>
-
-
-## Object Detection
-
-```python
-from pathlib import Path
-from paddlex import DetPipeline
-from paddlex import PaddleInferenceOption
-
-models = [
-    "PicoDet-L",
-    "PicoDet-S",
-    "PP-YOLOE_plus-L",
-    "PP-YOLOE_plus-M",
-    "PP-YOLOE_plus-S",
-    "PP-YOLOE_plus-X",
-    "RT-DETR-H",
-    "RT-DETR-L",
-    "RT-DETR-R18",
-    "RT-DETR-R50",
-    "RT-DETR-X",
-]
-output_base = Path("output")
-
-for model_name in models:
-    output_dir = output_base / model_name
-    try:
-        pipeline = DetPipeline(model_name, output_dir=output_dir, kernel_option=PaddleInferenceOption())
-        result = pipeline(
-            "/paddle/dataset/paddlex/det/det_coco_examples/images/road0.png")
-        print(result["boxes"])
-    except Exception as e:
-        print(f"[ERROR] model: {model_name}; err: {e}")
-    print(f"[INFO] model: {model_name} done!")
-```
-
-
-## Instance Segmentation
-
-```python
-from pathlib import Path
-from paddlex import InstanceSegPipeline
-from paddlex import PaddleInferenceOption
-
-models = ["Mask-RT-DETR-H", "Mask-RT-DETR-L"]
-output_base = Path("output")
-
-for model_name in models:
-    output_dir = output_base / model_name
-    try:
-        pipeline = InstanceSegPipeline(model_name, output_dir=output_dir, kernel_option=PaddleInferenceOption())
-        result = pipeline(
-            "/paddle/dataset/paddlex/instance_seg/instance_seg_coco_examples/images/aircraft-women-fashion-pilot-48797.png"
-        )
-        print(result["masks"])
-    except Exception as e:
-        print(f"[ERROR] model: {model_name}; err: {e}")
-    print(f"[INFO] model: {model_name} done!")
-```
-
-## Semantic Segmentation
-
-
-```python
-from pathlib import Path
-from paddlex import SegPipeline
-from paddlex import PaddleInferenceOption
-
-
-models = [
-    "Deeplabv3-R50",
-    "Deeplabv3-R101",
-    "Deeplabv3_Plus-R50",
-    "Deeplabv3_Plus-R101",
-    "PP-LiteSeg-T",
-    "OCRNet_HRNet-W48",
-]
-
-output_base = Path("output")
-
-for model_name in models:
-    output_dir = output_base / model_name
-    try:
-        pipeline = SegPipeline(model_name, output_dir=output_dir, kernel_option=PaddleInferenceOption())
-        result = pipeline(
-            "/paddle/dataset/paddlex/seg/seg_optic_examples/images/H0002.jpg"
-        )
-        print(result["seg_map"])
-    except Exception as e:
-        print(f"[ERROR] model: {model_name}; err: {e}")
-    print(f"[INFO] model: {model_name} done!")
-```

+ 0 - 63
docs/tutorials/wheel.md

@@ -1,63 +0,0 @@
-# Inference and Prediction with the PaddleX wheel
-
-## 1. Installation
-
-### 1.1 Install the PaddleX whl
-
-1. Install the official release
-
-```bash
-pip install paddlex
-```
-
-2. Build and install from source
-
-```bash
-cd PaddleX
-pip install .
-```
-
-### 1.2 Install PaddleX dependencies
-
-```bash
-paddlex --install
-```
-
-## 2. Inference and Prediction
-
-### 2.1 Inference from the CLI
-
-Taking the image classification model `PP-LCNet_x1_0` as an example, use the official model bundled with PaddleX to predict on an image (`/paddle/dataset/paddlex/cls/cls_flowers_examples/images/image_00002.jpg`) with the following command:
-
-```bash
-paddlex --pipeline image_classification --model PP-LCNet_x1_0 --input /paddle/dataset/paddlex/cls/cls_flowers_examples/images/image_00006.jpg
-```
-
-This yields the prediction result:
-
-```
-[{'class_ids': [309], 'scores': [0.19514], 'label_names': ['bee']}]
-```
-
-Taking OCR as an example, use the official `PP-OCRv4_mobile_det` and `PP-OCRv4_mobile_rec` models bundled with PaddleX to predict on an image (`/paddle/dataset/paddlex/ocr_det/ocr_det_dataset_examples/images/train_img_100.jpg`) with the following command:
-
-```bash
-paddlex --pipeline ocr --model PP-OCRv4_mobile_det PP-OCRv4_mobile_rec --input /paddle/dataset/paddlex/ocr_det/ocr_det_dataset_examples/images/train_img_100.jpg  --output ./
-```
-
-An example prediction image `ocr_result.jpg` is produced in the current directory.
-
-
-### 2.2 Inference with Python
-
-```python
-import paddlex
-
-model_name = "PP-LCNet_x1_0"
-
-kernel_option = paddlex.PaddleInferenceOption()
-kernel_option.set_device("gpu")
-
-model = paddlex.create_model(model_name, kernel_option=kernel_option)
-model.predict("/paddle/dataset/paddlex/cls/cls_flowers_examples/images/image_00002.jpg")
-```

+ 8 - 8
paddlex/engine.py

@@ -35,19 +35,19 @@ class Engine(object):
     def run(self):
         """ the main function """
         if self.config.Global.mode == "check_dataset":
-            check_dataset = build_dataset_checker(self.config)
-            return check_dataset()
+            dataset_checker = build_dataset_checker(self.config)
+            return dataset_checker.check()
         elif self.config.Global.mode == "train":
-            train = build_trainer(self.config)
-            train()
+            trainer = build_trainer(self.config)
+            trainer.train()
         elif self.config.Global.mode == "evaluate":
-            evaluate = build_evaluater(self.config)
-            return evaluate()
+            evaluator = build_evaluater(self.config)
+            return evaluator.evaluate()
         elif self.config.Global.mode == "export":
             raise_unsupported_api_error("export", self.__class__)
         elif self.config.Global.mode == "predict":
-            predict = build_predictor(self.config)
-            return predict()
+            predictor = build_predictor(self.config)
+            return predictor.predict()
         else:
             raise_unsupported_api_error(f"{self.config.Global.mode}",
                                         self.__class__)

+ 9 - 0
paddlex/modules/__init__.py

@@ -26,3 +26,12 @@ InstanceSegPredictor
 from .ts_anomaly_detection import TSADDatasetChecker, TSADTrainer, TSADEvaluator, TSADPredictor
 from .ts_classification import TSCLSDatasetChecker, TSCLSTrainer, TSCLSEvaluator, TSCLSPredictor
 from .ts_forecast import TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator, TSFCPredictor
+
+from .base.predictor.transforms import image_common
+from .image_classification import transforms as cls_transforms
+from .object_detection import transforms as det_transforms
+from .text_detection import transforms as text_det_transforms
+from .text_recognition import transforms as text_rec_transforms
+from .table_recognition import transforms as table_rec_transforms
+from .semantic_segmentation import transforms as seg_transforms
+from .instance_segmentation import transforms as instance_seg_transforms

+ 1 - 1
paddlex/modules/base/build_model.py

@@ -34,6 +34,6 @@ def build_model(model_name: str, device: str=None,
     config = Config(model_name, config_path)
 
     if device:
-        config.update_device(get_device(device).split(":")[0])
+        config.update_device(get_device(device))
     model = PaddleModel(config=config)
     return config, model

+ 1 - 150
paddlex/modules/base/dataset_checker/__init__.py

@@ -13,154 +13,5 @@
 # limitations under the License.
 
 
-import os
-from abc import ABC, abstractmethod
 
-from .utils import build_res_dict
-from ....utils.misc import AutoRegisterABCMetaClass
-from ....utils.config import AttrDict
-from ....utils.logging import info
-
-
-def build_dataset_checker(config: AttrDict) -> "BaseDatasetChecker":
-    """build dataset checker
-
-    Args:
-        config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
-
-    Returns:
-        BaseDatasetChecker: the dataset checker, which is subclass of BaseDatasetChecker.
-    """
-    model_name = config.Global.model
-    return BaseDatasetChecker.get(model_name)(config)
-
-
-class BaseDatasetChecker(ABC, metaclass=AutoRegisterABCMetaClass):
-    """ Base Dataset Checker """
-
-    __is_base = True
-
-    def __init__(self, config):
-        """Initialize the instance.
-
-        Args:
-            config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
-        """
-        super().__init__()
-        self.global_config = config.Global
-        self.check_dataset_config = config.CheckDataset
-        self.output_dir = os.path.join(self.global_config.output,
-                                       "check_dataset")
-
-    def __call__(self) -> dict:
-        """execute dataset checking
-
-        Returns:
-            dict: the dataset checking result.
-        """
-        dataset_dir = self.get_dataset_root(self.global_config.dataset_dir)
-
-        if not os.path.exists(self.output_dir):
-            os.makedirs(self.output_dir)
-
-        if self.check_dataset_config.get("convert", None):
-            if self.check_dataset_config.convert.get("enable", False):
-                self.convert_dataset(dataset_dir)
-                info("Convert dataset successfully !")
-
-        if self.check_dataset_config.get("split", None):
-            if self.check_dataset_config.split.get("enable", False):
-                self.split_dataset(dataset_dir)
-                info("Split dataset successfully !")
-
-        attrs = self.check_dataset(dataset_dir)
-        analysis = self.analyse(dataset_dir)
-
-        check_result = build_res_dict(True)
-        check_result["attributes"] = attrs
-        check_result["analysis"] = analysis
-        check_result["dataset_path"] = self.global_config.dataset_dir
-        check_result["show_type"] = self.get_show_type()
-        check_result["dataset_type"] = self.get_dataset_type()
-        info("Check dataset passed !")
-        return check_result
-
-    def get_dataset_root(self, dataset_dir: str) -> str:
-        """find the dataset root dir
-
-        Args:
-            dataset_dir (str): the directory that contain dataset.
-
-        Returns:
-            str: the root directory of dataset.
-        """
-        # XXX: forward compatible
-        # dataset_dir = [d for d in Path(dataset_dir).iterdir() if d.is_dir()]
-        # assert len(dataset_dir) == 1
-        # return dataset_dir[0].as_posix()
-        return dataset_dir
-
-    @abstractmethod
-    def check_dataset(self, dataset_dir: str):
-        """check if the dataset meets the specifications and get dataset summary
-
-        Args:
-            dataset_dir (str): the root directory of dataset.
-
-        Raises:
-            NotImplementedError
-        """
-        raise NotImplementedError
-
-    def convert_dataset(self, src_dataset_dir: str) -> str:
-        """convert the dataset from other type to specified type
-
-        Args:
-            src_dataset_dir (str): the root directory of dataset.
-
-        Returns:
-            str: the root directory of converted dataset.
-        """
-        dst_dataset_dir = src_dataset_dir
-        return dst_dataset_dir
-
-    def split_dataset(self, src_dataset_dir: str) -> str:
-        """repartition the train and validation dataset
-
-        Args:
-            src_dataset_dir (str): the root directory of dataset.
-
-        Returns:
-            str: the root directory of splited dataset.
-        """
-        dst_dataset_dir = src_dataset_dir
-        return dst_dataset_dir
-
-    def analyse(self, dataset_dir: str) -> dict:
-        """deep analyse dataset
-
-        Args:
-            dataset_dir (str): the root directory of dataset.
-
-        Returns:
-            dict: the deep analysis results.
-        """
-        return {}
-
-    @abstractmethod
-    def get_show_type(self):
-        """return the dataset show type
-
-        Raises:
-            NotImplementedError
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_dataset_type(self):
-        """ return the dataset type
-
-        Raises:
-            NotImplementedError
-        """
-        raise NotImplementedError
+from .dataset_checker import build_dataset_checker, BaseDatasetChecker

+ 166 - 0
paddlex/modules/base/dataset_checker/dataset_checker.py

@@ -0,0 +1,166 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from abc import ABC, abstractmethod
+
+from .utils import build_res_dict
+from ....utils.misc import AutoRegisterABCMetaClass
+from ....utils.config import AttrDict
+from ....utils.logging import info
+
+
+def build_dataset_checker(config: AttrDict) -> "BaseDatasetChecker":
+    """build dataset checker
+
+    Args:
+        config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
+
+    Returns:
+        BaseDatasetChecker: the dataset checker, which is subclass of BaseDatasetChecker.
+    """
+    model_name = config.Global.model
+    return BaseDatasetChecker.get(model_name)(config)
+
+
+class BaseDatasetChecker(ABC, metaclass=AutoRegisterABCMetaClass):
+    """ Base Dataset Checker """
+
+    __is_base = True
+
+    def __init__(self, config):
+        """Initialize the instance.
+
+        Args:
+            config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
+        """
+        super().__init__()
+        self.global_config = config.Global
+        self.check_dataset_config = config.CheckDataset
+        self.output_dir = os.path.join(self.global_config.output,
+                                       "check_dataset")
+
+    def check(self) -> dict:
+        """execute dataset checking
+
+        Returns:
+            dict: the dataset checking result.
+        """
+        dataset_dir = self.get_dataset_root(self.global_config.dataset_dir)
+
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir)
+
+        if self.check_dataset_config.get("convert", None):
+            if self.check_dataset_config.convert.get("enable", False):
+                self.convert_dataset(dataset_dir)
+                info("Convert dataset successfully !")
+
+        if self.check_dataset_config.get("split", None):
+            if self.check_dataset_config.split.get("enable", False):
+                self.split_dataset(dataset_dir)
+                info("Split dataset successfully !")
+
+        attrs = self.check_dataset(dataset_dir)
+        analysis = self.analyse(dataset_dir)
+
+        check_result = build_res_dict(True)
+        check_result["attributes"] = attrs
+        check_result["analysis"] = analysis
+        check_result["dataset_path"] = self.global_config.dataset_dir
+        check_result["show_type"] = self.get_show_type()
+        check_result["dataset_type"] = self.get_dataset_type()
+        info("Check dataset passed !")
+        return check_result
+
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        # XXX: forward compatible
+        # dataset_dir = [d for d in Path(dataset_dir).iterdir() if d.is_dir()]
+        # assert len(dataset_dir) == 1
+        # return dataset_dir[0].as_posix()
+        return dataset_dir
+
+    @abstractmethod
+    def check_dataset(self, dataset_dir: str):
+        """check if the dataset meets the specifications and get dataset summary
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+
+        Raises:
+            NotImplementedError
+        """
+        raise NotImplementedError
+
+    def convert_dataset(self, src_dataset_dir: str) -> str:
+        """convert the dataset from other type to specified type
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of converted dataset.
+        """
+        dst_dataset_dir = src_dataset_dir
+        return dst_dataset_dir
+
+    def split_dataset(self, src_dataset_dir: str) -> str:
+        """repartition the train and validation dataset
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of splited dataset.
+        """
+        dst_dataset_dir = src_dataset_dir
+        return dst_dataset_dir
+
+    def analyse(self, dataset_dir: str) -> dict:
+        """deep analyse dataset
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            dict: the deep analysis results.
+        """
+        return {}
+
+    @abstractmethod
+    def get_show_type(self):
+        """return the dataset show type
+
+        Raises:
+            NotImplementedError
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_dataset_type(self):
+        """return the dataset type
+
+        Raises:
+            NotImplementedError
+        """
+        raise NotImplementedError
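The hooks above split cleanly into three required overrides and two optional pass-throughs. A minimal sketch of a concrete checker, assuming a reduced stand-in base class (`_DatasetCheckerSketch` and `ClsDatasetChecker` are illustrative names, not part of the PaddleX API):

```python
from abc import ABC, abstractmethod


class _DatasetCheckerSketch(ABC):
    """Reduced stand-in for the checker base class shown in the diff."""

    @abstractmethod
    def check_dataset(self, dataset_dir: str):
        raise NotImplementedError

    @abstractmethod
    def get_show_type(self):
        raise NotImplementedError

    @abstractmethod
    def get_dataset_type(self):
        raise NotImplementedError

    # optional hooks default to pass-through, as in the diff
    def convert_dataset(self, src_dataset_dir: str) -> str:
        return src_dataset_dir

    def split_dataset(self, src_dataset_dir: str) -> str:
        return src_dataset_dir


class ClsDatasetChecker(_DatasetCheckerSketch):
    """Hypothetical checker for an image classification dataset."""

    def check_dataset(self, dataset_dir: str):
        # a real checker would scan label files and images here
        return {"num_classes": 2, "train_samples": 10}

    def get_show_type(self):
        return "image"

    def get_dataset_type(self):
        return "ClsDataset"


checker = ClsDatasetChecker()
attrs = checker.check_dataset("./dataset/demo")
```

A subclass that skips any of the three abstract methods cannot be instantiated, which is how the base class enforces the checking contract.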

+ 1 - 1
paddlex/modules/base/evaluator.py

@@ -94,7 +94,7 @@ class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
                 return False
         return True
 
-    def __call__(self) -> dict:
+    def evaluate(self) -> dict:
         """execute model evaluation
 
         Returns:

+ 0 - 14
paddlex/modules/base/predictor/io/readers.py

@@ -16,12 +16,8 @@
 
 import enum
 import itertools
-from pathlib import Path
 import cv2
 
-from .....utils.download import download
-from .....utils.cache import CACHE_DIR
-
 __all__ = ['ImageReader', 'VideoReader', 'ReaderType']
 
 
@@ -65,14 +61,6 @@ class _BaseReader(object):
         """ get default backend arguments """
         return {}
 
-    def _download_from_url(self, in_path):
-        if in_path.startswith("http"):
-            file_name = Path(in_path).name
-            save_path = Path(CACHE_DIR) / "predict_input" / file_name
-            download(in_path, save_path, overwrite=True)
-            return save_path.as_posix()
-        return in_path
-
 
 class ImageReader(_BaseReader):
     """ ImageReader """
@@ -82,8 +70,6 @@ class ImageReader(_BaseReader):
 
     def read(self, in_path):
         """ read the image file from path """
-        # XXX: auto download for url
-        in_path = self._download_from_url(in_path)
         arr = self._backend.read_file(in_path)
         return arr
 

+ 14 - 2
paddlex/modules/base/predictor/kernel_option.py

@@ -16,6 +16,8 @@
 
 from functools import wraps, partial
 
+from ....utils import logging
+
 
 def register(register_map, key):
     """register the option setting func
@@ -64,6 +66,7 @@ class PaddleInferenceOption(object):
             'run_mode': 'paddle',
             'batch_size': 1,
             'device': 'gpu',
+            'device_id': 0,
             'min_subgraph_size': 3,
             'shape_info_filename': None,
             'trt_calib_mode': False,
@@ -91,16 +94,25 @@ class PaddleInferenceOption(object):
         self._cfg['batch_size'] = batch_size
 
     @register2self('device')
-    def set_device(self, device: str):
+    def set_device(self, device_setting: str):
         """set device
         """
-        device = device.split(":")[0]
+        if len(device_setting.split(":")) == 1:
+            device = device_setting.split(":")[0]
+            device_id = 0
+        else:
+            assert len(device_setting.split(":")) == 2
+            device = device_setting.split(":")[0]
+            device_id = device_setting.split(":")[1].split(",")[0]
+            logging.warning(f"The device id has been set to {device_id}.")
+
         if device.lower() not in self.SUPPORT_DEVICE:
             support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
             raise ValueError(
                 f"`device` must be {support_run_mode_str}, but received {repr(device)}."
             )
         self._cfg['device'] = device.lower()
+        self._cfg['device_id'] = int(device_id)
 
     @register2self('min_subgraph_size')
     def set_min_subgraph_size(self, min_subgraph_size: int):
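The `set_device` change accepts either a bare device name (`"gpu"`) or a name with ids (`"gpu:1,2"`) and keeps only the first id. The parsing rule can be sketched standalone (the function name here is illustrative, not part of the API):

```python
def parse_device_setting(device_setting: str):
    """Split a device string into (device, device_id), mirroring set_device."""
    parts = device_setting.split(":")
    if len(parts) == 1:
        device, device_id = parts[0], 0
    else:
        assert len(parts) == 2
        device = parts[0]
        # only the first id is used when several are given, e.g. "gpu:1,2"
        device_id = int(parts[1].split(",")[0])
    return device.lower(), device_id
```

With this rule, `"gpu:1,2"` maps to `("gpu", 1)`, matching the warning the new code logs about the selected device id.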

+ 54 - 36
paddlex/modules/base/predictor/predictor.py

@@ -38,20 +38,37 @@ class BasePredictor(ABC, FromDictMixin, Node):
     def __init__(self,
                  model_dir,
                  kernel_option,
+                 output_dir,
                  pre_transforms=None,
                  post_transforms=None):
         super().__init__()
         self.model_dir = model_dir
-        self.pre_transforms = pre_transforms
-        self.post_transforms = post_transforms
         self.kernel_option = kernel_option
+        self.output_dir = output_dir
+        self.other_src = self.load_other_src()
+
+        logging.debug(
+            f"-------------------- {self.__class__.__name__} --------------------\n\
+Model: {self.model_dir}\n\
+Env: {self.kernel_option}")
+        self.pre_tfs, self.post_tfs = self.build_transforms(pre_transforms,
+                                                            post_transforms)
 
         param_path = os.path.join(model_dir, f"{self.MODEL_FILE_TAG}.pdiparams")
         model_path = os.path.join(model_dir, f"{self.MODEL_FILE_TAG}.pdmodel")
         self._predictor = _PaddleInferencePredictor(
             param_path=param_path, model_path=model_path, option=kernel_option)
 
-        self.other_src = self.load_other_src()
+    def build_transforms(self, pre_transforms, post_transforms):
+        """ build pre-transforms and post-transforms
+        """
+        pre_tfs = pre_transforms if pre_transforms is not None else self._get_pre_transforms_from_config(
+        )
+        logging.debug(f"Preprocess Ops: {self._format_transforms(pre_tfs)}")
+        post_tfs = post_transforms if post_transforms is not None else self._get_post_transforms_from_config(
+        )
+        logging.debug(f"Postprocessing: {self._format_transforms(post_tfs)}")
+        return pre_tfs, post_tfs
 
     def predict(self, input, batch_size=1):
         """ predict """
@@ -63,24 +80,10 @@ class BasePredictor(ABC, FromDictMixin, Node):
         if isinstance(input, dict):
             input = [input]
 
-        logging.debug(
-            f"-------------------- {self.__class__.__name__} --------------------\n\
-Model: {self.model_dir}\nEnv: {self.kernel_option}")
-        data = input[0]
-        if self.pre_transforms is not None:
-            pre_tfs = self.pre_transforms
-        else:
-            pre_tfs = self._get_pre_transforms_for_data(data)
-        logging.debug(f"Preprocess Ops: {self._format_transforms(pre_tfs)}")
-        if self.post_transforms is not None:
-            post_tfs = self.post_transforms
-        else:
-            post_tfs = self._get_post_transforms_for_data(data)
-        logging.debug(f"Postprocessing: {self._format_transforms(post_tfs)}")
-
         output = []
         for mini_batch in Batcher(input, batch_size=batch_size):
-            mini_batch = self._preprocess(mini_batch, pre_transforms=pre_tfs)
+            mini_batch = self._preprocess(
+                mini_batch, pre_transforms=self.pre_tfs)
 
             for data in mini_batch:
                 self.check_input_keys(data)
@@ -90,7 +93,8 @@ Model: {self.model_dir}\nEnv: {self.kernel_option}")
             for data in mini_batch:
                 self.check_output_keys(data)
 
-            mini_batch = self._postprocess(mini_batch, post_transforms=post_tfs)
+            mini_batch = self._postprocess(
+                mini_batch, post_transforms=self.post_tfs)
 
             output.extend(mini_batch)
 
@@ -104,12 +108,12 @@ Model: {self.model_dir}\nEnv: {self.kernel_option}")
         raise NotImplementedError
 
     @abstractmethod
-    def _get_pre_transforms_for_data(self, data):
+    def _get_pre_transforms_from_config(self):
         """ get preprocess transforms """
         raise NotImplementedError
 
     @abstractmethod
-    def _get_post_transforms_for_data(self, data):
+    def _get_post_transforms_from_config(self):
         """ get postprocess transforms """
         raise NotImplementedError
 
@@ -137,6 +141,11 @@ Model: {self.model_dir}\nEnv: {self.kernel_option}")
         """
         return None
 
+    def get_input_keys(self):
+        """get keys of input dict
+        """
+        return self.pre_tfs[0].get_input_keys()
+
 
 class PredictorBuilderByConfig(object):
     """build model predictor
@@ -149,7 +158,7 @@ class PredictorBuilderByConfig(object):
         """
         model_name = config.Global.model
 
-        device = config.Global.device.split(':')[0]
+        device = config.Global.device
 
         predict_config = deepcopy(config.Predict)
         model_dir = predict_config.pop('model_dir')
@@ -159,17 +168,16 @@ class PredictorBuilderByConfig(object):
 
         self.input_path = predict_config.pop('input_path')
 
-        self.predictor = BasePredictor.get(model_name)(model_dir, kernel_option,
-                                                       **predict_config)
-        self.output = config.Global.output
+        self.predictor = BasePredictor.get(model_name)(
+            model_dir=model_dir,
+            kernel_option=kernel_option,
+            output_dir=config.Global.output_dir,
+            **predict_config)
 
-    def __call__(self):
-        data = {
-            "input_path": self.input_path,
-            "cli_flag": True,
-            "output_dir": self.output
-        }
-        self.predictor.predict(data)
+    def predict(self):
+        """predict
+        """
+        self.predictor.predict({'input_path': self.input_path})
 
 
 def build_predictor(*args, **kwargs):
@@ -181,6 +189,7 @@ def build_predictor(*args, **kwargs):
 def create_model(model_name,
                  model_dir=None,
                  kernel_option=None,
+                 output_dir=None,
                  pre_transforms=None,
                  post_transforms=None,
                  *args,
@@ -189,7 +198,16 @@ def create_model(model_name,
     """
     kernel_option = PaddleInferenceOption(
     ) if kernel_option is None else kernel_option
-    model_dir = official_models[model_name] if model_dir is None else model_dir
-    return BasePredictor.get(model_name)(model_dir, kernel_option,
-                                         pre_transforms, post_transforms, *args,
+    if model_dir is None:
+        if model_name in official_models:
+            model_dir = official_models[model_name]
+        else:
+            # model name is invalid
+            BasePredictor.get(model_name)
+    return BasePredictor.get(model_name)(model_dir=model_dir,
+                                         kernel_option=kernel_option,
+                                         output_dir=output_dir,
+                                         pre_transforms=pre_transforms,
+                                         post_transforms=post_transforms,
+                                         *args,
                                          **kwargs)

+ 31 - 1
paddlex/modules/base/predictor/transforms/image_common.py

@@ -15,12 +15,16 @@
 
 
 import math
+from pathlib import Path
 
 import numpy as np
 import cv2
 
+from .....utils.download import download
+from .....utils.cache import CACHE_DIR
 from ..transform import BaseTransform
 from ..io.readers import ImageReader
+from ..io.writers import ImageWriter
 from . import image_functions as F
 
 __all__ = [
@@ -57,26 +61,52 @@ class ReadImage(BaseTransform):
         self.format = format
         flags = self._FLAGS_DICT[self.format]
         self._reader = ImageReader(backend='opencv', flags=flags)
+        self._writer = ImageWriter(backend='opencv')
 
     def apply(self, data):
         """ apply """
+        if 'image' in data:
+            img = data['image']
+            img_path = (Path(CACHE_DIR) / "predict_input" /
+                        "tmp_img.jpg").as_posix()
+            self._writer.write(img_path, img)
+            data['input_path'] = img_path
+            data['original_image'] = img
+            data['original_image_size'] = [img.shape[1], img.shape[0]]
+            return data
+
+        elif 'input_path' not in data:
+            raise KeyError(
+                f"Key {repr('input_path')} is required, but not found.")
+
         im_path = data['input_path']
+        # XXX: auto download for url
+        im_path = self._download_from_url(im_path)
         blob = self._reader.read(im_path)
         if self.format == 'RGB':
             if blob.ndim != 3:
                 raise RuntimeError("Array is not 3-dimensional.")
             # BGR to RGB
             blob = blob[..., ::-1]
+        data['input_path'] = im_path
         data['image'] = blob
         data['original_image'] = blob
         data['original_image_size'] = [blob.shape[1], blob.shape[0]]
         return data
 
+    def _download_from_url(self, in_path):
+        if in_path.startswith("http"):
+            file_name = Path(in_path).name
+            save_path = Path(CACHE_DIR) / "predict_input" / file_name
+            download(in_path, save_path, overwrite=True)
+            return save_path.as_posix()
+        return in_path
+
     @classmethod
     def get_input_keys(cls):
         """ get input keys """
         # input_path: Path of the image.
-        return ['input_path']
+        return [['input_path'], ['image']]
 
     @classmethod
     def get_output_keys(cls):

+ 3 - 3
paddlex/modules/base/predictor/utils/paddle_inference_predictor.py

@@ -41,11 +41,11 @@ self._create(param_path, model_path, option, delete_pass=delete_pass)
                                 len(param_buffer))
 
         if option.device == 'gpu':
-            config.enable_use_gpu(200, 0)
+            config.enable_use_gpu(200, option.device_id)
         elif option.device == 'npu':
             config.enable_custom_device('npu')
-            os.environ["FLAGS_npu_jit_compile"] = 0
-            os.environ["FLAGS_use_stride_kernel"] = 0
+            os.environ["FLAGS_npu_jit_compile"] = "0"
+            os.environ["FLAGS_use_stride_kernel"] = "0"
             os.environ["FLAGS_allocator_strategy"] = "auto_growth"
         elif option.device == 'xpu':
             config.enable_custom_device('npu')

+ 1 - 1
paddlex/modules/base/trainer/train_deamon.py

@@ -35,7 +35,7 @@ def try_except_decorator(func):
         except Exception as e:
             exc_type, exc_value, exc_tb = sys.exc_info()
             self.save_json()
-            traceback.logging.info_exception(exc_type, exc_value, exc_tb)
+            traceback.print_exception(exc_type, exc_value, exc_tb)
         finally:
             self.processing = False
 

+ 1 - 1
paddlex/modules/base/trainer/trainer.py

@@ -52,7 +52,7 @@ class BaseTrainer(ABC, metaclass=AutoRegisterABCMetaClass):
         self.deamon = self.build_deamon(self.global_config)
         self.pdx_config, self.pdx_model = build_model(self.global_config.model)
 
-    def __call__(self, *args, **kwargs):
+    def train(self, *args, **kwargs):
         """execute model training
         """
         os.makedirs(self.global_config.output, exist_ok=True)

+ 0 - 1
paddlex/modules/image_classification/predictor/keys.py

@@ -29,4 +29,3 @@ class ClsKeys(object):
     # Suite-specific keys
     CLS_PRED = 'cls_pred'
     CLS_RESULT = 'cls_result'
-    LABELS = 'labels'

+ 13 - 25
paddlex/modules/image_classification/predictor/predictor.py

@@ -64,32 +64,20 @@ class ClsPredictor(BasePredictor):
             dict_[K.CLS_PRED] = cls_out
         return pred
 
-    def _get_pre_transforms_for_data(self, data):
+    def _get_pre_transforms_from_config(self):
         """ get preprocess transforms """
-        if K.IMAGE not in data:
-            if K.IM_PATH not in data:
-                raise KeyError(
-                    f"Key {repr(K.IM_PATH)} is required, but not found.")
-            logging.info(
-                f"Transformation operators for data preprocessing will be inferred from config file."
-            )
-            pre_transforms = self.other_src.pre_transforms
-            pre_transforms.insert(0, image_common.ReadImage(format='RGB'))
-        else:
-            raise RuntimeError(
-                f"`{self.__class__.__name__}` does not have default transformation operators to preprocess the input. "
-                f"Please set `pre_transforms` when using the {repr(K.IMAGE)} key in input dict."
-            )
-        pre_transforms.insert(0, T.LoadLabels(self.other_src.labels))
+        logging.info(
+            f"Transformation operators for data preprocessing will be inferred from config file."
+        )
+        pre_transforms = self.other_src.pre_transforms
+        pre_transforms.insert(0, image_common.ReadImage(format='RGB'))
         return pre_transforms
 
-    def _get_post_transforms_for_data(self, data):
+    def _get_post_transforms_from_config(self):
         """ get postprocess transforms """
-        if data.get('cli_flag', False):
-            output_dir = data.get("output_dir", "./")
-            return [
-                # T.SaveDetResults(output_dir, labels=self.other_src.labels),
-                T.Topk(topk=1),
-                T.PrintResult()
-            ]
-        return []
+        post_transforms = self.other_src.post_transforms
+        post_transforms.extend([
+            T.PrintResult(), T.SaveClsResults(self.output_dir,
+                                              self.other_src.labels)
+        ])
+        return post_transforms

+ 102 - 36
paddlex/modules/image_classification/predictor/transforms.py

@@ -14,50 +14,42 @@
 
 
 import os
+import json
+from PIL import Image, ImageDraw, ImageFont
+from pathlib import Path
 
 import numpy as np
 
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
 from ...base import BaseTransform
+from ...base.predictor.io.writers import ImageWriter
 from .keys import ClsKeys as K
 from ....utils import logging
 
-__all__ = ["Topk", "NormalizeFeatures"]
+__all__ = ["Topk", "NormalizeFeatures", "PrintResult", "SaveClsResults"]
+
+
+def _parse_class_id_map(class_ids):
+    """ parse class id to label map file """
+    if class_ids is None:
+        return None
+    class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
+    return class_id_map
 
 
 class Topk(BaseTransform):
     """ Topk Transform """
 
-    def __init__(self, topk, class_id_map_file=None, delimiter=None):
+    def __init__(self, topk, class_ids=None):
         super().__init__()
         assert isinstance(topk, (int, ))
         self.topk = topk
-        self.delimiter = delimiter if delimiter is not None else " "
-        self.class_id_map = self._parse_class_id_map(class_id_map_file)
-
-    def _parse_class_id_map(self, class_id_map_file):
-        """ parse class id to label map file """
-        if class_id_map_file is None:
-            return None
-        if not os.path.exists(class_id_map_file):
-            logging.warning(
-                "Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
-            )
-            return None
-
-        class_id_map = {}
-        with open(class_id_map_file, 'r', encoding='utf-8') as fin:
-            lines = fin.readlines()
-            for line in lines:
-                partition = line.split("\n")[0].partition(self.delimiter)
-                class_id_map[int(partition[0])] = str(partition[-1])
-        return class_id_map
+        self.class_id_map = _parse_class_id_map(class_ids)
 
     def apply(self, data):
         """ apply """
         x = data[K.CLS_PRED]
-        class_id_map = data[
-            K.
-            LABELS] if self.class_id_map is None and K.LABELS in data else self.class_id_map
+        class_id_map = self.class_id_map
         y = []
         index = x.argsort(axis=0)[-self.topk:][::-1].astype("int32")
         clas_id_list = []
@@ -132,26 +124,100 @@ class PrintResult(BaseTransform):
         return []
 
 
-class LoadLabels(BaseTransform):
-    """load label to data
-    """
-
-    def __init__(self, labels=None):
+class SaveClsResults(BaseTransform):
+    def __init__(self, save_dir, class_ids=None):
         super().__init__()
-        self.labels = labels
+        self.save_dir = save_dir
+        self.class_id_map = _parse_class_id_map(class_ids)
+        self._writer = ImageWriter(backend='pillow')
+
+    def _get_colormap(self, rgb=False):
+        """
+        Get colormap
+        """
+        color_list = np.array([
+            0xFF, 0x00, 0x00, 0xCC, 0xFF, 0x00, 0x00, 0xFF, 0x66, 0x00, 0x66,
+            0xFF, 0xCC, 0x00, 0xFF, 0xFF, 0x4D, 0x00, 0x80, 0xff, 0x00, 0x00,
+            0xFF, 0xB2, 0x00, 0x1A, 0xFF, 0xFF, 0x00, 0xE5, 0xFF, 0x99, 0x00,
+            0x33, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x33, 0x00, 0xFF, 0xff, 0x00,
+            0x99, 0xFF, 0xE5, 0x00, 0x00, 0xFF, 0x1A, 0x00, 0xB2, 0xFF, 0x80,
+            0x00, 0xFF, 0xFF, 0x00, 0x4D
+        ]).astype(np.float32)
+        color_list = (color_list.reshape((-1, 3)))
+        if not rgb:
+            color_list = color_list[:, ::-1]
+        return color_list.astype('int32')
+
+    def _get_font_colormap(self, color_index):
+        """
+        Get font colormap
+        """
+        dark = np.array([0x14, 0x0E, 0x35])
+        light = np.array([0xFF, 0xFF, 0xFF])
+        light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+        if color_index in light_indexs:
+            return light.astype('int32')
+        else:
+            return dark.astype('int32')
 
     def apply(self, data):
-        """ apply """
-        if self.labels:
-            data[K.LABELS] = self.labels
+        """ Draw label on image """
+        ori_path = data[K.IM_PATH]
+        pred = data[K.CLS_PRED]
+        index = pred.argsort(axis=0)[-1].astype("int32")
+        score = pred[index].item()
+        label = self.class_id_map[int(index)]
+        label_str = f"{label} {score:.2f}"
+        file_name = os.path.basename(ori_path)
+        save_path = os.path.join(self.save_dir, file_name)
+
+        image = Image.open(ori_path)
+        image = image.convert('RGB')
+        image_size = image.size
+        draw = ImageDraw.Draw(image)
+        min_font_size = int(image_size[0] * 0.02)
+        max_font_size = int(image_size[0] * 0.05)
+        for font_size in range(max_font_size, min_font_size - 1, -1):
+            font = ImageFont.truetype(
+                PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+            text_width_tmp, text_height_tmp = draw.textsize(label_str, font)
+            if text_width_tmp <= image_size[0]:
+                break
+            else:
+                font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH,
+                                          min_font_size)
+        color_list = self._get_colormap(rgb=True)
+        color = tuple(color_list[0])
+        font_color = tuple(self._get_font_colormap(3))
+        text_width, text_height = draw.textsize(label_str, font)
+
+        rect_left = 3
+        rect_top = 3
+        rect_right = rect_left + text_width + 3
+        rect_bottom = rect_top + text_height + 6
+
+        draw.rectangle(
+            [(rect_left, rect_top), (rect_right, rect_bottom)], fill=color)
+
+        text_x = rect_left + 3
+        text_y = rect_top
+        draw.text((text_x, text_y), label_str, fill=font_color, font=font)
+        self._write_image(save_path, image)
+
         return data
 
+    def _write_image(self, path, image):
+        """ write image """
+        if os.path.exists(path):
+            logging.warning(f"{path} already exists. Overwriting it.")
+        self._writer.write(path, image)
+
     @classmethod
     def get_input_keys(cls):
         """ get input keys """
-        return []
+        return [K.IM_PATH, K.CLS_PRED]
 
     @classmethod
     def get_output_keys(cls):
         """ get output keys """
-        return [K.LABELS]
+        return []

+ 18 - 2
paddlex/modules/image_classification/predictor/utils.py

@@ -19,6 +19,7 @@ import codecs
 import yaml
 
 from ...base.predictor.transforms import image_common
+from . import transforms as T
 
 
 class InnerConfig(object):
@@ -36,7 +37,7 @@ class InnerConfig(object):
 
     @property
     def pre_transforms(self):
-        """ read preprocess transforms from  config file """
+        """ read preprocess transforms from config file """
         if "RecPreProcess" in list(self.inner_cfg.keys()):
             tfs_cfg = self.inner_cfg['RecPreProcess']['transform_ops']
         else:
@@ -53,7 +54,7 @@ class InnerConfig(object):
                 if "resize_short" in list(cfg[tf_key].keys()):
                     tf = image_common.ResizeByShort(
                         target_short_edge=cfg['ResizeImage'].get("resize_short",
-                                                                 (224, 224)),
+                                                                 224),
                         size_divisor=None,
                         interp='LINEAR')
                 else:
@@ -70,6 +71,21 @@ class InnerConfig(object):
         return tfs
 
     @property
+    def post_transforms(self):
+        """ read postprocess transforms from config file """
+        tfs_cfg = self.inner_cfg['PostProcess']
+        tfs = []
+        for tf_key in tfs_cfg:
+            if tf_key == 'Topk':
+                tf = T.Topk(
+                    topk=tfs_cfg['Topk']['topk'],
+                    class_ids=tfs_cfg['Topk']['label_list'])
+            else:
+                raise RuntimeError(f"Unsupported type: {tf_key}")
+            tfs.append(tf)
+        return tfs
+
+    @property
     def labels(self):
         """ the labels in inner config """
         return self.inner_cfg["PostProcess"]["Topk"]["label_list"]
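The new `post_transforms` property walks the `PostProcess` section of the inference config and maps each known key to a transform, raising on anything unsupported. The same dispatch can be sketched with plain dicts, building `(name, kwargs)` specs instead of transform objects (a stand-in, not the PaddleX API):

```python
def build_post_transforms(post_cfg):
    """Map a PostProcess config dict to (name, kwargs) transform specs."""
    tfs = []
    for tf_key in post_cfg:
        if tf_key == "Topk":
            tfs.append(("Topk", {
                "topk": post_cfg["Topk"]["topk"],
                "class_ids": post_cfg["Topk"]["label_list"],
            }))
        else:
            # unknown postprocess keys fail loudly, as in the diff
            raise RuntimeError(f"Unsupported type: {tf_key}")
    return tfs


specs = build_post_transforms({"Topk": {"topk": 5, "label_list": ["a", "b"]}})
```

Failing loudly on unknown keys keeps config drift visible instead of silently dropping a postprocess step.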

+ 1 - 1
paddlex/modules/instance_segmentation/dataset_checker/dataset_src/check_dataset.py

@@ -58,7 +58,7 @@ def check(dataset_dir, output_dir, sample_num=10):
             coco = COCO(file_list)
             num_class = len(coco.getCatIds())
 
-            vis_save_dir = osp.join(output_dir, 'tmp')
+            vis_save_dir = osp.join(output_dir, 'demo_img')
 
             image_info = jsondata['images']
             for i in range(sample_num):

+ 2 - 2
paddlex/modules/instance_segmentation/predictor/keys.py

@@ -24,9 +24,9 @@ class InstanceSegKeys(object):
     """
 
     # Common keys
-    IMAGE_PATH = 'input_path'
     IMAGE = 'image'
-    IMAGE_SHAPE = 'image_size'
+    IM_PATH = 'input_path'
+    IM_SIZE = 'image_size'
     SCALE_FACTOR = 'scale_factors'
     # Suite-specific keys
     BOXES = 'boxes'

+ 1 - 1
paddlex/modules/instance_segmentation/predictor/predictor.py

@@ -34,7 +34,7 @@ class InstanceSegPredictor(DetPredictor):
             axis=0).astype(
                 dtype=np.float32, copy=False)
         input_dict["im_shape"] = np.stack(
-            [data[K.IMAGE_SHAPE][::-1] for data in batch_input], axis=0).astype(
+            [data[K.IM_SIZE][::-1] for data in batch_input], axis=0).astype(
                 dtype=np.float32, copy=False)
 
         input_ = [input_dict[i] for i in self._predictor.get_input_names()]

+ 3 - 4
paddlex/modules/object_detection/predictor/keys.py

@@ -24,11 +24,10 @@ class DetKeys(object):
     """
 
     # Common keys
-    IMAGE_PATH = 'input_path'
     IMAGE = 'image'
-    IMAGE_SHAPE = 'image_size'
+    IM_PATH = 'input_path'
+    IM_SIZE = 'image_size'
     SCALE_FACTOR = 'scale_factors'
     # Suite-specific keys
     BOXES = 'boxes'
-    MASKS = 'masks'
-    LABELS = 'labels'
+    MASKS = 'masks'

+ 14 - 23
paddlex/modules/object_detection/predictor/predictor.py

@@ -42,7 +42,7 @@ class DetPredictor(BasePredictor):
     @classmethod
     def get_input_keys(cls):
         """ get input keys """
-        return [[K.IMAGE], [K.IMAGE_PATH]]
+        return [[K.IMAGE], [K.IM_PATH]]
 
     @classmethod
     def get_output_keys(cls):
@@ -60,7 +60,7 @@ class DetPredictor(BasePredictor):
             axis=0).astype(
                 dtype=np.float32, copy=False)
         input_dict["im_shape"] = np.stack(
-            [data[K.IMAGE_SHAPE][::-1] for data in batch_input], axis=0).astype(
+            [data[K.IM_SIZE][::-1] for data in batch_input], axis=0).astype(
                 dtype=np.float32, copy=False)
 
         input_ = [input_dict[i] for i in self._predictor.get_input_names()]
@@ -78,28 +78,19 @@ class DetPredictor(BasePredictor):
             batch_input[idx][K.BOXES] = np_boxes
         return pred
 
-    def _get_pre_transforms_for_data(self, data):
+    def _get_pre_transforms_from_config(self):
         """ get preprocess transforms """
-        if K.IMAGE not in data:
-            if K.IMAGE_PATH not in data:
-                raise KeyError(
-                    f"Key {repr(K.IMAGE_PATH)} is required, but not found.")
-            logging.info(
-                f"Transformation operators for data preprocessing will be inferred from config file."
-            )
-            pre_transforms = self.other_src.pre_transforms
-            pre_transforms.insert(0, image_common.ReadImage(format='RGB'))
-        else:
-            raise RuntimeError(
-                f"`{self.__class__.__name__}` does not have default transformation operators to preprocess the input. "
-                f"Please set `pre_transforms` when using the {repr(K.IMAGE)} key in input dict."
-            )
-        pre_transforms.insert(0, T.LoadLabels(self.other_src.labels))
+        logging.info(
+            f"Transformation operators for data preprocessing will be inferred from config file."
+        )
+        pre_transforms = self.other_src.pre_transforms
+        pre_transforms.insert(0, image_common.ReadImage(format='RGB'))
         return pre_transforms
 
-    def _get_post_transforms_for_data(self, data):
+    def _get_post_transforms_from_config(self):
         """ get postprocess transforms """
-        if data.get('cli_flag', False):
-            output_dir = data.get("output_dir", "./")
-            return [T.SaveDetResults(output_dir), T.PrintResult()]
-        return []
+        return [
+            T.SaveDetResults(
+                save_dir=self.output_dir, labels=self.other_src.labels),
+            T.PrintResult()
+        ]

+ 65 - 42
paddlex/modules/object_detection/predictor/transforms.py

@@ -17,18 +17,17 @@ import os
 
 import numpy as np
 import math
-from PIL import Image, ImageDraw, ImageFile
+from PIL import Image, ImageDraw, ImageFont
 
-from ....utils import logging
-from ...base import BaseTransform
 from .keys import DetKeys as K
+from ...base import BaseTransform
 from ...base.predictor.io.writers import ImageWriter
 from ...base.predictor.transforms import image_functions as F
 from ...base.predictor.transforms.image_common import _BaseResize, _check_image_size
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ....utils import logging
 
-__all__ = [
-    'SaveDetResults', 'PadStride', 'DetResize', 'PrintResult', 'LoadLabels'
-]
+__all__ = ['SaveDetResults', 'PadStride', 'DetResize', 'PrintResult']
 
 
 def get_color_map_list(num_classes):
@@ -52,6 +51,39 @@ def get_color_map_list(num_classes):
     return color_map
 
 
+def colormap(rgb=False):
+    """
+    Get colormap
+
+    The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
+utils/colormap.py
+    """
+    color_list = np.array([
+        0xFF, 0x00, 0x00, 0xCC, 0xFF, 0x00, 0x00, 0xFF, 0x66, 0x00, 0x66, 0xFF,
+        0xCC, 0x00, 0xFF, 0xFF, 0x4D, 0x00, 0x80, 0xff, 0x00, 0x00, 0xFF, 0xB2,
+        0x00, 0x1A, 0xFF, 0xFF, 0x00, 0xE5, 0xFF, 0x99, 0x00, 0x33, 0xFF, 0x00,
+        0x00, 0xFF, 0xFF, 0x33, 0x00, 0xFF, 0xff, 0x00, 0x99, 0xFF, 0xE5, 0x00,
+        0x00, 0xFF, 0x1A, 0x00, 0xB2, 0xFF, 0x80, 0x00, 0xFF, 0xFF, 0x00, 0x4D
+    ]).astype(np.float32)
+    color_list = (color_list.reshape((-1, 3)))
+    if not rgb:
+        color_list = color_list[:, ::-1]
+    return color_list.astype('int32')
+
+
+def font_colormap(color_index):
+    """
+    Get font color according to the index of colormap
+    """
+    dark = np.array([0x14, 0x0E, 0x35])
+    light = np.array([0xFF, 0xFF, 0xFF])
+    light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+    if color_index in light_indexs:
+        return light.astype('int32')
+    else:
+        return dark.astype('int32')
+
+
 def draw_box(img, np_boxes, labels, threshold=0.5):
     """
     Args:
@@ -63,18 +95,26 @@ def draw_box(img, np_boxes, labels, threshold=0.5):
     Returns:
         img (PIL.Image.Image): visualized image
     """
-    draw_thickness = min(img.size) // 320
+    font_size = int(0.024 * int(img.width)) + 2
+    font = ImageFont.truetype(
+        PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+
+    draw_thickness = int(max(img.size) * 0.005)
     draw = ImageDraw.Draw(img)
     clsid2color = {}
-    color_list = get_color_map_list(len(labels))
+    catid2fontcolor = {}
+    color_list = colormap(rgb=True)
     expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
     np_boxes = np_boxes[expect_boxes, :]
 
-    for dt in np_boxes:
+    for i, dt in enumerate(np_boxes):
         clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
         if clsid not in clsid2color:
-            clsid2color[clsid] = color_list[clsid]
+            color_index = i % len(color_list)
+            clsid2color[clsid] = color_list[color_index]
+            catid2fontcolor[clsid] = font_colormap(color_index)
         color = tuple(clsid2color[clsid])
+        font_color = tuple(catid2fontcolor[clsid])
 
         xmin, ymin, xmax, ymax = bbox
         # draw bbox
@@ -85,11 +125,18 @@ def draw_box(img, np_boxes, labels, threshold=0.5):
             fill=color)
 
         # draw label
-        text = "{} {:.4f}".format(labels[clsid], score)
-        tw, th = draw.textsize(text)
-        draw.rectangle(
-            [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
-        draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
+        text = "{} {:.2f}".format(labels[clsid], score)
+        tw, th = draw.textsize(text, font=font)
+        if ymin < th:
+            draw.rectangle(
+                [(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
+            draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
+        else:
+            draw.rectangle(
+                [(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
+            draw.text(
+                (xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
+
     return img
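The reworked label placement above falls back to drawing the label just below the box top when the box is closer to the image edge than the text height; otherwise the label sits above the box. The rectangle arithmetic can be isolated as a small pure function (a hypothetical helper, for illustration only):

```python
def label_rect(xmin, ymin, tw, th):
    """Label background rectangle for a detection box.

    Mirrors the placement rule above: if there is not enough room above
    the box (ymin < text height th), draw the label inside, just below
    the top edge; otherwise draw it above the box.
    """
    if ymin < th:
        return (xmin, ymin, xmin + tw + 4, ymin + th + 1)
    return (xmin, ymin - th, xmin + tw + 4, ymin + 1)


print(label_rect(10, 5, 50, 12))    # (10, 5, 64, 18)  -> label inside the box
print(label_rect(10, 100, 50, 12))  # (10, 88, 64, 101) -> label above the box
```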
 
 
@@ -144,13 +191,11 @@ class SaveDetResults(BaseTransform):
 
     def apply(self, data):
         """ apply """
-        ori_path = data[K.IMAGE_PATH]
+        ori_path = data[K.IM_PATH]
         file_name = os.path.basename(ori_path)
         save_path = os.path.join(self.save_dir, file_name)
 
-        labels = data[
-            K.
-            LABELS] if self.labels is None and K.LABELS in data else self.labels
+        labels = self.labels
         image = Image.open(ori_path)
         if K.MASKS in data:
             image = draw_mask(
@@ -174,7 +219,7 @@ class SaveDetResults(BaseTransform):
     @classmethod
     def get_input_keys(cls):
         """ get input keys """
-        return [K.IMAGE_PATH, K.BOXES]
+        return [K.IM_PATH, K.BOXES]
 
     @classmethod
     def get_output_keys(cls):
@@ -308,25 +353,3 @@ class PrintResult(BaseTransform):
     def get_output_keys(cls):
         """ get output keys """
         return []
-
-
-class LoadLabels(BaseTransform):
-    def __init__(self, labels=None):
-        super().__init__()
-        self.labels = labels
-
-    def apply(self, data):
-        """ apply """
-        if self.labels:
-            data[K.LABELS] = self.labels
-        return data
-
-    @classmethod
-    def get_input_keys(cls):
-        """ get input keys """
-        return []
-
-    @classmethod
-    def get_output_keys(cls):
-        """ get output keys """
-        return [K.LABELS]

+ 16 - 23
paddlex/modules/semantic_segmentation/predictor/predictor.py

@@ -33,12 +33,14 @@ class SegPredictor(BasePredictor):
     def __init__(self,
                  model_dir,
                  kernel_option,
+                 output_dir,
                  pre_transforms=None,
                  post_transforms=None,
                  has_prob_map=False):
         super().__init__(
             model_dir=model_dir,
             kernel_option=kernel_option,
+            output_dir=output_dir,
             pre_transforms=pre_transforms,
             post_transforms=post_transforms)
         self.has_prob_map = has_prob_map
@@ -82,8 +84,8 @@ class SegPredictor(BasePredictor):
                 dict_[K.SEG_MAP] = out_map
         return pred
 
-    def _get_pre_transforms_for_data(self, data):
-        """ _get_pre_transforms_for_data """
+    def _get_pre_transforms_from_config(self):
+        """ _get_pre_transforms_from_config """
         # If `K.IMAGE` (the decoded image) is found, return a default list of
         # transformation operators for the input (if possible).
         # If `K.IMAGE` (the decoded image) is not found, `K.IM_PATH` (the image
@@ -91,26 +93,17 @@ class SegPredictor(BasePredictor):
         # transformation operators from the config file.
         # In cases where the input contains both `K.IMAGE` and `K.IM_PATH`,
         # `K.IMAGE` takes precedence over `K.IM_PATH`.
-        if K.IMAGE not in data:
-            if K.IM_PATH not in data:
-                raise KeyError(
-                    f"Key {repr(K.IM_PATH)} is required, but not found.")
-            logging.info(
-                f"Transformation operators for data preprocessing will be inferred from config file."
-            )
-            pre_transforms = self.other_src.pre_transforms
-            pre_transforms.insert(0, image_common.ReadImage())
-            pre_transforms.append(image_common.ToCHWImage())
-        else:
-            raise RuntimeError(
-                f"`{self.__class__.__name__}` does not have default transformation operators to preprocess the input. "
-                f"Please set `pre_transforms` when using the {repr(K.IMAGE)} key in input dict."
-            )
+        logging.info(
+            "Transformation operators for data preprocessing will be inferred from the config file."
+        )

+        pre_transforms = self.other_src.pre_transforms
+        pre_transforms.insert(0, image_common.ReadImage())
+        pre_transforms.append(image_common.ToCHWImage())
         return pre_transforms
 
-    def _get_post_transforms_for_data(self, data):
-        """ _get_post_transforms_for_data """
-        if data.get('cli_flag', False):
-            output_dir = data.get("output_dir", "./")
-            return [T.SaveSegResults(output_dir), T.PrintResult()]
-        return []
+    def _get_post_transforms_from_config(self):
+        """ _get_post_transforms_from_config """
+        return [
+            T.GeneratePCMap(), T.SaveSegResults(self.output_dir),
+            T.PrintResult()
+        ]

+ 13 - 23
paddlex/modules/table_recognition/predictor/predictor.py

@@ -33,12 +33,14 @@ class TableRecPredictor(BasePredictor):
     def __init__(self,
                  model_dir,
                  kernel_option,
+                 output_dir,
                  pre_transforms=None,
                  post_transforms=None,
                  table_max_len=488):
         super().__init__(
             model_dir=model_dir,
             kernel_option=kernel_option,
+            output_dir=output_dir,
             pre_transforms=pre_transforms,
             post_transforms=post_transforms)
         self.table_max_len = table_max_len
@@ -72,29 +74,17 @@ class TableRecPredictor(BasePredictor):
             dict_[K.LOC_PROB] = bbox_prob[np.newaxis, :]
         return pred
 
-    def _get_pre_transforms_for_data(self, data):
-        """ _get_pre_transforms_for_data """
-        if K.IMAGE not in data and K.IM_PATH not in data:
-            raise KeyError(
-                f"Key {repr(K.IMAGE)} or {repr(K.IM_PATH)} is required, but not found."
-            )
-        pre_transforms = []
-        if K.IMAGE not in data:
-            pre_transforms.append(image_common.ReadImage())
-        pre_transforms.append(
-            image_common.ResizeByLong(target_long_edge=self.table_max_len))
-        pre_transforms.append(
+    def _get_pre_transforms_from_config(self):
+        """ _get_pre_transforms_from_config """
+        return [
+            image_common.ReadImage(),
+            image_common.ResizeByLong(target_long_edge=self.table_max_len),
             image_common.Normalize(
-                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
-        pre_transforms.append(
-            image_common.Pad(target_size=self.table_max_len, val=0.0))
-        pre_transforms.append(image_common.ToCHWImage())
-        return pre_transforms
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            image_common.Pad(target_size=self.table_max_len, val=0.0),
+            image_common.ToCHWImage()
+        ]
 
-    def _get_post_transforms_for_data(self, data):
+    def _get_post_transforms_from_config(self):
         """ get postprocess transforms """
-        post_transforms = [T.TableLabelDecode()]
-        if data.get('cli_flag', False):
-            output_dir = data.get("output_dir", "./")
-            post_transforms.append(T.SaveTableResults(output_dir))
-        return post_transforms
+        return [T.TableLabelDecode(), T.SaveTableResults(self.output_dir)]
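Across these predictors the refactor replaces per-sample transform selection (`_get_*_for_data(data)`) with fixed lists built once from config (`_get_*_from_config()`). How such a list is consumed can be sketched with stub transforms sharing an `apply(data)` interface (class and function names here are hypothetical, not PaddleX API):

```python
class AddKey:
    """Stub transform: copies the data dict and adds one key."""

    def __init__(self, key, value):
        self.key, self.value = key, value

    def apply(self, data):
        data = dict(data)
        data[self.key] = self.value
        return data


def run_transforms(data, transforms):
    # apply each transform in order, threading the data dict through
    for t in transforms:
        data = t.apply(data)
    return data


pipeline = [AddKey('image', 'decoded'), AddKey('shape', (488, 488))]
result = run_transforms({'input_path': 'a.jpg'}, pipeline)
print(result)
```

Because the list no longer depends on the incoming sample, it can be built once at construction time rather than on every `predict` call.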

+ 11 - 24
paddlex/modules/text_detection/predictor/predictor.py

@@ -58,28 +58,18 @@ class TextDetPredictor(BasePredictor):
 
         return pred
 
-    def _get_pre_transforms_for_data(self, data):
+    def _get_pre_transforms_from_config(self):
         """ get preprocess transforms """
-        if K.IMAGE not in data and K.IM_PATH not in data:
-            raise KeyError(
-                f"Key {repr(K.IMAGE)} or {repr(K.IM_PATH)} is required, but not found."
-            )
-        pre_transforms = []
-        if K.IMAGE not in data:
-            pre_transforms.append(image_common.ReadImage())
-        pre_transforms.append(
-            T.DetResizeForTest(
-                limit_side_len=960, limit_type="max"))
-        pre_transforms.append(
-            T.NormalizeImage(
-                mean=[0.485, 0.456, 0.406],
-                std=[0.229, 0.224, 0.225],
-                scale=1. / 255,
-                order='hwc'))
-        pre_transforms.append(image_common.ToCHWImage())
-        return pre_transforms
+        return [
+            image_common.ReadImage(), T.DetResizeForTest(
+                limit_side_len=960, limit_type="max"), T.NormalizeImage(
+                    mean=[0.485, 0.456, 0.406],
+                    std=[0.229, 0.224, 0.225],
+                    scale=1. / 255,
+                    order='hwc'), image_common.ToCHWImage()
+        ]
 
-    def _get_post_transforms_for_data(self, data):
+    def _get_post_transforms_from_config(self):
         """ get postprocess transforms """
         post_transforms = [
             T.DBPostProcess(
@@ -89,9 +79,6 @@ class TextDetPredictor(BasePredictor):
                 unclip_ratio=1.5,
                 use_dilation=False,
                 score_mode='fast',
-                box_type='quad'),
+                box_type='quad'), T.SaveTextDetResults(self.output_dir)
         ]
-        if data.get('cli_flag', False):
-            output_dir = data.get("output_dir", "./")
-            post_transforms.append(T.SaveTextDetResults(output_dir))
         return post_transforms

+ 5 - 0
paddlex/modules/text_detection/predictor/transforms.py

@@ -583,6 +583,11 @@ class SaveTextDetResults(BaseTransform):
 
     def apply(self, data):
         """ apply """
+        if self.save_dir is None:
+            logging.warning(
+                "The `save_dir` has been set to None, so the text detection result won't be saved."
+            )
+            return data
         save_path = os.path.join(self.save_dir, self.file_name)
         bbox_res = data[K.DT_POLYS]
         vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)

+ 10 - 18
paddlex/modules/text_recognition/predictor/predictor.py

@@ -66,24 +66,16 @@ class TextRecPredictor(BasePredictor):
             dict_[K.REC_PROBS] = probs[np.newaxis, :]
         return pred
 
-    def _get_pre_transforms_for_data(self, data):
-        """ _get_pre_transforms_for_data """
-        if K.IMAGE not in data and K.IM_PATH not in data:
-            raise KeyError(
-                f"Key {repr(K.IMAGE)} or {repr(K.IM_PATH)} is required, but not found."
-            )
-        pre_transforms = []
-        if K.IMAGE not in data:
-            pre_transforms.append(image_common.ReadImage())
-        else:
-            pre_transforms.append(image_common.GetImageInfo())
-        pre_transforms.append(T.OCRReisizeNormImg())
-        return pre_transforms
+    def _get_pre_transforms_from_config(self):
+        """ _get_pre_transforms_from_config """
+        return [
+            image_common.ReadImage(), image_common.GetImageInfo(),
+            T.OCRReisizeNormImg()
+        ]
 
-    def _get_post_transforms_for_data(self, data):
+    def _get_post_transforms_from_config(self):
         """ get postprocess transforms """
-        post_transforms = [T.CTCLabelDecode(self.other_src.PostProcess)]
-        if data.get('cli_flag', False):
-            output_dir = data.get("output_dir", "./")
-            post_transforms.append(T.PrintResult())
+        post_transforms = [
+            T.CTCLabelDecode(self.other_src.PostProcess), T.PrintResult()
+        ]
         return post_transforms

+ 2 - 4
paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/convert_dataset.py

@@ -31,8 +31,7 @@ def check_src_dataset(root_dir):
     err_msg_prefix = "Convert failed! Only '.xlsx/.xls' format files are supported."
 
     for dst_anno, src_anno in [("train.xlsx", "train.xls"),
-                               ("val.xlsx", "val.xls"),
-                               ("test.xlsx", "test.xls")]:
+                               ("val.xlsx", "val.xls")]:
         src_anno_path = os.path.join(root_dir, src_anno)
         dst_anno_path = os.path.join(root_dir, dst_anno)
         if not os.path.exists(src_anno_path) and not os.path.exists(
@@ -60,8 +59,7 @@ def convert_excel_dataset(input_dir):
     """
     # read excel file
     for dst_anno, src_anno in [("train.xlsx", "train.xls"),
-                               ("val.xlsx", "val.xls"),
-                               ("test.xlsx", "test.xls")]:
+                               ("val.xlsx", "val.xls")]:
         src_anno_path = os.path.join(input_dir, src_anno)
         dst_anno_path = os.path.join(input_dir, dst_anno)
 

+ 5 - 8
paddlex/modules/ts_anomaly_detection/dataset_checker/dataset_src/split_dataset.py

@@ -23,14 +23,14 @@ from tqdm import tqdm
 from .....utils.logging import info
 
 
-def split_dataset(root_dir, train_rate, val_rate, test_rate=0):
+def split_dataset(root_dir, train_rate, val_rate):
     """ split dataset """
-    assert train_rate + val_rate + test_rate == 100, \
-    f"The sum of train_rate({train_rate}), val_rate({val_rate}) and test_rate({test_rate}) should equal 100!"
+    assert train_rate + val_rate == 100, \
+    f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
     assert train_rate > 0 and val_rate > 0, \
     f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
 
-    tags = ['train.csv', 'val.csv', 'test.csv']
+    tags = ['train.csv', 'val.csv']
     df = pd.DataFrame()
     for tag in tags:
         if os.path.exists(osp.join(root_dir, tag)):
@@ -43,14 +43,11 @@ def split_dataset(root_dir, train_rate, val_rate, test_rate=0):
     df_len = df.shape[0]
     point_train = math.floor((df_len * train_rate / 100))
     point_val = math.floor((df_len * (train_rate + val_rate) / 100))
-    point_test = math.floor(
-        (df_len * (train_rate + val_rate + test_rate) / 100))
 
     train_df = df.iloc[:point_train, :]
     val_df = df.iloc[point_train:point_val, :]
-    test_df = df.iloc[point_val:, :]
 
-    df_dict = {'train.csv': train_df, 'val.csv': val_df, 'test.csv': test_df}
+    df_dict = {'train.csv': train_df, 'val.csv': val_df}
     for tag in df_dict.keys():
         save_path = osp.join(root_dir, tag)
         if os.path.exists(save_path):
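With the test split removed, the remaining cut points are computed from percentage rates via `math.floor`, exactly as in `split_dataset` above. A standalone sketch of that index arithmetic:

```python
import math


def split_points(n_rows, train_rate, val_rate):
    """Compute the train/val cut indices (rates must sum to 100)."""
    assert train_rate + val_rate == 100
    point_train = math.floor(n_rows * train_rate / 100)
    point_val = math.floor(n_rows * (train_rate + val_rate) / 100)
    return point_train, point_val


rows = list(range(10))
p_train, p_val = split_points(len(rows), 80, 20)
train, val = rows[:p_train], rows[p_train:p_val]
print(len(train), len(val))  # 8 2
```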

+ 6 - 24
paddlex/modules/ts_classification/dataset_checker/dataset_src/analyse_dataset.py

@@ -28,18 +28,12 @@ from .....utils.fonts import PINGFANG_FONT_FILE_PATH
 
 def deep_analyse(dataset_dir, output_dir, label_col='label'):
     """class analysis for dataset"""
-    tags = ['train', 'val', 'test']
+    tags = ['train', 'val']
     label_unique = None
     for tag in tags:
         csv_path = os.path.abspath(os.path.join(dataset_dir, tag + '.csv'))
-        if tag == 'test' and not os.path.exists(csv_path):
-            cls_test = None
-            continue
         df = pd.read_csv(csv_path)
         if label_col not in df.columns:
-            if tag == 'test':
-                cls_test = None
-                continue
             raise ValueError(
                 f"default label_col: {label_col} not in {tag} dataset")
         if label_unique is None:
@@ -52,17 +46,13 @@ def deep_analyse(dataset_dir, output_dir, label_col='label'):
             cls_train = [label_num for label_col, label_num in cls_dict.items()]
         elif tag == 'val':
             cls_val = [label_num for label_col, label_num in cls_dict.items()]
-        else:
-            cls_test = [label_num for label_col, label_num in cls_dict.items()]
     sorted_id = sorted(
         range(len(cls_train)), key=lambda k: cls_train[k], reverse=True)
     cls_train_sorted = sorted(cls_train, reverse=True)
     cls_val_sorted = [cls_val[index] for index in sorted_id]
-    if cls_test:
-        cls_test_sorted = [cls_test[index] for index in sorted_id]
     classes_sorted = [label_unique[index] for index in sorted_id]
     x = np.arange(len(label_unique))
-    width = 0.5 if not cls_test else 0.333
+    width = 0.5
 
     # bar
     os_system = platform.system().lower()
@@ -72,18 +62,10 @@ def deep_analyse(dataset_dir, output_dir, label_col='label'):
         font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH)
     fig, ax = plt.subplots(
         figsize=(max(8, int(len(label_unique) / 5)), 5), dpi=120)
-    ax.bar(x,
-           cls_train_sorted,
-           width=0.5 if not cls_test else 0.333,
-           label='train')
-    ax.bar(x + width,
-           cls_val_sorted,
-           width=0.5 if not cls_test else 0.333,
-           label='val')
-    if cls_test:
-        ax.bar(x + 2 * width, cls_test_sorted, width=0.333, label='test')
+    ax.bar(x, cls_train_sorted, width=0.5, label='train')
+    ax.bar(x + width, cls_val_sorted, width=0.5, label='val')
     plt.xticks(
-        x + width / 2 if not cls_test else x + width,
+        x + width / 2,
         classes_sorted,
         rotation=90,
         fontproperties=None if os_system == "windows" else font)
@@ -92,4 +74,4 @@ def deep_analyse(dataset_dir, output_dir, label_col='label'):
     fig.tight_layout()
     fig_path = os.path.join(output_dir, "histogram.png")
     fig.savefig(fig_path)
-    return {"histogram": "histogram.png"}
+    return {"histogram": os.path.join("check_dataset", "histogram.png")}

+ 2 - 4
paddlex/modules/ts_classification/dataset_checker/dataset_src/convert_dataset.py

@@ -30,8 +30,7 @@ def check_src_dataset(root_dir):
     err_msg_prefix = "Convert failed! Only files with the '.xlsx/.xls' suffix are supported."
 
     for dst_anno, src_anno in [("train.xlsx", "train.xls"),
-                               ("val.xlsx", "val.xls"),
-                               ("test.xlsx", "test.xls")]:
+                               ("val.xlsx", "val.xls")]:
         src_anno_path = os.path.join(root_dir, src_anno)
         dst_anno_path = os.path.join(root_dir, dst_anno)
         if not os.path.exists(src_anno_path) and not os.path.exists(
@@ -60,8 +59,7 @@ def convert_excel_dataset(input_dir):
 
     # read excel file
     for dst_anno, src_anno in [("train.xlsx", "train.xls"),
-                               ("val.xlsx", "val.xls"),
-                               ("test.xlsx", "test.xls")]:
+                               ("val.xlsx", "val.xls")]:
         src_anno_path = os.path.join(input_dir, src_anno)
         dst_anno_path = os.path.join(input_dir, dst_anno)
 

+ 5 - 10
paddlex/modules/ts_classification/dataset_checker/dataset_src/split_dataset.py

@@ -23,18 +23,14 @@ from tqdm import tqdm
 from .....utils.logging import info
 
 
-def split_dataset(root_dir,
-                  train_rate,
-                  val_rate,
-                  group_id='group_id',
-                  test_rate=0):
+def split_dataset(root_dir, train_rate, val_rate, group_id='group_id'):
     """ split dataset """
-    assert train_rate + val_rate + test_rate == 100, \
-    f"The sum of train_rate({train_rate}), val_rate({val_rate}) and test_rate({test_rate}) should equal 100!"
+    assert train_rate + val_rate == 100, \
+    f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
     assert train_rate > 0 and val_rate > 0, \
     f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
 
-    tags = ['train.csv', 'val.csv', 'test.csv']
+    tags = ['train.csv', 'val.csv']
     df = pd.DataFrame()
     group_unique = None
     for tag in tags:
@@ -63,8 +59,7 @@ def split_dataset(root_dir,
     group_len = len(dfs)
     point_train = math.floor((group_len * train_rate / 100))
     point_val = math.floor((group_len * (train_rate + val_rate) / 100))
-    point_test = math.floor(
-        (group_len * (train_rate + val_rate + test_rate) / 100))
+
     assert point_train > 0, "The train set is empty; train_percent should be larger."
     assert point_val - point_train > 0, "The val set is empty; val_percent should be larger."
 

+ 2 - 4
paddlex/modules/ts_forecast/dataset_checker/dataset_src/convert_dataset.py

@@ -30,8 +30,7 @@ def check_src_dataset(root_dir):
     err_msg_prefix = "Convert failed! Only files with the '.xlsx/.xls' suffix are supported."
 
     for dst_anno, src_anno in [("train.xlsx", "train.xls"),
-                               ("val.xlsx", "val.xls"),
-                               ("test.xlsx", "test.xls")]:
+                               ("val.xlsx", "val.xls")]:
         src_anno_path = os.path.join(root_dir, src_anno)
         dst_anno_path = os.path.join(root_dir, dst_anno)
         if not os.path.exists(src_anno_path) and not os.path.exists(
@@ -59,8 +58,7 @@ def convert_excel_dataset(input_dir):
     """
     # read excel file
     for dst_anno, src_anno in [("train.xlsx", "train.xls"),
-                               ("val.xlsx", "val.xls"),
-                               ("test.xlsx", "test.xls")]:
+                               ("val.xlsx", "val.xls")]:
         src_anno_path = os.path.join(input_dir, src_anno)
         dst_anno_path = os.path.join(input_dir, dst_anno)
 

+ 5 - 8
paddlex/modules/ts_forecast/dataset_checker/dataset_src/split_dataset.py

@@ -23,14 +23,14 @@ from tqdm import tqdm
 from .....utils.logging import info
 
 
-def split_dataset(root_dir, train_rate, val_rate, test_rate=0):
+def split_dataset(root_dir, train_rate, val_rate):
     """ split dataset """
-    assert train_rate + val_rate + test_rate == 100, \
-    f"The sum of train_rate({train_rate}), val_rate({val_rate}) and test_rate({test_rate}) should equal 100!"
+    assert train_rate + val_rate == 100, \
+    f"The sum of train_rate({train_rate}) and val_rate({val_rate}) should equal 100!"
     assert train_rate > 0 and val_rate > 0, \
     f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
 
-    tags = ['train.csv', 'val.csv', 'test.csv']
+    tags = ['train.csv', 'val.csv']
     df = pd.DataFrame()
     for tag in tags:
         if os.path.exists(osp.join(root_dir, tag)):
@@ -43,14 +43,11 @@ def split_dataset(root_dir, train_rate, val_rate, test_rate=0):
     df_len = df.shape[0]
     point_train = math.floor((df_len * train_rate / 100))
     point_val = math.floor((df_len * (train_rate + val_rate) / 100))
-    point_test = math.floor(
-        (df_len * (train_rate + val_rate + test_rate) / 100))
 
     train_df = df.iloc[:point_train, :]
     val_df = df.iloc[point_train:point_val, :]
-    test_df = df.iloc[point_val:, :]
 
-    df_dict = {'train.csv': train_df, 'val.csv': val_df, 'test.csv': test_df}
+    df_dict = {'train.csv': train_df, 'val.csv': val_df}
     for tag in df_dict.keys():
         save_path = osp.join(root_dir, tag)
         if os.path.exists(save_path):

+ 3 - 3
paddlex/modules/ts_forecast/predictor.py

@@ -61,7 +61,7 @@ is not exist, use default instead.")
             raise_model_not_found_error(model_dir)
         return None
 
-    def __call__(self, input=None, batch_size=1):
+    def predict(self, input=None, batch_size=1):
         """execute model predict
 
         Returns:
@@ -89,10 +89,10 @@ is not exist, use default instead.")
             "save_dir": self.global_config.output
         }
 
-    def _get_post_transforms_for_data(self):
+    def _get_post_transforms_from_config(self):
         pass
 
-    def _get_pre_transforms_for_data(self):
+    def _get_pre_transforms_from_config(self):
         pass
 
     def _run(self):

+ 3 - 3
paddlex/paddlex_cli.py

@@ -54,7 +54,7 @@ def args_cfg():
     parser.add_argument('--pipeline', type=str, help="")
     parser.add_argument('--model', nargs='+', help="")
     parser.add_argument('--input', type=str, help="")
-    parser.add_argument('--output', type=str, help="")
+    parser.add_argument('--output', type=str, default="./", help="")
     parser.add_argument('--device', type=str, default='gpu:0', help="")
 
     return parser.parse_args()
@@ -110,7 +110,7 @@ def install(args):
 
 def pipeline_predict(pipeline, model_name_list, input_path, output_dir, device):
     pipeline = build_pipeline(pipeline, model_name_list, output_dir, device)
-    pipeline.predict(input_path)
+    pipeline.predict({"input_path": input_path})
 
 
 # for CLI
@@ -121,6 +121,6 @@ def main():
     if args.install:
         install(args)
     else:
-        print_info()
+        # print_info()
         return pipeline_predict(args.pipeline, args.model, args.input,
                                 args.output, args.device)
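`pipeline_predict` now wraps the path in a dict keyed by input name instead of passing a bare string. A toy stand-in illustrating that convention (`StubPipeline` is hypothetical, not the real pipeline class):

```python
class StubPipeline:
    """Hypothetical stand-in for a pipeline whose predict() takes a
    dict keyed by input name rather than a bare path string."""

    def get_input_keys(self):
        return ['input_path']

    def predict(self, input):
        # reject unexpected keys, mirroring the dict-input contract
        assert set(input) <= set(self.get_input_keys())
        return {'input_path': input['input_path'], 'result': 'ok'}


out = StubPipeline().predict({'input_path': 'demo.jpg'})
print(out['input_path'])  # demo.jpg
```

The dict form lets different pipelines accept different input keys (see `get_input_keys` added to `BasePipeline` below in this commit) without changing the CLI call site.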

+ 22 - 13
paddlex/pipelines/PPOCR/pipeline.py

@@ -25,7 +25,7 @@ from .utils import draw_ocr_box_txt
 class OCRPipeline(BasePipeline):
     """OCR Pipeline
     """
-    support_models = "ocr"
+    support_models = "PP-OCRv4"
 
     def __init__(self,
                  text_det_model_name=None,
@@ -35,18 +35,27 @@ class OCRPipeline(BasePipeline):
                  text_det_kernel_option=None,
                  text_rec_kernel_option=None,
                  output_dir=None,
+                 device="gpu",
                  **kwargs):
         self.text_det_model_name = text_det_model_name
         self.text_rec_model_name = text_rec_model_name
         self.text_det_model_dir = text_det_model_dir
         self.text_rec_model_dir = text_rec_model_dir
         self.output_dir = output_dir
+        self.device = device
         self.text_det_kernel_option = self.get_kernel_option(
         ) if text_det_kernel_option is None else text_det_kernel_option
         self.text_rec_kernel_option = self.get_kernel_option(
         ) if text_rec_kernel_option is None else text_rec_kernel_option
 
-        self.text_det_post_transforms = [
+        if self.text_det_model_name is not None and self.text_rec_model_name is not None:
+            self.load_model()
+
+    def load_model(self):
+        """load model predictor
+        """
+        assert self.text_det_model_name is not None and self.text_rec_model_name is not None
+        text_det_post_transforms = [
             text_det_T.DBPostProcess(
                 thresh=0.3,
                 box_thresh=0.6,
@@ -58,27 +67,21 @@ class OCRPipeline(BasePipeline):
             # TODO
             text_det_T.CropByPolys(det_box_type="foo")
         ]
-        if self.text_det_model_name is not None and self.text_rec_model_name is not None:
-            self.load_model()
 
-    def load_model(self):
-        """load model predictor
-        """
-        assert self.text_det_model_name is not None and self.text_rec_model_name is not None
         self.text_det_model = create_model(
             self.text_det_model_name,
             self.text_det_model_dir,
             kernel_option=self.text_det_kernel_option,
-            post_transforms=self.text_det_post_transforms)
+            post_transforms=text_det_post_transforms)
         self.text_rec_model = create_model(
             self.text_rec_model_name,
             self.text_rec_model_dir,
             kernel_option=self.text_rec_kernel_option)
 
-    def predict(self, input_path):
+    def predict(self, input):
         """predict
         """
-        result = self.text_det_model.predict({"input_path": input_path})
+        result = self.text_det_model.predict(input)
         all_rec_result = []
         for i, img in enumerate(result["sub_imgs"]):
             rec_result = self.text_rec_model.predict({"image": img})
@@ -88,8 +91,9 @@ class OCRPipeline(BasePipeline):
         if self.output_dir is not None:
             draw_img = draw_ocr_box_txt(result['original_image'],
                                         result['dt_polys'], result["rec_text"])
+            fn = os.path.basename(result['input_path'])
             cv2.imwrite(
-                os.path.join(self.output_dir, "ocr_result.jpg"),
+                os.path.join(self.output_dir, fn),
                 draw_img[:, :, ::-1], )
 
         return result
@@ -108,5 +112,10 @@ class OCRPipeline(BasePipeline):
         """get kernel option
         """
         kernel_option = PaddleInferenceOption()
-        kernel_option.set_device("gpu")
+        kernel_option.set_device(self.device)
         return kernel_option
+
+    def get_input_keys(self):
+        """get the dict keys expected in the input argument
+        """
+        return self.text_det_model.get_input_keys()

+ 6 - 0
paddlex/pipelines/base/pipeline.py

@@ -60,3 +60,9 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
             model_list (list): list of model name.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def get_input_keys(self):
+    """get the dict keys expected in the input argument
+        """
+        raise NotImplementedError

+ 15 - 12
paddlex/pipelines/image_classification/pipeline.py

@@ -26,42 +26,40 @@ class ClsPipeline(BasePipeline):
     def __init__(self,
                  model_name=None,
                  model_dir=None,
+                 output_dir=None,
                  kernel_option=None,
+                 device="gpu",
                  **kwargs):
         super().__init__()
         self.model_name = model_name
         self.model_dir = model_dir
+        self.output_dir = output_dir
+        self.device = device
         self.kernel_option = self.get_kernel_option(
         ) if kernel_option is None else kernel_option
-        self.post_transforms = self.get_post_transforms()
         if self.model_name is not None:
             self.load_model()
 
-    def predict(self, input_path):
+    def predict(self, input):
         """predict
         """
-        return self.model.predict({"input_path": input_path})
+        return self.model.predict(input)
 
     def load_model(self):
         """load model predictor
         """
         assert self.model_name is not None
         self.model = create_model(
-            self.model_name,
+            model_name=self.model_name,
             model_dir=self.model_dir,
-            kernel_option=self.kernel_option,
-            post_transforms=self.post_transforms)
-
-    def get_post_transforms(self):
-        """get post transform ops
-        """
-        return [T.Topk(topk=1), T.PrintResult()]
+            output_dir=self.output_dir,
+            kernel_option=self.kernel_option)
 
     def get_kernel_option(self):
         """get kernel option
         """
         kernel_option = PaddleInferenceOption()
-        kernel_option.set_device("gpu")
+        kernel_option.set_device(self.device)
         return kernel_option
 
     def update_model_name(self, model_name_list):
@@ -72,3 +70,8 @@ class ClsPipeline(BasePipeline):
         """
         assert len(model_name_list) == 1
         self.model_name = model_name_list[0]
+
+    def get_input_keys(self):
+        """get dict keys of the `input` argument
+        """
+        return self.model.get_input_keys()
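The same refactor is applied to all four pipelines below: a `device` argument replaces the hard-coded `"gpu"` when building the default kernel option. A minimal standalone sketch of that pattern, using stand-in classes (the real `PaddleInferenceOption` and pipeline APIs differ):

```python
class PaddleInferenceOption:
    """Stand-in for paddlex's kernel option class (assumption: real API differs)."""

    def __init__(self):
        self.device = None

    def set_device(self, device):
        self.device = device


class PipelineSketch:
    """Sketch of the refactored constructor: the stored `device` string
    drives the default kernel option instead of a hard-coded "gpu"."""

    def __init__(self, kernel_option=None, device="gpu"):
        self.device = device
        # Build a default kernel option only when none is supplied.
        self.kernel_option = (self.get_kernel_option()
                              if kernel_option is None else kernel_option)

    def get_kernel_option(self):
        opt = PaddleInferenceOption()
        opt.set_device(self.device)  # was: opt.set_device("gpu")
        return opt


pipeline = PipelineSketch(device="cpu")
print(pipeline.kernel_option.device)  # prints "cpu"
```

This keeps backward compatibility (`device` defaults to `"gpu"`) while letting callers target other backends without constructing a kernel option themselves.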

+ 14 - 13
paddlex/pipelines/instance_segmentation/pipeline.py

@@ -28,11 +28,12 @@ class InstanceSegPipeline(BasePipeline):
                  model_dir=None,
                  output_dir="./output",
                  kernel_option=None,
+                 device="gpu",
                  **kwargs):
         self.model_name = model_name
         self.model_dir = model_dir
         self.output_dir = output_dir
-        self.post_transforms = self.get_post_transforms(model_dir)
+        self.device = device
         self.kernel_option = self.get_kernel_option(
         ) if kernel_option is None else kernel_option
         if self.model_name is not None:
@@ -43,26 +44,21 @@ class InstanceSegPipeline(BasePipeline):
         """
         assert self.model_name is not None
         self.model = create_model(
-            self.model_name,
+            model_name=self.model_name,
             model_dir=self.model_dir,
-            kernel_option=self.kernel_option,
-            post_transforms=self.post_transforms)
+            output_dir=self.output_dir,
+            kernel_option=self.kernel_option)
 
-    def predict(self, input_path):
+    def predict(self, input):
         """predict
         """
-        return self.model.predict({"input_path": input_path})
-
-    def get_post_transforms(self, model_dir):
-        """get post transform ops
-        """
-        return [T.SaveDetResults(self.output_dir), T.PrintResult()]
+        return self.model.predict(input)
 
     def get_kernel_option(self):
         """get kernel option
         """
         kernel_option = PaddleInferenceOption()
-        kernel_option.set_device("gpu")
+        kernel_option.set_device(self.device)
         return kernel_option
 
     def update_model_name(self, model_name_list):
@@ -72,4 +68,9 @@ class InstanceSegPipeline(BasePipeline):
             model_list (list): list of model name.
         """
         assert len(model_name_list) == 1
-        self.model_name = model_name_list[0]
+        self.model_name = model_name_list[0]
+
+    def get_input_keys(self):
+        """get dict keys of the `input` argument
+        """
+        return self.model.get_input_keys()

+ 13 - 12
paddlex/pipelines/object_detection/pipeline.py

@@ -28,11 +28,12 @@ class DetPipeline(BasePipeline):
                  model_dir=None,
                  output_dir="./output",
                  kernel_option=None,
+                 device="gpu",
                  **kwargs):
         self.model_name = model_name
         self.model_dir = model_dir
         self.output_dir = output_dir
-        self.post_transforms = self.get_post_transforms(model_dir)
+        self.device = device
         self.kernel_option = self.get_kernel_option(
         ) if kernel_option is None else kernel_option
         if self.model_name is not None:
@@ -43,26 +44,21 @@ class DetPipeline(BasePipeline):
         """
         assert self.model_name is not None
         self.model = create_model(
-            self.model_name,
+            model_name=self.model_name,
             model_dir=self.model_dir,
-            kernel_option=self.kernel_option,
-            post_transforms=self.post_transforms)
+            output_dir=self.output_dir,
+            kernel_option=self.kernel_option)
 
-    def predict(self, input_path):
+    def predict(self, input):
         """predict
         """
-        return self.model.predict({"input_path": input_path})
-
-    def get_post_transforms(self, model_dir):
-        """get post transform ops
-        """
-        return [T.SaveDetResults(self.output_dir), T.PrintResult()]
+        return self.model.predict(input)
 
     def get_kernel_option(self):
         """get kernel option
         """
         kernel_option = PaddleInferenceOption()
-        kernel_option.set_device("gpu")
+        kernel_option.set_device(self.device)
+        return kernel_option
 
     def update_model_name(self, model_name_list):
         """update model name and re
@@ -72,3 +68,8 @@ class DetPipeline(BasePipeline):
         """
         assert len(model_name_list) == 1
         self.model_name = model_name_list[0]
+
+    def get_input_keys(self):
+        """get dict keys of the `input` argument
+        """
+        return self.model.get_input_keys()

+ 13 - 12
paddlex/pipelines/semantic_segmentation/pipeline.py

@@ -28,11 +28,12 @@ class SegPipeline(BasePipeline):
                  model_dir=None,
                  output_dir="./output",
                  kernel_option=None,
+                 device="gpu",
                  **kwargs):
         self.model_name = model_name
         self.model_dir = model_dir
         self.output_dir = output_dir
-        self.post_transforms = self.get_post_transforms()
+        self.device = device
         self.kernel_option = self.get_kernel_option(
         ) if kernel_option is None else kernel_option
         if self.model_name is not None:
@@ -43,26 +44,21 @@ class SegPipeline(BasePipeline):
         """
         assert self.model_name is not None
         self.model = create_model(
-            self.model_name,
+            model_name=self.model_name,
             model_dir=self.model_dir,
-            kernel_option=self.kernel_option,
-            post_transforms=self.post_transforms)
+            output_dir=self.output_dir,
+            kernel_option=self.kernel_option)
 
-    def predict(self, input_path):
+    def predict(self, input):
         """predict
         """
-        return self.model.predict({"input_path": input_path})
-
-    def get_post_transforms(self):
-        """get post transform ops
-        """
-        return [T.SaveSegResults(self.output_dir), T.PrintResult()]
+        return self.model.predict(input)
 
     def get_kernel_option(self):
         """get kernel option
         """
         kernel_option = PaddleInferenceOption()
-        kernel_option.set_device("gpu")
+        kernel_option.set_device(self.device)
         return kernel_option
 
     def update_model_name(self, model_name_list):
@@ -73,3 +69,8 @@ class SegPipeline(BasePipeline):
         """
         assert len(model_name_list) == 1
         self.model_name = model_name_list[0]
+
+    def get_input_keys(self):
+        """get dict keys of the `input` argument
+        """
+        return self.model.get_input_keys()

+ 3 - 0
paddlex/repo_apis/PaddleOCR_api/text_rec/config.py

@@ -202,6 +202,9 @@ class TextRecConfig(BaseConfig):
             },
             'mlu': {
                 'Global.use_mlu': True
+            },
+            'npu': {
+                'Global.use_npu': True
             }
         }
         default_cfg.update(device_cfg[device])
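This hunk extends a device-override table: each device key maps to config flags that get merged into the defaults via `dict.update`. A standalone sketch of the pattern (the dict keys mirror the diff; the surrounding values are illustrative only):

```python
# Default config before device selection (illustrative values).
default_cfg = {"Global.use_gpu": False}

# Per-device overrides, keyed by device name.
device_cfg = {
    "gpu": {"Global.use_gpu": True},
    "mlu": {"Global.use_mlu": True},
    "npu": {"Global.use_npu": True},  # entry added by this commit
}

device = "npu"
default_cfg.update(device_cfg[device])
print(default_cfg)
# {'Global.use_gpu': False, 'Global.use_npu': True}
```

Only the selected device's flags are merged in, so choosing `"npu"` leaves the `mlu` flag untouched.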

+ 1 - 1
paddlex/utils/config.py

@@ -198,7 +198,7 @@ def get_config(fname, overrides=None, show=False):
 
 def parse_args():
     """ parse args """
-    parser = argparse.ArgumentParser("generic-image-rec train script")
+    parser = argparse.ArgumentParser("PaddleX script")
     parser.add_argument(
         '-c',
         '--config',

+ 2 - 2
paddlex/utils/device.py

@@ -27,8 +27,8 @@ def get_device(device_cfg, using_device_number=None):
     assert device.lower() in SUPPORTED_DEVICE_TYPE
     if device.lower() in ["gpu", "xpu", "npu", "mlu"]:
         if device.lower() == "npu":
-            os.environ["FLAGS_npu_jit_compile"] = 0
-            os.environ["FLAGS_use_stride_kernel"] = 0
+            os.environ["FLAGS_npu_jit_compile"] = "0"
+            os.environ["FLAGS_use_stride_kernel"] = "0"
             os.environ["FLAGS_allocator_strategy"] = "auto_growth"
         elif device.lower() == "mlu":
             os.environ["CUSTOM_DEVICE_BLACK_LIST"] = "set_value"
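The quoting fix above matters because `os.environ` only accepts string values; the previous integer assignment raised a `TypeError` at runtime. A quick demonstration:

```python
import os

# Fixed form: the flag value is a string.
os.environ["FLAGS_npu_jit_compile"] = "0"

# Old form: assigning an int to os.environ raises TypeError.
try:
    os.environ["FLAGS_use_stride_kernel"] = 0
except TypeError:
    print("int values are rejected")  # prints "int values are rejected"
```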

+ 2 - 2
paddlex/utils/errors/others.py

@@ -112,8 +112,8 @@ def raise_class_not_found_error(cls_name, base_cls, all_entities=None):
     base_cls_name = base_cls.__name__
     msg = f"`{cls_name}` is not registered on {base_cls_name}."
     if all_entities is not None:
-        all_entities_str = ",".join(all_entities)
-        msg += f"\nThe registied entities:`[{all_entities_str}]`"
+        all_entities_str = ", ".join(all_entities)
+        msg += f"\nThe registered entities: [{all_entities_str}]"
     raise ClassNotFoundException(msg)