refine table related method (#3600)

Liu Jiaxuan 8 months ago
parent
commit
0c795513f8

+ 7 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/table_recognition_v2.en.md

@@ -893,6 +893,13 @@ In the above Python script, the following steps are executed:
 </td>
 <td><code>None</code></td>
 </tr>
+<tr>
+<td><code>use_table_cells_ocr_results</code></td>
+<td>Whether to enable the table-cells OCR mode. When disabled, the global OCR result is used to fill the HTML table; when enabled, OCR is performed cell by cell and the cell texts are used to fill the HTML table. The two modes behave differently in different scenarios, so choose according to the actual situation.</td>
+<td><code>bool</code></td>
+<td><ul><li><b>bool</b>: <code>True</code> or <code>False</code></li></ul></td>
+<td><code>False</code></td>
+</tr>
 </table>
 
 (3) Process the prediction results, where each sample's prediction result is represented as a corresponding Result object, and supports operations such as printing, saving as an image, saving as an `xlsx` file, saving as an `HTML` file, and saving as a `json` file:
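For reference, a minimal usage sketch of the new parameter against the pipeline's Python API, assuming the standard `create_pipeline` / `predict` interface described in these docs (the input path and save directory are placeholders, not part of this change):

```python
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="table_recognition_v2")

# Enable per-cell OCR: each detected cell is cropped and recognized individually,
# instead of filling the HTML table from the global OCR result.
output = pipeline.predict("table.jpg", use_table_cells_ocr_results=True)

for res in output:
    res.print()                              # inspect the structured result
    res.save_to_html(save_path="./output/")  # save the reconstructed HTML table
```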

+ 8 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/table_recognition_v2.md

@@ -895,6 +895,14 @@ for res in output:
 <li><b>float</b>: any floating-point number greater than <code>0</code>
     <li><b>None</b>: if set to <code>None</code>, the value initialized by the pipeline, <code>0.0</code>, is used by default, i.e., no threshold is set</li></ul></td>
 <td><code>None</code></td>
+</tr>
+<tr>
+<td><code>use_table_cells_ocr_results</code></td>
+<td>Whether to enable the table-cells OCR mode. When disabled, the global OCR result is used to fill the HTML table; when enabled, OCR is performed cell by cell and the cell texts are used to fill the HTML table. The two modes behave differently in different scenarios, so choose according to the actual situation.</td>
+<td><code>bool</code></td>
+<td><ul><li><b>bool</b>: <code>True</code> or <code>False</code></li></ul></td>
+<td><code>False</code></td>
 
 </tr></table>
 

+ 38 - 2
paddlex/inference/pipelines/table_recognition/pipeline_v2.py

@@ -15,6 +15,7 @@
 import os, sys
 from typing import Any, Dict, Optional, Union, List, Tuple
 import numpy as np
+import math
 import cv2
 from sklearn.cluster import KMeans
 from ..base import BasePipeline
@@ -497,12 +498,40 @@ class TableRecognitionPipelineV2(BasePipeline):
         if len(final_results) <= 0.6*html_pred_boxes_nums:
             final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
         return final_results
+    
+    def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
+        """
+        Splits OCR bounding boxes by table cells and retrieves text.
+
+        Args:
+            ori_img (ndarray): The original image from which text regions will be extracted.
+            cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.
+
+        Returns:
+            list: A list containing the recognized texts from each cell.
+        """
+
+        # Check if cells_bboxes is a list and convert it if not.
+        if not isinstance(cells_bboxes, list):
+            cells_bboxes = cells_bboxes.tolist()
+        texts_list = []  # Initialize a list to store the recognized texts.
+        # Process each bounding box provided in cells_bboxes.
+        for i in range(len(cells_bboxes)):
+            # Extract and round up the coordinates of the bounding box.
+            x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
+            # Perform OCR on the defined region of the image and get the recognized text.
+            rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
+            # Concatenate the texts and append them to the texts_list.
+            texts_list.append(''.join(rec_te["rec_texts"]))
+        # Return the list of recognized texts from each cell.
+        return texts_list
 
     def predict_single_table_recognition_res(
         self,
         image_array: np.ndarray,
         overall_ocr_res: OCRResult,
         table_box: list,
+        use_table_cells_ocr_results: bool = False,
         flag_find_nei_text: bool = True,
     ) -> SingleTableRecognitionResult:
         """
@@ -517,6 +546,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         Returns:
             SingleTableRecognitionResult: single table recognition result.
         """
+
         table_cls_pred = next(self.table_cls_model(image_array))
         table_cls_result = self.extract_results(table_cls_pred, "cls")
         if table_cls_result == "wired_table":
@@ -538,8 +568,12 @@ class TableRecognitionPipelineV2(BasePipeline):
         table_cells_result = self.cells_det_results_reprocessing(
             table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
         )
+        if use_table_cells_ocr_results:
+            cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
+        else:
+            cells_texts_list = []
         single_table_recognition_res = get_table_recognition_res(
-            table_box, table_structure_result, table_cells_result, overall_ocr_res
+            table_box, table_structure_result, table_cells_result, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
         )
         neighbor_text = ""
         if flag_find_nei_text:
@@ -567,6 +601,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         text_det_box_thresh: Optional[float] = None,
         text_det_unclip_ratio: Optional[float] = None,
         text_rec_score_thresh: Optional[float] = None,
+        use_table_cells_ocr_results: bool = False,
         **kwargs,
     ) -> TableRecognitionResult:
         """
@@ -638,6 +673,7 @@ class TableRecognitionPipelineV2(BasePipeline):
                     doc_preprocessor_image,
                     overall_ocr_res,
                     table_box,
+                    use_table_cells_ocr_results,
                     flag_find_nei_text=False,
                 )
                 single_table_rec_res["table_region_id"] = table_region_id
@@ -654,7 +690,7 @@ class TableRecognitionPipelineV2(BasePipeline):
                         table_box = crop_img_info["box"]
                         single_table_rec_res = (
                             self.predict_single_table_recognition_res(
-                                crop_img_info["img"], overall_ocr_res, table_box
+                                crop_img_info["img"], overall_ocr_res, table_box, use_table_cells_ocr_results
                             )
                         )
                         single_table_rec_res["table_region_id"] = table_region_id
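The new helper simply crops each cell box (coordinates rounded up with `math.ceil`), runs the general OCR pipeline on the crop, and concatenates the recognized lines per cell. A self-contained sketch of the same idea, with `ocr_fn` as a stand-in for `self.general_ocr_pipeline` (the stand-in callable and its return type are assumptions for illustration):

```python
import math
from typing import Callable, List, Sequence

import numpy as np


def ocr_texts_per_cell(
    img: np.ndarray,
    cell_bboxes: Sequence[Sequence[float]],
    ocr_fn: Callable[[np.ndarray], List[str]],
) -> List[str]:
    """Crop each cell bbox from `img` and return one concatenated text per cell."""
    texts: List[str] = []
    for x1, y1, x2, y2 in cell_bboxes:
        # Round coordinates up so float boxes map onto valid pixel indices.
        x1, y1, x2, y2 = (math.ceil(v) for v in (x1, y1, x2, y2))
        crop = img[y1:y2, x1:x2, :]
        # `ocr_fn` is assumed to return the recognized text lines for a crop.
        texts.append("".join(ocr_fn(crop)))
    return texts
```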

+ 23 - 2
paddlex/inference/pipelines/table_recognition/table_recognition_post_processing_v2.py

@@ -403,6 +403,8 @@ def get_table_recognition_res(
     table_structure_result: list,
     table_cells_result: list,
     overall_ocr_res: OCRResult,
+    cells_texts_list: list,
+    use_table_cells_ocr_results: bool,
 ) -> SingleTableRecognitionResult:
     """
     Retrieve table recognition result from cropped image info, table structure prediction, and overall OCR result.
@@ -412,6 +414,8 @@ def get_table_recognition_res(
         table_structure_result (list): Predicted table structure.
         table_cells_result (list): Predicted table cells.
         overall_ocr_res (OCRResult): Overall OCR result from the input image.
+        cells_texts_list (list): OCR texts recognized from each table cell.
+        use_table_cells_ocr_results (bool): Whether to use the per-cell OCR results.
 
     Returns:
         SingleTableRecognitionResult: An object containing the single table recognition result.
@@ -425,12 +429,29 @@ def get_table_recognition_res(
     crop_start_point = [table_box[0][0], table_box[0][1]]
     img_shape = overall_ocr_res["doc_preprocessor_res"]["output_img"].shape[0:2]
 
+    if len(table_cells_result) == 0 or len(table_ocr_pred["rec_boxes"]) == 0:
+        pred_html = ' '.join(table_structure_result)
+        if len(table_cells_result) != 0:
+            table_cells_result = convert_table_structure_pred_bbox(
+                table_cells_result, crop_start_point, img_shape
+            )
+        single_img_res = {
+            "cell_box_list": table_cells_result,
+            "table_ocr_pred": table_ocr_pred,
+            "pred_html": pred_html,
+        }
+        return SingleTableRecognitionResult(single_img_res)
+
     table_cells_result = convert_table_structure_pred_bbox(
         table_cells_result, crop_start_point, img_shape
     )
 
-    ocr_dt_boxes = table_ocr_pred["rec_boxes"]
-    ocr_texts_res = table_ocr_pred["rec_texts"]
+    if use_table_cells_ocr_results:
+        ocr_dt_boxes = table_cells_result
+        ocr_texts_res = cells_texts_list
+    else:
+        ocr_dt_boxes = table_ocr_pred["rec_boxes"]
+        ocr_texts_res = table_ocr_pred["rec_texts"]
 
     table_cells_result, table_cells_flag = sort_table_cells_boxes(table_cells_result)
     row_start_index = find_row_start_index(table_structure_result)
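In effect, the post-processing now chooses its fill source explicitly: the global OCR boxes and texts by default, or the detected cell boxes paired with the per-cell texts when the new mode is on. A simplified restatement of that selection (the function name and signature are illustrative only, not the module's API):

```python
from typing import List, Tuple


def pick_fill_source(
    use_table_cells_ocr_results: bool,
    global_boxes: List[List[float]],
    global_texts: List[str],
    cell_boxes: List[List[float]],
    cell_texts: List[str],
) -> Tuple[List[List[float]], List[str]]:
    """Choose which (boxes, texts) pair is matched against the predicted HTML cells."""
    if use_table_cells_ocr_results:
        # Per-cell mode: the detected cell boxes carry their own recognized texts.
        return cell_boxes, cell_texts
    # Global mode: reuse detection boxes and texts from the full-image OCR pass.
    return global_boxes, global_texts
```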

+ 0 - 12
paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py

@@ -75,18 +75,6 @@ def check(dataset_dir, output, dataset_type="PubTabTableRecDataset", sample_num=
                             )
                             sample_paths[tag].append(sample_path)
 
-                        boxes_num = len(cells)
-                        tokens_num = sum(
-                            [
-                                structure.count(x)
-                                for x in ["<td>", "<td", "<eb></eb>", "<td></td>"]
-                            ]
-                        )
-                        if boxes_num != tokens_num:
-                            raise CheckFailedError(
-                                f"The number of cells needs to be consistent with the number of tokens "
-                                "but the number of cells is {boxes_num}, and the number of tokens is {tokens_num}."
-                            )
         meta = {}
 
         meta["train_samples"] = sample_cnts["train"]

+ 5 - 0
paddlex/utils/pipeline_arguments.py

@@ -196,6 +196,11 @@ PIPELINE_ARGUMENTS = {
     ],
     "table_recognition_v2": [
         {
+            "name": "--use_table_cells_ocr_results",
+            "type": bool,
+            "help": "Determines whether to use cells OCR results",
+        },
+        {
             "name": "--use_doc_orientation_classify",
             "type": bool,
             "help": "Determines whether to use document preprocessing",