
Add formula model (#2729)

* add formula model

* rename yaml

* repair formula visual bug
liuhongen1234567 10 months ago
parent commit
f3e97f130f
36 changed files with 3952 additions and 23 deletions
  1. +1 -1    paddlex/configs/modules/formula_recognition/LaTeX_OCR_rec.yaml
  2. +40 -0   paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml
  3. +40 -0   paddlex/configs/modules/formula_recognition/PP-FormulaNet-S.yaml
  4. +40 -0   paddlex/configs/modules/formula_recognition/UniMERNet.yaml
  5. +2 -1    paddlex/inference/models_new/__init__.py
  6. +15 -0   paddlex/inference/models_new/formula_recognition/__init__.py
  7. +158 -0  paddlex/inference/models_new/formula_recognition/predictor.py
  8. +986 -0  paddlex/inference/models_new/formula_recognition/processors.py
  9. +317 -0  paddlex/inference/models_new/formula_recognition/result.py
  10. +3 -0   paddlex/inference/utils/official_models.py
  11. +5 -0   paddlex/modules/formula_recognition/__init__.py
  12. +98 -0  paddlex/modules/formula_recognition/dataset_checker/__init__.py
  13. +19 -0  paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py
  14. +157 -0 paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py
  15. +80 -0  paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py
  16. +94 -0  paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py
  17. +81 -0  paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py
  18. +64 -0  paddlex/modules/formula_recognition/evaluator.py
  19. +22 -0  paddlex/modules/formula_recognition/exportor.py
  20. +3 -0   paddlex/modules/formula_recognition/model_list.py
  21. +111 -0 paddlex/modules/formula_recognition/trainer.py
  22. +0 -3   paddlex/modules/text_recognition/dataset_checker/__init__.py
  23. +0 -3   paddlex/modules/text_recognition/evaluator.py
  24. +0 -3   paddlex/modules/text_recognition/exportor.py
  25. +0 -3   paddlex/modules/text_recognition/trainer.py
  26. +1 -0   paddlex/repo_apis/PaddleOCR_api/__init__.py
  27. +1 -0   paddlex/repo_apis/PaddleOCR_api/configs/LaTeX_OCR_rec.yml
  28. +117 -0 paddlex/repo_apis/PaddleOCR_api/configs/PP-FormulaNet-L.yaml
  29. +115 -0 paddlex/repo_apis/PaddleOCR_api/configs/PP-FormulaNet-S.yaml
  30. +113 -0 paddlex/repo_apis/PaddleOCR_api/configs/UniMERNet.yaml
  31. +16 -0  paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py
  32. +544 -0 paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py
  33. +396 -0 paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py
  34. +73 -0  paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py
  35. +240 -0 paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py
  36. +0 -9   paddlex/repo_apis/PaddleOCR_api/text_rec/register.py

+ 1 - 1
paddlex/configs/modules/formula_recognition/LaTeX_OCR_rec.yaml

@@ -8,7 +8,7 @@ Global:
 CheckDataset:
   convert: 
     enable: True
-    src_dataset_type: MSTextRecDataset
+    src_dataset_type: FormulaRecDataset
   split: 
     enable: False
     train_percent: null

+ 40 - 0
paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-FormulaNet-L
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "./dataset/ocr_rec_latexocr_dataset_example"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: FormulaRecDataset
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 20
+  batch_size_train: 5
+  batch_size_val: 5
+  learning_rate: 0.0001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-FormulaNet-L_pretrained.pdparams
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: output/best_accuracy/best_accuracy.pdparams
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-FormulaNet-L_pretrained.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference_pir"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_formula_rec_001.png"
+  kernel_option:
+    run_mode: paddle

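For reference, configs like the one above are driven through the standard PaddleX entry point, with Global.mode selecting the stage. The commands below follow the usual PaddleX CLI convention and are illustrative, not part of this diff:

    # dataset check, then training, with the config added above
    python main.py -c paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml \
        -o Global.mode=check_dataset
    python main.py -c paddlex/configs/modules/formula_recognition/PP-FormulaNet-L.yaml \
        -o Global.mode=train -o Global.device=gpu:0,1,2,3

Switching Global.mode to evaluate or predict reuses the same file, per the comment on the mode key.
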
+ 40 - 0
paddlex/configs/modules/formula_recognition/PP-FormulaNet-S.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-FormulaNet-S
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "./dataset/ocr_rec_latexocr_dataset_example"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: FormulaRecDataset
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 20
+  batch_size_train: 30
+  batch_size_val: 10
+  learning_rate: 0.0001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-FormulaNet-S_pretrained.pdparams
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: output/best_accuracy/best_accuracy.pdparams
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-FormulaNet-S_pretrained.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference_pir"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_formula_rec_001.png"
+  kernel_option:
+    run_mode: paddle

+ 40 - 0
paddlex/configs/modules/formula_recognition/UniMERNet.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: UniMERNet
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "./dataset/ocr_rec_latexocr_dataset_example"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: FormulaRecDataset
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 20
+  batch_size_train: 7
+  batch_size_val: 20
+  learning_rate: 0.0001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/UniMERNet_pretrained.pdparams
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: output/best_accuracy/best_accuracy.pdparams
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/UniMERNet_pretrained.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference_pir"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_formula_rec_001.png"
+  kernel_option:
+    run_mode: paddle

+ 2 - 1
paddlex/inference/models_new/__init__.py

@@ -24,6 +24,7 @@ from .image_classification import ClasPredictor
 from .object_detection import DetPredictor
 from .text_detection import TextDetPredictor
 from .text_recognition import TextRecPredictor
+from .formula_recognition import FormulaRecPredictor
 
 # from .table_recognition import TablePredictor
 # from .object_detection import DetPredictor
@@ -40,7 +41,7 @@ from .image_unwarping import WarpPredictor
 
 # from .multilabel_classification import MLClasPredictor
 # from .anomaly_detection import UadPredictor
-# from .formula_recognition import LaTeXOCRPredictor
+
 # from .face_recognition import FaceRecPredictor
 
 

+ 15 - 0
paddlex/inference/models_new/formula_recognition/__init__.py

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import FormulaRecPredictor

+ 158 - 0
paddlex/inference/models_new/formula_recognition/predictor.py

@@ -0,0 +1,158 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ....utils.func_register import FuncRegister
+from ....modules.formula_recognition.model_list import MODELS
+from ...common.batch_sampler import ImageBatchSampler
+from ...common.reader import ReadImage
+from ..common import (
+    StaticInfer,
+)
+from ..base import BasicPredictor
+from .processors import (
+    MinMaxResize,
+    LatexTestTransform,
+    LatexImageFormat,
+    LaTeXOCRDecode,
+    NormalizeImage,
+    ToBatch,
+    UniMERNetImgDecode,
+    UniMERNetDecode,
+    UniMERNetTestTransform,
+    UniMERNetImageFormat,
+)
+
+from .result import FormulaRecResult
+
+
+class FormulaRecPredictor(BasicPredictor):
+
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pre_tfs, self.infer, self.post_op = self._build()
+
+    def _build_batch_sampler(self):
+        return ImageBatchSampler()
+
+    def _get_result_class(self):
+        return FormulaRecResult
+
+    def _build(self):
+        pre_tfs = {"Read": ReadImage(format="RGB")}
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            tf_key = list(cfg.keys())[0]
+            assert tf_key in self._FUNC_MAP
+            func = self._FUNC_MAP[tf_key]
+            args = cfg.get(tf_key, {})
+            name, op = func(self, **args) if args else func(self)
+            if op:
+                pre_tfs[name] = op
+        pre_tfs["ToBatch"] = ToBatch()
+
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        post_op = self.build_postprocess(**self.config["PostProcess"])
+        return pre_tfs, infer, post_op
+
+    def process(self, batch_data):
+        batch_raw_imgs = self.pre_tfs["Read"](imgs=batch_data)
+        if self.model_name == "LaTeX_OCR_rec":
+            batch_imgs = self.pre_tfs["MinMaxResize"](imgs=batch_raw_imgs)
+            batch_imgs = self.pre_tfs["LatexTestTransform"](imgs=batch_imgs)
+            batch_imgs = self.pre_tfs["NormalizeImage"](imgs=batch_imgs)
+            batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
+        elif self.model_name == "UniMERNet":
+            batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=batch_raw_imgs)
+            batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
+            batch_imgs = self.pre_tfs["UniMERNetImageFormat"](imgs=batch_imgs)
+        elif self.model_name in ("PP-FormulaNet-S", "PP-FormulaNet-L"):
+            batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=batch_raw_imgs)
+            batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
+            batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
+
+        x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
+        batch_preds = self.infer(x=x)
+        batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
+        rec_formula = self.post_op(batch_preds)
+        return {
+            "input_path": batch_data,
+            "input_img": batch_raw_imgs,
+            "rec_formula": rec_formula,
+        }
+
+    @register("DecodeImage")
+    def build_readimg(self, channel_first, img_mode):
+        assert not channel_first
+        return "Read", ReadImage(format=img_mode)
+
+    @register("MinMaxResize")
+    def build_min_max_resize(self, min_dimensions, max_dimensions):
+        return "MinMaxResize", MinMaxResize(
+            min_dimensions=min_dimensions, max_dimensions=max_dimensions
+        )
+
+    @register("LatexTestTransform")
+    def build_latex_test_transform(
+        self,
+    ):
+        return "LatexTestTransform", LatexTestTransform()
+
+    @register("NormalizeImage")
+    def build_normalize(self, mean, std, order="chw"):
+        return "NormalizeImage", NormalizeImage(mean=mean, std=std, order=order)
+
+    @register("LatexImageFormat")
+    def build_latexocr_imageformat(self):
+        return "LatexImageFormat", LatexImageFormat()
+
+    @register("UniMERNetImgDecode")
+    def build_unimernet_decode(self, input_size):
+        return "UniMERNetImgDecode", UniMERNetImgDecode(input_size)
+
+    def build_postprocess(self, **kwargs):
+        if kwargs.get("name") == "LaTeXOCRDecode":
+            return LaTeXOCRDecode(
+                character_list=kwargs.get("character_dict"),
+            )
+        elif kwargs.get("name") == "UniMERNetDecode":
+            return UniMERNetDecode(
+                character_list=kwargs.get("character_dict"),
+            )
+        else:
+            raise ValueError(f"Unsupported PostProcess op: {kwargs.get('name')}")
+
+    @register("UniMERNetTestTransform")
+    def build_unimernet_test_transform(self):
+        return "UniMERNetTestTransform", UniMERNetTestTransform()
+
+    @register("UniMERNetImageFormat")
+    def build_unimernet_imageformat(self):
+        return "UniMERNetImageFormat", UniMERNetImageFormat()
+
+    @register("UniMERNetLabelEncode")
+    def build_unimernet_label_encode(self, *args, **kwargs):
+        # Train-only op; returning (None, None) makes _build() skip it.
+        return None, None
+
+    @register("KeepKeys")
+    def build_keep_keys(self, *args, **kwargs):
+        # Train-only op; returning (None, None) makes _build() skip it.
+        return None, None

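For a quick smoke test of the new predictor, a minimal sketch follows; it assumes the top-level create_model helper dispatches the model name to FormulaRecPredictor via the `entities = MODELS` registry, and it reuses the demo image from the configs above:

    from paddlex import create_model

    # Assumption: "PP-FormulaNet-S" resolves to FormulaRecPredictor through MODELS.
    model = create_model("PP-FormulaNet-S")
    for res in model.predict("general_formula_rec_001.png", batch_size=1):
        print(res["rec_formula"])  # decoded LaTeX string assembled in process()
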
+ 986 - 0
paddlex/inference/models_new/formula_recognition/processors.py

@@ -0,0 +1,986 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import os.path as osp
+
+import re
+import numpy as np
+from PIL import Image, ImageOps, ImageDraw
+import cv2
+import math
+import json
+import tempfile
+from tokenizers import Tokenizer as TokenizerFast
+from tokenizers import AddedToken
+from typing import List, Tuple, Optional, Any, Dict, Union
+
+from ....utils import logging
+
+
+class MinMaxResize:
+    """Class for resizing images to be within specified minimum and maximum dimensions, with padding and normalization."""
+
+    def __init__(
+        self,
+        min_dimensions: Optional[List[int]] = [32, 32],
+        max_dimensions: Optional[List[int]] = [672, 192],
+        **kwargs,
+    ) -> None:
+        """Initializes the MinMaxResize class with minimum and maximum dimensions.
+
+        Args:
+            min_dimensions (list of int, optional): Minimum dimensions (width, height). Defaults to [32, 32].
+            max_dimensions (list of int, optional): Maximum dimensions (width, height). Defaults to [672, 192].
+            **kwargs: Additional keyword arguments for future expansion.
+        """
+        self.min_dimensions = min_dimensions
+        self.max_dimensions = max_dimensions
+
+    def pad_(self, img: Image.Image, divable: int = 32) -> Image.Image:
+        """Pads the image to ensure its dimensions are divisible by a specified value.
+
+        Args:
+            img (PIL.Image.Image): The input image.
+            divable (int, optional): The value by which the dimensions should be divisible. Defaults to 32.
+
+        Returns:
+            PIL.Image.Image: The padded image.
+        """
+        threshold = 128
+        data = np.array(img.convert("LA"))
+        if data[..., -1].var() == 0:
+            data = (data[..., 0]).astype(np.uint8)
+        else:
+            data = (255 - data[..., -1]).astype(np.uint8)
+        data = (data - data.min()) / (data.max() - data.min()) * 255
+        if data.mean() > threshold:
+            # To invert the text to white
+            gray = 255 * (data < threshold).astype(np.uint8)
+        else:
+            gray = 255 * (data > threshold).astype(np.uint8)
+            data = 255 - data
+
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        rect = data[b : b + h, a : a + w]
+        im = Image.fromarray(rect).convert("L")
+        dims = []
+        for x in [w, h]:
+            div, mod = divmod(x, divable)
+            dims.append(divable * (div + (1 if mod > 0 else 0)))
+        padded = Image.new("L", dims, 255)
+        padded.paste(im, (0, 0, im.size[0], im.size[1]))
+        return padded
+
+    def minmax_size_(
+        self,
+        img: Image.Image,
+        max_dimensions: Optional[List[int]],
+        min_dimensions: Optional[List[int]],
+    ) -> Image.Image:
+        """Resizes the image to be within the specified minimum and maximum dimensions.
+
+        Args:
+            img (PIL.Image.Image): The input image.
+            max_dimensions (list of int or None): Maximum dimensions (width, height).
+            min_dimensions (list of int or None): Minimum dimensions (width, height).
+
+        Returns:
+            PIL.Image.Image: The resized image.
+        """
+        if max_dimensions is not None:
+            ratios = [a / b for a, b in zip(img.size, max_dimensions)]
+            if any([r > 1 for r in ratios]):
+                size = np.array(img.size) // max(ratios)
+                img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
+        if min_dimensions is not None:
+            # If any image dimension is below the minimum, pad with white so
+            # every dimension is at least the corresponding minimum.
+            padded_size = [
+                max(img_dim, min_dim)
+                for img_dim, min_dim in zip(img.size, min_dimensions)
+            ]
+            if padded_size != list(img.size):  # at least one dim was too small
+                padded_im = Image.new("L", padded_size, 255)
+                padded_im.paste(img, img.getbbox())
+                img = padded_im
+        return img
+
+    def resize(self, img: np.ndarray) -> np.ndarray:
+        """Resizes the input image according to the specified minimum and maximum dimensions.
+
+        Args:
+            img (np.ndarray): The input image as a numpy array.
+
+        Returns:
+            np.ndarray: The resized image as a numpy array with three channels.
+        """
+        h, w = img.shape[:2]
+        if (
+            self.min_dimensions[0] <= w <= self.max_dimensions[0]
+            and self.min_dimensions[1] <= h <= self.max_dimensions[1]
+        ):
+            return img
+        else:
+            img = Image.fromarray(np.uint8(img))
+            img = self.minmax_size_(
+                self.pad_(img), self.max_dimensions, self.min_dimensions
+            )
+            img = np.array(img)
+            img = np.dstack((img, img, img))
+            return img
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """Applies the resize method to a list of images.
+
+        Args:
+            imgs (list of np.ndarray): The list of input images as numpy arrays.
+
+        Returns:
+            list of np.ndarray: The list of resized images as numpy arrays with three channels.
+        """
+        return [self.resize(img) for img in imgs]
+
+
+class LatexTestTransform:
+    """
+    A transform class for processing images according to Latex test requirements.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        """
+        Initialize the transform with default number of output channels.
+        """
+        super().__init__()
+        self.num_output_channels = 3
+
+    def transform(self, img: np.ndarray) -> np.ndarray:
+        """
+        Convert the input image to grayscale, squeeze it, and merge to create an output
+        image with the specified number of output channels.
+
+        Parameters:
+            img (np.array): The input image.
+
+        Returns:
+            np.array: The transformed image.
+        """
+        grayscale_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        squeezed = np.squeeze(grayscale_image)
+        return cv2.merge([squeezed] * self.num_output_channels)
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """
+        Apply the transform to a list of images.
+
+        Parameters:
+            imgs (list of np.array): The list of input images.
+
+        Returns:
+            list of np.array: The list of transformed images.
+        """
+        return [self.transform(img) for img in imgs]
+
+
+class LatexImageFormat:
+    """Class for formatting images to a specific format suitable for LaTeX."""
+
+    def __init__(self, **kwargs) -> None:
+        """Initializes the LatexImageFormat class with optional keyword arguments."""
+        super().__init__()
+
+    def format(self, img: np.ndarray) -> np.ndarray:
+        """Formats a single image to the LaTeX-compatible format.
+
+        Args:
+            img (numpy.ndarray): The input image as a numpy array.
+
+        Returns:
+            numpy.ndarray: The formatted image as a numpy array with an added dimension for color.
+        """
+        im_h, im_w = img.shape[:2]
+        divide_h = math.ceil(im_h / 16) * 16
+        divide_w = math.ceil(im_w / 16) * 16
+        img = img[:, :, 0]
+        img = np.pad(
+            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
+        )
+        img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
+        return img_expanded[np.newaxis, :]
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """Applies the format method to a list of images.
+
+        Args:
+            imgs (list of numpy.ndarray): A list of input images as numpy arrays.
+
+        Returns:
+            list of numpy.ndarray: A list of formatted images as numpy arrays.
+        """
+        return [self.format(img) for img in imgs]
+
+
+class NormalizeImage(object):
+    """Normalize an image by subtracting the mean and dividing by the standard deviation.
+
+    Args:
+        scale (float or str): The scale factor to apply to the image. If a string is provided, it will be evaluated as a Python expression.
+        mean (list of float): The mean values to subtract from each channel. Defaults to [0.485, 0.456, 0.406].
+        std (list of float): The standard deviation values to divide by for each channel. Defaults to [0.229, 0.224, 0.225].
+        order (str): The order of dimensions for the mean and std. 'chw' for channels-height-width, 'hwc' for height-width-channels. Defaults to 'chw'.
+        **kwargs: Additional keyword arguments that may be used by subclasses.
+
+    Attributes:
+        scale (float): The scale factor applied to the image.
+        mean (numpy.ndarray): The mean values reshaped according to the specified order.
+        std (numpy.ndarray): The standard deviation values reshaped according to the specified order.
+    """
+
+    def __init__(
+        self,
+        scale: Optional[Union[float, str]] = None,
+        mean: Optional[List[float]] = None,
+        std: Optional[List[float]] = None,
+        order: str = "chw",
+        **kwargs,
+    ) -> None:
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    def normalize(self, img: Union[np.ndarray, Image.Image]) -> np.ndarray:
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+        img = (img.astype("float32") * self.scale - self.mean) / self.std
+        return img
+
+    def __call__(self, imgs: List[Union[np.ndarray, Image.Image]]) -> List[np.ndarray]:
+        """Apply normalization to a list of images."""
+        return [self.normalize(img) for img in imgs]
+
+
+class ToBatch(object):
+    """A class for batching images."""
+
+    def __init__(self, **kwargs) -> None:
+        """Initializes the ToBatch object."""
+        super(ToBatch, self).__init__()
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """Concatenates a list of images into a single batch.
+
+        Args:
+            imgs (list): A list of image arrays to be concatenated.
+
+        Returns:
+            list: A list containing the concatenated batch of images wrapped in another list (to comply with common batch processing formats).
+        """
+        batch_imgs = np.concatenate(imgs)
+        batch_imgs = batch_imgs.copy()
+        x = [batch_imgs]
+        return x
+
+
+class LaTeXOCRDecode(object):
+    """Class for decoding LaTeX OCR tokens based on a provided character list."""
+
+    def __init__(self, character_list: List[str], **kwargs) -> None:
+        """Initializes the LaTeXOCRDecode object.
+
+        Args:
+            character_list (list): The list of characters to use for tokenization.
+            **kwargs: Additional keyword arguments for initialization.
+        """
+        super(LaTeXOCRDecode, self).__init__()
+        temp_path = tempfile.gettempdir()
+        rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
+        try:
+            with open(rec_char_dict_path, "w") as f:
+                json.dump(character_list, f)
+        except Exception as e:
+            logging.warning(f"Failed to create latexocr_tokenizer.json: {e}")
+        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
+
+    def post_process(self, s: str) -> str:
+        """Post-processes the decoded LaTeX string.
+
+        Args:
+            s (str): The decoded LaTeX string to post-process.
+
+        Returns:
+            str: The post-processed LaTeX string.
+        """
+        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+        letter = "[a-zA-Z]"
+        noletter = r"[\W_^\d]"
+        names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
+        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+        news = s
+        while True:
+            s = news
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+            news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
+            if news == s:
+                break
+        return s
+
+    def decode(self, tokens: np.ndarray) -> List[str]:
+        """Decodes the provided tokens into LaTeX strings.
+
+        Args:
+            tokens (np.array): The tokens to decode.
+
+        Returns:
+            list: The decoded LaTeX strings.
+        """
+        if len(tokens.shape) == 1:
+            tokens = tokens[None, :]
+        dec = [self.tokenizer.decode(tok) for tok in tokens]
+        dec_str_list = [
+            "".join(detok.split(" "))
+            .replace("Ġ", " ")
+            .replace("[EOS]", "")
+            .replace("[BOS]", "")
+            .replace("[PAD]", "")
+            .strip()
+            for detok in dec
+        ]
+        return [self.post_process(dec_str) for dec_str in dec_str_list]
+
+    def __call__(
+        self,
+        preds: np.ndarray,
+        label: Optional[np.ndarray] = None,
+        mode: str = "eval",
+        *args,
+        **kwargs,
+    ) -> Tuple[List[str], List[str]]:
+        """Calls the object with the provided predictions and label.
+
+        Args:
+            preds (np.array): The predictions to decode.
+            label (np.array, optional): The labels to decode. Defaults to None.
+            mode (str): The mode to run in, either 'train' or 'eval'. Defaults to 'eval'.
+            *args: Positional arguments to pass.
+            **kwargs: Keyword arguments to pass.
+
+        Returns:
+            tuple or list: The decoded text and optionally the decoded label.
+        """
+        if mode == "train":
+            preds_idx = np.array(preds.argmax(axis=2))
+            text = self.decode(preds_idx)
+        else:
+            text = self.decode(np.array(preds))
+        if label is None:
+            return text
+        label = self.decode(np.array(label))
+        return text, label
+
+
+class UniMERNetImgDecode(object):
+    """Class for decoding images for UniMERNet, including cropping margins, resizing, and padding."""
+
+    def __init__(
+        self, input_size: Tuple[int, int], random_padding: bool = False, **kwargs
+    ) -> None:
+        """Initializes the UniMERNetImgDecode class with input size and random padding options.
+
+        Args:
+            input_size (tuple): The desired input size for the images (height, width).
+            random_padding (bool): Whether to use random padding for resizing.
+            **kwargs: Additional keyword arguments."""
+        self.input_size = input_size
+        self.random_padding = random_padding
+
+    def crop_margin(self, img: Image.Image) -> Image.Image:
+        """Crops the margin of the image based on grayscale thresholding.
+
+        Args:
+            img (PIL.Image.Image): The input image.
+
+        Returns:
+            PIL.Image.Image: The cropped image."""
+        data = np.array(img.convert("L"))
+        data = data.astype(np.uint8)
+        max_val = data.max()
+        min_val = data.min()
+        if max_val == min_val:
+            return img
+        data = (data - min_val) / (max_val - min_val) * 255
+        gray = 255 * (data < 200).astype(np.uint8)
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        return img.crop((a, b, w + a, h + b))
+
+    def get_dimensions(self, img: Union[Image.Image, np.ndarray]) -> List[int]:
+        """Gets the dimensions of the image.
+
+        Args:
+            img (PIL.Image.Image or numpy.ndarray): The input image.
+
+        Returns:
+            list: A list containing the number of channels, height, and width."""
+        if hasattr(img, "getbands"):
+            channels = len(img.getbands())
+        else:
+            channels = img.channels
+        width, height = img.size
+        return [channels, height, width]
+
+    def _compute_resized_output_size(
+        self,
+        image_size: Tuple[int, int],
+        size: Union[int, Tuple[int, int]],
+        max_size: Optional[int] = None,
+    ) -> List[int]:
+        """Computes the resized output size of the image.
+
+        Args:
+            image_size (tuple): The original size of the image (height, width).
+            size (int or tuple): The desired size for the smallest edge or both height and width.
+            max_size (int, optional): The maximum allowed size for the longer edge.
+
+        Returns:
+            list: A list containing the new height and width."""
+        if len(size) == 1:  # specified size only for the smallest edge
+            h, w = image_size
+            short, long = (w, h) if w <= h else (h, w)
+            requested_new_short = size if isinstance(size, int) else size[0]
+
+            new_short, new_long = requested_new_short, int(
+                requested_new_short * long / short
+            )
+
+            if max_size is not None:
+                if max_size <= requested_new_short:
+                    raise ValueError(
+                        f"max_size = {max_size} must be strictly greater than the requested "
+                        f"size for the smaller edge size = {size}"
+                    )
+                if new_long > max_size:
+                    new_short, new_long = int(max_size * new_short / new_long), max_size
+
+            new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+        else:  # specified both h and w
+            new_w, new_h = size[1], size[0]
+        return [new_h, new_w]
+
+    def resize(
+        self, img: Image.Image, size: Union[int, Tuple[int, int]]
+    ) -> Image.Image:
+        """Resizes the image to the specified size.
+
+        Args:
+            img (PIL.Image.Image): The input image.
+            size (int or tuple): The desired size for the smallest edge or both height and width.
+
+        Returns:
+            PIL.Image.Image: The resized image."""
+        _, image_height, image_width = self.get_dimensions(img)
+        if isinstance(size, int):
+            size = [size]
+        max_size = None
+        output_size = self._compute_resized_output_size(
+            (image_height, image_width), size, max_size
+        )
+        img = img.resize(tuple(output_size[::-1]), resample=Image.BILINEAR)
+        return img
+
+    def img_decode(self, img: np.ndarray) -> Optional[np.ndarray]:
+        """Decodes the image by cropping margins, resizing, and adding padding.
+
+        Args:
+            img (numpy.ndarray): The input image array.
+
+        Returns:
+            numpy.ndarray: The decoded image array."""
+        try:
+            img = self.crop_margin(Image.fromarray(img).convert("RGB"))
+        except OSError:
+            return
+        if img.height == 0 or img.width == 0:
+            return
+        img = self.resize(img, min(self.input_size))
+        img.thumbnail((self.input_size[1], self.input_size[0]))
+        delta_width = self.input_size[1] - img.width
+        delta_height = self.input_size[0] - img.height
+        if self.random_padding:
+            pad_width = np.random.randint(low=0, high=delta_width + 1)
+            pad_height = np.random.randint(low=0, high=delta_height + 1)
+        else:
+            pad_width = delta_width // 2
+            pad_height = delta_height // 2
+        padding = (
+            pad_width,
+            pad_height,
+            delta_width - pad_width,
+            delta_height - pad_height,
+        )
+        return np.array(ImageOps.expand(img, padding))
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[Optional[np.ndarray]]:
+        """Calls the img_decode method on a list of images.
+
+        Args:
+            imgs (list of numpy.ndarray): The list of input image arrays.
+
+        Returns:
+            list of numpy.ndarray: The list of decoded image arrays."""
+        return [self.img_decode(img) for img in imgs]
+
+
+class UniMERNetDecode(object):
+    """Class for decoding tokenized inputs using UniMERNet tokenizer.
+
+    Attributes:
+        SPECIAL_TOKENS_ATTRIBUTES (List[str]): List of special token attributes.
+        model_input_names (List[str]): List of model input names.
+        max_seq_len (int): Maximum sequence length.
+        pad_token_id (int): ID for the padding token.
+        bos_token_id (int): ID for the beginning-of-sequence token.
+        eos_token_id (int): ID for the end-of-sequence token.
+        padding_side (str): Padding side, either 'left' or 'right'.
+        pad_token (str): Padding token.
+        pad_token_type_id (int): Type ID for the padding token.
+        pad_to_multiple_of (Optional[int]): If set, pad to a multiple of this value.
+        tokenizer (TokenizerFast): Fast tokenizer instance.
+
+    Args:
+        character_list (Dict[str, Any]): Dictionary containing tokenizer configuration.
+        **kwargs: Additional keyword arguments.
+    """
+
+    SPECIAL_TOKENS_ATTRIBUTES = [
+        "bos_token",
+        "eos_token",
+        "unk_token",
+        "sep_token",
+        "pad_token",
+        "cls_token",
+        "mask_token",
+        "additional_special_tokens",
+    ]
+
+    def __init__(
+        self,
+        character_list: Dict[str, Any],
+        **kwargs,
+    ) -> None:
+        """Initializes the UniMERNetDecode class.
+
+        Args:
+            character_list (Dict[str, Any]): Dictionary containing tokenizer configuration.
+            **kwargs: Additional keyword arguments.
+        """
+
+        self._unk_token = "<unk>"
+        self._bos_token = "<s>"
+        self._eos_token = "</s>"
+        self._pad_token = "<pad>"
+        self._sep_token = None
+        self._cls_token = None
+        self._mask_token = None
+        self._additional_special_tokens = []
+        self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
+        self.max_seq_len = 2048
+        self.pad_token_id = 1
+        self.bos_token_id = 0
+        self.eos_token_id = 2
+        self.padding_side = "right"
+        self.pad_token = "<pad>"
+        self.pad_token_type_id = 0
+        self.pad_to_multiple_of = None
+
+        temp_path = tempfile.gettempdir()
+        fast_tokenizer_file = os.path.join(temp_path, "tokenizer.json")
+        tokenizer_config_file = os.path.join(temp_path, "tokenizer_config.json")
+        try:
+            with open(fast_tokenizer_file, "w") as f:
+                json.dump(character_list["fast_tokenizer_file"], f)
+            with open(tokenizer_config_file, "w") as f:
+                json.dump(character_list["tokenizer_config_file"], f)
+        except Exception as e:
+            logging.warning(
+                f"Failed to create tokenizer.json and tokenizer_config.json: {e}"
+            )
+
+        self.tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
+        added_tokens_decoder = {}
+        added_tokens_map = {}
+        if tokenizer_config_file is not None:
+            with open(
+                tokenizer_config_file, encoding="utf-8"
+            ) as tokenizer_config_handle:
+                init_kwargs = json.load(tokenizer_config_handle)
+                if "added_tokens_decoder" in init_kwargs:
+                    for idx, token in init_kwargs["added_tokens_decoder"].items():
+                        if isinstance(token, dict):
+                            token = AddedToken(**token)
+                        if isinstance(token, AddedToken):
+                            added_tokens_decoder[int(idx)] = token
+                            added_tokens_map[str(token)] = token
+                        else:
+                            raise ValueError(
+                                f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
+                            )
+                init_kwargs["added_tokens_decoder"] = added_tokens_decoder
+                added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
+                tokens_to_add = [
+                    token
+                    for index, token in sorted(
+                        added_tokens_decoder.items(), key=lambda x: x[0]
+                    )
+                    if token not in added_tokens_decoder
+                ]
+                added_tokens_encoder = self.added_tokens_encoder(added_tokens_decoder)
+                encoder = list(added_tokens_encoder.keys()) + [
+                    str(token) for token in tokens_to_add
+                ]
+                tokens_to_add += [
+                    token
+                    for token in self.all_special_tokens_extended
+                    if token not in encoder and token not in tokens_to_add
+                ]
+                if len(tokens_to_add) > 0:
+                    is_last_special = None
+                    tokens = []
+                    special_tokens = self.all_special_tokens
+                    for token in tokens_to_add:
+                        is_special = (
+                            (token.special or str(token) in special_tokens)
+                            if isinstance(token, AddedToken)
+                            else str(token) in special_tokens
+                        )
+                        if is_last_special is None or is_last_special == is_special:
+                            tokens.append(token)
+                        else:
+                            self._add_tokens(tokens, special_tokens=is_last_special)
+                            tokens = [token]
+                        is_last_special = is_special
+                    if tokens:
+                        self._add_tokens(tokens, special_tokens=is_last_special)
+
+    def _add_tokens(
+        self, new_tokens: List[Union[AddedToken, str]], special_tokens: bool = False
+    ) -> List[Union[AddedToken, str]]:
+        """Adds new tokens to the tokenizer.
+
+        Args:
+            new_tokens (List[Union[AddedToken, str]]): Tokens to be added.
+            special_tokens (bool): Indicates whether the tokens are special tokens.
+
+        Returns:
+            List[Union[AddedToken, str]]: added tokens.
+        """
+        if special_tokens:
+            return self.tokenizer.add_special_tokens(new_tokens)
+
+        return self.tokenizer.add_tokens(new_tokens)
+
+    def added_tokens_encoder(
+        self, added_tokens_decoder: Dict[int, AddedToken]
+    ) -> Dict[str, int]:
+        """Creates an encoder dictionary from added tokens.
+
+        Args:
+            added_tokens_decoder (Dict[int, AddedToken]): Dictionary mapping token IDs to tokens.
+
+        Returns:
+            Dict[str, int]: Dictionary mapping token strings to IDs.
+        """
+        return {
+            k.content: v
+            for v, k in sorted(added_tokens_decoder.items(), key=lambda item: item[0])
+        }
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        """Retrieves all special tokens.
+
+        Returns:
+            List[str]: List of all special tokens as strings.
+        """
+        all_toks = [str(s) for s in self.all_special_tokens_extended]
+        return all_toks
+
+    @property
+    def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]:
+        """Retrieves all special tokens, including extended ones.
+
+        Returns:
+            List[Union[str, AddedToken]]: List of all special tokens.
+        """
+        all_tokens = []
+        seen = set()
+        for value in self.special_tokens_map_extended.values():
+            if isinstance(value, (list, tuple)):
+                tokens_to_add = [token for token in value if str(token) not in seen]
+            else:
+                tokens_to_add = [value] if str(value) not in seen else []
+            seen.update(map(str, tokens_to_add))
+            all_tokens.extend(tokens_to_add)
+        return all_tokens
+
+    @property
+    def special_tokens_map_extended(self) -> Dict[str, Union[str, List[str]]]:
+        """Retrieves the extended map of special tokens.
+
+        Returns:
+            Dict[str, Union[str, List[str]]]: Dictionary mapping special token attributes to their values.
+        """
+        set_attr = {}
+        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+            attr_value = getattr(self, "_" + attr)
+            if attr_value:
+                set_attr[attr] = attr_value
+        return set_attr
+
+    def convert_ids_to_tokens(
+        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+    ) -> Union[str, List[str]]:
+        """Converts token IDs to token strings.
+
+        Args:
+            ids (Union[int, List[int]]): Token ID(s) to convert.
+            skip_special_tokens (bool): Whether to skip special tokens during conversion.
+
+        Returns:
+            Union[str, List[str]]: Converted token string(s).
+        """
+        if isinstance(ids, int):
+            return self.tokenizer.id_to_token(ids)
+        tokens = []
+        for index in ids:
+            index = int(index)
+            if skip_special_tokens and index in self.all_special_ids:
+                continue
+            tokens.append(self.tokenizer.id_to_token(index))
+        return tokens
+
+    def detokenize(self, tokens: List[List[int]]) -> List[List[str]]:
+        """Detokenizes a list of token IDs back into strings.
+
+        Args:
+            tokens (List[List[int]]): List of token ID lists.
+
+        Returns:
+            List[List[str]]: List of detokenized strings.
+        """
+        self.tokenizer.bos_token = "<s>"
+        self.tokenizer.eos_token = "</s>"
+        self.tokenizer.pad_token = "<pad>"
+        toks = [self.convert_ids_to_tokens(tok) for tok in tokens]
+        for b in range(len(toks)):
+            for i in reversed(range(len(toks[b]))):
+                if toks[b][i] is None:
+                    toks[b][i] = ""
+                toks[b][i] = toks[b][i].replace("Ġ", " ").strip()
+                if toks[b][i] in (
+                    [
+                        self.tokenizer.bos_token,
+                        self.tokenizer.eos_token,
+                        self.tokenizer.pad_token,
+                    ]
+                ):
+                    del toks[b][i]
+        return toks
+
+    def token2str(self, token_ids: List[List[int]]) -> List[str]:
+        """Converts a list of token IDs to strings.
+
+        Args:
+            token_ids (List[List[int]]): List of token ID lists.
+
+        Returns:
+            List[str]: List of converted strings.
+        """
+        generated_text = []
+        for tok_id in token_ids:
+            end_idx = np.argwhere(tok_id == self.eos_token_id)
+            if len(end_idx) > 0:
+                end_idx = int(end_idx[0][0])
+                tok_id = tok_id[: end_idx + 1]
+            generated_text.append(
+                self.tokenizer.decode(tok_id, skip_special_tokens=True)
+            )
+        generated_text = [self.post_process(text) for text in generated_text]
+        return generated_text
+
+    def normalize(self, s: str) -> str:
+        """Normalizes a string by removing unnecessary spaces.
+
+        Args:
+            s (str): String to normalize.
+
+        Returns:
+            str: Normalized string.
+        """
+        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+        letter = "[a-zA-Z]"
+        noletter = r"[\W_^\d]"
+        names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
+        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+        news = s
+        while True:
+            s = news
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+            news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
+            if news == s:
+                break
+        return s
+
+    def post_process(self, text: str) -> str:
+        """Post-processes a string by fixing text and normalizing it.
+
+        Args:
+            text (str): String to post-process.
+
+        Returns:
+            str: Post-processed string.
+        """
+        from ftfy import fix_text
+
+        text = fix_text(text)
+        text = self.normalize(text)
+        return text
+
+    def __call__(
+        self,
+        preds: np.ndarray,
+        label: Optional[np.ndarray] = None,
+        mode: str = "eval",
+        *args,
+        **kwargs,
+    ) -> Union[List[str], tuple]:
+        """Processes predictions and optionally labels, returning the decoded text.
+
+        Args:
+            preds (np.ndarray): Model predictions.
+            label (Optional[np.ndarray]): True labels, if available.
+            mode (str): Mode of operation, either 'train' or 'eval'.
+
+        Returns:
+            Union[List[str], tuple]: Decoded text, optionally with labels.
+        """
+        if mode == "train":
+            preds_idx = np.array(preds.argmax(axis=2))
+            text = self.token2str(preds_idx)
+        else:
+            text = self.token2str(np.array(preds))
+        if label is None:
+            return text
+        label = self.token2str(np.array(label))
+        return text, label
+
+
+class UniMERNetTestTransform:
+    """
+    A class for transforming images according to UniMERNet test specifications.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        """
+        Initializes the UniMERNetTestTransform class.
+        """
+        super().__init__()
+        self.num_output_channels = 3
+
+    def transform(self, img: np.ndarray) -> np.ndarray:
+        """
+        Transforms a single image for UniMERNet testing.
+
+        Args:
+            img (numpy.ndarray): The input image.
+
+        Returns:
+            numpy.ndarray: The transformed image.
+        """
+        mean = [0.7931, 0.7931, 0.7931]
+        std = [0.1738, 0.1738, 0.1738]
+        scale = float(1 / 255.0)
+        shape = (1, 1, 3)
+        mean = np.array(mean).reshape(shape).astype("float32")
+        std = np.array(std).reshape(shape).astype("float32")
+        img = (img.astype("float32") * scale - mean) / std
+        grayscale_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        squeezed = np.squeeze(grayscale_image)
+        img = cv2.merge([squeezed] * self.num_output_channels)
+        return img
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """
+        Applies the transform to a list of images.
+
+        Args:
+            imgs (list of numpy.ndarray): The list of input images.
+
+        Returns:
+            list of numpy.ndarray: The list of transformed images.
+        """
+        return [self.transform(img) for img in imgs]
+
+
+class UniMERNetImageFormat:
+    """Class for formatting images to UniMERNet's required format."""
+
+    def __init__(self, **kwargs) -> None:
+        """Initializes the UniMERNetImageFormat instance."""
+        super().__init__()
+
+    def format(self, img: np.ndarray) -> np.ndarray:
+        """Formats a single image to UniMERNet's required format.
+
+        Args:
+            img (numpy.ndarray): The input image to be formatted.
+
+        Returns:
+            numpy.ndarray: The formatted image.
+        """
+        im_h, im_w = img.shape[:2]
+        divide_h = math.ceil(im_h / 32) * 32
+        divide_w = math.ceil(im_w / 32) * 32
+        img = img[:, :, 0]
+        img = np.pad(
+            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
+        )
+        img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
+        return img_expanded[np.newaxis, :]
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """Applies the format method to a list of images.
+
+        Args:
+            imgs (list of numpy.ndarray): The list of input images to be formatted.
+
+        Returns:
+            list of numpy.ndarray: The list of formatted images.
+        """
+        return [self.format(img) for img in imgs]

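The preprocessing ops above are plain list-in/list-out callables, so they can be exercised outside the predictor. A minimal sketch, assuming the module path introduced by this diff and a dummy white crop in place of a real formula image:

    import numpy as np
    from paddlex.inference.models_new.formula_recognition.processors import (
        MinMaxResize,
        LatexTestTransform,
        NormalizeImage,
        LatexImageFormat,
        ToBatch,
    )

    # Dummy 64x256 RGB crop; in the predictor, ReadImage supplies real images.
    img = np.full((64, 256, 3), 255, dtype=np.uint8)

    batch = [img]
    for op in (
        MinMaxResize(),               # keep W within [32, 672] and H within [32, 192]
        LatexTestTransform(),         # grayscale, replicated back to 3 channels
        NormalizeImage(order="hwc"),  # scale to [0, 1], subtract mean, divide by std
        LatexImageFormat(),           # pad to multiples of 16, emit (1, 1, H, W)
    ):
        batch = op(imgs=batch)

    x = ToBatch()(imgs=batch)  # [ndarray of shape (N, 1, H, W)], the StaticInfer input
    print(x[0].shape)          # (1, 1, 64, 256)
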
+ 317 - 0
paddlex/inference/models_new/formula_recognition/result.py

@@ -0,0 +1,317 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import cv2
+import PIL
+import math
+import random
+import tempfile
+import subprocess
+import numpy as np
+from pathlib import Path
+from PIL import Image, ImageDraw, ImageFont
+
+from ...common.result import BaseCVResult
+from ....utils import logging
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+class FormulaRecResult(BaseCVResult):
+    def _to_str(self, *args, **kwargs):
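+        # Unescape doubled backslashes so the printed LaTeX reads as source, not repr.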
+        return super()._to_str(*args, **kwargs).replace("\\\\", "\\")
+
+    def _to_img(
+        self,
+    ):
+        """Draw formula on image"""
+        image = Image.fromarray(self._input_img)
+        try:
+            env_valid()
+        except subprocess.CalledProcessError:
+            logging.warning(
+                "Please refer to Section 2.3, Formula Recognition Pipeline Visualization, of the Formula Recognition Pipeline Tutorial and install the LaTeX rendering engine first."
+            )
+            return image
+
+        rec_formula = str(self["rec_formula"])
+        image = np.array(image.convert("RGB"))
+        xywh = crop_white_area(image)
+        if xywh is not None:
+            x, y, w, h = xywh
+            image = image[y : y + h, x : x + w]
+        image = Image.fromarray(image)
+        image_width, image_height = image.size
+        box = [[0, 0], [image_width, 0], [image_width, image_height], [0, image_height]]
+        try:
+            img_formula = draw_formula_module(
+                image.size, box, rec_formula, is_debug=False
+            )
+            img_formula = Image.fromarray(img_formula)
+            render_width, render_height = img_formula.size
+            resize_height = render_height
+            resize_width = int(resize_height * image_width / image_height)
+            image = image.resize((resize_width, resize_height), Image.LANCZOS)
+
+            new_image_width = image.width + int(render_width) + 10
+            new_image = Image.new(
+                "RGB", (new_image_width, render_height), (255, 255, 255)
+            )
+            new_image.paste(image, (0, 0))
+            new_image.paste(img_formula, (image.width + 10, 0))
+            return new_image
+        except subprocess.CalledProcessError:
+            logging.warning("Syntax error detected in formula, rendering failed.")
+            return image
+
+
+def get_align_equation(equation):
+    is_align = False
+    equation = str(equation) + "\n"
+    begin_dict = [
+        r"begin{align}",
+        r"begin{align*}",
+    ]
+    for begin_sym in begin_dict:
+        if begin_sym in equation:
+            is_align = True
+            break
+    if not is_align:
+        equation = (
+            r"\begin{equation}"
+            + "\n"
+            + equation.strip()
+            + r"\nonumber"
+            + "\n"
+            + r"\end{equation}"
+            + "\n"
+        )
+    return equation
+
+
+def generate_tex_file(tex_file_path, equation):
+    with open(tex_file_path, "w") as fp:
+        start_template = (
+            r"\documentclass{article}" + "\n"
+            r"\usepackage{cite}" + "\n"
+            r"\usepackage{amsmath,amssymb,amsfonts}" + "\n"
+            r"\usepackage{graphicx}" + "\n"
+            r"\usepackage{textcomp}" + "\n"
+            r"\DeclareMathSizes{14}{14}{9.8}{7}" + "\n"
+            r"\pagestyle{empty}" + "\n"
+            r"\begin{document}" + "\n"
+            r"\begin{large}" + "\n"
+        )
+        fp.write(start_template)
+        equation = get_align_equation(equation)
+        fp.write(equation)
+        end_template = r"\end{large}" + "\n" r"\end{document}" + "\n"
+        fp.write(end_template)
+
+
+def generate_pdf_file(tex_path, pdf_dir, is_debug=False):
+    if os.path.exists(tex_path):
+        command = "pdflatex -halt-on-error -output-directory={} {}".format(
+            pdf_dir, tex_path
+        )
+        if is_debug:
+            subprocess.check_call(command, shell=True)
+        else:
+            with open(os.devnull, "w") as devnull:
+                subprocess.check_call(
+                    command, stdout=devnull, stderr=subprocess.STDOUT, shell=True
+                )
+
+
+def crop_white_area(image):
+    image = np.array(image).astype("uint8")
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    _, thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
+    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    if len(contours) > 0:
+        x, y, w, h = cv2.boundingRect(np.concatenate(contours))
+        return [x, y, w, h]
+    else:
+        return None
+
+
+def pdf2img(pdf_path, img_path, is_padding=False):
+    import fitz
+
+    pdfDoc = fitz.open(pdf_path)
+    if pdfDoc.page_count != 1:
+        return None
+    for pg in range(pdfDoc.page_count):
+        page = pdfDoc[pg]
+        rotate = 0
+        zoom_x = 2
+        zoom_y = 2
+        mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate)
+        pix = page.get_pixmap(matrix=mat, alpha=False)
+        img_dir = os.path.dirname(img_path)
+        if img_dir and not os.path.exists(img_dir):
+            os.makedirs(img_dir)
+
+        pix._writeIMG(img_path, 7, 100)
+        img = cv2.imread(img_path)
+        xywh = crop_white_area(img)
+
+        if xywh is not None:
+            x, y, w, h = xywh
+            img = img[y : y + h, x : x + w]
+            if is_padding:
+                img = cv2.copyMakeBorder(
+                    img, 30, 30, 30, 30, cv2.BORDER_CONSTANT, value=(255, 255, 255)
+                )
+            return img
+    return None
+
+
+def draw_formula_module(img_size, box, formula, is_debug=False):
+    """draw box formula for module"""
+    box_width, box_height = img_size
+    with tempfile.TemporaryDirectory() as td:
+        tex_file_path = os.path.join(td, "temp.tex")
+        pdf_file_path = os.path.join(td, "temp.pdf")
+        img_file_path = os.path.join(td, "temp.jpg")
+        generate_tex_file(tex_file_path, formula)
+        if os.path.exists(tex_file_path):
+            generate_pdf_file(tex_file_path, td, is_debug)
+        formula_img = None
+        if os.path.exists(pdf_file_path):
+            formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
+        if formula_img is not None:
+            return formula_img
+        return draw_box_txt_fine(
+            img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
+        )
+
+
+def env_valid():
+    with tempfile.TemporaryDirectory() as td:
+        tex_file_path = os.path.join(td, "temp.tex")
+        pdf_file_path = os.path.join(td, "temp.pdf")
+        img_file_path = os.path.join(td, "temp.jpg")
+        formula = "a+b=c"
+        is_debug = False
+        generate_tex_file(tex_file_path, formula)
+        if os.path.exists(tex_file_path):
+            generate_pdf_file(tex_file_path, td, is_debug)
+        if os.path.exists(pdf_file_path):
+            pdf2img(pdf_file_path, img_file_path, is_padding=False)
+
+
+def draw_box_formula_fine(img_size, box, formula, is_debug=False):
+    """draw box formula for pipeline"""
+    box_height = int(
+        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+    )
+    box_width = int(
+        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+    )
+    with tempfile.TemporaryDirectory() as td:
+        tex_file_path = os.path.join(td, "temp.tex")
+        pdf_file_path = os.path.join(td, "temp.pdf")
+        img_file_path = os.path.join(td, "temp.jpg")
+        generate_tex_file(tex_file_path, formula)
+        if os.path.exists(tex_file_path):
+            generate_pdf_file(tex_file_path, td, is_debug)
+        formula_img = None
+        if os.path.exists(pdf_file_path):
+            formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
+        if formula_img is not None:
+            formula_h, formula_w = formula_img.shape[:-1]
+            resize_height = box_height
+            resize_width = formula_w * resize_height / formula_h
+            formula_img = cv2.resize(
+                formula_img, (int(resize_width), int(resize_height))
+            )
+            formula_h, formula_w = formula_img.shape[:-1]
+            pts1 = np.float32(
+                [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+            )
+            pts2 = np.array(box, dtype=np.float32)
+            M = cv2.getPerspectiveTransform(pts1, pts2)
+            formula_img = np.array(formula_img, dtype=np.uint8)
+            img_right_text = cv2.warpPerspective(
+                formula_img,
+                M,
+                img_size,
+                flags=cv2.INTER_NEAREST,
+                borderMode=cv2.BORDER_CONSTANT,
+                borderValue=(255, 255, 255),
+            )
+        else:
+            img_right_text = draw_box_txt_fine(
+                img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
+            )
+        return img_right_text
+
+
+def draw_box_txt_fine(img_size, box, txt, font_path):
+    """draw box text"""
+    box_height = int(
+        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+    )
+    box_width = int(
+        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+    )
+
+    if box_height > 2 * box_width and box_height > 30:
+        img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_height, box_width), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+        img_text = img_text.transpose(Image.ROTATE_270)
+    else:
+        img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_width, box_height), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+
+    pts1 = np.float32(
+        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+    )
+    pts2 = np.array(box, dtype=np.float32)
+    M = cv2.getPerspectiveTransform(pts1, pts2)
+
+    img_text = np.array(img_text, dtype=np.uint8)
+    img_right_text = cv2.warpPerspective(
+        img_text,
+        M,
+        img_size,
+        flags=cv2.INTER_NEAREST,
+        borderMode=cv2.BORDER_CONSTANT,
+        borderValue=(255, 255, 255),
+    )
+    return img_right_text
+
+
+def create_font(txt, sz, font_path):
+    """create font"""
+    font_size = int(sz[1] * 0.8)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    if int(PIL.__version__.split(".")[0]) < 10:
+        length = font.getsize(txt)[0]
+    else:
+        length = font.getlength(txt)
+
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font
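
For reference, the rendering helpers above compose into a small LaTeX round-trip: generate_tex_file wraps the formula in a standalone document, generate_pdf_file shells out to pdflatex, and pdf2img rasterizes and crops the page. A minimal sketch, assuming pdflatex and PyMuPDF are installed (it mirrors what draw_formula_module does internally):

    import os
    import tempfile

    def render_formula(formula):
        with tempfile.TemporaryDirectory() as td:
            tex_path = os.path.join(td, "temp.tex")
            pdf_path = os.path.join(td, "temp.pdf")
            img_path = os.path.join(td, "temp.jpg")
            generate_tex_file(tex_path, formula)    # wrap formula in a LaTeX document
            generate_pdf_file(tex_path, td)         # compile it with pdflatex
            if os.path.exists(pdf_path):
                return pdf2img(pdf_path, img_path)  # rasterize the page, crop whitespace
            return None                             # compilation failed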

+ 3 - 0
paddlex/inference/utils/official_models.py

@@ -248,6 +248,9 @@ PP-LCNet_x1_0_vehicle_attribute_infer.tar",
     "SLANet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/SLANet_infer.tar",
     "SLANet_plus": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/SLANet_plus_infer.tar",
     "LaTeX_OCR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/LaTeX_OCR_rec_infer.tar",
+    "UniMERNet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/UniMERNet_infer.tar",
+    "PP-FormulaNet-S": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/PP-FormulaNet-S_infer.tar",
+    "PP-FormulaNet-L": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/PP-FormulaNet-L_infer.tar",
     "FasterRCNN-ResNet34-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/FasterRCNN-ResNet34-FPN_infer.tar",
     "FasterRCNN-ResNet50": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/FasterRCNN-ResNet50_infer.tar",
     "FasterRCNN-ResNet50-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/FasterRCNN-ResNet50-FPN_infer.tar",

+ 5 - 0
paddlex/modules/formula_recognition/__init__.py

@@ -11,3 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .dataset_checker import FormulaRecDatasetChecker
+from .trainer import FormulaRecTrainer
+from .evaluator import FormulaRecEvaluator
+from .exportor import FormulaRecExportor

+ 98 - 0
paddlex/modules/formula_recognition/dataset_checker/__init__.py

@@ -0,0 +1,98 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...base import BaseDatasetChecker
+from .dataset_src import check, split_dataset, deep_analyse, convert
+
+from ..model_list import MODELS
+
+
+class FormulaRecDatasetChecker(BaseDatasetChecker):
+    """Dataset Checker for Text Recognition Model"""
+
+    entities = MODELS
+    sample_num = 10
+
+    def convert_dataset(self, src_dataset_dir: str) -> str:
+        """convert the dataset from other type to specified type
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of converted dataset.
+        """
+        return convert(
+            self.check_dataset_config.convert.src_dataset_type, src_dataset_dir
+        )
+
+    def split_dataset(self, src_dataset_dir: str) -> str:
+        """repartition the train and validation dataset
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of splited dataset.
+        """
+        return split_dataset(
+            src_dataset_dir,
+            self.check_dataset_config.split.train_percent,
+            self.check_dataset_config.split.val_percent,
+        )
+
+    def check_dataset(self, dataset_dir: str, sample_num: int = sample_num) -> dict:
+        """check if the dataset meets the specifications and get dataset summary
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+            sample_num (int): the number to be sampled.
+        Returns:
+            dict: dataset summary.
+        """
+        return check(
+            dataset_dir,
+            self.global_config.output,
+            sample_num=sample_num,
+            dataset_type=self.get_dataset_type(),
+        )
+
+    def analyse(self, dataset_dir: str) -> dict:
+        """deep analyse dataset
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            dict: the deep analysis results.
+        """
+        datatype = "FormulaRecDataset"
+        return deep_analyse(dataset_dir, self.output, datatype=datatype)
+
+    def get_show_type(self) -> str:
+        """get the show type of dataset
+
+        Returns:
+            str: show type
+        """
+        return "image"
+
+    def get_dataset_type(self) -> str:
+        """return the dataset type
+
+        Returns:
+            str: dataset type
+        """
+        return "FormulaRecDataset"

+ 19 - 0
paddlex/modules/formula_recognition/dataset_checker/dataset_src/__init__.py

@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .check_dataset import check
+from .convert_dataset import convert
+from .split_dataset import split_dataset
+from .analyse_dataset import deep_analyse

+ 157 - 0
paddlex/modules/formula_recognition/dataset_checker/dataset_src/analyse_dataset.py

@@ -0,0 +1,157 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import json
+import math
+import platform
+from pathlib import Path
+
+from collections import defaultdict
+from PIL import Image
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+from matplotlib import font_manager
+
+from .....utils.file_interface import custom_open
+from .....utils.logging import warning
+from .....utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+def simple_analyse(dataset_path, images_dict):
+    """
+    Analyse the dataset samples by returning the sample counts and image paths
+
+    Args:
+        dataset_path (str): dataset path
+        images_dict (dict): train, val and test image paths
+
+    Returns:
+        tuple: tuple of a status message, and the sample numbers and image paths for the train, val and test subsets.
+
+    """
+    tags = ["train", "val", "test"]
+    sample_cnts = defaultdict(int)
+    img_paths = defaultdict(list)
+    res = [None] * 6
+
+    for tag in tags:
+        file_list = os.path.join(dataset_path, f"{tag}.txt")
+        if not os.path.exists(file_list):
+            if tag in ("train", "val"):
+                res.insert(0, "数据集不符合规范,请先通过数据校准")
+                return res
+            else:
+                continue
+        else:
+            with custom_open(file_list, "r") as f:
+                all_lines = f.readlines()
+
+            # Each line corresponds to a sample
+            sample_cnts[tag] = len(all_lines)
+            img_paths[tag] = images_dict[tag]
+
+    return (
+        "完成数据分析",
+        sample_cnts[tags[0]],
+        sample_cnts[tags[1]],
+        sample_cnts[tags[2]],
+        img_paths[tags[0]],
+        img_paths[tags[1]],
+        img_paths[tags[2]],
+    )
+
+
+def deep_analyse(dataset_path, output, datatype="FormulaRecDataset"):
+    """class analysis for dataset"""
+    tags = ["train", "val"]
+    x_max = []
+    classes_max = []
+    for tag in tags:
+        image_path = os.path.join(dataset_path, f"{tag}.txt")
+        str_nums = []
+        labels_cnt = {}
+        with custom_open(image_path, "r") as f:
+            lines = f.readlines()
+        for line in lines:
+            line = line.strip().split("\t")
+            if len(line) != 2:
+                warning(f"Error in {line}.")
+                continue
+            str_nums.append(len(line[1]))
+
+        max_length = min(768, max(str_nums))
+        interval = 20
+
+        start = 0
+        num_bins = math.ceil(max_length / interval)
+        for i in range(1, num_bins + 1):
+            stop = min(i * interval, max_length)
+            labels_cnt[f"{start}-{stop}"] = sum(start < n <= stop for n in str_nums)
+            start = stop
+        overflow = sum(n > max_length for n in str_nums)
+        if overflow != 0:
+            labels_cnt[f"> {max_length}"] = overflow
+        if tag == "train":
+            cnts_train = [cat_ids for cat_name, cat_ids in labels_cnt.items()]
+            x_train = np.arange(len(cnts_train))
+            if len(x_train) > len(x_max):
+                x_max = x_train
+                classes_max = [cat_name for cat_name, cat_ids in labels_cnt.items()]
+        elif tag == "val":
+            cnts_val = [cat_ids for cat_name, cat_ids in labels_cnt.items()]
+            x_val = np.arange(len(cnts_val))
+            if len(x_val) > len(x_max):
+                x_max = x_val
+                classes_max = [cat_name for cat_name, cat_ids in labels_cnt.items()]
+
+    width = 0.3
+
+    # bar
+    os_system = platform.system().lower()
+    if os_system == "windows":
+        plt.rcParams["font.sans-serif"] = "FangSong"
+    else:
+        font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH, size=15)
+
+    fig, ax = plt.subplots(figsize=(15, 9), dpi=120)
+    xlabel_name = "公式长度区间"
+
+    ax.bar(x_train, cnts_train, width=0.3, label="train")
+    ax.bar(x_val + width, cnts_val, width=0.3, label="val")
+    plt.xticks(x_max + width / 2, classes_max, rotation=90)
+    plt.legend(prop={"size": 18})
+    ax.set_xlabel(
+        xlabel_name,
+        fontproperties=None if os_system == "windows" else font,
+        fontsize=12,
+    )
+    ax.set_ylabel(
+        "Number of images",
+        fontproperties=None if os_system == "windows" else font,
+        fontsize=12,
+    )
+
+    canvas = FigureCanvasAgg(fig)
+    canvas.draw()
+    width, height = fig.get_size_inches() * fig.get_dpi()
+    pie_array = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").reshape(
+        int(height), int(width), 3
+    )
+    fig1_path = os.path.join(output, "histogram.png")
+    cv2.imwrite(fig1_path, pie_array)
+
+    return {"histogram": os.path.join("check_dataset", "histogram.png")}

+ 80 - 0
paddlex/modules/formula_recognition/dataset_checker/dataset_src/check_dataset.py

@@ -0,0 +1,80 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import os.path as osp
+from collections import defaultdict
+
+from PIL import Image
+import json
+import numpy as np
+
+from .....utils.errors import DatasetFileNotFoundError, CheckFailedError
+
+
+def check(
+    dataset_dir, output, dataset_type="FormulaRecDataset", mode="fast", sample_num=10
+):
+    """check dataset"""
+    if dataset_type == "FormulaRecDataset":
+        # Custom dataset
+        if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
+            raise DatasetFileNotFoundError(file_path=dataset_dir)
+        tags = ["train", "val"]
+        delim = "\t"
+        valid_num_parts = 2
+        max_recorded_sample_cnts = 50
+        sample_cnts = dict()
+        sample_paths = defaultdict(list)
+        for tag in tags:
+            file_list = osp.join(dataset_dir, f"{tag}.txt")
+            if not osp.exists(file_list):
+                if tag in ("train", "val"):
+                    # train and val file lists must exist
+                    raise DatasetFileNotFoundError(
+                        file_path=file_list,
+                        solution=f"Ensure that both `train.txt` and `val.txt` exist in {dataset_dir}",
+                    )
+                else:
+                    continue
+            else:
+                with open(file_list, "r", encoding="utf-8") as f:
+                    all_lines = f.readlines()
+                    sample_cnts[tag] = len(all_lines)
+                    for line in all_lines:
+                        substr = line.strip("\n").split(delim)
+                        if len(line.strip("\n")) < 1:
+                            continue
+                        if len(substr) != valid_num_parts:
+                            raise CheckFailedError(
+                                f"Error in {line}. The number of delimiter-separated items in each row "
+                                f"in {file_list} should be {valid_num_parts} (current delimiter is '{delim}')."
+                            )
+                        file_name = substr[0]
+                        img_path = osp.join(dataset_dir, file_name)
+                        if len(sample_paths[tag]) < max_recorded_sample_cnts:
+                            sample_paths[tag].append(os.path.relpath(img_path, output))
+
+                        if not os.path.exists(img_path):
+                            raise DatasetFileNotFoundError(file_path=img_path)
+
+        meta = {}
+        meta["train_samples"] = sample_cnts["train"]
+        meta["train_sample_paths"] = sample_paths["train"][:sample_num]
+
+        meta["val_samples"] = sample_cnts["val"]
+        meta["val_sample_paths"] = sample_paths["val"][:sample_num]
+
+        return meta
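
For reference, the checker expects train.txt and val.txt to hold one sample per line: an image path relative to the dataset root, a tab, then the LaTeX source. Illustrative lines (the paths and formulas are made up):

    images/0001.png	\frac { z } { 1 - z }
    images/0002.png	x ^ { 2 } + y ^ { 2 } = 1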

+ 94 - 0
paddlex/modules/formula_recognition/dataset_checker/dataset_src/convert_dataset.py

@@ -0,0 +1,94 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+import json
+import random
+import math
+import pickle
+from tqdm import tqdm
+from collections import defaultdict
+import imagesize
+from .....utils.errors import ConvertFailedError
+from .....utils.logging import info, warning
+
+
+def check_src_dataset(root_dir, dataset_type):
+    """check src dataset format validity"""
+    if dataset_type in ("FormulaRecDataset"):
+        anno_suffix = ".txt"
+    else:
+        raise ConvertFailedError(
+            message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 FormulaRecDataset 格式。"
+        )
+
+    err_msg_prefix = f"数据格式转换失败!请参考上述`{dataset_type}格式数据集示例`检查待转换数据集格式。"
+
+    for anno in ["train.txt", "val.txt"]:
+        src_anno_path = os.path.join(root_dir, anno)
+        if not os.path.exists(src_anno_path):
+            raise ConvertFailedError(
+                message=f"{err_msg_prefix}保证{src_anno_path}文件存在。"
+            )
+    return None
+
+
+def convert(dataset_type, input_dir):
+    """convert dataset to pkl format"""
+    # check format validity
+    check_src_dataset(input_dir, dataset_type)
+    if dataset_type in ("FormulaRecDataset"):
+        convert_pkl_dataset(input_dir)
+    else:
+        raise ConvertFailedError(
+            message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 FormulaRecDataset 格式。"
+        )
+
+
+def convert_pkl_dataset(root_dir):
+    for anno in ["train.txt", "val.txt"]:
+        src_img_dir = root_dir
+        src_anno_path = os.path.join(root_dir, anno)
+        txt2pickle(src_img_dir, src_anno_path, root_dir)
+
+
+def txt2pickle(images, equations, save_dir):
+    phase = os.path.basename(equations).replace(".txt", "")
+    save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(phase))
+    min_dimensions = (32, 32)
+    max_dimensions = (672, 192)
+    data = defaultdict(list)
+    pic_num = 0
+    if images is not None and equations is not None:
+        with open(equations, "r") as f:
+            lines = f.readlines()
+            for l in tqdm(lines, total=len(lines)):
+                l = l.strip()
+                img_name, equation = l.split("\t")
+                img_path = os.path.join(images, img_name)
+                width, height = imagesize.get(img_path)
+                if (
+                    min_dimensions[0] <= width <= max_dimensions[0]
+                    and min_dimensions[1] <= height <= max_dimensions[1]
+                ):
+                    divide_h = math.ceil(height / 16) * 16
+                    divide_w = math.ceil(width / 16) * 16
+                    data[(divide_w, divide_h)].append((equation, img_name))
+                    pic_num += 1
+        data = dict(data)
+        with open(save_p, "wb") as file:
+            pickle.dump(data, file)
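
The generated latexocr_train.pkl / latexocr_val.pkl map a (width, height) bucket, padded up to multiples of 16, to the list of (equation, image_name) pairs that fit it, so same-sized images can be batched together. A minimal sketch of reading one back:

    import pickle

    with open("latexocr_train.pkl", "rb") as f:
        data = pickle.load(f)
    for (divide_w, divide_h), samples in data.items():
        equation, img_name = samples[0]
        print(divide_w, divide_h, len(samples))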

+ 81 - 0
paddlex/modules/formula_recognition/dataset_checker/dataset_src/split_dataset.py

@@ -0,0 +1,81 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from random import shuffle
+
+from .....utils.file_interface import custom_open
+
+
+def split_dataset(dataset_root, train_rate, val_rate):
+    """
+    Split the image dataset into train and val subsets by the given ratios and generate the corresponding .txt files.
+
+    Args:
+        dataset_root (str): root directory of the dataset.
+        train_rate (int): percentage of the dataset used for training (%).
+        val_rate (int): percentage of the dataset used for validation (%).
+
+    Returns:
+        str: information about the split result.
+    """
+    sum_rate = train_rate + val_rate
+    if sum_rate != 100:
+        return "训练集、验证集比例之和需要等于100,请修改后重试"
+    tags = ["train", "val"]
+
+    valid_path = False
+    image_files = []
+    for tag in tags:
+        split_image_list = os.path.abspath(os.path.join(dataset_root, f"{tag}.txt"))
+        rename_image_list = os.path.abspath(
+            os.path.join(dataset_root, f"{tag}.txt.bak")
+        )
+        if os.path.exists(split_image_list):
+            with custom_open(split_image_list, "r") as f:
+                lines = f.readlines()
+            image_files = image_files + lines
+            valid_path = True
+            if not os.path.exists(rename_image_list):
+                os.rename(split_image_list, rename_image_list)
+
+    if not valid_path:
+        return f"数据集目录下保存待划分文件{tags[0]}.txt或{tags[1]}.txt不存在,请检查后重试"
+
+    shuffle(image_files)
+    start = 0
+    image_num = len(image_files)
+    rate_list = [train_rate, val_rate]
+    for i, tag in enumerate(tags):
+        rate = rate_list[i]
+        if rate == 0:
+            continue
+        if rate > 100 or rate < 0:
+            return f"{tag} 数据集的比例应该在0~100之间."
+
+        end = start + round(image_num * rate / 100)
+        if sum(rate_list[i + 1 :]) == 0:
+            end = image_num
+
+        txt_file = os.path.abspath(os.path.join(dataset_root, tag + ".txt"))
+        with custom_open(txt_file, "w") as f:
+            m = 0
+            for id in range(start, end):
+                m += 1
+                f.write(image_files[id])
+        start = end
+    return dataset_root
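
A hedged sketch of invoking the splitter directly (via the CLI this corresponds to CheckDataset.split.enable=True with train_percent and val_percent; the path is illustrative). Note that the existing train.txt/val.txt are first backed up as *.txt.bak, then merged, shuffled and re-split:

    from paddlex.modules.formula_recognition.dataset_checker.dataset_src import split_dataset

    # Percentages must sum to 100.
    result = split_dataset(
        "./dataset/ocr_rec_latexocr_dataset_example", train_rate=90, val_rate=10
    )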

+ 64 - 0
paddlex/modules/formula_recognition/evaluator.py

@@ -0,0 +1,64 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+
+from ..base import BaseEvaluator
+from .model_list import MODELS
+
+
+class FormulaRecEvaluator(BaseEvaluator):
+    """Text Recognition Model Evaluator"""
+
+    entities = MODELS
+
+    def update_config(self):
+        """update evalution config"""
+        if self.eval_config.log_interval:
+            self.pdx_config.update_log_interval(self.eval_config.log_interval)
+        if self.global_config["model"] == "LaTeX_OCR_rec":
+            self.pdx_config.update_dataset(
+                self.global_config.dataset_dir, "LaTeXOCRDataSet"
+            )
+        elif self.global_config["model"] in (
+            "UniMERNet",
+            "PP-FormulaNet-L",
+            "PP-FormulaNet-S",
+        ):
+            self.pdx_config.update_dataset(
+                self.global_config.dataset_dir, "SimpleDataSet"
+            )
+        label_dict_path = None
+        if self.eval_config.get("label_dict_path"):
+            label_dict_path = self.eval_config.label_dict_path
+        else:
+            label_dict_path = (
+                Path(self.eval_config.weight_path).parent / "label_dict.txt"
+            )
+            if not label_dict_path.exists():
+                label_dict_path = None
+        if label_dict_path is not None:
+            self.pdx_config.update_label_dict_path(label_dict_path)
+
+    def get_eval_kwargs(self) -> dict:
+        """get key-value arguments of model evalution function
+
+        Returns:
+            dict: the arguments of evaluation function.
+        """
+        return {
+            "weight_path": self.eval_config.weight_path,
+            "device": self.get_device(),
+        }

+ 22 - 0
paddlex/modules/formula_recognition/exportor.py

@@ -0,0 +1,22 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class FormulaRecExportor(BaseExportor):
+    """Text Recognition Model Exportor"""
+
+    entities = MODELS

+ 3 - 0
paddlex/modules/formula_recognition/model_list.py

@@ -14,4 +14,7 @@
 
 MODELS = [
     "LaTeX_OCR_rec",
+    "UniMERNet",
+    "PP-FormulaNet-S",
+    "PP-FormulaNet-L",
 ]

+ 111 - 0
paddlex/modules/formula_recognition/trainer.py

@@ -0,0 +1,111 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+
+from ..base import BaseTrainer
+from ...utils.config import AttrDict
+from .model_list import MODELS
+
+
+class FormulaRecTrainer(BaseTrainer):
+    """Text Recognition Model Trainer"""
+
+    entities = MODELS
+
+    def dump_label_dict(self, src_label_dict_path: str):
+        """dump label dict config
+
+        Args:
+            src_label_dict_path (str): path of the source label dict file to be copied.
+        """
+        dst_label_dict_path = Path(self.global_config.output).joinpath("label_dict.txt")
+        shutil.copyfile(src_label_dict_path, dst_label_dict_path)
+
+    def update_config(self):
+        """update training config"""
+        if self.train_config.log_interval:
+            self.pdx_config.update_log_interval(self.train_config.log_interval)
+        if self.train_config.eval_interval:
+            self.pdx_config._update_eval_interval_by_epoch(
+                self.train_config.eval_interval
+            )
+        if self.train_config.save_interval:
+            self.pdx_config.update_save_interval(self.train_config.save_interval)
+
+        if self.global_config["model"] == "LaTeX_OCR_rec":
+            self.pdx_config.update_dataset(
+                self.global_config.dataset_dir, "LaTeXOCRDataSet"
+            )
+        elif self.global_config["model"] in (
+            "UniMERNet",
+            "PP-FormulaNet-L",
+            "PP-FormulaNet-S",
+        ):
+            self.pdx_config.update_dataset(
+                self.global_config.dataset_dir, "SimpleDataSet"
+            )
+
+        label_dict_path = Path(self.global_config.dataset_dir).joinpath("dict.txt")
+        if label_dict_path.exists():
+            self.pdx_config.update_label_dict_path(label_dict_path)
+            self.dump_label_dict(label_dict_path)
+
+        if self.train_config.pretrain_weight_path:
+            self.pdx_config.update_pretrained_weights(
+                self.train_config.pretrain_weight_path
+            )
+
+        if self.global_config["model"] == "LaTeX_OCR_rec":
+            if (
+                self.train_config.batch_size_train is not None
+                and self.train_config.batch_size_val is not None
+            ):
+                self.pdx_config.update_batch_size_pair(
+                    self.train_config.batch_size_train, self.train_config.batch_size_val
+                )
+        else:
+            if (
+                self.train_config.batch_size_train is not None
+                and self.train_config.batch_size_val is not None
+            ):
+                self.pdx_config.update_batch_size(
+                    self.train_config.batch_size_train, self.train_config.batch_size_val
+                )
+
+        if self.train_config.learning_rate is not None:
+            self.pdx_config.update_learning_rate(self.train_config.learning_rate)
+        if self.train_config.epochs_iters is not None:
+            self.pdx_config._update_epochs(self.train_config.epochs_iters)
+        if (
+            self.train_config.resume_path is not None
+            and self.train_config.resume_path != ""
+        ):
+            self.pdx_config._update_checkpoints(self.train_config.resume_path)
+        if self.global_config.output is not None:
+            self.pdx_config._update_output_dir(self.global_config.output)
+
+    def get_train_kwargs(self) -> dict:
+        """get key-value arguments of model training function
+
+        Returns:
+            dict: the arguments of training function.
+        """
+        return {
+            "device": self.get_device(),
+            "dy2st": self.train_config.get("dy2st", False),
+        }
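
Note the two batch-size paths in update_config above: LaTeX_OCR_rec goes through update_batch_size_pair, while UniMERNet and the PP-FormulaNet models go through update_batch_size. Illustrative calls against the FormulaRecConfig defined later in this diff (cfg and the numbers are hypothetical):

    # LaTeX_OCR_rec buckets samples by image size, so batch size lives on the dataset:
    cfg.update_batch_size_pair(19, 1)  # Train/Eval.dataset.batch_size_per_pair
    # The UniMERNet-style models use a plain dataloader batch size:
    cfg.update_batch_size(14, 30)      # Train/Eval.loader.batch_size_per_card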

+ 0 - 3
paddlex/modules/text_recognition/dataset_checker/__init__.py

@@ -25,9 +25,6 @@ from ...base import BaseDatasetChecker
 from .dataset_src import check, split_dataset, deep_analyse, convert
 
 from ..model_list import MODELS
-from ...formula_recognition.model_list import MODELS as MODELS_LaTeX
-
-MODELS = MODELS + MODELS_LaTeX
 
 
 class TextRecDatasetChecker(BaseDatasetChecker):

+ 0 - 3
paddlex/modules/text_recognition/evaluator.py

@@ -17,9 +17,6 @@ from pathlib import Path
 
 from ..base import BaseEvaluator
 from .model_list import MODELS
-from ..formula_recognition.model_list import MODELS as MODELS_LaTeX
-
-MODELS = MODELS + MODELS_LaTeX
 
 
 class TextRecEvaluator(BaseEvaluator):

+ 0 - 3
paddlex/modules/text_recognition/exportor.py

@@ -14,9 +14,6 @@
 
 from ..base import BaseExportor
 from .model_list import MODELS
-from ..formula_recognition.model_list import MODELS as MODELS_LaTeX
-
-MODELS = MODELS + MODELS_LaTeX
 
 
 class TextRecExportor(BaseExportor):

+ 0 - 3
paddlex/modules/text_recognition/trainer.py

@@ -20,9 +20,6 @@ from pathlib import Path
 from ..base import BaseTrainer
 from ...utils.config import AttrDict
 from .model_list import MODELS
-from ..formula_recognition.model_list import MODELS as MODELS_LaTeX
-
-MODELS = MODELS + MODELS_LaTeX
 
 
 class TextRecTrainer(BaseTrainer):

+ 1 - 0
paddlex/repo_apis/PaddleOCR_api/__init__.py

@@ -17,4 +17,5 @@
 # task
 from .text_det import register
 from .text_rec import register
+from .formula_rec import register
 from .table_rec import register

+ 1 - 0
paddlex/repo_apis/PaddleOCR_api/configs/LaTeX_OCR_rec.yml

@@ -107,6 +107,7 @@ Eval:
     keep_smaller_batches: True
     transforms:
       - DecodeImage:
+          img_mode: RGB
           channel_first: False
       - MinMaxResize:
           min_dimensions: [32, 32]

+ 117 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-FormulaNet-L.yaml

@@ -0,0 +1,117 @@
+Global:
+  use_gpu: True
+  epoch_num: 10
+  log_smooth_window: 10
+  print_batch_step: 10
+  save_model_dir: ./output/rec/pp_formulanet_l/
+  save_epoch_step: 2
+  # evaluation is run every 417 iterations (1 epoch, batch_size = 24); max_seq_len: 1024
+  eval_batch_step: [0, 417]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/datasets/pme_demo/0000013.png
+  infer_mode: False
+  use_space_char: False
+  rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
+  max_new_tokens: &max_new_tokens 1024
+  input_size: &input_size [768, 768]
+  save_res_path: ./output/rec/predicts_unimernet_latexocr.txt
+  allow_resize_largeImg: False
+  start_ema: True
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  weight_decay: 0.05
+  lr:
+    name: LinearWarmupCosine
+    learning_rate: 0.0001
+
+Architecture:
+  model_type: rec
+  algorithm: PP-FormulaNet-L
+  in_channels: 3
+  Transform:
+  Backbone:
+    name: Vary_VIT_B_Formula
+    image_size: 768 
+    encoder_embed_dim: 768
+    encoder_depth: 12
+    encoder_num_heads: 12
+    encoder_global_attn_indexes: [2, 5, 8, 11]
+  Head:
+    name: PPFormulaNet_Head
+    max_new_tokens: *max_new_tokens
+    decoder_start_token_id: 0
+    decoder_ffn_dim: 2048
+    decoder_hidden_size: 512
+    decoder_layers: 8
+    temperature: 0.2
+    do_sample: False
+    top_p: 0.95 
+    encoder_hidden_size: 1024
+    is_export: False
+    length_aware: False 
+    use_parallel: False
+    parallel_step: 0
+
+Loss:
+  name: PPFormulaNet_L_Loss
+
+PostProcess:
+  name:  UniMERNetDecode
+  rec_char_dict_path:  *rec_char_dict_path
+
+Metric:
+  name: LaTeXOCRMetric
+  main_indicator:  exp_rate
+  cal_blue_score: False
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./ocr_rec_latexocr_dataset_example
+    label_file_list: ["./ocr_rec_latexocr_dataset_example/train.txt"]
+    transforms:
+      - UniMERNetImgDecode:
+          input_size: *input_size
+      - UniMERNetTrainTransform: 
+      - LatexImageFormat:
+      - UniMERNetLabelEncode:
+          rec_char_dict_path: *rec_char_dict_path
+          max_seq_len:  *max_new_tokens
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'attention_mask']
+
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 6
+    num_workers: 0
+    collate_fn: UniMERNetCollator
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./ocr_rec_latexocr_dataset_example
+    label_file_list: ["./ocr_rec_latexocr_dataset_example/val.txt"]
+    transforms:
+      - UniMERNetImgDecode:
+          input_size: *input_size
+      - UniMERNetTestTransform:
+      - LatexImageFormat:
+      - UniMERNetLabelEncode:
+          max_seq_len:  *max_new_tokens
+          rec_char_dict_path: *rec_char_dict_path
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'attention_mask', 'filename']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 10
+    num_workers: 0
+    collate_fn: UniMERNetCollator
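
One detail worth noting in these configs: &max_new_tokens/*max_new_tokens and &rec_char_dict_path/*rec_char_dict_path are YAML anchors, so the token budget and tokenizer dict are defined once under Global and reused by the Head, the label encoder and the post-process. A quick sanity check, assuming PyYAML:

    import yaml

    with open("paddlex/repo_apis/PaddleOCR_api/configs/PP-FormulaNet-L.yaml") as f:
        cfg = yaml.safe_load(f)
    # Anchors are resolved at load time, so these values are guaranteed to agree.
    assert (
        cfg["Architecture"]["Head"]["max_new_tokens"]
        == cfg["Global"]["max_new_tokens"]
        == 1024
    )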

+ 115 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-FormulaNet-S.yaml

@@ -0,0 +1,115 @@
+Global:
+  use_gpu: True
+  epoch_num: 20
+  log_smooth_window: 10
+  print_batch_step: 10
+  save_model_dir: ./output/rec/pp_formulanet_s/
+  save_epoch_step: 2
+  # evaluation is run every 179 iterations (1 epoch, batch_size = 56); max_seq_len: 1024
+  eval_batch_step: [0, 179]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/datasets/pme_demo/0000013.png
+  infer_mode: False
+  use_space_char: False
+  rec_char_dict_path: &rec_char_dict_path  ppocr/utils/dict/unimernet_tokenizer
+  max_new_tokens: &max_new_tokens 1024
+  input_size: &input_size [384, 384]
+  save_res_path: ./output/rec/predicts_unimernet_latexocr.txt
+  allow_resize_largeImg: False
+  start_ema: True
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  weight_decay: 0.05
+  lr:
+    name: LinearWarmupCosine
+    learning_rate: 0.0001
+
+Architecture:
+  model_type: rec
+  algorithm: PP-FormulaNet-S
+  in_channels: 3
+  Transform:
+  Backbone:
+    name: PPHGNetV2_B4
+    class_num: 1024
+
+  Head:
+    name: PPFormulaNet_Head
+    max_new_tokens:  *max_new_tokens
+    decoder_start_token_id: 0
+    decoder_ffn_dim: 1536
+    decoder_hidden_size: 384
+    decoder_layers: 2
+    temperature: 0.2
+    do_sample: False
+    top_p: 0.95 
+    encoder_hidden_size: 2048
+    is_export: False
+    length_aware: True 
+    use_parallel: True
+    parallel_step: 3
+
+Loss:
+  name: PPFormulaNet_S_Loss
+  parallel_step: 3
+
+PostProcess:
+  name:  UniMERNetDecode
+  rec_char_dict_path: *rec_char_dict_path
+
+Metric:
+  name: LaTeXOCRMetric
+  main_indicator:  exp_rate
+  cal_blue_score: False
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./ocr_rec_latexocr_dataset_example
+    label_file_list: ["./ocr_rec_latexocr_dataset_example/train.txt"]
+    transforms:
+      - UniMERNetImgDecode:
+          input_size: *input_size
+      - UniMERNetTrainTransform: 
+      - LatexImageFormat:
+      - UniMERNetLabelEncode:
+          rec_char_dict_path: *rec_char_dict_path
+          max_seq_len: *max_new_tokens
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'attention_mask']
+
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 14
+    num_workers: 0
+    collate_fn: UniMERNetCollator
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./ocr_rec_latexocr_dataset_example
+    label_file_list: ["./ocr_rec_latexocr_dataset_example/val.txt"]
+    transforms:
+      - UniMERNetImgDecode:
+          input_size:  *input_size
+      - UniMERNetTestTransform:
+      - LatexImageFormat:
+      - UniMERNetLabelEncode:
+          max_seq_len: *max_new_tokens
+          rec_char_dict_path: *rec_char_dict_path
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'attention_mask', 'filename']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 30
+    num_workers: 0
+    collate_fn: UniMERNetCollator

+ 113 - 0
paddlex/repo_apis/PaddleOCR_api/configs/UniMERNet.yaml

@@ -0,0 +1,113 @@
+Global:
+  use_gpu: True
+  epoch_num: 40
+  log_smooth_window: 10
+  print_batch_step: 10
+  save_model_dir: ./output/rec/unimernet/
+  save_epoch_step: 5
+  # evaluation is run every 10 iterations after the 0th iteration
+  eval_batch_step: [0, 10]
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/datasets/pme_demo/0000013.png
+  infer_mode: False
+  use_space_char: False
+  rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
+  input_size: &input_size [192, 672]
+  max_seq_len: &max_seq_len 1024
+  save_res_path: ./output/rec/predicts_unimernet_plus_config_latexocr.txt
+  allow_resize_largeImg: False
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  weight_decay: 0.05
+  lr:
+    name: LinearWarmupCosine
+    learning_rate: 1e-4
+    start_lr: 1e-5
+    min_lr: 1e-8
+    warmup_steps: 5000
+
+Architecture:
+  model_type: rec
+  algorithm: UniMERNet
+  in_channels: 3
+  Transform:
+  Backbone:
+    name: DonutSwinModel
+    hidden_size : 1024
+    num_layers: 4
+    num_heads: [4, 8, 16, 32]
+    add_pooling_layer: True
+    use_mask_token: False
+  Head:
+    name: UniMERNetHead
+    max_new_tokens: 1536
+    decoder_start_token_id: 0
+    temperature: 0.2
+    do_sample: False
+    top_p: 0.95 
+    encoder_hidden_size: 1024
+    is_export: False
+    length_aware: True 
+
+Loss:
+  name: UniMERNetLoss
+
+PostProcess:
+  name:  UniMERNetDecode
+  rec_char_dict_path: *rec_char_dict_path
+
+Metric:
+  name: LaTeXOCRMetric
+  main_indicator:  exp_rate
+  cal_blue_score: False
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/UniMERNet/
+    label_file_list: ["./train_data/UniMERNet/train_unimernet_1M.txt"]
+    transforms:
+      - UniMERNetImgDecode:
+          input_size: *input_size
+      - UniMERNetTrainTransform: 
+      - UniMERNetImageFormat:
+      - UniMERNetLabelEncode:
+          rec_char_dict_path: *rec_char_dict_path
+          max_seq_len: *max_seq_len
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'attention_mask']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 7
+    num_workers: 0
+    collate_fn: UniMERNetCollator
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/UniMERNet/UniMER-Test/cpe
+    label_file_list: ["./train_data/UniMERNet/test_unimernet_cpe.txt"]
+    transforms:
+      - UniMERNetImgDecode:
+          input_size: *input_size
+      - UniMERNetTestTransform:
+      - UniMERNetImageFormat:
+      - UniMERNetLabelEncode:
+          max_seq_len: *max_seq_len
+          rec_char_dict_path: *rec_char_dict_path
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'attention_mask']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 30
+    num_workers: 0
+    collate_fn: UniMERNetCollator

+ 16 - 0
paddlex/repo_apis/PaddleOCR_api/formula_rec/__init__.py

@@ -0,0 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from . import register

+ 544 - 0
paddlex/repo_apis/PaddleOCR_api/formula_rec/config.py

@@ -0,0 +1,544 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import yaml
+from typing import Union
+from ...base import BaseConfig
+from ....utils.misc import abspath
+from ..config_utils import load_config, merge_config
+
+
+class FormulaRecConfig(BaseConfig):
+    """Formula Recognition Config"""
+
+    def update(self, dict_like_obj: list):
+        """update self
+
+        Args:
+            dict_like_obj (dict): dict of pairs(key0.key1.idx.key2=value).
+        """
+        dict_ = merge_config(self.dict, dict_like_obj)
+        self.reset_from_dict(dict_)
+
+    def load(self, config_file_path: str):
+        """load config from yaml file
+
+        Args:
+            config_file_path (str): the path of yaml file.
+
+        Raises:
+            TypeError: the content of yaml file `config_file_path` error.
+        """
+        dict_ = load_config(config_file_path)
+        if not isinstance(dict_, dict):
+            raise TypeError
+        self.reset_from_dict(dict_)
+
+    def dump(self, config_file_path: str):
+        """dump self to yaml file
+
+        Args:
+            config_file_path (str): the path to save self as yaml file.
+        """
+        with open(config_file_path, "w", encoding="utf-8") as f:
+            yaml.dump(self.dict, f, default_flow_style=False, sort_keys=False)
+
+    def update_dataset(
+        self,
+        dataset_path: str,
+        dataset_type: str = None,
+        *,
+        train_list_path: str = None,
+    ):
+        """update dataset settings
+
+        Args:
+            dataset_path (str): the root path of dataset.
+            dataset_type (str, optional): dataset type. Defaults to None.
+            train_list_path (str, optional): the path of train dataset annotation file . Defaults to None.
+
+        Raises:
+            ValueError: the dataset_type error.
+        """
+        dataset_path = abspath(dataset_path)
+        if dataset_type is None:
+            dataset_type = "SimpleDataSet"
+        if not train_list_path:
+            train_list_path = os.path.join(dataset_path, "train.txt")
+
+        if dataset_type == "SimpleDataSet":
+            _cfg = {
+                "Train.dataset.name": dataset_type,
+                "Train.dataset.data_dir": dataset_path,
+                "Train.dataset.label_file_list": [train_list_path],
+                "Eval.dataset.name": "SimpleDataSet",
+                "Eval.dataset.data_dir": dataset_path,
+                "Eval.dataset.label_file_list": [os.path.join(dataset_path, "val.txt")],
+            }
+            self.update(_cfg)
+        elif dataset_type == "LaTeXOCRDataSet":
+            _cfg = {
+                "Train.dataset.name": dataset_type,
+                "Train.dataset.data_dir": dataset_path,
+                "Train.dataset.data": os.path.join(dataset_path, "latexocr_train.pkl"),
+                "Train.dataset.label_file_list": [train_list_path],
+                "Eval.dataset.name": dataset_type,
+                "Eval.dataset.data_dir": dataset_path,
+                "Eval.dataset.data": os.path.join(dataset_path, "latexocr_val.pkl"),
+                "Eval.dataset.label_file_list": [os.path.join(dataset_path, "val.txt")],
+                "Global.character_dict_path": os.path.join(dataset_path, "dict.txt"),
+            }
+            self.update(_cfg)
+        else:
+            raise ValueError(f"{repr(dataset_type)} is not supported.")
+
+    def update_batch_size(
+        self, batch_size_train: int, batch_size_val: int, mode: str = "train"
+    ):
+        """update batch size setting
+
+        Args:
+            batch_size_train (int): the train batch size to set.
+            batch_size_val (int): the eval batch size to set.
+            mode (str, optional): unused; kept for interface compatibility. Defaults to 'train'.
+        """
+
+        _cfg = {
+            "Train.loader.batch_size_per_card": batch_size_train,
+            "Eval.loader.batch_size_per_card": batch_size_val,
+        }
+        self.update(_cfg)
+
+    def update_batch_size_pair(
+        self, batch_size_train: int, batch_size_val: int, mode: str = "train"
+    ):
+        """update batch size setting
+        Args:
+            batch_size (int): the batch size number to set.
+            mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
+                Defaults to 'train'.
+        Raises:
+            ValueError: mode error.
+        """
+        _cfg = {
+            "Train.dataset.batch_size_per_pair": batch_size_train,
+            "Eval.dataset.batch_size_per_pair": batch_size_val,
+        }
+
+        self.update(_cfg)
+
+    def update_learning_rate(self, learning_rate: float):
+        """update learning rate
+
+        Args:
+            learning_rate (float): the learning rate value to set.
+        """
+        _cfg = {
+            "Optimizer.lr.learning_rate": learning_rate,
+        }
+        self.update(_cfg)
+
+    def update_label_dict_path(self, dict_path: str):
+        """update label dict file path
+
+        Args:
+            dict_path (str): the path to label dict file.
+        """
+        _cfg = {
+            "Global.character_dict_path": abspath(dict_path),
+        }
+        self.update(_cfg)
+
+    def update_warmup_epochs(self, warmup_epochs: int):
+        """update warmup epochs
+
+        Args:
+            warmup_epochs (int): the warmup epochs value to set.
+        """
+        _cfg = {"Optimizer.lr.warmup_epoch": warmup_epochs}
+        self.update(_cfg)
+
+    def update_pretrained_weights(self, pretrained_model: str):
+        """update pretrained weight path
+
+        Args:
+            pretrained_model (str): the local path or url of pretrained weight file to set.
+        """
+        if pretrained_model:
+            if not pretrained_model.startswith(("http://", "https://")):
+                pretrained_model = abspath(pretrained_model)
+        self.update(
+            {"Global.pretrained_model": pretrained_model, "Global.checkpoints": ""}
+        )
+
+    # TODO
+    def update_class_path(self, class_path: str):
+        """_summary_
+
+        Args:
+            class_path (str): _description_
+        """
+        self.update(
+            {
+                "PostProcess.class_path": class_path,
+            }
+        )
+
+    def _update_amp(self, amp: Union[None, str]):
+        """update AMP settings
+
+        Args:
+            amp (None | str): the AMP level if it is not None or `OFF`.
+        """
+        _cfg = {
+            "Global.use_amp": amp is not None and amp != "OFF",
+            "Global.amp_level": amp,
+        }
+        self.update(_cfg)
+
+    def update_device(self, device: str):
+        """update device setting
+
+        Args:
+            device (str): the running device to set
+        """
+        device = device.split(":")[0]
+        default_cfg = {
+            "Global.use_gpu": False,
+            "Global.use_xpu": False,
+            "Global.use_npu": False,
+            "Global.use_mlu": False,
+            "Global.use_gcu": False,
+        }
+
+        device_cfg = {
+            "cpu": {},
+            "gpu": {"Global.use_gpu": True},
+            "xpu": {"Global.use_xpu": True},
+            "mlu": {"Global.use_mlu": True},
+            "npu": {"Global.use_npu": True},
+            "gcu": {"Global.use_gcu": True},
+        }
+        default_cfg.update(device_cfg[device])
+        self.update(default_cfg)
+
+    def _update_epochs(self, epochs: int):
+        """update epochs setting
+
+        Args:
+            epochs (int): the epochs number value to set
+        """
+        self.update({"Global.epoch_num": epochs})
+
+    def _update_checkpoints(self, resume_path: Union[None, str]):
+        """update checkpoint setting
+
+        Args:
+            resume_path (None | str): the resume training setting. if is `None`, train from scratch, otherwise,
+                train from checkpoint file that path is `.pdparams` file.
+        """
+        self.update(
+            {"Global.checkpoints": abspath(resume_path), "Global.pretrained_model": ""}
+        )
+
+    def _update_to_static(self, dy2st: bool):
+        """update config to set dynamic to static mode
+
+        Args:
+            dy2st (bool): whether or not to use the dynamic to static mode.
+        """
+        self.update({"Global.to_static": dy2st})
+
+    def _update_use_vdl(self, use_vdl: bool):
+        """update config to set VisualDL
+
+        Args:
+            use_vdl (bool): whether or not to use VisualDL.
+        """
+        self.update({"Global.use_visualdl": use_vdl})
+
+    def _update_output_dir(self, save_dir: str):
+        """update output directory
+
+        Args:
+            save_dir (str): the path to save output.
+        """
+        self.update({"Global.save_model_dir": abspath(save_dir)})
+
+    # TODO
+    # def _update_log_interval(self, log_interval):
+    #     self.update({'Global.print_batch_step': log_interval})
+
+    def update_log_interval(self, log_interval: int):
+        """update log interval(by steps)
+
+        Args:
+            log_interval (int): the log interval value to set.
+        """
+        self.update({"Global.print_batch_step": log_interval})
+
+    def _update_eval_interval(self, eval_start_step: int, eval_interval: int):
+        """update eval interval (by steps)
+
+        Args:
+            eval_start_step (int): step number from which the evaluation is enabled.
+            eval_interval (int): the eval interval value to set.
+        """
+        self.update({"Global.eval_batch_step": [eval_start_step, eval_interval]})
+
+    def update_log_ranks(self, device):
+        """update log ranks
+
+        Args:
+            device (str): the running device, e.g. `gpu:0,1`; the part after `:` is used as the log ranks
+        """
+        log_ranks = device.split(":")[1]
+        self.update({"Global.log_ranks": log_ranks})
+
+    def update_print_mem_info(self, print_mem_info: bool):
+        """setting print memory info"""
+        assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
+        self.update({"Global.print_mem_info": f"{print_mem_info}"})
+
+    def update_shared_memory(self, shared_memory: bool):
+        """update shared memory setting of train and eval dataloader
+
+        Args:
+            shared_memory (bool): whether or not to use shared memory
+        """
+        assert isinstance(shared_memory, bool), "shared_memory should be a bool"
+        _cfg = {
+            "Train.loader.use_shared_memory": f"{shared_memory}",
+            "Eval.loader.use_shared_memory": f"{shared_memory}",
+        }
+        self.update(_cfg)
+
+    def update_shuffle(self, shuffle: bool):
+        """update shuffle setting of train and eval dataloader
+
+        Args:
+            shuffle (bool): whether or not to shuffle the data
+        """
+        assert isinstance(shuffle, bool), "shuffle should be a bool"
+        _cfg = {
+            f"Train.loader.shuffle": shuffle,
+            f"Train.loader.shuffle": shuffle,
+        }
+        self.update(_cfg)
+
+    def update_cal_metrics(self, cal_metrics: bool):
+        """update calculate metrics setting
+        Args:
+            cal_metrics (bool): whether or not to calculate metrics during train
+        """
+        assert isinstance(cal_metrics, bool), "cal_metrics should be a bool"
+        self.update({"Global.cal_metric_during_train": cal_metrics})
+
+    def update_seed(self, seed: int):
+        """update seed
+
+        Args:
+            seed (int): the random seed value to set
+        """
+        assert isinstance(seed, int), "seed should be an int"
+        self.update({"Global.seed": seed})
+
+    def _update_eval_interval_by_epoch(self, eval_interval: int):
+        """update eval interval (by epochs)
+
+        Args:
+            eval_interval (int): the eval interval value to set.
+        """
+        self.update({"Global.eval_batch_epoch": eval_interval})
+
+    def update_eval_interval(self, eval_interval: int, eval_start_step: int = 0):
+        """update eval interval(by steps)
+
+        Args:
+            eval_interval (int): the eval interval value to set.
+            eval_start_step (int, optional): step number from which the evaluation is enabled. Defaults to 0.
+        """
+        self._update_eval_interval(eval_start_step, eval_interval)
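+
+    # Note: in PaddleOCR, `Global.eval_batch_step: [start, interval]` means that
+    # evaluation begins after `start` training steps and then runs every
+    # `interval` steps.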
+
+    def _update_save_interval(self, save_interval: int):
+        """update save interval(by steps)
+
+        Args:
+            save_interval (int): the save interval value to set.
+        """
+        self.update({"Global.save_epoch_step": save_interval})
+
+    def update_save_interval(self, save_interval: int):
+        """update save interval(by steps)
+
+        Args:
+            save_interval (int): the save interval value to set.
+        """
+        self._update_save_interval(save_interval)
+
+    def _update_infer_img(self, infer_img: str, infer_list: str = None):
+        """update image list to be infered
+
+        Args:
+            infer_img (str): path to the image file to be infered. It would be ignored when `infer_list` is be set.
+            infer_list (str, optional): path to the .txt file containing the paths to image to be infered.
+                Defaults to None.
+        """
+        if infer_list:
+            self.update({"Global.infer_list": infer_list})
+        self.update({"Global.infer_img": infer_img})
+
+    def _update_save_inference_dir(self, save_inference_dir: str):
+        """update the directory saving infer outputs
+
+        Args:
+            save_inference_dir (str): the directory saving infer outputs.
+        """
+        self.update({"Global.save_inference_dir": abspath(save_inference_dir)})
+
+    def _update_save_res_path(self, save_res_path: str):
+        """update the .txt file path saving OCR model inference result
+
+        Args:
+            save_res_path (str): the .txt file path saving OCR model inference result.
+        """
+        self.update({"Global.save_res_path": abspath(save_res_path)})
+
+    def update_num_workers(
+        self, num_workers: int, modes: Union[str, list] = ["train", "eval"]
+    ):
+        """update workers number of train or eval dataloader
+
+        Args:
+            num_workers (int): the value of train and eval dataloader workers number to set.
+            modes (str | list, optional): the dataloader mode(s) to update. Defaults to ['train', 'eval'].
+
+        Raises:
+            ValueError: mode error. The `mode` should be `train`, `eval` or `['train', 'eval']`.
+        """
+        if not isinstance(modes, list):
+            modes = [modes]
+        for mode in modes:
+            if not mode in ("train", "eval"):
+                raise ValueError
+            if mode == "train":
+                self["Train"]["loader"]["num_workers"] = num_workers
+            else:
+                self["Eval"]["loader"]["num_workers"] = num_workers
+
+    def _get_model_type(self) -> str:
+        """get model type
+
+        Returns:
+            str: model type, i.e. `Architecture.algorithm` or `Architecture.Models.Student.algorithm`.
+        """
+        if "Models" in self.dict["Architecture"]:
+            return self.dict["Architecture"]["Models"]["Student"]["algorithm"]
+
+        return self.dict["Architecture"]["algorithm"]
+
+    def get_epochs_iters(self) -> int:
+        """get epochs
+
+        Returns:
+            int: the epochs value, i.e., `Global.epoch_num` in config.
+        """
+        return self.dict["Global"]["epoch_num"]
+
+    def get_learning_rate(self) -> float:
+        """get learning rate
+
+        Returns:
+            float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
+        """
+        return self.dict["Optimizer"]["lr"]["learning_rate"]
+
+    def get_batch_size(self, mode="train") -> int:
+        """get batch size
+
+        Args:
+            mode (str, optional): the mode to get the batch size for. Currently unused: the train
+                batch size is always returned. Defaults to 'train'.
+
+        Returns:
+            int: the batch size value, i.e., `Train.loader.batch_size_per_card` in config.
+        """
+        return self.dict["Train"]["loader"]["batch_size_per_card"]
+
+    def get_qat_epochs_iters(self) -> int:
+        """get qat epochs
+
+        Returns:
+            int: the epochs value.
+        """
+        return self.get_epochs_iters()
+
+    def get_qat_learning_rate(self) -> float:
+        """get qat learning rate
+
+        Returns:
+            float: the learning rate value.
+        """
+        return self.get_learning_rate()
+
+    def get_label_dict_path(self) -> str:
+        """get label dict file path
+
+        Returns:
+            str: the label dict file path, i.e., `Global.character_dict_path` in config.
+        """
+        return self.dict["Global"]["character_dict_path"]
+
+    def _get_dataset_root(self) -> str:
+        """get root directory of dataset, i.e. `DataLoader.Train.dataset.data_dir`
+
+        Returns:
+            str: the root directory of dataset
+        """
+        return self.dict["Train"]["dataset"]["data_dir"]
+
+    def _get_infer_shape(self) -> Union[None, str]:
+        """get resize scale of the ResizeImg operation used in evaluation
+
+        Returns:
+            str: resize scale, i.e. `Eval.dataset.transforms.ResizeImg.image_shape`,
+                or None if no ResizeImg operation is found.
+        """
+        size = None
+        transforms = self.dict["Eval"]["dataset"]["transforms"]
+        for op in transforms:
+            op_name = list(op)[0]
+            if "ResizeImg" in op_name:
+                size = op[op_name]["image_shape"]
+        if size is None:
+            return None
+        return ",".join(str(x) for x in size)
+
+    def get_train_save_dir(self) -> str:
+        """get the directory to save output
+
+        Returns:
+            str: the directory to save output
+        """
+        return self["Global"]["save_model_dir"]
+
+    def get_predict_save_dir(self) -> str:
+        """get the directory to save output in predicting
+
+        Returns:
+            str: the directory to save output
+        """
+        return os.path.dirname(self["Global"]["save_res_path"])

+ 396 - 0
paddlex/repo_apis/PaddleOCR_api/formula_rec/model.py

@@ -0,0 +1,396 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ...base import BaseModel
+from ...base.utils.arg import CLIArgument
+from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
+from ....utils.misc import abspath
+from ....utils import logging
+
+
+class FormulaRecModel(BaseModel):
+    """Formula Recognition Model"""
+
+    METRICS = [
+        "acc",
+        "norm_edit_dis",
+        "Teacher_acc",
+        "Teacher_norm_edit_dis",
+        "precision",
+        "recall",
+        "hmean",
+    ]
+
+    def train(
+        self,
+        batch_size: int = None,
+        learning_rate: float = None,
+        epochs_iters: int = None,
+        ips: str = None,
+        device: str = "gpu",
+        resume_path: str = None,
+        dy2st: bool = False,
+        amp: str = "OFF",
+        num_workers: int = None,
+        use_vdl: bool = True,
+        save_dir: str = None,
+        **kwargs,
+    ) -> CompletedProcess:
+        """train self
+
+        Args:
+            batch_size (int, optional): the train batch size value. Defaults to None.
+            learning_rate (float, optional): the train learning rate value. Defaults to None.
+            epochs_iters (int, optional): the train epochs value. Defaults to None.
+            ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
+            device (str, optional): the running device. Defaults to 'gpu'.
+            resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
+                to None. Defaults to None.
+            dy2st (bool, optional): Enable dynamic to static. Defaults to False.
+            amp (str, optional): the amp settings. Defaults to 'OFF'.
+            num_workers (int, optional): the workers number. Defaults to None.
+            use_vdl (bool, optional): enable VisualDL. Defaults to True.
+            save_dir (str, optional): the directory path to save train output. Defaults to None.
+
+        Returns:
+           CompletedProcess: the result of training subprocess execution.
+        """
+        config = self.config.copy()
+
+        if batch_size is not None:
+            config.update_batch_size(batch_size)
+
+        if learning_rate is not None:
+            config.update_learning_rate(learning_rate)
+
+        if epochs_iters is not None:
+            config._update_epochs(epochs_iters)
+
+        # No need to handle `ips`
+
+        config.update_device(device)
+
+        if resume_path is not None:
+            resume_path = abspath(resume_path)
+            config._update_checkpoints(resume_path)
+
+        config._update_to_static(dy2st)
+
+        config._update_amp(amp)
+
+        if num_workers is not None:
+            config.update_num_workers(num_workers, "train")
+
+        config._update_use_vdl(use_vdl)
+
+        if save_dir is not None:
+            save_dir = abspath(save_dir)
+        else:
+            save_dir = abspath(config.get_train_save_dir())
+        config._update_output_dir(save_dir)
+
+        cli_args = []
+
+        do_eval = kwargs.pop("do_eval", True)
+
+        profile = kwargs.pop("profile", None)
+        if profile is not None:
+            cli_args.append(CLIArgument("--profiler_options", profile))
+
+        # Benchmarking mode settings
+        benchmark = kwargs.pop("benchmark", None)
+        if benchmark is not None:
+            envs = benchmark.get("env", None)
+            seed = benchmark.get("seed", None)
+            do_eval = benchmark.get("do_eval", False)
+            num_workers = benchmark.get("num_workers", None)
+            config.update_log_ranks(device)
+            config._update_amp(benchmark.get("amp", None))
+            config.update_shuffle(benchmark.get("shuffle", False))
+            config.update_cal_metrics(benchmark.get("cal_metrics", True))
+            config.update_shared_memory(benchmark.get("shared_memory", True))
+            config.update_print_mem_info(benchmark.get("print_mem_info", True))
+            if num_workers is not None:
+                config.update_num_workers(num_workers)
+            if seed is not None:
+                config.update_seed(seed)
+            if envs is not None:
+                for env_name, env_value in envs.items():
+                    os.environ[env_name] = str(env_value)
+
+        # PDX related settings
+        device_type = device.split(":")[0]
+        uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        config.update({"Global.uniform_output_enabled": uniform_output_enabled})
+        config.update({"Global.pdx_model_name": self.name})
+
+        self._assert_empty_kwargs(kwargs)
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+
+            return self.runner.train(
+                config_path, cli_args, device, ips, save_dir, do_eval=do_eval
+            )
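+
+    # A minimal training call might look like (values and paths are illustrative):
+    #     model.train(batch_size=8, learning_rate=1e-4, epochs_iters=20,
+    #                 device="gpu:0,1", save_dir="output/formula_rec")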
+
+    def evaluate(
+        self,
+        weight_path: str,
+        batch_size: int = None,
+        ips: str = None,
+        device: str = "gpu",
+        amp: str = "OFF",
+        num_workers: int = None,
+        **kwargs,
+    ) -> CompletedProcess:
+        """evaluate self using specified weight
+
+        Args:
+            weight_path (str): the path of model weight file to be evaluated.
+            batch_size (int, optional): the batch size value in evaluating. Defaults to None.
+            ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
+            device (str, optional): the running device. Defaults to 'gpu'.
+            amp (str, optional): the AMP setting. Defaults to 'OFF'.
+            num_workers (int, optional): the workers number in evaluating. Defaults to None.
+
+        Returns:
+            CompletedProcess: the result of evaluating subprocess execution.
+        """
+        config = self.config.copy()
+
+        weight_path = abspath(weight_path)
+        config._update_checkpoints(weight_path)
+
+        if batch_size is not None:
+            config.update_batch_size(batch_size)
+
+        # No need to handle `ips`
+
+        config.update_device(device)
+
+        config._update_amp(amp)
+
+        if num_workers is not None:
+            config.update_num_workers(num_workers, "eval")
+
+        self._assert_empty_kwargs(kwargs)
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+            cp = self.runner.evaluate(config_path, [], device, ips)
+            return cp
+
+    def predict(
+        self,
+        weight_path: str,
+        input_path: str,
+        device: str = "gpu",
+        save_dir: str = None,
+        **kwargs,
+    ) -> CompletedProcess:
+        """predict using specified weight
+
+        Args:
+            weight_path (str): the path of model weight file used to predict.
+            input_path (str): the path of image file to be predicted.
+            device (str, optional): the running device. Defaults to 'gpu'.
+            save_dir (str, optional): the directory path to save predict output. Defaults to None.
+
+        Returns:
+            CompletedProcess: the result of predicting subprocess execution.
+        """
+        config = self.config.copy()
+
+        weight_path = abspath(weight_path)
+        config.update_pretrained_weights(weight_path)
+
+        input_path = abspath(input_path)
+        config._update_infer_img(
+            input_path, infer_list=kwargs.pop("input_list_path", None)
+        )
+
+        config.update_device(device)
+
+        if save_dir is not None:
+            save_dir = abspath(save_dir)
+        else:
+            save_dir = abspath(config.get_predict_save_dir())
+        config._update_save_res_path(os.path.join(save_dir, "res.txt"))
+
+        self._assert_empty_kwargs(kwargs)
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+            return self.runner.predict(config_path, [], device)
+
+    def export(self, weight_path: str, save_dir: str, **kwargs) -> CompletedProcess:
+        """export the dynamic model to static model
+
+        Args:
+            weight_path (str): the model weight file path that used to export.
+            save_dir (str): the directory path to save export output.
+
+        Returns:
+            CompletedProcess: the result of exporting subprocess execution.
+        """
+        config = self.config.copy()
+
+        device = kwargs.pop("device", None)
+        if device:
+            config.update_device(device)
+
+        if not weight_path.startswith("http"):
+            weight_path = abspath(weight_path)
+        config.update_pretrained_weights(weight_path)
+
+        save_dir = abspath(save_dir)
+        config._update_save_inference_dir(save_dir)
+
+        class_path = kwargs.pop("class_path", None)
+        if class_path is not None:
+            config.update_class_path(class_path)
+
+        # PDX related settings
+        uniform_output_enabled = kwargs.pop("uniform_output_enabled", True)
+        config.update({"Global.uniform_output_enabled": uniform_output_enabled})
+        config.update({"Global.pdx_model_name": self.name})
+
+        self._assert_empty_kwargs(kwargs)
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+            return self.runner.export(config_path, [], None, save_dir)
+
+    def infer(
+        self,
+        model_dir: str,
+        input_path: str,
+        device: str = "gpu",
+        save_dir: str = None,
+        **kwargs,
+    ) -> CompletedProcess:
+        """predict image using infernece model
+
+        Args:
+            model_dir (str): the directory path of inference model files that would use to predict.
+            input_path (str): the path of image that would be predict.
+            device (str, optional): the running device. Defaults to 'gpu'.
+            save_dir (str, optional): the directory path to save output. Defaults to None.
+
+        Returns:
+            CompletedProcess: the result of the inference subprocess execution.
+        """
+        config = self.config.copy()
+        cli_args = []
+
+        model_dir = abspath(model_dir)
+        cli_args.append(CLIArgument("--rec_model_dir", model_dir))
+
+        input_path = abspath(input_path)
+        cli_args.append(CLIArgument("--image_dir", input_path))
+
+        device_type, _ = parse_device(device)
+        cli_args.append(CLIArgument("--use_gpu", str(device_type == "gpu")))
+
+        if save_dir is not None:
+            logging.warning("`save_dir` will not be used.")
+
+        dict_path = kwargs.pop("dict_path", None)
+        if dict_path is not None:
+            dict_path = abspath(dict_path)
+        else:
+            dict_path = config.get_label_dict_path()
+        cli_args.append(CLIArgument("--rec_char_dict_path", dict_path))
+
+        model_type = config._get_model_type()
+        cli_args.append(CLIArgument("--rec_algorithm", model_type))
+        infer_shape = config._get_infer_shape()
+        if infer_shape is not None:
+            cli_args.append(CLIArgument("--rec_image_shape", infer_shape))
+
+        self._assert_empty_kwargs(kwargs)
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+            return self.runner.infer(config_path, cli_args, device)
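+
+    # The assembled command roughly resembles (arguments illustrative):
+    #     python tools/infer/predict_rec.py --rec_model_dir <model_dir>
+    #         --image_dir <input_path> --use_gpu True
+    #         --rec_char_dict_path <dict_path> --rec_algorithm <algorithm>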
+
+    def compression(
+        self,
+        weight_path: str,
+        batch_size: int = None,
+        learning_rate: float = None,
+        epochs_iters: int = None,
+        device: str = "gpu",
+        use_vdl: bool = True,
+        save_dir: str = None,
+        **kwargs,
+    ) -> "tuple[CompletedProcess, CompletedProcess]":
+        """compression model
+
+        Args:
+            weight_path (str): the path to weight file of model.
+            batch_size (int, optional): the batch size value of compression training. Defaults to None.
+            learning_rate (float, optional): the learning rate value of compression training. Defaults to None.
+            epochs_iters (int, optional): the epochs or iters of compression training. Defaults to None.
+            device (str, optional): the device to run compression training. Defaults to 'gpu'.
+            use_vdl (bool, optional): whether or not to use VisualDL. Defaults to True.
+            save_dir (str, optional): the directory to save output. Defaults to None.
+
+        Returns:
+            tuple[CompletedProcess, CompletedProcess]: the results of the compression training and exporting subprocesses.
+        """
+        config = self.config.copy()
+        export_cli_args = []
+
+        weight_path = abspath(weight_path)
+        config.update_pretrained_weights(weight_path)
+
+        if batch_size is not None:
+            config.update_batch_size(batch_size)
+
+        if learning_rate is not None:
+            config.update_learning_rate(learning_rate)
+
+        if epochs_iters is not None:
+            config._update_epochs(epochs_iters)
+
+        config.update_device(device)
+
+        config._update_use_vdl(use_vdl)
+
+        if save_dir is not None:
+            save_dir = abspath(save_dir)
+        else:
+            save_dir = abspath(config.get_train_save_dir())
+        config._update_output_dir(save_dir)
+        export_cli_args.append(
+            CLIArgument(
+                "-o", f"Global.save_inference_dir={os.path.join(save_dir, 'export')}"
+            )
+        )
+
+        self._assert_empty_kwargs(kwargs)
+
+        with self._create_new_config_file() as config_path:
+            config.dump(config_path)
+
+            return self.runner.compression(
+                config_path, [], export_cli_args, device, save_dir
+            )

+ 73 - 0
paddlex/repo_apis/PaddleOCR_api/formula_rec/register.py

@@ -0,0 +1,73 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+
+from ...base.register import register_model_info, register_suite_info
+from .model import FormulaRecModel
+from .runner import FormulaRecRunner
+from .config import FormulaRecConfig
+
+REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLEOCR_PATH")
+PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
+
+register_suite_info(
+    {
+        "suite_name": "FormulaRec",
+        "model": FormulaRecModel,
+        "runner": FormulaRecRunner,
+        "config": FormulaRecConfig,
+        "runner_root_path": REPO_ROOT_PATH,
+    }
+)
+
+
+register_model_info(
+    {
+        "model_name": "LaTeX_OCR_rec",
+        "suite": "FormulaRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "LaTeX_OCR_rec.yml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+    }
+)
+
+
+register_model_info(
+    {
+        "model_name": "UniMERNet",
+        "suite": "FormulaRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "UniMERNet.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+    }
+)
+
+
+register_model_info(
+    {
+        "model_name": "PP-FormulaNet-S",
+        "suite": "FormulaRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-FormulaNet-S.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+    }
+)
+
+
+register_model_info(
+    {
+        "model_name": "PP-FormulaNet-L",
+        "suite": "FormulaRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-FormulaNet-L.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+    }
+)
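+
+# Once registered, these models can be resolved by name (e.g. "PP-FormulaNet-S"),
+# and the "FormulaRec" suite ties each of them to the FormulaRecModel,
+# FormulaRecRunner and FormulaRecConfig classes declared above.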

+ 240 - 0
paddlex/repo_apis/PaddleOCR_api/formula_rec/runner.py

@@ -0,0 +1,240 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import re
+
+from ...base import BaseRunner
+from ...base.utils.subprocess import CompletedProcess
+
+
+class FormulaRecRunner(BaseRunner):
+    """Formula Recognition Runner"""
+
+    def train(
+        self,
+        config_path: str,
+        cli_args: list,
+        device: str,
+        ips: str,
+        save_dir: str,
+        do_eval=True,
+    ) -> CompletedProcess:
+        """train model
+
+        Args:
+            config_path (str): the config file path used to train.
+            cli_args (list): the additional parameters.
+            device (str): the training device.
+            ips (str): the ip addresses of nodes when using distribution.
+            save_dir (str): the directory path to save training output.
+            do_eval (bool, optional): whether or not to evaluate model during training. Defaults to True.
+
+        Returns:
+            CompletedProcess: the result of training subprocess execution.
+        """
+        args, env = self.distributed(device, ips, log_dir=save_dir)
+        cmd = [*args, "tools/train.py", "-c", config_path, *cli_args]
+        if not do_eval:
+            # Periodic evaluation cannot be switched off in PaddleOCR, so we
+            # effectively disable it by pushing the first evaluation far beyond
+            # the end of training.
+            inf = int(1.0e11)
+            cmd.extend(["-o", f"Global.eval_batch_step={inf}"])
+        return self.run_cmd(
+            cmd,
+            env=env,
+            switch_wdir=True,
+            echo=True,
+            silent=False,
+            capture_output=True,
+            log_path=self._get_train_log_path(save_dir),
+        )
+
+    def evaluate(
+        self, config_path: str, cli_args: list, device: str, ips: str
+    ) -> CompletedProcess:
+        """run model evaluating
+
+        Args:
+            config_path (str): the config file path used to evaluate.
+            cli_args (list): the additional parameters.
+            device (str): the evaluating device.
+            ips (str): the ip addresses of nodes when using distribution.
+
+        Returns:
+            CompletedProcess: the result of evaluating subprocess execution.
+        """
+        args, env = self.distributed(device, ips)
+        cmd = [*args, "tools/eval.py", "-c", config_path]
+
+        cp = self.run_cmd(
+            cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
+        )
+        if cp.returncode == 0:
+            metric_dict = _extract_eval_metrics(cp.stdout)
+            cp.metrics = metric_dict
+        return cp
+
+    def predict(
+        self, config_path: str, cli_args: list, device: str
+    ) -> CompletedProcess:
+        """run predicting using dynamic mode
+
+        Args:
+            config_path (str): the config file path used to predict.
+            cli_args (list): the additional parameters.
+            device (str): unused.
+
+        Returns:
+            CompletedProcess: the result of predicting subprocess execution.
+        """
+        cmd = [self.python, "tools/infer_rec.py", "-c", config_path]
+        return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+    def export(
+        self, config_path: str, cli_args: list, device: str, save_dir: str = None
+    ) -> CompletedProcess:
+        """run exporting
+
+        Args:
+            config_path (str): the path of config file used to export.
+            cli_args (list): the additional parameters.
+            device (str): unused.
+            save_dir (str, optional): the directory path to save exporting output. Defaults to None.
+
+        Returns:
+            CompletedProcess: the result of exporting subprocess execution.
+        """
+        # `device` unused
+        cmd = [self.python, "tools/export_model.py", "-c", config_path]
+        cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+        return cp
+
+    def infer(self, config_path: str, cli_args: list, device: str) -> CompletedProcess:
+        """run predicting using inference model
+
+        Args:
+            config_path (str): the path of config file used to predict.
+            cli_args (list): the additional parameters.
+            device (str): unused.
+
+        Returns:
+            CompletedProcess: the result of the inference subprocess execution.
+        """
+        cmd = [self.python, "tools/infer/predict_rec.py", *cli_args]
+        return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+    def compression(
+        self,
+        config_path: str,
+        train_cli_args: list,
+        export_cli_args: list,
+        device: str,
+        train_save_dir: str,
+    ) -> "tuple[CompletedProcess, CompletedProcess]":
+        """run compression model
+
+        Args:
+            config_path (str): the path of config file used to predict.
+            train_cli_args (list): the additional training parameters.
+            export_cli_args (list): the additional exporting parameters.
+            device (str): the running device.
+            train_save_dir (str): the directory path to save output.
+
+        Returns:
+            tuple[CompletedProcess, CompletedProcess]: the results of the training and exporting subprocesses.
+        """
+        # Step 1: Train model
+        args, env = self.distributed(device, log_dir=train_save_dir)
+        cmd = [*args, "deploy/slim/quantization/quant.py", "-c", config_path]
+        cp_train = self.run_cmd(
+            cmd,
+            env=env,
+            switch_wdir=True,
+            echo=True,
+            silent=False,
+            capture_output=True,
+            log_path=self._get_train_log_path(train_save_dir),
+        )
+
+        # Step 2: Export model
+        export_cli_args = [
+            *export_cli_args,
+            "-o",
+            f"Global.checkpoints={train_save_dir}/latest",
+        ]
+        cmd = [
+            self.python,
+            "deploy/slim/quantization/export_model.py",
+            "-c",
+            config_path,
+            *export_cli_args,
+        ]
+        cp_export = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+        return cp_train, cp_export
+
+
+def _extract_eval_metrics(stdout: str) -> dict:
+    """extract evaluation metrics from training log
+
+    Args:
+        stdout (str): the training log
+
+    Returns:
+        dict: the training metric
+    """
+
+    def _lazy_split_lines(s):
+        prev_idx = 0
+        while True:
+            curr_idx = s.find(os.linesep, prev_idx)
+            if curr_idx == -1:
+                curr_idx = len(s)
+            yield s[prev_idx:curr_idx]
+            prev_idx = curr_idx + len(os.linesep)
+            if prev_idx >= len(s):
+                break
+
+    _DP = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
+    pattern_key_pairs = [
+        (re.compile(r"acc:(_dp)$".replace("_dp", _DP)), "acc"),
+        (re.compile(r"norm_edit_dis:(_dp)$".replace("_dp", _DP)), "norm_edit_dis"),
+        (re.compile(r"Teacher_acc:(_dp)$".replace("_dp", _DP)), "teacher_acc"),
+        (
+            re.compile(r"Teacher_norm_edit_dis:(_dp)$".replace("_dp", _DP)),
+            "teacher_norm_edit_dis",
+        ),
+        (re.compile(r"precision:(_dp)$".replace("_dp", _DP)), "precision"),
+        (re.compile(r"recall:(_dp)$".replace("_dp", _DP)), "recall"),
+        (re.compile(r"hmean:(_dp)$".replace("_dp", _DP)), "hmean"),
+        (re.compile(r"exp_rate:(_dp)$".replace("_dp", _DP)), "exp_rate"),
+    ]
+
+    metric_dict = dict()
+    start_match = False
+    for line in _lazy_split_lines(stdout):
+        if "metric eval" in line:
+            start_match = True
+        if start_match:
+            for pattern, key in pattern_key_pairs:
+                match = pattern.search(line)
+                if match:
+                    assert len(match.groups()) == 1
+                    # Newer overwrites older
+                    metric_dict[key] = float(match.group(1))
+    return metric_dict
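+
+
+# A minimal illustration (log lines are made up): for a stdout containing
+#     metric eval ***************
+#     acc:0.8532
+#     norm_edit_dis:0.9127
+# `_extract_eval_metrics` returns {"acc": 0.8532, "norm_edit_dis": 0.9127};
+# if a key matches on several lines, the latest value wins.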

+ 0 - 9
paddlex/repo_apis/PaddleOCR_api/text_rec/register.py

@@ -197,12 +197,3 @@ register_model_info(
         "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
     }
 )
-
-register_model_info(
-    {
-        "model_name": "LaTeX_OCR_rec",
-        "suite": "TextRec",
-        "config_path": osp.join(PDX_CONFIG_DIR, "LaTeX_OCR_rec.yml"),
-        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
-    }
-)