Parcourir la source

add formula_rec pipeline (#2779)

* add formula_rec pipeline

* add comment for formula pipeline

* remove print

* adjust import order
liuhongen1234567 il y a 10 mois
Parent
commit
a7ca025449

+ 43 - 0
api_examples/pipelines/test_formula_recognition.py

@@ -0,0 +1,43 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="formula_recognition")
+
+output = pipeline.predict(
+    "./test_samples/general_formula_recognition01.png", use_layout_detection=True
+)
+
+# output = pipeline.predict(
+#     "./test_samples/general_formula_recognition01.pdf",
+#     use_layout_detection=True,
+# )
+
+# output = pipeline.predict(
+#     "./test_samples/general_formula_recognition02.png",
+#     use_layout_detection=False,
+# )
+
+# img_list = [ "./test_samples/general_formula_recognition03.png", \
+#     "./test_samples/general_formula_recognition04.png", \
+#         "./test_samples/general_formula_recognition05.png",]
+# output = pipeline.predict(
+#     img_list,
+#     use_layout_detection=True,
+# )
+
+for res in output:
+    # res.save_to_img("./output/")
+    res.save_results("./output")

+ 35 - 0
paddlex/configs/pipelines/formula_recognition.yaml

@@ -0,0 +1,35 @@
+
+pipeline_name: formula_recognition
+
+use_layout_detection: True
+use_doc_preprocessor: True
+
+SubModules:
+  LayoutDetection:
+    module_name: layout_detection
+    model_name: RT-DETR-H_layout_17cls
+    model_dir: null
+    batch_size: 1
+
+  FormulaRecognition:
+    module_name: formula_recognition
+    model_name: PP-FormulaNet-L
+    model_dir: null
+    batch_size: 5
+
+SubPipelines:
+  DocPreprocessor:
+    pipeline_name: doc_preprocessor
+    use_doc_orientation_classify: True
+    use_doc_unwarping: True
+    SubModules:
+      DocOrientationClassify:
+        module_name: doc_text_orientation
+        model_name: PP-LCNet_x1_0_doc_ori
+        model_dir: null
+        batch_size: 1
+      DocUnwarping:
+        module_name: image_unwarping
+        model_name: UVDoc
+        model_dir: null
+        batch_size: 1

+ 137 - 69
paddlex/inference/models_new/formula_recognition/result.py

@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-import sys
+import os, sys
+from typing import Any, Dict, Optional, List
 import cv2
 import PIL
+import fitz
 import math
 import random
 import tempfile
@@ -27,6 +28,7 @@ from PIL import Image, ImageDraw, ImageFont
 from ...common.result import BaseCVResult
 from ....utils import logging
 from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ....utils.file_interface import custom_open
 
 
 class FormulaRecResult(BaseCVResult):
@@ -35,8 +37,18 @@ class FormulaRecResult(BaseCVResult):
 
     def _to_img(
         self,
-    ):
-        """Draw formula on image"""
+    ) -> Image.Image:
+        """
+        Draws a recognized formula on an image.
+
+        This method processes an input image to recognize and render a LaTeX formula.
+        It overlays the rendered formula onto the input image and returns the combined image.
+        If the LaTeX rendering engine is not installed or a syntax error is detected,
+        it logs a warning and returns the original image.
+
+        Returns:
+            Image.Image: An image with the recognized formula rendered alongside the original image.
+        """
         image = Image.fromarray(self["input_img"])
         try:
             env_valid()
@@ -77,7 +89,19 @@ class FormulaRecResult(BaseCVResult):
             return image
 
 
