
Refine table pipelines for layout (#3647)

Liu Jiaxuan, 8 months ago
parent commit
f4b4377443

+ 27 - 0
paddlex/configs/pipelines/PP-StructureV3.yaml

@@ -96,6 +96,33 @@ SubPipelines:
         module_name: table_cells_detection
         model_name: RT-DETR-L_wireless_table_cell_det
         model_dir: null
+    SubPipelines:
+      GeneralOCR:
+        pipeline_name: OCR
+        text_type: general
+        use_doc_preprocessor: False
+        use_textline_orientation: True
+        SubModules:
+          TextDetection:
+            module_name: text_detection
+            model_name: PP-OCRv4_server_det
+            model_dir: null
+            limit_side_len: 1200
+            limit_type: max
+            thresh: 0.3
+            box_thresh: 0.4
+            unclip_ratio: 2.0
+          TextLineOrientation:
+            module_name: textline_orientation
+            model_name: PP-LCNet_x0_25_textline_ori
+            model_dir: null
+            batch_size: 1
+          TextRecognition:
+            module_name: text_recognition
+            model_name: PP-OCRv4_server_rec_doc
+            model_dir: null
+            batch_size: 6
+        score_thresh: 0.0
 
   SealRecognition:
     pipeline_name: seal_recognition
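
The new GeneralOCR block gives table recognition its own OCR sub-pipeline (server-grade detection and recognition models, with text-line orientation enabled) instead of reusing the document-level OCR results. As a rough sketch, assuming `table_rec_config` is the TableRecognition section of the loaded YAML, the pipeline code below retrieves the block like this (get_general_ocr_config is an illustrative helper, not a PaddleX API):

def get_general_ocr_config(table_rec_config: dict):
    """Return the nested GeneralOCR sub-pipeline config, or None if absent."""
    # Mirrors config.get("SubPipelines", {}).get("GeneralOCR", None)
    # in the pipeline changes below.
    return table_rec_config.get("SubPipelines", {}).get("GeneralOCR", None)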

+ 38 - 9
paddlex/inference/pipelines/table_recognition/pipeline.py

@@ -88,6 +88,11 @@ class TableRecognitionPipeline(BasePipeline):
                 {"pipeline_config_error": "config error for general_ocr_pipeline!"},
             )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+        else:
+            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                None
+            )
 
         self._crop_by_boxes = CropByBoxes()
 
@@ -217,6 +222,33 @@ class TableRecognitionPipeline(BasePipeline):
             doc_preprocessor_res = {}
             doc_preprocessor_image = image_array
         return doc_preprocessor_res, doc_preprocessor_image
+
+    def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
+        """
+        Splits OCR bounding boxes by table cells and retrieves text.
+
+        Args:
+            ori_img (ndarray): The original image from which text regions will be extracted.
+            cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.
+
+        Returns:
+            list: A list containing the recognized texts from each cell.
+        """
+
+        # Check if cells_bboxes is a list and convert it if not.
+        if not isinstance(cells_bboxes, list):
+            cells_bboxes = cells_bboxes.tolist()
+        texts_list = []  # Initialize a list to store the recognized texts.
+        # Process each bounding box provided in cells_bboxes.
+        for i in range(len(cells_bboxes)):
+            # Extract and round up the coordinates of the bounding box.
+            x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
+            # Perform OCR on the defined region of the image and get the recognized text.
+            rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
+            # Concatenate the texts and append them to the texts_list.
+            texts_list.append(''.join(rec_te["rec_texts"]))
+        # Return the list of recognized texts from each cell.
+        return texts_list
 
     def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
         """
@@ -270,15 +302,9 @@ class TableRecognitionPipeline(BasePipeline):
         """
         table_structure_pred = next(self.table_structure_model(image_array))
         if use_table_cells_ocr_results == True:
-            table_cells_result = list(
-                map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
-            )
-            table_cells_result = [
-                [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result
-            ]
-            cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                image_array, table_cells_result
-            )
+            table_cells_result = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
+            table_cells_result = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result]
+            cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
         else:
             cells_texts_list = []
         single_table_recognition_res = get_table_recognition_res(
@@ -381,6 +407,9 @@ class TableRecognitionPipeline(BasePipeline):
                         text_rec_score_thresh=text_rec_score_thresh,
                     )
                 )
+            elif use_table_cells_ocr_results == True:
+                assert self.general_ocr_config_bak != None
+                self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)
 
             table_res_list = []
             table_region_id = 1
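
The per-cell OCR path centers on split_ocr_bboxes_by_table_cells: rather than matching page-level OCR boxes to cells, it crops each detected cell out of the original image, runs the general OCR pipeline on the crop, and joins the recognized fragments into one string per cell. The same logic in standalone form, with run_ocr standing in for the OCR pipeline (assumed, as in the diff, to produce a dict with a "rec_texts" list):

import math
from typing import Callable, List, Sequence

import numpy as np

def ocr_texts_per_cell(
    image: np.ndarray,
    cell_boxes: Sequence[Sequence[float]],
    run_ocr: Callable[[np.ndarray], dict],
) -> List[str]:
    """Crop each [x1, y1, x2, y2] cell and join its recognized texts."""
    texts = []
    for box in cell_boxes:
        # Round coordinates up, as the pipeline does with math.ceil.
        x1, y1, x2, y2 = (math.ceil(v) for v in box)
        crop = image[y1:y2, x1:x2, :]
        result = run_ocr(crop)  # e.g. next(general_ocr_pipeline(crop))
        texts.append("".join(result["rec_texts"]))
    return texts

One caveat worth noting: rounding x1/y1 up shrinks the crop by up to a pixel on the left and top edges, which is usually harmless for cell-sized regions.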

+ 17 - 28
paddlex/inference/pipelines/table_recognition/pipeline_v2.py

@@ -128,6 +128,11 @@ class TableRecognitionPipelineV2(BasePipeline):
                 {"pipeline_config_error": "config error for general_ocr_pipeline!"},
             )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+        else:
+            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                None
+            )
 
         self._crop_by_boxes = CropByBoxes()
 
@@ -595,25 +600,15 @@ class TableRecognitionPipelineV2(BasePipeline):
                 use_e2e_model = True
             else:
                 table_cells_pred = next(
-                    self.wireless_table_cells_detection_model(
-                        image_array, threshold=0.3
-                    )
+                    self.wireless_table_cells_detection_model(image_array, threshold=0.3)
                 )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                 # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
 
         if use_e2e_model == False:
-            table_structure_result = self.extract_results(
-                table_structure_pred, "table_stru"
-            )
-            table_cells_result, table_cells_score = self.extract_results(
-                table_cells_pred, "det"
-            )
-            table_cells_result, table_cells_score = self.cells_det_results_nms(
-                table_cells_result, table_cells_score
-            )
-            ocr_det_boxes = self.get_region_ocr_det_boxes(
-                overall_ocr_res["rec_boxes"].tolist(), table_box
-            )
+            table_structure_result = self.extract_results(table_structure_pred, "table_stru")
+            table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
+            table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
+            ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
             table_cells_result = self.cells_det_results_reprocessing(
                 table_cells_result,
                 table_cells_score,
@@ -621,9 +616,7 @@ class TableRecognitionPipelineV2(BasePipeline):
                 len(table_structure_pred["bbox"]),
             )
             if use_table_cells_ocr_results == True:
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result
-                )
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res(
@@ -636,16 +629,9 @@ class TableRecognitionPipelineV2(BasePipeline):
             )
         else:
             if use_table_cells_ocr_results == True:
-                table_cells_result_e2e = list(
-                    map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
-                )
-                table_cells_result_e2e = [
-                    [rect[0], rect[1], rect[4], rect[5]]
-                    for rect in table_cells_result_e2e
-                ]
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result_e2e
-                )
+                table_cells_result_e2e = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
+                table_cells_result_e2e = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e]
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result_e2e)
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res_e2e(
@@ -749,6 +735,9 @@ class TableRecognitionPipelineV2(BasePipeline):
                         text_rec_score_thresh=text_rec_score_thresh,
                     )
                 )
+            elif use_table_cells_ocr_results == True:
+                assert self.general_ocr_config_bak != None
+                self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)
 
             table_res_list = []
             table_region_id = 1
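
Both pipelines also share the quad-to-rectangle conversion [rect[0], rect[1], rect[4], rect[5]]: the structure model emits 8-value cell quadrilaterals (x1, y1, ..., x4, y4), and keeping indices 0, 1, 4, and 5 takes the first and third corners as the top-left and bottom-right of an axis-aligned box. In isolation:

def quad_to_rect(quad):
    """[x1, y1, x2, y2, x3, y3, x4, y4] -> [x1, y1, x3, y3]."""
    return [quad[0], quad[1], quad[4], quad[5]]

# A cell predicted as a clockwise quad starting at the top-left corner:
print(quad_to_rect([10, 20, 110, 20, 110, 60, 10, 60]))  # [10, 20, 110, 60]

This assumes the quad is stored clockwise from the top-left and is roughly axis-aligned; a noticeably skewed cell would be better served by taking min/max over all four corners.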