
latexocr_paddlex (#1911)

* latexocr_paddlex

* add_tokenizer in requirement

* delete_print

* add official model of latexocr
liuhongen1234567 1 year ago
commit 6948cc7466

+ 6 - 2
README.md

@@ -89,8 +89,8 @@ PaddleX 3.0 covers 16 industrial-grade model pipelines, of which 9 basic pipelines can
     <td>Mask-RT-DETR-L<br/>Mask-RT-DETR-H</td>
   </tr>
   <tr>
-    <td rowspan="2">Basic Pipelines</td>
-    <td rowspan="2">General OCR</td>
+    <td rowspan="3">Basic Pipelines</td>
+    <td rowspan="3">General OCR</td>
     <td>Text Detection</td>
     <td>PP-OCRv4_mobile_det<br/>PP-OCRv4_server_det</td>
   </tr>
@@ -99,6 +99,10 @@ PaddleX 3.0 covers 16 industrial-grade model pipelines, of which 9 basic pipelines can
     <td>PP-OCRv4_mobile_rec<br/>PP-OCRv4_server_rec</td>
   </tr>
   <tr>
+    <td>Formula Recognition</td>
+    <td>LaTeX_OCR_rec</td>
+  </tr>
+  <tr>
     <td rowspan="4">Basic Pipelines</td>
     <td rowspan="4">General Table Recognition</td>
     <td>Layout Region Detection</td>

+ 110 - 15
docs/tutorials/data/dataset_check.md

@@ -474,10 +474,105 @@ python main.py -c paddlex/configs/text_recognition/PP-OCRv4_mobile_rec.yaml \
 
 
 Data conversion and dataset splitting can be enabled at the same time. For dataset splitting, the original annotation files are renamed to `xxx.bak` in their original location. The above parameters can also be set by appending command-line arguments, for example, to re-split the dataset and set the train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
 
-## 7. Table Recognition Task Module Data Validation
+## 7. Formula Recognition Task Module Data Validation
 
 ### 7.1 Data Preparation
 
+You need to prepare data according to the data format requirements supported by PaddleX. For data annotation you can refer to [PaddleX Data Annotation](./annotation/README.md), and for an introduction to the data formats you can refer to [PaddleX Data Format Introduction](./dataset_format.md). Here we provide a formula recognition demo dataset for you to use.
+
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ocr_rec_latexocr_dataset_example.tar -P ./dataset
+tar -xf ./dataset/ocr_rec_latexocr_dataset_example.tar -C ./dataset/
+```
+
+### 7.2 Dataset Validation
+
+Validating the dataset takes just one command:
+
+```bash
+python main.py -c paddlex/configs/text_recognition/LaTeX_OCR_rec.yml \
+    -o Global.mode=check_dataset \
+    -o Global.dataset_dir=./dataset/ocr_rec_latexocr_dataset_example \
+    -o CheckDataset.convert.enable=True \
+    -o CheckDataset.convert.src_dataset_type=PKL
+```
+
+After the above command is executed, PaddleX validates the dataset and collects its basic statistics. If the command succeeds, `Check dataset passed !` is printed in the log, and the related outputs are saved under `./output/check_dataset` in the current directory, including visualized example sample images and a sample distribution histogram. The validation result file is saved to `./output/check_dataset_result.json`; its content is as follows:
+```
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "train_samples": 10001,
+    "train_sample_paths": [
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0077809.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0161600.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0002077.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0178425.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0010959.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0079266.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0142495.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0196376.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0185513.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/train/0217146.png"
+    ],
+    "val_samples": 501,
+    "val_sample_paths": [
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0053264.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0100521.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0146333.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0072788.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0002022.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0203664.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0082217.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0208199.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0111236.png",
+      "../dataset/ocr_rec_latexocr_dataset_example/val/0204453.png"
+    ]
+  },
+  "analysis": {
+    "histogram": "check_dataset/histogram.png"
+  },
+  "dataset_path": "./dataset/ocr_rec_latexocr_dataset_example",
+  "show_type": "image",
+  "dataset_type": "MSTextRecDataset"
+}
+```
+In the validation result above, `check_pass` being `true` means the dataset format meets the requirements. The other fields are explained as follows:
+
+- attributes.train_samples: the number of training samples in this dataset is 10001;
+- attributes.val_samples: the number of validation samples in this dataset is 501;
+- attributes.train_sample_paths: a list of relative paths to the visualized training sample images of this dataset;
+- attributes.val_sample_paths: a list of relative paths to the visualized validation sample images of this dataset;
+
+In addition, the dataset validation also analyzes the distribution of character lengths over all samples in the dataset and plots a distribution histogram (histogram.png):
+![Sample distribution histogram](https://github.com/user-attachments/assets/256b1084-ef52-4cf7-87d3-9367f410235b)
+
+**Note**: Only data that passes validation can be used for training and evaluation.
+
+
+### 7.3 Dataset Format Conversion / Dataset Splitting (Optional)
+
+To convert the dataset format or re-split the dataset, you can modify the configuration file or append hyperparameters on the command line.
+
+Parameters related to dataset validation can be set by modifying the fields under `CheckDataset` in the configuration file; some of the parameters are described below, followed by an illustrative configuration sketch:
+
+* `CheckDataset`:
+    * `convert`:
+        * `enable`: whether to perform dataset format conversion; set to `True` to enable conversion, default is `False`;
+        * `src_dataset_type`: the source dataset format when conversion is enabled; the available source format is `PKL`, default is `null`;
+    * `split`:
+        * `enable`: whether to re-split the dataset; formula recognition does not support dataset re-splitting, default is `False`;
+        * `train_percent`: if re-splitting the dataset, the percentage of the training set, any integer between 0 and 100, which must sum to 100 with the value of `val_percent`;
+        * `val_percent`: if re-splitting the dataset, the percentage of the validation set, any integer between 0 and 100, which must sum to 100 with the value of `train_percent`;
+
+Data conversion and dataset splitting can be enabled at the same time. For dataset splitting, the original annotation files are renamed to `xxx.bak` in their original location. The above parameters can also be set by appending command-line arguments, for example, to re-split the dataset and set the train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
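+
+As an illustrative sketch only (grounded in the `CheckDataset` section of the `LaTeX_OCR_rec.yml` config added by this change), the fields described above look as follows when PKL conversion is enabled; everything except `convert.enable: True` and `convert.src_dataset_type: PKL` matches the shipped defaults:
+
+```yaml
+CheckDataset:
+  convert:
+    enable: True           # run the PKL conversion during check_dataset
+    src_dataset_type: PKL  # the only supported source format for formula recognition
+  split:
+    enable: False          # re-splitting is not supported for formula recognition
+    train_percent: null
+    val_percent: null
+```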
+
+## 8. Table Recognition Task Module Data Validation
+
+### 8.1 Data Preparation
+
 You need to prepare data according to the data format requirements supported by PaddleX. For data annotation you can refer to [PaddleX Data Annotation](./annotation/README.md), and for an introduction to the data formats you can refer to [PaddleX Data Format Introduction](./dataset_format.md). Here we provide a table recognition demo dataset for you to use.
 
 ```bash
@@ -486,7 +581,7 @@ wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/table_rec_dataset_e
 tar -xf ./dataset/table_rec_dataset_examples.tar -C ./dataset/
 ```
 
-### 7.2 Dataset Validation
+### 8.2 Dataset Validation
 
 Validating the dataset takes just one command:
 
@@ -529,7 +624,7 @@ python main.py -c paddlex/configs/table_recognition/SLANet.yaml \
 **Note**: Only data that passes validation can be used for training and evaluation.
 
 
-### 7.3 Dataset Format Conversion / Dataset Splitting (Optional)
+### 8.3 Dataset Format Conversion / Dataset Splitting (Optional)
 
 To convert the dataset format or re-split the dataset, you can modify the configuration file or append hyperparameters on the command line.
 
@@ -546,9 +641,9 @@ python main.py -c paddlex/configs/table_recognition/SLANet.yaml \
 
 
 Data conversion and dataset splitting can be enabled at the same time. For dataset splitting, the original annotation files are renamed to `xxx.bak` in their original location. The above parameters can also be set by appending command-line arguments, for example, to re-split the dataset and set the train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
 
-## 8. Time Series Forecasting Task Module Data Validation
+## 9. Time Series Forecasting Task Module Data Validation
 
-### 8.1 Data Preparation
+### 9.1 Data Preparation
 
 You need to prepare data according to the data format requirements supported by PaddleX. For an introduction to the data formats you can refer to [PaddleX Data Format Introduction](./dataset_format.md). Here we provide a time series forecasting demo dataset for you to use.
 
@@ -558,7 +653,7 @@ wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ts_dataset_examples
 tar -xf ./dataset/ts_dataset_examples.tar -C ./dataset/
 ```
 
-### 8.2 Dataset Validation
+### 9.2 Dataset Validation
 
 Validating the dataset takes just one command:
 
@@ -659,7 +754,7 @@ python main.py -c paddlex/configs/ts_forecast/DLinear.yaml \
 **Note**: Only data that passes validation can be used for training and evaluation.
 
 
-### 8.3 Dataset Format Conversion / Dataset Splitting (Optional)
+### 9.3 Dataset Format Conversion / Dataset Splitting (Optional)
 
 To convert the dataset format or re-split the dataset, you can modify the configuration file or append hyperparameters on the command line.
 
@@ -676,9 +771,9 @@ python main.py -c paddlex/configs/ts_forecast/DLinear.yaml \
 
 
 Data conversion and dataset splitting can be enabled at the same time. For dataset splitting, the original annotation files are renamed to `xxx.bak` in their original location. The above parameters can also be set by appending command-line arguments, for example, to re-split the dataset and set the train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
 
-## 9. Time Series Anomaly Detection Task Module Data Validation
+## 10. Time Series Anomaly Detection Task Module Data Validation
 
-### 9.1 Data Preparation
+### 10.1 Data Preparation
 
 You need to prepare data according to the data format requirements supported by PaddleX. For an introduction to the data formats you can refer to [PaddleX Data Format Introduction](./dataset_format.md). Here we provide a time series anomaly detection demo dataset for you to use.
 
@@ -688,7 +783,7 @@ wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ts_anomaly_examples
 tar -xf ./dataset/ts_anomaly_examples.tar -C ./dataset/
 ```
 
-### 9.2 Dataset Validation
+### 10.2 Dataset Validation
 
 Validating the dataset takes just one command:
 
@@ -757,7 +852,7 @@ python main.py -c paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml \
 **Note**: Only data that passes validation can be used for training and evaluation.
 
 
-### 9.3 Dataset Format Conversion / Dataset Splitting (Optional)
+### 10.3 Dataset Format Conversion / Dataset Splitting (Optional)
 
 To convert the dataset format or re-split the dataset, you can modify the configuration file or append hyperparameters on the command line.
 
@@ -774,9 +869,9 @@ python main.py -c paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml \
 
 
 Data conversion and dataset splitting can be enabled at the same time. For dataset splitting, the original annotation files are renamed to `xxx.bak` in their original location. The above parameters can also be set by appending command-line arguments, for example, to re-split the dataset and set the train/validation ratio: `-o CheckDataset.split.enable=True -o CheckDataset.split.train_percent=80 -o CheckDataset.split.val_percent=20`.
 
-## 10. Time Series Classification Task Module Data Validation
+## 11. Time Series Classification Task Module Data Validation
 
-### 10.1 Data Preparation
+### 11.1 Data Preparation
 
 You need to prepare data according to the data format requirements supported by PaddleX. For an introduction to the data formats you can refer to [PaddleX Data Format Introduction](./dataset_format.md). Here we provide a time series classification demo dataset for you to use.
 
@@ -786,7 +881,7 @@ wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/ts_classify_example
 tar -xf ./dataset/ts_classify_examples.tar -C ./dataset/
 ```
 
-### 10.2 Dataset Validation
+### 11.2 Dataset Validation
 
 Validating the dataset takes just one command:
 
@@ -867,7 +962,7 @@ python main.py -c paddlex/configs/ts_classify_examples/DLinear_ad.yaml \
 **Note**: Only data that passes validation can be used for training and evaluation.
 
 
-### 10.3 Dataset Format Conversion / Dataset Splitting (Optional)
+### 11.3 Dataset Format Conversion / Dataset Splitting (Optional)
 
 To convert the dataset format or re-split the dataset, you can modify the configuration file or append hyperparameters on the command line.
 

+ 8 - 4
docs/tutorials/models/support_model_list.md

@@ -187,11 +187,15 @@
 | :--- | :---: |
 | PP-OCRv4_server_rec | [PP-OCRv4_server_rec.yaml](../../../paddlex/configs/text_recognition/PP-OCRv4_server_rec.yaml)|
 | PP-OCRv4_mobile_rec | [PP-OCRv4_mobile_rec.yaml](../../../paddlex/configs/text_recognition/PP-OCRv4_mobile_rec.yaml)|
-## 8. Layout Analysis
+## 8. Formula Recognition
+| Model Name | config |
+| :--- | :---: |
+| LaTeX_OCR_rec | [LaTeX_OCR_rec.yml](../../../paddlex/configs/text_recognition/LaTeX_OCR_rec.yml)|
+## 9. Layout Analysis
 | Model Name | config |
 | :--- | :---: |
 | PicoDet_layout_1x | [PicoDet_layout_1x.yaml](../../../paddlex/configs/structure_analysis/PicoDet_layout_1x.yaml)|
-## 9. Time Series Anomaly Detection
+## 10. Time Series Anomaly Detection
 | Model Name | config |
 | :--- | :---: |
 | DLinear_ad | [DLinear_ad.yaml](../../../paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml)|
@@ -199,11 +203,11 @@
 | TimesNet_ad | [TimesNet_ad.yaml](../../../paddlex/configs/ts_anomaly_detection/TimesNet_ad.yaml)|
 | AutoEncoder_ad | [AutoEncoder_ad.yaml](../../../paddlex/configs/ts_anomaly_detection/AutoEncoder_ad.yaml)|
 | Nonstationary_ad | [Nonstationary_ad.yaml](../../../paddlex/configs/ts_anomaly_detection/Nonstationary_ad.yaml)|
-## 10. Time Series Classification
+## 11. Time Series Classification
 | Model Name | config |
 | :--- | :---: |
 | TimesNet_cls | [TimesNet_cls.yaml](../../../paddlex/configs/ts_classification/TimesNet_cls.yaml)|
-## 11. Time Series Forecasting
+## 12. Time Series Forecasting
 | Model Name | config |
 | :--- | :---: |
 | DLinear | [DLinear.yaml](../../../paddlex/configs/ts_forecast/DLinear.yaml)|

+ 37 - 0
paddlex/configs/text_recognition/LaTeX_OCR_rec.yml

@@ -0,0 +1,37 @@
+Global:
+  model: LaTeX_OCR_rec
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "./dataset/ocr_rec_latexocr_dataset_example"
+  device: gpu:0
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 2
+  batch_size_train: 40
+  batch_size_val: 10
+  learning_rate: 0.0001
+  pretrain_weight_path: null
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy.pdparams"
+  log_interval: 1
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_formula_rec_001.png"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1
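The `Global.mode` comment above lists the supported modes (`check_dataset`/`train`/`evaluate`/`predict`), and modes are switched with the same `-o` overrides used in the dataset-check tutorial. A rough, illustrative sketch of driving this config from the repository root, assuming the demo dataset used earlier:

```bash
# Illustrative only: train LaTeX_OCR_rec on the demo dataset after check_dataset has passed
python main.py -c paddlex/configs/text_recognition/LaTeX_OCR_rec.yml \
    -o Global.mode=train \
    -o Global.dataset_dir=./dataset/ocr_rec_latexocr_dataset_example
```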

+ 0 - 1
paddlex/modules/base/dataset_checker/dataset_checker.py

@@ -74,7 +74,6 @@ class BaseDatasetChecker(ABC, metaclass=AutoRegisterABCMetaClass):
 
 
         attrs = self.check_dataset(dataset_dir)
         analysis = self.analyse(dataset_dir)
-
         check_result = build_res_dict(True)
         check_result["attributes"] = attrs
         check_result["analysis"] = analysis

+ 2 - 0
paddlex/modules/base/predictor/utils/official_models.py

@@ -276,6 +276,8 @@ openatom_rec_svtrv2_ch_infer.tar",
     "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/PicoDet-L_layout_infer.tar",
     "SLANet":
     "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/SLANet_infer.tar",
+    "LaTeX_OCR_rec":
+    "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/LaTeX_OCR_rec_infer.tar",
 }
 
 

+ 13 - 5
paddlex/modules/text_recognition/dataset_checker/__init__.py

@@ -21,7 +21,7 @@ from PIL import Image
 import json
 
 from ...base import BaseDatasetChecker
-from .dataset_src import check, split_dataset, deep_analyse
+from .dataset_src import check, split_dataset, deep_analyse, convert
 
 from ..model_list import MODELS
 
@@ -41,7 +41,8 @@ class TextRecDatasetChecker(BaseDatasetChecker):
         Returns:
             str: the root directory of converted dataset.
         """
-        return src_dataset_dir
+        return convert(self.check_dataset_config.convert.src_dataset_type,
+                       src_dataset_dir)
 
     def split_dataset(self, src_dataset_dir: str) -> str:
         """repartition the train and validation dataset
@@ -66,7 +67,7 @@ class TextRecDatasetChecker(BaseDatasetChecker):
         Returns:
             dict: dataset summary.
         """
-        return check(dataset_dir, self.global_config.output, sample_num=10)
+        return check(dataset_dir, self.global_config.output, sample_num=10, dataset_type=self.get_dataset_type())
 
     def analyse(self, dataset_dir: str) -> dict:
         """deep analyse dataset
@@ -77,7 +78,11 @@ class TextRecDatasetChecker(BaseDatasetChecker):
         Returns:
             dict: the deep analysis results.
         """
-        return deep_analyse(dataset_dir, self.output)
+        if self.global_config['model'] in ['LaTeX_OCR_rec']:
+            datatype = "LaTeXOCRDataset"
+        else:
+            datatype = "MSTextRecDataset"
+        return deep_analyse(dataset_dir, self.output, datatype=datatype)
 
     def get_show_type(self) -> str:
         """get the show type of dataset
@@ -93,4 +98,7 @@ class TextRecDatasetChecker(BaseDatasetChecker):
         Returns:
             str: dataset type
         """
-        return "MSTextRecDataset"
+        if self.global_config['model'] in ['LaTeX_OCR_rec']:
+            return "LaTeXOCRDataset"
+        else:
+            return "MSTextRecDataset"

+ 1 - 0
paddlex/modules/text_recognition/dataset_checker/dataset_src/__init__.py

@@ -14,5 +14,6 @@
 
 
 
 
 from .check_dataset import check
+from .convert_dataset import convert
 from .split_dataset import split_dataset
 from .analyse_dataset import deep_analyse

+ 19 - 6
paddlex/modules/text_recognition/dataset_checker/dataset_src/analyse_dataset.py

@@ -72,7 +72,7 @@ def simple_analyse(dataset_path, images_dict):
             img_paths[tags[2]])
 
 
-def deep_analyse(dataset_path, output):
+def deep_analyse(dataset_path, output, datatype = "MSTextRecDataset"):
     """class analysis for dataset"""
     tags = ['train', 'val']
     all_instances = 0
@@ -90,10 +90,15 @@ def deep_analyse(dataset_path, output):
                 warning(f"Error in {line}.")
                 continue
             str_nums.append(len(line[1]))
-        max_length = min(100, max(str_nums))
+        if datatype == "LaTeXOCRDataset":
+            max_length = min(768, max(str_nums))
+            interval = 20
+        else:
+            max_length = min(100, max(str_nums))
+            interval = 5
         start = 0
-        for i in range(1, math.ceil((max_length / 5))):
-            stop = i * 5
+        for i in range(1, math.ceil((max_length / interval))):
+            stop = i * interval
             num_str = sum(start < i <= stop for i in str_nums)
             labels_cnt[f'{start}-{stop}'] = num_str
             start = stop
@@ -126,12 +131,18 @@ def deep_analyse(dataset_path, output):
     else:
         font = font_manager.FontProperties(
             fname=PINGFANG_FONT_FILE_PATH, size=15)
-    fig, ax = plt.subplots(figsize=(10, 5), dpi=120)
+    if datatype == "LaTeXOCRDataset":
+        fig, ax = plt.subplots(figsize=(15, 9), dpi=120)
+        xlabel_name = '公式长度区间'
+    else:
+        fig, ax = plt.subplots(figsize=(10, 5), dpi=120)
+        xlabel_name = '文本字长度区间'
     ax.bar(x_train, cnts_train, width=0.3, label='train')
     ax.bar(x_val + width, cnts_val, width=0.3, label='val')
     plt.xticks(x_max + width / 2, classes_max, rotation=90)
+    plt.legend(prop = {'size':18})
     ax.set_xlabel(
-        '文本字长度区间',
+        xlabel_name,
         fontproperties=None if os_system == "windows" else font,
         fontsize=12)
     ax.set_ylabel(
@@ -149,3 +160,5 @@ def deep_analyse(dataset_path, output):
     cv2.imwrite(fig1_path, pie_array)
 
     return {"histogram": os.path.join("check_dataset", "histogram.png")}
+
+

+ 18 - 10
paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py

@@ -31,24 +31,28 @@ def check(dataset_dir,
           mode='fast',
           sample_num=10):
     """ check dataset """
-    # dataset_dir = osp.abspath(dataset_dir)
-    if dataset_type == 'SimpleDataSet' or 'MSTextRecDataset':
+    if dataset_type in ('SimpleDataSet', 'MSTextRecDataset', 'LaTeXOCRDataset'):
         # Custom dataset
         if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
             raise DatasetFileNotFoundError(file_path=dataset_dir)
-
         tags = ['train', 'val']
         delim = '\t'
         valid_num_parts = 2
         max_recorded_sample_cnts = 50
         sample_cnts = dict()
         sample_paths = defaultdict(list)
-
-        dict_file = osp.join(dataset_dir, 'dict.txt')
-        if not osp.exists(dict_file):
-            raise DatasetFileNotFoundError(
-                file_path=dict_file,
-                solution=f"Ensure that `dict.txt` exist in {dataset_dir}")
+        if dataset_type == 'LaTeXOCRDataset':
+            dict_file = osp.join(dataset_dir, 'latex_ocr_tokenizer.json')
+            if not osp.exists(dict_file):
+                raise DatasetFileNotFoundError(
+                    file_path=dict_file,
+                    solution=f"Ensure that `latex_ocr_tokenizer.json` exists in {dataset_dir}")
+        else:
+            dict_file = osp.join(dataset_dir, 'dict.txt')
+            if not osp.exists(dict_file):
+                raise DatasetFileNotFoundError(
+                    file_path=dict_file,
+                    solution=f"Ensure that `dict.txt` exists in {dataset_dir}")
         for tag in tags:
             file_list = osp.join(dataset_dir, f'{tag}.txt')
             if not osp.exists(file_list):
@@ -76,7 +80,10 @@
                                 "in {file_list} should be {valid_num_parts} (current delimiter is '{delim}')."
                             )
                         file_name = substr[0]
-                        img_path = osp.join(dataset_dir, file_name)
+                        if dataset_type == 'LaTeXOCRDataset':
+                            img_path = osp.join(dataset_dir, tag, file_name)
+                        else:
+                            img_path = osp.join(dataset_dir, file_name)
                         if len(sample_paths[tag]) < max_recorded_sample_cnts:
                             sample_paths[tag].append(
                                 os.path.relpath(img_path, output))
@@ -94,3 +101,4 @@ def check(dataset_dir,
         # meta['dict_file'] = dict_file
 
         return meta
+

+ 88 - 0
paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py

@@ -0,0 +1,88 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+import json
+import random
+import math
+import pickle
+from tqdm import tqdm
+from collections import defaultdict
+from paddle.utils import try_import
+from .....utils.errors import ConvertFailedError
+from .....utils.logging import info, warning
+
+
+def check_src_dataset(root_dir, dataset_type):
+    """ check src dataset format validity """
+    if dataset_type in ("PKL",):
+        anno_suffix = ".pkl"
+    else:
+        raise ConvertFailedError(
+            message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 PKL 格式。"
+        )
+
+    err_msg_prefix = f"数据格式转换失败!请参考上述`{dataset_type}格式数据集示例`检查待转换数据集格式。"
+
+    for anno in ['train.txt','val.txt','latex_ocr_tokenizer.json']:
+        src_anno_path = os.path.join(root_dir, anno)
+        if not os.path.exists(src_anno_path):
+            raise ConvertFailedError(
+                    message=f"{err_msg_prefix}保证{src_anno_path}文件存在。")
+    return None
+
+def convert(dataset_type, input_dir):
+    """ convert dataset to pkl format """
+    # check format validity
+    check_src_dataset(input_dir, dataset_type)
+    if dataset_type in ("PKL",):
+        convert_pkl_dataset(input_dir) 
+    else:
+        raise ConvertFailedError(message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 PKL 格式。")
+
+def convert_pkl_dataset(root_dir):
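+    """Build latexocr_train.pkl / latexocr_val.pkl in root_dir from the train.txt / val.txt annotations and their image folders."""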
+    for anno in ['train.txt','val.txt']:
+        src_img_dir = os.path.join(root_dir, anno.replace(".txt",""))
+        src_anno_path = os.path.join(root_dir, anno)
+        txt2pickle(src_img_dir, src_anno_path, root_dir)
+
+def txt2pickle(images, equations, save_dir):
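+    """Group (equation, image_path) pairs by image size rounded up to a multiple of 16 and pickle the resulting dict to save_dir."""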
+    imagesize = try_import("imagesize")
+    save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(images.split("/")[-1]))
+    min_dimensions = (32, 32)
+    max_dimensions = (672, 192)
+    max_length = 512
+    data = defaultdict(lambda: [])
+    pic_num = 0
+    if images is not None and equations is not None:
+        with open(equations, "r") as f:
+            lines = f.readlines()
+            for l in tqdm(lines, total = len(lines)):
+                l = l.strip()
+                img_name, equation = l.split("\t")
+                img_path = os.path.join( os.path.abspath(images), img_name)
+                width, height = imagesize.get(img_path)
+                if (
+                    min_dimensions[0] <= width <= max_dimensions[0]
+                    and min_dimensions[1] <= height <= max_dimensions[1]
+                ):
+                    divide_h = math.ceil(height / 16) * 16
+                    divide_w = math.ceil(width / 16) * 16
+                    data[(divide_w, divide_h)].append((equation, img_path))
+                    pic_num +=1
+        data = dict(data)
+        with open(save_p, "wb") as file:
+            pickle.dump(data, file)
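For orientation only, a small sketch (not part of this commit) of how a pickle produced by `txt2pickle` above could be inspected; it relies solely on the structure built in this file, a dict mapping `(width, height)` buckets to lists of `(equation, image_path)` tuples:

```python
import pickle

# Path assumed from convert_pkl_dataset(): "<dataset_dir>/latexocr_train.pkl"
with open("./dataset/ocr_rec_latexocr_dataset_example/latexocr_train.pkl", "rb") as f:
    buckets = pickle.load(f)

# Keys are image sizes rounded up to multiples of 16; values are (LaTeX string, image path) pairs.
for (w, h), samples in sorted(buckets.items()):
    print(f"{w}x{h}: {len(samples)} samples")
```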

+ 6 - 3
paddlex/modules/text_recognition/evaluator.py

@@ -28,9 +28,12 @@ class TextRecEvaluator(BaseEvaluator):
         """
         """
         if self.eval_config.log_interval:
             self.pdx_config.update_log_interval(self.eval_config.log_interval)
-
-        self.pdx_config.update_dataset(self.global_config.dataset_dir,
-                                       "MSTextRecDataset")
+        if self.global_config['model']=='LaTeX_OCR_rec':
+            self.pdx_config.update_dataset(self.global_config.dataset_dir,
+                                        "LaTeXOCRDataSet")
+        else:
+            self.pdx_config.update_dataset(self.global_config.dataset_dir,
+                                        "MSTextRecDataset")
         label_dict_path = None
         if self.eval_config.get("label_dict_path"):
             label_dict_path = self.eval_config.label_dict_path

+ 1 - 0
paddlex/modules/text_recognition/model_list.py

@@ -17,4 +17,5 @@ MODELS = [
     'PP-OCRv4_server_rec',
     'SVTRv2_server_rec',
     'RepSVTR_mobile_rec',
+    'LaTeX_OCR_rec',
 ]

+ 18 - 7
paddlex/modules/text_recognition/predictor/predictor.py

@@ -68,14 +68,25 @@ class TextRecPredictor(BasePredictor):
 
 
     def _get_pre_transforms_from_config(self):
         """ _get_pre_transforms_from_config """
-        return [
-            image_common.ReadImage(), image_common.GetImageInfo(),
-            T.OCRReisizeNormImg()
-        ]
+        if self.model_name == 'LaTeX_OCR_rec':
+            return [
+                image_common.ReadImage(), image_common.GetImageInfo(),
+                T.LaTeXOCRReisizeNormImg()
+            ]
+        else:
+            return [
+                image_common.ReadImage(), image_common.GetImageInfo(),
+                T.OCRReisizeNormImg()
+            ]
 
     def _get_post_transforms_from_config(self):
         """ get postprocess transforms """
-        post_transforms = [
-            T.CTCLabelDecode(self.other_src.PostProcess), T.PrintResult()
-        ]
+        if self.model_name =='LaTeX_OCR_rec':
+            post_transforms = [
+                T.LaTeXOCRDecode(self.other_src.PostProcess), T.PrintResult()
+            ]
+        else:
+            post_transforms = [
+                T.CTCLabelDecode(self.other_src.PostProcess), T.PrintResult()
+            ]
         return post_transforms

+ 167 - 2
paddlex/modules/text_recognition/predictor/transforms.py

@@ -22,13 +22,15 @@ from PIL import Image
 import cv2
 import math
 import paddle
-
+import json
+import tempfile
+from tokenizers import Tokenizer as TokenizerFast
 from ....utils import logging
 from ...base.predictor import BaseTransform
 from ...base.predictor.io.writers import TextWriter
 from .keys import TextRecKeys as K
 
-__all__ = ['OCRReisizeNormImg', 'CTCLabelDecode', 'SaveTextRecResults']
+__all__ = ['OCRReisizeNormImg', 'LaTeXOCRReisizeNormImg', 'CTCLabelDecode', 'LaTeXOCRDecode', 'SaveTextRecResults']
 
 
 class OCRReisizeNormImg(BaseTransform):
@@ -80,6 +82,113 @@ class OCRReisizeNormImg(BaseTransform):
         return [K.IMAGE]
 
 
+class LaTeXOCRReisizeNormImg(BaseTransform):
+    """ for LaTeX-OCR image resize and normalization """
+
+    def __init__(self, rec_image_shape=[3, 48, 320]):
+        super().__init__()
+        self.rec_image_shape = rec_image_shape
+
+    def pad_(self, img, divable=32):
+        threshold = 128
+        data = np.array(img.convert("LA"))
+        if data[..., -1].var() == 0:
+            data = (data[..., 0]).astype(np.uint8)
+        else:
+            data = (255 - data[..., -1]).astype(np.uint8)
+        data = (data - data.min()) / (data.max() - data.min()) * 255
+        if data.mean() > threshold:
+            # To invert the text to white
+            gray = 255 * (data < threshold).astype(np.uint8)
+        else:
+            gray = 255 * (data > threshold).astype(np.uint8)
+            data = 255 - data
+
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        rect = data[b: b + h, a: a + w]
+        im = Image.fromarray(rect).convert("L")
+        dims = []
+        for x in [w, h]:
+            div, mod = divmod(x, divable)
+            dims.append(divable * (div + (1 if mod > 0 else 0)))
+        padded = Image.new("L", dims, 255)
+        padded.paste(im, (0, 0, im.size[0], im.size[1]))
+        return padded
+
+    def minmax_size_(
+            self,
+            img,
+            max_dimensions,
+            min_dimensions,
+    ):
+        if max_dimensions is not None:
+            ratios = [a / b for a, b in zip(img.size, max_dimensions)]
+            if any([r > 1 for r in ratios]):
+                size = np.array(img.size) // max(ratios)
+                img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
+        if min_dimensions is not None:
+            # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
+            padded_size = [
+                max(img_dim, min_dim)
+                for img_dim, min_dim in zip(img.size, min_dimensions)
+            ]
+            if padded_size != list(img.size):  # assert hypothesis
+                padded_im = Image.new("L", padded_size, 255)
+                padded_im.paste(img, img.getbbox())
+                img = padded_im
+        return img
+
+    def norm_img_latexocr(self, img):
+        # normalize to a single-channel (gray scale) input for the recognizer
+        shape = (1, 1, 3)
+        mean = [0.7931, 0.7931, 0.7931]
+        std = [0.1738, 0.1738, 0.1738]
+        scale = 255.0
+        min_dimensions = [32, 32]
+        max_dimensions = [672, 192]
+        mean = np.array(mean).reshape(shape).astype("float32")
+        std = np.array(std).reshape(shape).astype("float32")
+
+        im_h, im_w = img.shape[:2]
+        if (
+                min_dimensions[0] <= im_w <= max_dimensions[0]
+                and min_dimensions[1] <= im_h <= max_dimensions[1]
+        ):
+            pass
+        else:
+            img = Image.fromarray(np.uint8(img))
+            img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
+            img = np.array(img)
+            im_h, im_w = img.shape[:2]
+            img = np.dstack([img, img, img])
+        img = (img.astype("float32") * scale - mean) / std
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        divide_h = math.ceil(im_h / 16) * 16
+        divide_w = math.ceil(im_w / 16) * 16
+        img = np.pad(
+            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
+        )
+        img = img[:, :, np.newaxis].transpose(2, 0, 1)
+        img = img.astype("float32")
+        return img
+
+    def apply(self, data):
+        """ apply """
+        data[K.IMAGE] = self.norm_img_latexocr(data[K.IMAGE])
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """ get input keys """
+        return [K.IMAGE, K.ORI_IM_SIZE]
+
+    @classmethod
+    def get_output_keys(cls):
+        """ get output keys """
+        return [K.IMAGE]
+
+
 class BaseRecLabelDecode(BaseTransform):
     """ Convert between text-label and text-index """
 
@@ -219,6 +328,62 @@ class CTCLabelDecode(BaseRecLabelDecode):
         return [K.REC_TEXT, K.REC_SCORE]
 
 
+class LaTeXOCRDecode(object):
+    """Convert between latex-symbol and symbol-index"""
+
+    def __init__(self, post_process_cfg=None, **kwargs):
+        assert post_process_cfg['name'] == 'LaTeXOCRDecode'
+
+        super(LaTeXOCRDecode, self).__init__()
+        character_list = post_process_cfg['character_dict']
+        temp_path = tempfile.gettempdir()
+        rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
+        try:
+            with open(rec_char_dict_path, "w") as f:
+                json.dump(character_list, f)
+        except Exception as e:
+            print(f'创建 latexocr_tokenizer.json 文件失败, 原因{str(e)}')
+        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
+
+    def post_process(self, s):
+        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+        letter = "[a-zA-Z]"
+        noletter = "[\W_^\d]"
+        names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
+        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+        news = s
+        while True:
+            s = news
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+            news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
+            if news == s:
+                break
+        return s
+
+    def decode(self, tokens):
+        if len(tokens.shape) == 1:
+            tokens = tokens[None, :]
+
+        dec = [self.tokenizer.decode(tok) for tok in tokens]
+        dec_str_list = [
+            "".join(detok.split(" "))
+            .replace("Ġ", " ")
+            .replace("[EOS]", "")
+            .replace("[BOS]", "")
+            .replace("[PAD]", "")
+            .strip()
+            for detok in dec
+        ]
+        return [str(self.post_process(dec_str)) for dec_str in dec_str_list]
+
+    def __call__(self, data):
+        preds = data[K.REC_PROBS]
+        text = self.decode(preds)
+        data[K.REC_TEXT] = text[0]
+        return data
+
+
 class SaveTextRecResults(BaseTransform):
     """ SaveTextRecResults """
     _TEXT_REC_RES_SUFFIX = '_text_rec'

+ 20 - 6
paddlex/modules/text_recognition/trainer.py

@@ -59,9 +59,14 @@ class TextRecTrainer(BaseTrainer):
         if self.train_config.save_interval:
             self.pdx_config.update_save_interval(
                 self.train_config.save_interval)
-
-        self.pdx_config.update_dataset(self.global_config.dataset_dir,
-                                       "MSTextRecDataset")
+
+        if self.global_config['model']=='LaTeX_OCR_rec':
+            self.pdx_config.update_dataset(self.global_config.dataset_dir,
+                                        "LaTeXOCRDataSet")
+        else:
+            self.pdx_config.update_dataset(self.global_config.dataset_dir,
+                                        "MSTextRecDataset")
+
         label_dict_path = Path(self.global_config.dataset_dir).joinpath(
             "dict.txt")
         if label_dict_path.exists():
@@ -71,8 +76,14 @@ class TextRecTrainer(BaseTrainer):
         if self.train_config.pretrain_weight_path:
             self.pdx_config.update_pretrained_weights(
                 self.train_config.pretrain_weight_path)
-        if self.train_config.batch_size is not None:
-            self.pdx_config.update_batch_size(self.train_config.batch_size)
+
+        if self.global_config['model']=='LaTeX_OCR_rec':
+            if self.train_config.batch_size_train is not None and self.train_config.batch_size_val:
+                self.pdx_config.update_batch_size_pair(self.train_config.batch_size_train, self.train_config.batch_size_val)
+        else:
+            if self.train_config.batch_size is not None:
+                self.pdx_config.update_batch_size(self.train_config.batch_size)
+
         if self.train_config.learning_rate is not None:
             self.pdx_config.update_learning_rate(
                 self.train_config.learning_rate)
@@ -126,7 +137,10 @@ class TextRecTrainDeamon(BaseTrainDeamon):
         """ get the score by pdstates file """
         if not Path(pdstates_path).exists():
             return 0
-        return paddle.load(pdstates_path)['best_model_dict']['acc']
+        if self.global_config['model'] == 'LaTeX_OCR_rec':
+            return paddle.load(pdstates_path)['best_model_dict']['exp_rate']
+        else:
+            return paddle.load(pdstates_path)['best_model_dict']['acc']
 
     def get_epoch_id_by_pdparams_prefix(self, pdparams_prefix):
         """ get the epoch_id by pdparams file """

+ 126 - 0
paddlex/repo_apis/PaddleOCR_api/configs/LaTeX_OCR_rec.yml

@@ -0,0 +1,126 @@
+Global:
+  use_gpu: True
+  epoch_num: 500
+  log_smooth_window: 20
+  print_batch_step: 100
+  save_model_dir: ./output/rec/latex_ocr/
+  save_epoch_step: 5
+  max_seq_len: 512
+  # evaluation is run every 60000 iterations (22 epoch)(batch_size = 56)
+  eval_batch_step: [0, 60000]
+  cal_metric_during_train: True
+  pretrained_model: https://paddle-model-ecology.bj.bcebos.com/pretrained/rec_latex_ocr_trained.pdparams
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/datasets/pme_demo/0000013.png
+  infer_mode: False
+  use_space_char: False
+  rec_char_dict_path:  ppocr/utils/dict/latex_ocr_tokenizer.json
+  save_res_path: ./output/rec/predicts_latexocr.txt
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Const
+    learning_rate: 0.0001
+
+Architecture:
+  model_type: rec
+  algorithm: LaTeXOCR
+  in_channels: 1
+  Transform:
+  Backbone:
+    name: HybridTransformer
+    img_size: [192, 672]
+    patch_size: 16
+    num_classes: 0
+    embed_dim: 256
+    depth: 4
+    num_heads: 8
+    input_channel: 1
+    is_predict: False
+    is_export: False
+  Head:
+    name: LaTeXOCRHead
+    pad_value: 0
+    is_export: False
+    decoder_args:
+      attn_on_attn: True
+      cross_attend: True
+      ff_glu: True
+      rel_pos_bias: False
+      use_scalenorm: False
+
+Loss:
+  name: LaTeXOCRLoss
+
+PostProcess:
+  name: LaTeXOCRDecode
+  rec_char_dict_path: ppocr/utils/dict/latex_ocr_tokenizer.json
+
+Metric:
+  name: LaTeXOCRMetric
+  main_indicator:  exp_rate
+  cal_blue_score: False
+
+Train:
+  dataset:
+    name: LaTeXOCRDataSet
+    data: ./train_data/LaTeXOCR/latexocr_train.pkl
+    min_dimensions: [32, 32]
+    max_dimensions: [672, 192]
+    batch_size_per_pair: 40
+    keep_smaller_batches: False
+    transforms:
+      - DecodeImage:
+          channel_first: False
+      - MinMaxResize:
+          min_dimensions: [32, 32]
+          max_dimensions: [672, 192]        
+      - LatexTrainTransform:
+          bitmap_prob: .04
+      - NormalizeImage:
+          mean: [0.7931, 0.7931, 0.7931]
+          std: [0.1738, 0.1738, 0.1738]
+          order: 'hwc'
+      - LatexImageFormat:
+      - KeepKeys:
+          keep_keys: ['image']
+  loader:
+    shuffle: True
+    batch_size_per_card: 1
+    drop_last: False
+    num_workers: 0
+    collate_fn: LaTeXOCRCollator
+
+Eval:
+  dataset:
+    name: LaTeXOCRDataSet
+    data: ./train_data/LaTeXOCR/latexocr_val.pkl
+    min_dimensions: [32, 32]
+    max_dimensions: [672, 192]
+    batch_size_per_pair: 10
+    keep_smaller_batches: True
+    transforms:
+      - DecodeImage:
+          channel_first: False
+      - MinMaxResize:
+          min_dimensions: [32, 32]
+          max_dimensions: [672, 192]  
+      - LatexTestTransform:
+      - NormalizeImage:
+          mean: [0.7931, 0.7931, 0.7931]
+          std: [0.1738, 0.1738, 0.1738]
+          order: 'hwc'
+      - LatexImageFormat:
+      - KeepKeys:
+          keep_keys: ['image']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1
+    num_workers: 0
+    collate_fn: LaTeXOCRCollator

+ 33 - 1
paddlex/repo_apis/PaddleOCR_api/text_rec/config.py

@@ -79,7 +79,7 @@ class TextRecConfig(BaseConfig):
             train_list_path = f"{train_list_path}"
         else:
             train_list_path = os.path.join(dataset_path, 'train.txt')
-        if dataset_type == 'TextRecDataset' or "MSTextRecDataset":
+        if (dataset_type == 'TextRecDataset') or (dataset_type=="MSTextRecDataset"):
             _cfg = {
                 'Train.dataset.name': dataset_type,
                 'Train.dataset.data_dir': dataset_path,
@@ -92,6 +92,21 @@ class TextRecConfig(BaseConfig):
                 os.path.join(dataset_path, 'dict.txt')
             }
             self.update(_cfg)
+        elif dataset_type == "LaTeXOCRDataSet":
+            _cfg = {
+                    'Train.dataset.name': dataset_type,
+                    'Train.dataset.data_dir': dataset_path,
+                    'Train.dataset.data': os.path.join(dataset_path, "latexocr_train.pkl"),
+                    'Train.dataset.label_file_list': [train_list_path],
+                    'Eval.dataset.name': dataset_type,
+                    'Eval.dataset.data_dir': dataset_path,
+                    'Eval.dataset.data': os.path.join(dataset_path, "latexocr_val.pkl"),
+                    'Eval.dataset.label_file_list':
+                    [os.path.join(dataset_path, 'val.txt')],
+                    'Global.character_dict_path':
+                    os.path.join(dataset_path, 'dict.txt')
+                }
+            self.update(_cfg)
         else:
             raise ValueError(f"{repr(dataset_type)} is not supported.")
 
@@ -114,6 +129,23 @@ class TextRecConfig(BaseConfig):
             _cfg['Train.sampler.first_bs'] = batch_size
         self.update(_cfg)
 
+    def update_batch_size_pair(self, batch_size_train: int, batch_size_val: int, mode: str='train'):
+        """update the train/eval batch size pair setting
+        Args:
+            batch_size_train (int): value for `Train.dataset.batch_size_per_pair`.
+            batch_size_val (int): value for `Eval.dataset.batch_size_per_pair`.
+            mode (str, optional): kept for interface consistency; currently unused.
+                Defaults to 'train'.
+        """
+        _cfg = {
+            'Train.dataset.batch_size_per_pair': batch_size_train,
+            'Eval.dataset.batch_size_per_pair': batch_size_val,
+        }
+        # if "sampler" in self.dict['Train']:
+        #     _cfg['Train.sampler.first_bs'] = 1
+        self.update(_cfg)
+
     def update_learning_rate(self, learning_rate: float):
     def update_learning_rate(self, learning_rate: float):
         """update learning rate
 
+ 7 - 0
paddlex/repo_apis/PaddleOCR_api/text_rec/register.py

@@ -58,3 +58,10 @@ register_model_info({
     'config_path': osp.join(PDX_CONFIG_DIR, 'RepSVTR_mobile_rec.yaml'),
     'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
 })
+
+register_model_info({
+    'model_name': 'LaTeX_OCR_rec',
+    'suite': 'TextRec',
+    'config_path': osp.join(PDX_CONFIG_DIR, 'LaTeX_OCR_rec.yml'),
+    'supported_apis': ['train', 'evaluate', 'predict', 'export', 'infer']
+})

+ 1 - 1
paddlex/repo_manager/meta.py

@@ -63,7 +63,7 @@ REPO_META = {
         'lib_name': 'paddleocr',
         'pdx_pkg_name': 'PaddleOCR_api',
         'editable': False,
-        'extra_req_files': ['ppstructure/kie/requirements.txt'],
+        'extra_req_files': ['ppstructure/kie/requirements.txt', 'docs/algorithm/formula_recognition/requirements.txt'],
         'path_env': 'PADDLE_PDX_PADDLEOCR_PATH',
         'requires': ['PaddleNLP'],
     },

+ 2 - 1
requirements.txt

@@ -12,4 +12,5 @@ pyclipper
 shapely
 pandas
 parsley
-requests
+requests
+tokenizers==0.19.1