Add table pipeline v2 (#2772)

* Update paddlepaddle_install.md

* Update paddlepaddle_install_en.md

* add table pipev2

* Update test_table_recognition_v2.py

* Update test_table_recognition_v2.py
Liu Jiaxuan 10 months ago
parent
commit
534325ed1e

+ 26 - 0
api_examples/pipelines/test_table_recognition_v2.py

@@ -0,0 +1,26 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlex import create_pipeline
+
+pipeline = create_pipeline(pipeline="table_recognition_v2")
+
+output = pipeline.predict("./test_samples/table_recognition.jpg")
+
+for res in output:
+    res.print()
+    res.save_to_img("./output")
+    res.save_to_json("./output")
+    res.save_to_xlsx("./output")
+    res.save_to_html("./output")

+ 2 - 2
paddlex/configs/pipelines/table_recognition.yaml

@@ -1,14 +1,14 @@
 
 pipeline_name: table_recognition
 
-use_layout_detection: True
 use_doc_preprocessor: True
+use_layout_detection: True
 use_ocr_model: True
 
 SubModules:
   LayoutDetection:
     module_name: layout_detection
-    model_name: RT-DETR-H_layout_3cls
+    model_name: RT-DETR-H_layout_17cls
     model_dir: null
 
   TableStructureRecognition:

+ 76 - 0
paddlex/configs/pipelines/table_recognition_v2.yaml

@@ -0,0 +1,76 @@
+
+pipeline_name: table_recognition_v2
+
+use_doc_preprocessor: True
+use_layout_detection: True
+use_ocr_model: True
+
+SubModules:
+  LayoutDetection:
+    module_name: layout_detection
+    model_name: RT-DETR-H_layout_17cls
+    model_dir: null
+  
+  TableClassification:
+    module_name: table_classification
+    model_name: PP-LCNet_x1_0_table_cls
+    model_dir: null
+
+  WiredTableStructureRecognition:
+    module_name: table_structure_recognition
+    model_name: SLANeXt_wired
+    model_dir: null
+  
+  WirelessTableStructureRecognition:
+    module_name: table_structure_recognition
+    model_name: SLANeXt_wireless
+    model_dir: null
+  
+  WiredTableCellsDetection:
+    module_name: table_cells_detection
+    model_name: RT-DETR-L_wired_table_cell_det
+    model_dir: null
+  
+  WirelessTableCellsDetection:
+    module_name: table_cells_detection
+    model_name: RT-DETR-L_wireless_table_cell_det
+    model_dir: null
+
+SubPipelines:
+  DocPreprocessor:
+    pipeline_name: doc_preprocessor
+    use_doc_orientation_classify: True
+    use_doc_unwarping: True
+    SubModules:
+      DocOrientationClassify:
+        module_name: doc_text_orientation
+        model_name: PP-LCNet_x1_0_doc_ori
+        model_dir: null
+
+      DocUnwarping:
+        module_name: image_unwarping
+        model_name: UVDoc
+        model_dir: null
+
+  GeneralOCR:
+    pipeline_name: OCR
+    text_type: general
+    use_doc_preprocessor: False
+    use_textline_orientation: False
+    SubModules:
+      TextDetection:
+        module_name: text_detection
+        model_name: PP-OCRv4_server_det
+        model_dir: null
+        limit_side_len: 960
+        limit_type: max
+        thresh: 0.3
+        box_thresh: 0.6
+        unclip_ratio: 2.0
+        
+      TextRecognition:
+        module_name: text_recognition
+        model_name: PP-OCRv4_server_rec
+        model_dir: null
+        batch_size: 1
+        score_thresh: 0

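The v2 config above wires six sub-models (layout detection, table classification, wired/wireless structure recognition, wired/wireless cell detection) together with the doc-preprocessor and OCR sub-pipelines. A minimal sketch of customizing it (not part of this commit; it assumes, as with the existing pipelines, that create_pipeline also accepts a path to a pipeline YAML and a device argument, and the file path is illustrative):

from paddlex import create_pipeline

# Copy paddlex/configs/pipelines/table_recognition_v2.yaml to a local file and,
# for example, point WiredTableStructureRecognition.model_dir at a fine-tuned model.
pipeline = create_pipeline(
    pipeline="./my_table_recognition_v2.yaml",  # illustrative local copy of the config
    device="gpu:0",
)

for res in pipeline.predict("./test_samples/table_recognition.jpg"):
    res.save_to_html("./output")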
+ 1 - 0
paddlex/inference/pipelines_new/__init__.py

@@ -26,6 +26,7 @@ from .image_classification import ImageClassificationPipeline
 from .object_detection import ObjectDetectionPipeline
 from .seal_recognition import SealRecognitionPipeline
 from .table_recognition import TableRecognitionPipeline
+from .table_recognition import TableRecognitionPipelineV2
 from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeline
 from .formula_recognition import FormulaRecognitionPipeline
 from .image_multilabel_classification import ImageMultiLabelClassificationPipeline

+ 1 - 0
paddlex/inference/pipelines_new/table_recognition/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .pipeline import TableRecognitionPipeline
+from .pipeline_v2 import TableRecognitionPipelineV2

+ 434 - 0
paddlex/inference/pipelines_new/table_recognition/pipeline_v2.py

@@ -0,0 +1,434 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os, sys
+from typing import Any, Dict, Optional
+import numpy as np
+import cv2
+from ..base import BasePipeline
+from ..components import CropByBoxes
+from .utils import get_neighbor_boxes_idx
+from .table_recognition_post_processing_v2 import get_table_recognition_res
+from .result import SingleTableRecognitionResult, TableRecognitionResult
+from ....utils import logging
+from ...utils.pp_option import PaddlePredictorOption
+from ...common.reader import ReadImage
+from ...common.batch_sampler import ImageBatchSampler
+from ..ocr.result import OCRResult
+from ..doc_preprocessor.result import DocPreprocessorResult
+
+# [TODO] To be updated: models_new -> models
+from ...models_new.object_detection.result import DetResult
+
+
+class TableRecognitionPipelineV2(BasePipeline):
+    """Table Recognition Pipeline"""
+
+    entities = ["table_recognition_v2"]
+
+    def __init__(
+        self,
+        config: Dict,
+        device: str = None,
+        pp_option: PaddlePredictorOption = None,
+        use_hpip: bool = False,
+        hpi_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Initializes the layout parsing pipeline.
+
+        Args:
+            config (Dict): Configuration dictionary containing various settings.
+            device (str, optional): Device to run the predictions on. Defaults to None.
+            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
+            use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
+            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
+        """
+
+        super().__init__(
+            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
+        )
+
+        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
+        if self.use_doc_preprocessor:
+            doc_preprocessor_config = config.get("SubPipelines", {}).get(
+                "DocPreprocessor",
+                {
+                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
+                },
+            )
+            self.doc_preprocessor_pipeline = self.create_pipeline(
+                doc_preprocessor_config
+            )
+
+        self.use_layout_detection = config.get("use_layout_detection", True)
+        if self.use_layout_detection:
+            layout_det_config = config.get("SubModules", {}).get(
+                "LayoutDetection",
+                {"model_config_error": "config error for layout_det_model!"},
+            )
+            self.layout_det_model = self.create_model(layout_det_config)
+
+        table_cls_config = config.get("SubModules", {}).get(
+            "TableClassification",
+            {"model_config_error": "config error for table_classification_model!"},
+        )
+        self.table_cls_model = self.create_model(table_cls_config)
+
+        wired_table_rec_config = config.get("SubModules", {}).get(
+            "WiredTableStructureRecognition",
+            {"model_config_error": "config error for wired_table_structure_model!"},
+        )
+        self.wired_table_rec_model = self.create_model(wired_table_rec_config)
+
+        wireless_table_rec_config = config.get("SubModules", {}).get(
+            "WirelessTableStructureRecognition",
+            {"model_config_error": "config error for wireless_table_structure_model!"},
+        )
+        self.wireless_table_rec_model = self.create_model(wireless_table_rec_config)
+        
+        wired_table_cells_det_config = config.get("SubModules", {}).get(
+            "WiredTableCellsDetection",
+            {"model_config_error": "config error for wired_table_cells_detection_model!"},
+        )
+        self.wired_table_cells_detection_model = self.create_model(wired_table_cells_det_config)
+
+        wireless_table_cells_det_config = config.get("SubModules", {}).get(
+            "WirelessTableCellsDetection",
+            {"model_config_error": "config error for wireless_table_cells_detection_model!"},
+        )
+        self.wireless_table_cells_detection_model = self.create_model(wireless_table_cells_det_config)
+    
+        self.use_ocr_model = config.get("use_ocr_model", True)
+        if self.use_ocr_model:
+            general_ocr_config = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                {"pipeline_config_error": "config error for general_ocr_pipeline!"},
+            )
+            self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+
+        self._crop_by_boxes = CropByBoxes()
+
+        self.batch_sampler = ImageBatchSampler(batch_size=1)
+        self.img_reader = ReadImage(format="BGR")
+
+    def get_model_settings(
+        self,
+        use_doc_orientation_classify: Optional[bool],
+        use_doc_unwarping: Optional[bool],
+        use_layout_detection: Optional[bool],
+        use_ocr_model: Optional[bool],
+    ) -> dict:
+        """
+        Get the model settings based on the provided parameters or default values.
+
+        Args:
+            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
+            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
+            use_layout_detection (Optional[bool]): Whether to use layout detection.
+            use_ocr_model (Optional[bool]): Whether to use OCR model.
+
+        Returns:
+            dict: A dictionary containing the model settings.
+        """
+        if use_doc_orientation_classify is None and use_doc_unwarping is None:
+            use_doc_preprocessor = self.use_doc_preprocessor
+        else:
+            if use_doc_orientation_classify is True or use_doc_unwarping is True:
+                use_doc_preprocessor = True
+            else:
+                use_doc_preprocessor = False
+
+        if use_layout_detection is None:
+            use_layout_detection = self.use_layout_detection
+
+        if use_ocr_model is None:
+            use_ocr_model = self.use_ocr_model
+
+        return dict(
+            use_doc_preprocessor=use_doc_preprocessor,
+            use_layout_detection=use_layout_detection,
+            use_ocr_model=use_ocr_model,
+        )
+
+    def check_model_settings_valid(
+        self,
+        model_settings: Dict,
+        overall_ocr_res: OCRResult,
+        layout_det_res: DetResult,
+    ) -> bool:
+        """
+        Check if the input parameters are valid based on the initialized models.
+
+        Args:
+            model_settings (Dict): A dictionary containing input parameters.
+            overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
+                The overall OCR result with convert_points_to_boxes information.
+            layout_det_res (DetResult): The layout detection result.
+        Returns:
+            bool: True if all required models are initialized according to input parameters, False otherwise.
+        """
+
+        if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
+            logging.error(
+                "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
+            )
+            return False
+
+        if model_settings["use_layout_detection"]:
+            if layout_det_res is not None:
+                logging.error(
+                    "The layout detection model has already been initialized, please set use_layout_detection=False"
+                )
+                return False
+
+            if not self.use_layout_detection:
+                logging.error(
+                    "Set use_layout_detection, but the models for layout detection are not initialized."
+                )
+                return False
+
+        if model_settings["use_ocr_model"]:
+            if overall_ocr_res is not None:
+                logging.error(
+                    "The OCR models have already been initialized, please set use_ocr_model=False"
+                )
+                return False
+
+            if not self.use_ocr_model:
+                logging.error(
+                    "Set use_ocr_model, but the models for OCR are not initialized."
+                )
+                return False
+        else:
+            if overall_ocr_res is None:
+                logging.error("Set use_ocr_model=False, but no OCR results were found.")
+                return False
+        return True
+
+    def predict_doc_preprocessor_res(
+        self, image_array: np.ndarray, input_params: dict
+    ) -> tuple[DocPreprocessorResult, np.ndarray]:
+        """
+        Preprocess the document image based on input parameters.
+
+        Args:
+            image_array (np.ndarray): The input image array.
+            input_params (dict): Dictionary containing preprocessing parameters.
+
+        Returns:
+            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
+                                              result dictionary and the processed image array.
+        """
+        if input_params["use_doc_preprocessor"]:
+            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
+            use_doc_unwarping = input_params["use_doc_unwarping"]
+            doc_preprocessor_res = next(
+                self.doc_preprocessor_pipeline(
+                    image_array,
+                    use_doc_orientation_classify=use_doc_orientation_classify,
+                    use_doc_unwarping=use_doc_unwarping,
+                )
+            )
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+        else:
+            doc_preprocessor_res = {}
+            doc_preprocessor_image = image_array
+        return doc_preprocessor_res, doc_preprocessor_image
+
+    def extract_results(self, pred, task):
+        if task == "cls":
+            return pred['label_names'][np.argmax(pred['scores'])]
+        elif task == "det":
+            threshold = 0.0
+            result = []
+            if 'boxes' in pred and isinstance(pred['boxes'], list):
+                for box in pred['boxes']:
+                    if isinstance(box, dict) and 'score' in box and 'coordinate' in box:
+                        score = box['score']
+                        coordinate = box['coordinate']
+                        if isinstance(score, float) and score > threshold:
+                            result.append(coordinate)
+            return result
+        elif task == "table_stru":
+            return pred["structure"]
+        else:
+            return None
+
+    def predict_single_table_recognition_res(
+        self,
+        image_array: np.ndarray,
+        overall_ocr_res: OCRResult,
+        table_box: list,
+        flag_find_nei_text: bool = True,
+    ) -> SingleTableRecognitionResult:
+        """
+        Predict table recognition results from an image array, layout detection results, and OCR results.
+
+        Args:
+            image_array (np.ndarray): The input image represented as a numpy array.
+            overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
+                The overall OCR results containing text recognition information.
+            table_box (list): The table box coordinates.
+            flag_find_nei_text (bool): Whether to find neighboring text.
+        Returns:
+            SingleTableRecognitionResult: single table recognition result.
+        """
+        table_cls_pred = next(self.table_cls_model(image_array))
+        table_cls_result = self.extract_results(table_cls_pred, "cls")
+        if table_cls_result == "wired_table":
+            table_structure_pred = next(self.wired_table_rec_model(image_array))
+            table_cells_pred = next(self.wired_table_cells_detection_model(image_array))
+        elif table_cls_result == "wireless_table":
+            table_structure_pred = next(self.wireless_table_rec_model(image_array))
+            table_cells_pred = next(self.wireless_table_cells_detection_model(image_array))
+        table_structure_result = self.extract_results(table_structure_pred, "table_stru")
+        table_cells_result = self.extract_results(table_cells_pred, "det")
+        single_table_recognition_res = get_table_recognition_res(
+            table_box, table_structure_result, table_cells_result, overall_ocr_res
+        )
+        neighbor_text = ""
+        if flag_find_nei_text:
+            match_idx_list = get_neighbor_boxes_idx(
+                overall_ocr_res["rec_boxes"], table_box
+            )
+            if len(match_idx_list) > 0:
+                for idx in match_idx_list:
+                    neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
+        single_table_recognition_res["neighbor_text"] = neighbor_text
+        return single_table_recognition_res
+
+    def predict(
+        self,
+        input: str | list[str] | np.ndarray | list[np.ndarray],
+        use_doc_orientation_classify: Optional[bool] = None,
+        use_doc_unwarping: Optional[bool] = None,
+        use_layout_detection: Optional[bool] = None,
+        use_ocr_model: Optional[bool] = None,
+        overall_ocr_res: Optional[OCRResult] = None,
+        layout_det_res: Optional[DetResult] = None,
+        text_det_limit_side_len: Optional[int] = None,
+        text_det_limit_type: Optional[str] = None,
+        text_det_thresh: Optional[float] = None,
+        text_det_box_thresh: Optional[float] = None,
+        text_det_unclip_ratio: Optional[float] = None,
+        text_rec_score_thresh: Optional[float] = None,
+        **kwargs,
+    ) -> TableRecognitionResult:
+        """
+        This function predicts the table recognition result for the given input.
+
+        Args:
+            input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or PDF(s) to be processed.
+            use_layout_detection (bool): Whether to use layout detection.
+            use_doc_orientation_classify (bool): Whether to use document orientation classification.
+            use_doc_unwarping (bool): Whether to use document unwarping.
+            overall_ocr_res (OCRResult): The overall OCR result with convert_points_to_boxes information.
+                It will be used if it is not None and use_ocr_model is False.
+            layout_det_res (DetResult): The layout detection result.
+                It will be used if it is not None and use_layout_detection is False.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            TableRecognitionResult: The predicted table recognition result.
+        """
+
+        model_settings = self.get_model_settings(
+            use_doc_orientation_classify,
+            use_doc_unwarping,
+            use_layout_detection,
+            use_ocr_model,
+        )
+
+        if not self.check_model_settings_valid(
+            model_settings, overall_ocr_res, layout_det_res
+        ):
+            yield {"error": "the input params for model settings are invalid!"}
+
+        for img_id, batch_data in enumerate(self.batch_sampler(input)):
+            if not isinstance(batch_data[0], str):
+                # TODO: add input_path support for ndarray and PDF inputs
+                input_path = f"{img_id}"
+            else:
+                input_path = batch_data[0]
+
+            image_array = self.img_reader(batch_data)[0]
+
+            if model_settings["use_doc_preprocessor"]:
+                doc_preprocessor_res = next(
+                    self.doc_preprocessor_pipeline(
+                        image_array,
+                        use_doc_orientation_classify=use_doc_orientation_classify,
+                        use_doc_unwarping=use_doc_unwarping,
+                    )
+                )
+            else:
+                doc_preprocessor_res = {"output_img": image_array}
+
+            doc_preprocessor_image = doc_preprocessor_res["output_img"]
+
+            if model_settings["use_ocr_model"]:
+                overall_ocr_res = next(
+                    self.general_ocr_pipeline(
+                        doc_preprocessor_image,
+                        text_det_limit_side_len=text_det_limit_side_len,
+                        text_det_limit_type=text_det_limit_type,
+                        text_det_thresh=text_det_thresh,
+                        text_det_box_thresh=text_det_box_thresh,
+                        text_det_unclip_ratio=text_det_unclip_ratio,
+                        text_rec_score_thresh=text_rec_score_thresh,
+                    )
+                )
+
+            table_res_list = []
+            table_region_id = 1
+            if not model_settings["use_layout_detection"] and layout_det_res is None:
+                layout_det_res = {}
+                img_height, img_width = doc_preprocessor_image.shape[:2]
+                table_box = [0, 0, img_width - 1, img_height - 1]
+                single_table_rec_res = self.predict_single_table_recognition_res(
+                    doc_preprocessor_image,
+                    overall_ocr_res,
+                    table_box,
+                    flag_find_nei_text=False,
+                )
+                single_table_rec_res["table_region_id"] = table_region_id
+                table_res_list.append(single_table_rec_res)
+                table_region_id += 1
+            else:
+                if model_settings["use_layout_detection"]:
+                    layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+
+                for box_info in layout_det_res["boxes"]:
+                    if box_info["label"].lower() in ["table"]:
+                        crop_img_info = self._crop_by_boxes(image_array, [box_info])
+                        crop_img_info = crop_img_info[0]
+                        table_box = crop_img_info["box"]
+                        single_table_rec_res = (
+                            self.predict_single_table_recognition_res(
+                                crop_img_info["img"], overall_ocr_res, table_box
+                            )
+                        )
+                        single_table_rec_res["table_region_id"] = table_region_id
+                        table_res_list.append(single_table_rec_res)
+                        table_region_id += 1
+
+            single_img_res = {
+                "input_path": input_path,
+                "doc_preprocessor_res": doc_preprocessor_res,
+                "layout_det_res": layout_det_res,
+                "overall_ocr_res": overall_ocr_res,
+                "table_res_list": table_res_list,
+                "model_settings": model_settings,
+            }
+            yield TableRecognitionResult(single_img_res)

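A usage sketch for the new predict() flags (not part of this commit; the sample path is illustrative): with use_layout_detection=False and no layout_det_res supplied, the branch above treats the whole preprocessed image as a single table region, which suits images that are already cropped to a table.

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="table_recognition_v2")

# Layout detection disabled and no precomputed layout_det_res passed in, so
# predict() falls back to table_box = [0, 0, img_width - 1, img_height - 1].
output = pipeline.predict(
    "./test_samples/cropped_table.jpg",  # illustrative: an image already cropped to one table
    use_layout_detection=False,
)

for res in output:
    res.save_to_html("./output")
    res.save_to_xlsx("./output")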
+ 274 - 0
paddlex/inference/pipelines_new/table_recognition/table_recognition_post_processing_v2.py

@@ -0,0 +1,274 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Any, Dict, Optional
+import numpy as np
+from ..layout_parsing.utils import get_sub_regions_ocr_res
+from ..components import convert_points_to_boxes
+from .result import SingleTableRecognitionResult
+from ..ocr.result import OCRResult
+
+
+def get_ori_image_coordinate(x: int, y: int, box_list: list) -> list:
+    """
+    Map box coordinates from the cropped image back to the original image.
+    Args:
+        x (int): x coordinate of cropped image
+        y (int): y coordinate of cropped image
+        box_list (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
+    Returns:
+        list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
+    """
+    if not box_list:
+        return box_list
+    offset = np.array([x, y] * 4)
+    box_list = np.array(box_list)
+    if box_list.shape[-1] == 2:
+        offset = offset.reshape(4, 2)
+    ori_box_list = offset + box_list
+    return ori_box_list
+
+
+def distance(box_1: list, box_2: list) -> float:
+    """
+    compute the distance between two boxes
+
+    Args:
+        box_1 (list): first rectangle box, e.g. (x1, y1, x2, y2)
+        box_2 (list): second rectangle box, e.g. (x1, y1, x2, y2)
+
+    Returns:
+        float: the distance between two boxes
+    """
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    center1_x = (x1 + x2) / 2
+    center1_y = (y1 + y2) / 2
+    center2_x = (x3 + x4) / 2
+    center2_y = (y3 + y4) / 2
+    dis = math.sqrt((center2_x - center1_x) ** 2 + (center2_y - center1_y) ** 2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1: list, rec2: list) -> float:
+    """
+    computing IoU
+    Args:
+        rec1 (list): (x1, y1, x2, y2)
+        rec2 (list): (x1, y1, x2, y2)
+    Returns:
+        float: Intersection over Union
+    """
+    # computing area of each rectangles
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find the each edge of intersect rectangle
+    left_line = max(rec1[0], rec2[0])
+    right_line = min(rec1[2], rec2[2])
+    top_line = max(rec1[1], rec2[1])
+    bottom_line = min(rec1[3], rec2[3])
+
+    # judge if there is an intersect
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+    else:
+        intersect = (right_line - left_line) * (bottom_line - top_line)
+        return (intersect / (sum_area - intersect)) * 1.0
+
+
+def match_table_and_ocr(cell_box_list: list, ocr_dt_boxes: list) -> dict:
+    """
+    match table and ocr
+
+    Args:
+        cell_box_list (list): bbox for table cell, 2 points, [left, top, right, bottom]
+        ocr_dt_boxes (list): bbox for ocr, 2 points, [left, top, right, bottom]
+
+    Returns:
+        dict: matched dict, key is table index, value is ocr index
+    """
+    matched = {}
+    for i, ocr_box in enumerate(np.array(ocr_dt_boxes)):
+        ocr_box = ocr_box.astype(np.float32)
+        distances = []
+        for j, table_box in enumerate(cell_box_list):
+            if len(table_box) == 8:
+                table_box = [
+                    np.min(table_box[0::2]),
+                    np.min(table_box[1::2]),
+                    np.max(table_box[0::2]),
+                    np.max(table_box[1::2]),
+                ]
+            distances.append(
+                (distance(table_box, ocr_box), 1.0 - compute_iou(table_box, ocr_box))
+            )  # compute iou and l1 distance
+        sorted_distances = distances.copy()
+        # select det box by iou and l1 distance
+        sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
+        if distances.index(sorted_distances[0]) not in matched.keys():
+            matched[distances.index(sorted_distances[0])] = [i]
+        else:
+            matched[distances.index(sorted_distances[0])].append(i)
+    return matched
+
+
+def get_html_result(
+    matched_index: dict, ocr_contents: dict, pred_structures: list
+) -> str:
+    """
+    Generates HTML content based on the matched index, OCR contents, and predicted structures.
+
+    Args:
+        matched_index (dict): A dictionary containing matched indices.
+        ocr_contents (dict): A dictionary of OCR contents.
+        pred_structures (list): A list of predicted HTML structures.
+
+    Returns:
+        str: Generated HTML content as a string.
+    """
+    pred_html = []
+    td_index = 0
+    head_structure = pred_structures[0:3]
+    html = "".join(head_structure)
+    table_structure = pred_structures[3:-3]
+    for tag in table_structure:
+        if "</td>" in tag:
+            if "<td></td>" == tag:
+                pred_html.extend("<td>")
+            if td_index in matched_index.keys():
+                b_with = False
+                if (
+                    "<b>" in ocr_contents[matched_index[td_index][0]]
+                    and len(matched_index[td_index]) > 1
+                ):
+                    b_with = True
+                    pred_html.extend("<b>")
+                for i, td_index_index in enumerate(matched_index[td_index]):
+                    content = ocr_contents[td_index_index]
+                    if len(matched_index[td_index]) > 1:
+                        if len(content) == 0:
+                            continue
+                        if content[0] == " ":
+                            content = content[1:]
+                        if "<b>" in content:
+                            content = content[3:]
+                        if "</b>" in content:
+                            content = content[:-4]
+                        if len(content) == 0:
+                            continue
+                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
+                            content += " "
+                    pred_html.extend(content)
+                if b_with:
+                    pred_html.extend("</b>")
+            if "<td></td>" == tag:
+                pred_html.append("</td>")
+            else:
+                pred_html.append(tag)
+            td_index += 1
+        else:
+            pred_html.append(tag)
+    html += "".join(pred_html)
+    end_structure = pred_structures[-3:]
+    html += "".join(end_structure)
+    return html
+
+
+def sort_table_cells_boxes(boxes):
+    '''
+    Sort bounding boxes into reading order: cluster rows with DBSCAN on the top-left y-coordinate (y1), then sort each row left to right by the top-left x-coordinate (x1).
+
+    Args:
+        boxes (list of lists): The input list of bounding boxes, where each bounding box is formatted as [x1, y1, x2, y2].
+
+    Returns:
+        sorted_boxes (list of lists): The list of bounding boxes sorted.
+    '''
+    import numpy as np
+    from sklearn.cluster import DBSCAN
+
+    # Extract the top-left y-coordinates (y1)
+    y1_coords = np.array([box[1] for box in boxes])
+    y1_coords = y1_coords.reshape(-1, 1)  # Convert to a 2D array
+
+    # Choose an appropriate eps parameter based on the range of y-values
+    y_range = y1_coords.max() - y1_coords.min()
+    eps = y_range / 50  # Adjust the denominator as needed
+
+    # Perform clustering using DBSCAN
+    db = DBSCAN(eps=eps, min_samples=1).fit(y1_coords)
+    labels = db.labels_
+
+    # Group bounding boxes by their labels
+    clusters = {}
+    for label, box in zip(labels, boxes):
+        if label not in clusters:
+            clusters[label] = []
+        clusters[label].append(box)
+
+    # Sort rows based on y-coordinates
+    # Compute the average y1 value for each row and sort from top to bottom
+    sorted_rows = sorted(clusters.items(), key=lambda item: np.mean([box[1] for box in item[1]]))
+
+    # Within each row, sort by x1 coordinate
+    sorted_boxes = []
+    for label, row in sorted_rows:
+        row_sorted = sorted(row, key=lambda x: x[0])
+        sorted_boxes.extend(row_sorted)
+
+    return sorted_boxes
+
+
+def get_table_recognition_res(
+    table_box: list, table_structure_result: list, table_cells_result: list, overall_ocr_res: OCRResult
+) -> SingleTableRecognitionResult:
+    """
+    Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.
+
+    Args:
+        table_box (list): Bounding box of the cropped table region in the original image.
+        table_structure_result (list): Predicted table structure tags.
+        table_cells_result (list): Detected table cell boxes.
+        overall_ocr_res (OCRResult): Overall OCR result from the input image.
+
+    Returns:
+        SingleTableRecognitionResult: An object containing the single table recognition result.
+    """
+    table_box = np.array([table_box])
+    table_ocr_pred = get_sub_regions_ocr_res(overall_ocr_res, table_box)
+
+    crop_start_point = [table_box[0][0], table_box[0][1]]
+    img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
+
+    ocr_dt_boxes = table_ocr_pred["rec_boxes"]
+    ocr_texts_res = table_ocr_pred["rec_texts"]
+
+    table_cells_result = sort_table_cells_boxes(table_cells_result)
+    ocr_dt_boxes = sort_table_cells_boxes(ocr_dt_boxes)
+
+    matched_index = match_table_and_ocr(table_cells_result, ocr_dt_boxes)
+    pred_html = get_html_result(matched_index, ocr_texts_res, table_structure_result)
+
+    single_img_res = {
+        "cell_box_list": table_cells_result,
+        "table_ocr_pred": table_ocr_pred,
+        "pred_html": pred_html,
+    }
+
+    return SingleTableRecognitionResult(single_img_res)

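A small self-contained sketch of the matching helpers above (not part of this commit; it assumes the module is importable under the path added by this PR), using hand-made boxes to show how cells are put into reading order and how OCR boxes are assigned to cell indices:

from paddlex.inference.pipelines_new.table_recognition.table_recognition_post_processing_v2 import (
    sort_table_cells_boxes,
    match_table_and_ocr,
)

# Four cells of a 2x2 grid given out of order, as [x1, y1, x2, y2].
cells = [[100, 0, 200, 20], [0, 0, 90, 20], [0, 30, 90, 50], [100, 30, 200, 50]]
cells = sort_table_cells_boxes(cells)
# Rows are clustered on y1 with DBSCAN, then each row is sorted left to right:
# [[0, 0, 90, 20], [100, 0, 200, 20], [0, 30, 90, 50], [100, 30, 200, 50]]

# Two OCR boxes: one inside the first cell, one inside the last cell.
ocr_boxes = [[5, 2, 85, 18], [105, 32, 195, 48]]
matched = match_table_and_ocr(cells, ocr_boxes)
print(matched)  # expected: {0: [0], 3: [1]}  (cell index -> list of OCR box indices)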
+ 5 - 0
paddlex/inference/utils/official_models.py

@@ -318,6 +318,11 @@ PP-LCNet_x1_0_vehicle_attribute_infer.tar",
     "PP-TSMv2-LCNetV2_16frames_uniform": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-TSMv2-LCNetV2_16frames_uniform_infer.tar",
     "MaskFormer_tiny": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/MaskFormer_tiny_infer.tar",
     "MaskFormer_small": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/MaskFormer_small_infer.tar",
+    "PP-LCNet_x1_0_table_cls": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/PP-LCNet_x1_0_table_cls_infer.tar",
+    "SLANeXt_wired": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/SLANeXt_wired_infer.tar",
+    "SLANeXt_wireless": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/SLANeXt_wireless_infer.tar",
+    "RT-DETR-L_wired_table_cell_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/RT-DETR-L_wired_table_cell_det_infer.tar",
+    "RT-DETR-L_wireless_table_cell_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/RT-DETR-L_wireless_table_cell_det_infer.tar",
     "YOWO": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/YOWO_infer.tar",
     "PP-TinyPose_128x96": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-TinyPose_128x96_infer.tar",
     "PP-TinyPose_256x192": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0rc0/PP-TinyPose_256x192_infer.tar",