-def get_align_equation(equation):
+def get_align_equation(equation: str) -> str:
+    """
+    Wraps an equation in LaTeX environment tags if not already aligned.
+
+    This function checks if a given LaTeX equation contains any alignment tags (`align` or `align*`).
+    If the equation does not contain these tags, it wraps the equation in `equation` and `nonumber` tags.
+
+    Args:
+        equation (str): The LaTeX equation to be checked and potentially modified.
+
+    Returns:
+        str: The modified equation with appropriate LaTeX tags for alignment.
+    """
     is_align = False
     equation = str(equation) + "\n"
     begin_dict = [
@@ -101,8 +125,19 @@ def get_align_equation(equation):
     return equation
 
 
-def generate_tex_file(tex_file_path, equation):
-    with open(tex_file_path, "w") as fp:
+def generate_tex_file(tex_file_path: str, equation: str) -> None:
+    """
+    Generates a LaTeX file containing a specific equation.
+
+    This function creates a LaTeX file at the specified file path, writing the necessary
+    LaTeX preamble and wrapping the provided equation in a document structure. The equation
+    is processed to ensure it includes alignment tags if necessary.
+
+    Args:
+        tex_file_path (str): The file path where the LaTeX file will be saved.
+        equation (str): The LaTeX equation to be written into the file.
+    """
+    with custom_open(tex_file_path, "w") as fp:
         start_template = (
             r"\documentclass{article}" + "\n"
             r"\usepackage{cite}" + "\n"
@@ -121,7 +156,24 @@ def generate_tex_file(tex_file_path, equation):
         fp.write(end_template)
 
 
-def generate_pdf_file(tex_path, pdf_dir, is_debug=False):
+def generate_pdf_file(
+    tex_path: str, pdf_dir: str, is_debug: bool = False
+) -> Optional[bool]:
+    """
+    Generates a PDF file from a LaTeX file using pdflatex.
+
+    This function checks if the specified LaTeX file exists, and then runs pdflatex to generate a PDF file
+    in the specified directory. It can run in debug mode to show detailed output or in silent mode.
+
+    Args:
+        tex_path (str): The path to the LaTeX file.
+        pdf_dir (str): The directory where the PDF file will be saved.
+        is_debug (bool, optional): If True, runs pdflatex with detailed output. Defaults to False.
+
+    Returns:
+        Optional[bool]: Returns True if the PDF was generated successfully, False if the LaTeX file does not exist,
+                        and None if an error occurred during the pdflatex execution.
+    """
     if os.path.exists(tex_path):
         command = "pdflatex -halt-on-error -output-directory={} {}".format(
             pdf_dir, tex_path
@@ -129,13 +181,27 @@ def generate_pdf_file(tex_path, pdf_dir, is_debug=False):
         if is_debug:
             subprocess.check_call(command, shell=True)
         else:
-            devNull = open(os.devnull, "w")
+            devNull = custom_open(os.devnull, "w")
             subprocess.check_call(
                 command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
             )
 
 
-def crop_white_area(image):
+def crop_white_area(image: np.ndarray) -> Optional[List[int]]:
+    """
+    Finds and returns the bounding box of the non-white area in an image.
+
+    This function converts an image to grayscale and uses binary thresholding to
+    find contours. It then calculates the bounding rectangle around the non-white
+    areas of the image.
+
+    Args:
+        image (np.ndarray): The input image as a NumPy array.
+
+    Returns:
+        Optional[List[int]]: A list [x, y, w, h] representing the bounding box of
+                             the non-white area, or None if no such area is found.
+    """
     image = np.array(image).astype("uint8")
     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
     _, thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
@@ -147,8 +213,18 @@ def crop_white_area(image):
         return None
 
 
-def pdf2img(pdf_path, img_path, is_padding=False):
-    import fitz
+def pdf2img(pdf_path: str, img_path: str, is_padding: bool = False):
+    """
+    Converts a single-page PDF to an image, optionally cropping white areas and adding padding.
+
+    Args:
+        pdf_path (str): The path to the PDF file.
+        img_path (str): The path where the image will be saved.
+        is_padding (bool): If True, adds a 30-pixel white padding around the image.
+
+    Returns:
+        np.ndarray: The resulting image as a NumPy array, or None if the PDF is not single-page.
+    """
 
     pdfDoc = fitz.open(pdf_path)
     if pdfDoc.page_count != 1:
@@ -160,11 +236,10 @@ def pdf2img(pdf_path, img_path, is_padding=False):
         zoom_y = 2
         mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate)
         pix = page.get_pixmap(matrix=mat, alpha=False)
-        if not os.path.exists(img_path):
-            os.makedirs(img_path)
-
-        pix._writeIMG(img_path, 7, 100)
-        img = cv2.imread(img_path)
+        getpngdata = pix.tobytes(output="png")
+        # decode as np.uint8
+        image_array = np.frombuffer(getpngdata, dtype=np.uint8)
+        img = cv2.imdecode(image_array, cv2.IMREAD_ANYCOLOR)
         xywh = crop_white_area(img)
 
         if xywh is not None:
@@ -178,8 +253,21 @@ def pdf2img(pdf_path, img_path, is_padding=False):
     return None
 
 
-def draw_formula_module(img_size, box, formula, is_debug=False):
-    """draw box formula for module"""
+def draw_formula_module(
+    img_size: tuple, box: list, formula: str, is_debug: bool = False
+):
+    """
+    Draw box formula for module.
+
+    Args:
+        img_size (tuple): The size of the image as (width, height).
+        box (list): The coordinates for the bounding box.
+        formula (str): The LaTeX formula to render.
+        is_debug (bool): If True, retains intermediate files for debugging purposes.
+
+    Returns:
+        np.ndarray: The resulting image with the formula or an error message.
+    """
     box_width, box_height = img_size
     with tempfile.TemporaryDirectory() as td:
         tex_file_path = os.path.join(td, "temp.tex")
@@ -200,7 +288,13 @@ def draw_formula_module(img_size, box, formula, is_debug=False):
         return img_right_text
 
 
-def env_valid():
+def env_valid() -> bool:
+    """
+    Validates if the environment is correctly set up to convert LaTeX formulas to images.
+
+    Returns:
+        bool: True if the environment is valid and the conversion is successful, False otherwise.
+    """
     with tempfile.TemporaryDirectory() as td:
         tex_file_path = os.path.join(td, "temp.tex")
         pdf_file_path = os.path.join(td, "temp.pdf")
@@ -214,55 +308,19 @@ def env_valid():
             formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
 
 
-def draw_box_formula_fine(img_size, box, formula, is_debug=False):
-    """draw box formula for pipeline"""
-    box_height = int(
-        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
-    )
-    box_width = int(
-        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
-    )
-    with tempfile.TemporaryDirectory() as td:
-        tex_file_path = os.path.join(td, "temp.tex")
-        pdf_file_path = os.path.join(td, "temp.pdf")
-        img_file_path = os.path.join(td, "temp.jpg")
-        generate_tex_file(tex_file_path, formula)
-        if os.path.exists(tex_file_path):
-            generate_pdf_file(tex_file_path, td, is_debug)
-        formula_img = None
-        if os.path.exists(pdf_file_path):
-            formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
-        if formula_img is not None:
-            formula_h, formula_w = formula_img.shape[:-1]
-            resize_height = box_height
-            resize_width = formula_w * resize_height / formula_h
-            formula_img = cv2.resize(
-                formula_img, (int(resize_width), int(resize_height))
-            )
-            formula_h, formula_w = formula_img.shape[:-1]
-            pts1 = np.float32(
-                [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
-            )
-            pts2 = np.array(box, dtype=np.float32)
-            M = cv2.getPerspectiveTransform(pts1, pts2)
-            formula_img = np.array(formula_img, dtype=np.uint8)
-            img_right_text = cv2.warpPerspective(
-                formula_img,
-                M,
-                img_size,
-                flags=cv2.INTER_NEAREST,
-                borderMode=cv2.BORDER_CONSTANT,
-                borderValue=(255, 255, 255),
-            )
-        else:
-            img_right_text = draw_box_txt_fine(
-                img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
-            )
-        return img_right_text
+def draw_box_txt_fine(img_size: tuple, box: list, txt: str, font_path: str):
+    """
+    Draw box text.
 
+    Args:
+        img_size (tuple): Size of the image as (width, height).
+        box (list): List of four points defining the box, each point is a tuple (x, y).
+        txt (str): The text to draw inside the box.
+        font_path (str): Path to the font file to be used for drawing text.
 
-def draw_box_txt_fine(img_size, box, txt, font_path):
-    """draw box text"""
+    Returns:
+        np.ndarray: Image array with the text drawn and transformed to fit the box.
+    """
     box_height = int(
         math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
     )
@@ -302,8 +360,18 @@ def draw_box_txt_fine(img_size, box, txt, font_path):
     return img_right_text
 
 
-def create_font(txt, sz, font_path):
-    """create font"""
+def create_font(txt: str, sz: tuple, font_path: str) -> ImageFont.FreeTypeFont:
+    """
+    Creates a font object with a size that ensures the text fits within the specified dimensions.
+
+    Args:
+        txt (str): The text to fit.
+        sz (tuple): The target size as (width, height).
+        font_path (str): The path to the font file.
+
+    Returns:
+        ImageFont.FreeTypeFont: A PIL font object at the appropriate size.
+    """
     font_size = int(sz[1] * 0.8)
     font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
     if int(PIL.__version__.split(".")[0]) < 10:

+ 1 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -25,6 +25,7 @@ from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
 from .image_classification import ImageClassificationPipeline
 from .seal_recognition import SealRecognitionPipeline
 from .table_recognition import TableRecognitionPipeline
+from .formula_recognition import FormulaRecognitionPipeline
 from .video_classification import VideoClassificationPipeline
 from .anomaly_detection import AnomalyDetectionPipeline
 from .ts_forecasting import TSFcPipeline

+ 15 - 0
paddlex/inference/pipelines_new/formula_recognition/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline import FormulaRecognitionPipeline

+ 259 - 0
paddlex/inference/pipelines_new/formula_recognition/pipeline.py

@@ -0,0 +1,259 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+from typing import Any, Dict, Optional
+import numpy as np
+import cv2
+from ..base import BasePipeline
+from ..components import CropByBoxes
+from ..layout_parsing.utils import convert_points_to_boxes
+
+from .result import FormulaRecognitionResult
+from ...models_new.formula_recognition.result import (
+    FormulaRecResult as SingleFormulaRecognitionResult,
+)
+from ....utils import logging
+from ...utils.pp_option import PaddlePredictorOption
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
+from ..ocr.result import OCRResult
+from ..doc_preprocessor.result import DocPreprocessorResult
+
+# [TODO] 待更新models_new到models
+from ...models_new.object_detection.result import DetResult
+
+
+class FormulaRecognitionPipeline(BasePipeline):
+    """Formula Recognition Pipeline"""
+
+    entities = ["formula_recognition"]
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initializes the layout parsing pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        self.use_doc_preprocessor = False
+        if "use_doc_preprocessor" in config:
+            self.use_doc_preprocessor = config["use_doc_preprocessor"]
+
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
+
+        self.use_layout_detection = True
+        if "use_layout_detection" in config:
+            self.use_layout_detection = config["use_layout_detection"]
+        if self.use_layout_detection:
+            layout_det_config = config["SubModules"]["LayoutDetection"]
+            self.layout_det_model = self.create_model(layout_det_config)
+
+        formula_recognition_config = config["SubModules"]["FormulaRecognition"]
+        self.formula_recognition_model = self.create_model(formula_recognition_config)
+
+        self._crop_by_boxes = CropByBoxes()
+
+        self.batch_sampler = ImageBatchSampler(batch_size=1)
+        self.img_reader = ReadImage(format="BGR")
+
+    def check_input_params_valid(
+        self, input_params: Dict, layout_det_res: DetResult
+    ) -> bool:
+        """
+        Check if the input parameters are valid based on the initialized models.
+
+        Args:
+            input_params (Dict): A dictionary containing input parameters.
+            layout_det_res (DetResult): The layout detection result.
+        Returns:
+            bool: True if all required models are initialized according to input parameters, False otherwise.
+        """
+
+        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+            logging.error(
+                "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
+            )
+            return False
+
+        if input_params["use_layout_detection"]:
+            if layout_det_res is not None:
+                logging.error(
+                    "The layout detection model has already been initialized, please set use_layout_detection=False"
+                )
+                return False
+
+            if not self.use_layout_detection:
+                logging.error(
+                    "Set use_layout_detection, but the models for layout detection are not initialized."
+                )
+                return False
+
+        return True
+
+    def predict_doc_preprocessor_res(
+        self, image_array: np.ndarray, input_params: dict
+    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+        """
+        Preprocess the document image based on input parameters.
+
+        Args:
+            image_array (np.ndarray): The input image array.
+            input_params (dict): Dictionary containing preprocessing parameters.
+
+        Returns:
+            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
+                                              result dictionary and the processed image array.
+        """
+        if input_params["use_doc_preprocessor"]:
+            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
+            use_doc_unwarping = input_params["use_doc_unwarping"]
+            doc_preprocessor_res = next(
+                self.doc_preprocessor_pipeline(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                )
+            )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        else:
+            doc_preprocessor_res = {}
+            doc_preprocessor_image = image_array
+        return doc_preprocessor_res, doc_preprocessor_image
+
+    def predict_single_formula_recognition_res(
+        self,
+        image_array: np.ndarray,
+    ) -> SingleFormulaRecognitionResult:
+        """
+        Predict formula recognition results from an image array, layout detection results.
+
+        Args:
+            image_array (np.ndarray): The input image represented as a numpy array.
+            formula_box (list): The formula box coordinates.
+            flag_find_nei_text (bool): Whether to find neighboring text.
+        Returns:
+            SingleFormulaRecognitionResult: single formula recognition result.
+        """
+
+        formula_recognition_pred = next(self.formula_recognition_model(image_array))
+
+        return formula_recognition_pred
+
+    def predict(
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_layout_detection: bool = True,
+        use_doc_orientation_classify: bool = False,
+        use_doc_unwarping: bool = False,
+        layout_det_res: DetResult = None,
+        **kwargs
+    ) -> FormulaRecognitionResult:
+        """
+        This function predicts the layout parsing result for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) of pdf(s) to be processed.
+            use_layout_detection (bool): Whether to use layout detection.
+            use_doc_orientation_classify (bool): Whether to use document orientation classification.
+            use_doc_unwarping (bool): Whether to use document unwarping.
+            layout_det_res (DetResult): The layout detection result.
+                It will be used if it is not None and use_layout_detection is False.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            formulaRecognitionResult: The predicted formula recognition result.
+        """
+
+        input_params = {
+            "use_layout_detection": use_layout_detection,
+            "use_doc_preprocessor": self.use_doc_preprocessor,
+            "use_doc_orientation_classify": use_doc_orientation_classify,
+            "use_doc_unwarping": use_doc_unwarping,
+        }
+
+        if use_doc_orientation_classify or use_doc_unwarping:
+            input_params["use_doc_preprocessor"] = True
+        else:
+            input_params["use_doc_preprocessor"] = False
+
+        if not self.check_input_params_valid(input_params, layout_det_res):
+            yield None
+
+        for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            image_array = self.img_reader(batch_data)[0]
+            input_path = batch_data[0]
+            img_id += 1
+
+            doc_preprocessor_res, doc_preprocessor_image = (
+                self.predict_doc_preprocessor_res(image_array, input_params)
+            )
+
+            formula_res_list = []
+            formula_region_id = 1
+
+            if not input_params["use_layout_detection"] and layout_det_res is None:
+                layout_det_res = {}
+                img_height, img_width = doc_preprocessor_image.shape[:2]
+                single_formula_rec_res = self.predict_single_formula_recognition_res(
+                    doc_preprocessor_image,
+                )
+                single_formula_rec_res["formula_region_id"] = formula_region_id
+                formula_res_list.append(single_formula_rec_res)
+                formula_region_id += 1
+            else:
+                if input_params["use_layout_detection"]:
+                    layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["formula"]:
+                        crop_img_info = self._crop_by_boxes(image_array, [box_info])
+                        crop_img_info = crop_img_info[0]
+                        single_formula_rec_res = (
+                            self.predict_single_formula_recognition_res(
+                                crop_img_info["img"]
+                            )
+                        )
+                        single_formula_rec_res["formula_region_id"] = formula_region_id
+                        single_formula_rec_res["dt_polys"] = box_info["coordinate"]
+                        formula_res_list.append(single_formula_rec_res)
+                        formula_region_id += 1
+
+            single_img_res = {
+                "layout_det_res": layout_det_res,
+                "doc_preprocessor_res": doc_preprocessor_res,
+                "formula_res_list": formula_res_list,
+                "input_params": input_params,
+                "img_id": img_id,
+                "img_name": input_path,
+            }
+            yield FormulaRecognitionResult(single_img_res)

+ 216 - 0
paddlex/inference/pipelines_new/formula_recognition/result.py

@@ -0,0 +1,216 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+from typing import Tuple
+import cv2
+import PIL
+import math
+import random
+import tempfile
+import subprocess
+import numpy as np
+from pathlib import Path
+from PIL import Image, ImageDraw, ImageFont
+
+from ...common.result import BaseCVResult
+from ....utils import logging
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ...models_new.formula_recognition.result import (
+    get_align_equation,
+    generate_tex_file,
+    generate_pdf_file,
+    env_valid,
+    pdf2img,
+    create_font,
+    crop_white_area,
+    draw_box_txt_fine,
+)
+
+
+class FormulaRecognitionResult(dict):
+    """Layout Parsing Result"""
+
+    def __init__(self, data) -> None:
+        """Initializes a new instance of the class with the specified data."""
+        super().__init__(data)
+
+    def save_to_img(self, save_path: str) -> None:
+        """
+        Saves an image with overlaid formula recognition results.
+
+        This function attempts to save an image with recognized formulas highlighted
+        and annotated. It verifies the environment setup before proceeding and logs
+        a warning if the necessary rendering engine is not installed. The output image
+        consists of two halves: the left side shows the original image with bounding
+        boxes, and the right side shows the recognized formulas.
+
+        Args:
+            save_path (str): The directory path where the output image will be saved.
+
+        Returns:
+            None
+        """
+        try:
+            env_valid()
+        except subprocess.CalledProcessError as e:
+            logging.warning(
+                "Please refer to 2.3 Formula Recognition Pipeline Visualization in Formula Recognition Pipeline Tutorial to install the LaTeX rendering engine at first."
+            )
+            return None
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
+        img_id = self["img_id"]
+        img_name = self["img_name"]
+        if len(self["layout_det_res"]) <= 0:
+            return
+        image = Image.fromarray(self["layout_det_res"]["input_img"])
+        h, w = image.height, image.width
+        img_left = image.copy()
+        img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+        random.seed(0)
+        draw_left = ImageDraw.Draw(img_left)
+
+        formula_save_path = os.path.join(save_path, "formula_img_{}.jpg".format(img_id))
+        formula_res_list = self["formula_res_list"]
+        for tno in range(len(self["formula_res_list"])):
+            formula_res = self["formula_res_list"][tno]
+            formula_region_id = formula_res["formula_region_id"]
+            formula = str(formula_res["rec_formula"])
+            dt_polys = formula_res["dt_polys"]
+            x1, y1, x2, y2 = list(dt_polys)
+            try:
+                color = (
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                )
+                box = [x1, y1, x2, y1, x2, y2, x1, y2]
+                box = np.array(box).reshape([-1, 2])
+                pts = [(x, y) for x, y in box.tolist()]
+                draw_left.polygon(pts, outline=color, width=8)
+                draw_left.polygon(box, fill=color)
+                img_right_text = draw_box_formula_fine(
+                    (w, h),
+                    box,
+                    formula,
+                    is_debug=False,
+                )
+                pts = np.array(box, np.int32).reshape((-1, 1, 2))
+                cv2.polylines(img_right_text, [pts], True, color, 1)
+                img_right = cv2.bitwise_and(img_right, img_right_text)
+            except subprocess.CalledProcessError as e:
+                logging.warning("Syntax error detected in formula, rendering failed.")
+                continue
+        img_left = Image.blend(image, img_left, 0.5)
+        img_show = Image.new("RGB", (int(w * 2), h), (255, 255, 255))
+        img_show.paste(img_left, (0, 0, w, h))
+        img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+        img_show.save(formula_save_path)
+
+    def save_results(self, save_path: str) -> None:
+        """Save the formula recognition results to the specified directory.
+
+        Args:
+            save_path (str): The directory path to save the results.
+        """
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
+        if not os.path.isdir(save_path):
+            return
+
+        img_id = self["img_id"]
+        layout_det_res = self["layout_det_res"]
+        if len(layout_det_res) > 0:
+            save_img_path = Path(save_path) / f"layout_det_result_img{img_id}.jpg"
+            layout_det_res.save_to_img(save_img_path)
+        self.save_to_img(save_path)
+        input_params = self["input_params"]
+        if input_params["use_doc_preprocessor"]:
+            save_img_path = Path(save_path) / f"doc_preprocessor_result_img{img_id}.jpg"
+            self["doc_preprocessor_res"].save_to_img(save_img_path)
+        for tno in range(len(self["formula_res_list"])):
+            formula_res = self["formula_res_list"][tno]
+            formula_region_id = formula_res["formula_region_id"]
+            save_img_path = (
+                Path(save_path)
+                / f"formula_res_img{img_id}_region{formula_region_id}.jpg"
+            )
+            formula_res.save_to_img(save_img_path)
+        return
+
+
+def draw_box_formula_fine(
+    img_size: Tuple[int, int], box: np.ndarray, formula: str, is_debug: bool = False
+) -> np.ndarray:
+    """draw box formula for pipeline"""
+    """
+    Draw box formula for pipeline.
+
+    This function generates a LaTeX formula image and transforms it to fit
+    within a specified bounding box on a larger image. If the rendering fails,
+    it will write "Rendering Failed" inside the box.
+
+    Args:
+        img_size (Tuple[int, int]): The size of the image (width, height).
+        box (np.ndarray): A numpy array representing the four corners of the bounding box.
+        formula (str): The LaTeX formula to render.
+        is_debug (bool, optional): If True, enables debug mode. Defaults to False.
+
+    Returns:
+        np.ndarray: An image array with the rendered formula inside the specified box.
+    """
+    box_height = int(
+        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
+    )
+    box_width = int(
+        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
+    )
+    with tempfile.TemporaryDirectory() as td:
+        tex_file_path = os.path.join(td, "temp.tex")
+        pdf_file_path = os.path.join(td, "temp.pdf")
+        img_file_path = os.path.join(td, "temp.jpg")
+        generate_tex_file(tex_file_path, formula)
+        if os.path.exists(tex_file_path):
+            generate_pdf_file(tex_file_path, td, is_debug)
+        formula_img = None
+        if os.path.exists(pdf_file_path):
+            formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
+        if formula_img is not None:
+            formula_h, formula_w = formula_img.shape[:-1]
+            resize_height = box_height
+            resize_width = formula_w * resize_height / formula_h
+            formula_img = cv2.resize(
+                formula_img, (int(resize_width), int(resize_height))
+            )
+            formula_h, formula_w = formula_img.shape[:-1]
+            pts1 = np.float32(
+                [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
+            )
+            pts2 = np.array(box, dtype=np.float32)
+            M = cv2.getPerspectiveTransform(pts1, pts2)
+            formula_img = np.array(formula_img, dtype=np.uint8)
+            img_right_text = cv2.warpPerspective(
+                formula_img,
+                M,
+                img_size,
+                flags=cv2.INTER_NEAREST,
+                borderMode=cv2.BORDER_CONSTANT,
+                borderValue=(255, 255, 255),
+            )
+        else:
+            img_right_text = draw_box_txt_fine(
+                img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
+            )
+        return img_right_text