
Update table_rec_v2 interface (#3608)

* Update table_rec_v2 interface

* Update
Lin Manhui · 8 months ago
commit 46345a1740

+ 7 - 1
docs/pipeline_usage/tutorials/ocr_pipelines/table_recognition.en.md

@@ -870,7 +870,7 @@ In the above Python script, the following steps are executed:
 </tr>
 <td><code>use_table_cells_ocr_results</code></td>
 <td>Whether to enable Table-Cells-OCR mode, when not enabled, use global OCR result to fill to HTML table, when enabled, do OCR cell by cell and fill to HTML table (it will increase the time consuming). Both of them perform differently in different scenarios, please choose according to the actual situation.</td>
-<td><code>bool|False</code></td>
+<td><code>bool</code></td>
 <td>
 <ul>
 <li><b>bool</b>:<code>True</code> or <code>False</code>
@@ -1249,6 +1249,12 @@ Below are the API references for basic serving deployment and multi-language ser
 <td>Please refer to the description of the <code>text_rec_score_thresh</code> parameter of the pipeline object's <code>predict</code> method.</td>
 <td>No</td>
 </tr>
+<tr>
+<td><code>useTableCellsOcrResults</code></td>
+<td><code>boolean</code></td>
+<td>Please refer to the description of the <code>use_table_cells_ocr_results</code> parameter of the pipeline object's <code>predict</code> method.</td>
+<td>No</td>
+</tr>
 </tbody>
 </table>
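
For context, a minimal local-usage sketch of the switch documented above, following the `create_pipeline`/`predict` pattern used elsewhere in this tutorial; the input path is a placeholder and the result-saving call assumes the table recognition result object described in the tutorial.

```python
from paddlex import create_pipeline

# Minimal sketch: enable cell-by-cell OCR when filling the HTML table.
# "table.jpg" is a placeholder input; all other parameters keep their defaults.
pipeline = create_pipeline(pipeline="table_recognition")
output = pipeline.predict("table.jpg", use_table_cells_ocr_results=True)
for res in output:
    res.print()
    res.save_to_xlsx(save_path="./output/")
```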
 

+ 7 - 1
docs/pipeline_usage/tutorials/ocr_pipelines/table_recognition.md

@@ -815,7 +815,7 @@ for res in output:
 </tr>
 <td><code>use_table_cells_ocr_results</code></td>
 <td>是否启用单元格OCR模式,不启用时采用全局OCR结果填充至HTML表格,启用时逐个单元格做OCR并填充至HTML表格(会增加耗时)。二者在不同场景下性能不同,请根据实际情况选择。</td>
-<td><code>bool|False</code></td>
+<td><code>bool</code></td>
 <td>
 <ul>
 <li><b>bool</b>:<code>True</code> 或者 <code>False</code>
@@ -1194,6 +1194,12 @@ for res in output:
 <td>请参阅产线对象中 <code>predict</code> 方法的 <code>text_rec_score_thresh</code> 参数相关说明。</td>
 <td>否</td>
 </tr>
+<tr>
+<td><code>useTableCellsOcrResults</code></td>
+<td><code>boolean</code></td>
+<td>请参阅产线对象中 <code>predict</code> 方法的 <code>use_table_cells_ocr_results</code> 参数相关说明。</td>
+<td>否</td>
+</tr>
 </tbody>
 </table>
 <ul>

+ 21 - 3
docs/pipeline_usage/tutorials/ocr_pipelines/table_recognition_v2.en.md

@@ -893,7 +893,7 @@ In the above Python script, the following steps are executed:
 </tr>
 <td><code>use_table_cells_ocr_results</code></td>
 <td>Whether to enable Table-Cells-OCR mode, when not enabled, use global OCR result to fill to HTML table, when enabled, do OCR cell by cell and fill to HTML table (it will increase the time consuming). Both of them perform differently in different scenarios, please choose according to the actual situation.</td>
-<td><code>bool|False</code></td>
+<td><code>bool</code></td>
 <td>
 <ul>
 <li><b>bool</b>:<code>True</code> or <code>False</code>
@@ -901,7 +901,7 @@ In the above Python script, the following steps are executed:
 </tr>
 <td><code>use_e2e_wired_table_rec_model</code></td>
 <td>Whether to enable the wired table end-to-end prediction mode, when not enabled, using the table cells detection model prediction results filled to the HTML table, when enabled, using the end-to-end table structure recognition model cell prediction results filled to the HTML table. Both of them have different performance in different scenarios, please choose according to the actual situation.</td>
-<td><code>bool|False</code></td>
+<td><code>bool</code></td>
 <td>
 <ul>
 <li><b>bool</b>:<code>True</code> or <code>False</code>
@@ -909,7 +909,7 @@ In the above Python script, the following steps are executed:
 </tr>
 <td><code>use_e2e_wireless_table_rec_model</code></td>
 <td>Whether to enable the wireless table end-to-end prediction mode, when not enabled, using the table cells detection model prediction results filled to the HTML table, when enabled, using the end-to-end table structure recognition model cell prediction results filled to the HTML table. Both of them have different performance in different scenarios, please choose according to the actual situation.</td>
-<td><code>bool|False</code></td>
+<td><code>bool</code></td>
 <td>
 <ul>
 <li><b>bool</b>:<code>True</code> or <code>False</code>
@@ -1322,6 +1322,24 @@ Below are the API references for basic serving deployment and multi-language ser
 <td>Please refer to the description of the <code>text_rec_score_thresh</code> parameter of the pipeline object's <code>predict</code> method.</td>
 <td>No</td>
 </tr>
+<tr>
+<td><code>useTableCellsOcrResults</code></td>
+<td><code>boolean</code></td>
+<td>Please refer to the description of the <code>use_table_cells_ocr_results</code> parameter of the pipeline object's <code>predict</code> method.</td>
+<td>No</td>
+</tr>
+<tr>
+<td><code>useE2eWiredTableRecModel</code></td>
+<td><code>boolean</code></td>
+<td>Please refer to the description of the <code>use_e2e_wired_table_rec_model</code> parameter of the pipeline object's <code>predict</code> method.</td>
+<td>No</td>
+</tr>
+<tr>
+<td><code>useE2eWirelessTableRecModel</code></td>
+<td><code>boolean</code></td>
+<td>Please refer to the description of the <code>use_e2e_wireless_table_rec_model</code> parameter of the pipeline object's <code>predict</code> method.</td>
+<td>No</td>
+</tr>
 </tbody>
 </table>
 <p>Each element in <code>tableRecResults</code> is an <code>object</code> with the following properties:</p>
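
A hedged client-side sketch of passing the three new serving fields; the service URL, port, endpoint path, and the `file`/`fileType` request keys are assumptions based on the basic serving pattern described in this tutorial, so adjust them to your deployment.

```python
import base64

import requests

# Assumed local basic-serving endpoint for the table_recognition_v2 pipeline;
# the URL and the "file"/"fileType" keys follow the serving examples in this
# tutorial and may differ in your deployment.
API_URL = "http://localhost:8080/table-recognition-v2"

with open("table.jpg", "rb") as f:
    payload = {
        "file": base64.b64encode(f.read()).decode("ascii"),
        "fileType": 1,  # assumed: 1 marks an image input in the OCR serving schemas
        # The switches added in this change; all default to false when omitted.
        "useTableCellsOcrResults": True,
        "useE2eWiredTableRecModel": False,
        "useE2eWirelessTableRecModel": False,
    }

resp = requests.post(API_URL, json=payload)
resp.raise_for_status()
print(resp.json()["result"]["tableRecResults"])
```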

+ 21 - 3
docs/pipeline_usage/tutorials/ocr_pipelines/table_recognition_v2.md

@@ -896,7 +896,7 @@ for res in output:
 </tr>
 <td><code>use_table_cells_ocr_results</code></td>
 <td>是否启用单元格OCR模式,不启用时采用全局OCR结果填充至HTML表格,启用时逐个单元格做OCR并填充至HTML表格(会增加耗时)。二者在不同场景下性能不同,请根据实际情况选择。</td>
-<td><code>bool|False</code></td>
+<td><code>bool</code></td>
 <td>
 <ul>
 <li><b>bool</b>:<code>True</code> 或者 <code>False</code>
@@ -904,7 +904,7 @@ for res in output:
 </tr>
 <td><code>use_e2e_wired_table_rec_model</code></td>
 <td>是否启用有线表格端到端预测模式,不启用时采用表格单元格检测模型预测结果填充至HTML表格,启用时采用端到端表格结构识别模型的单元格预测结果填充至HTML表格。二者在不同场景下性能不同,请根据实际情况选择。</td>
-<td><code>bool|False</code></td>
+<td><code>bool</code></td>
 <td>
 <ul>
 <li><b>bool</b>:<code>True</code> 或者 <code>False</code>
@@ -912,7 +912,7 @@ for res in output:
 </tr>
 <td><code>use_e2e_wireless_table_rec_model</code></td>
 <td>是否启用无线表格端到端预测模式,不启用时采用表格单元格检测模型预测结果填充至HTML表格,启用时采用端到端表格结构识别模型的单元格预测结果填充至HTML表格。二者在不同场景下性能不同,请根据实际情况选择。</td>
-<td><code>bool|False</code></td>
+<td><code>bool</code></td>
 <td>
 <ul>
 <li><b>bool</b>:<code>True</code> 或者 <code>False</code>
@@ -1326,6 +1326,24 @@ for res in output:
 <td>请参阅产线对象中 <code>predict</code> 方法的 <code>text_rec_score_thresh</code> 参数相关说明。</td>
 <td>否</td>
 </tr>
+<tr>
+<td><code>useTableCellsOcrResults</code></td>
+<td><code>boolean</code></td>
+<td>请参阅产线对象中 <code>predict</code> 方法的 <code>use_table_cells_ocr_results</code> 参数相关说明。</td>
+<td>否</td>
+</tr>
+<tr>
+<td><code>useE2eWiredTableRecModel</code></td>
+<td><code>boolean</code></td>
+<td>请参阅产线对象中 <code>predict</code> 方法的 <code>use_e2e_wired_table_rec_model</code> 参数相关说明。</td>
+<td>否</td>
+</tr>
+<tr>
+<td><code>useE2eWirelessTableRecModel</code></td>
+<td><code>boolean</code></td>
+<td>请参阅产线对象中 <code>predict</code> 方法的 <code>use_e2e_wireless_table_rec_model</code> 参数相关说明。</td>
+<td>否</td>
+</tr>
 </tbody>
 </table>
 <p><code>tableRecResults</code>中的每个元素为一个<code>object</code>,具有如下属性:</p>

+ 12 - 6
paddlex/inference/pipelines/table_recognition/pipeline.py

@@ -217,7 +217,7 @@ class TableRecognitionPipeline(BasePipeline):
             doc_preprocessor_res = {}
             doc_preprocessor_image = image_array
         return doc_preprocessor_res, doc_preprocessor_image
-    
+
     def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
         """
         Splits OCR bounding boxes by table cells and retrieves text.
@@ -241,7 +241,7 @@ class TableRecognitionPipeline(BasePipeline):
             # Perform OCR on the defined region of the image and get the recognized text.
             rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
             # Concatenate the texts and append them to the texts_list.
-            texts_list.append(''.join(rec_te["rec_texts"]))
+            texts_list.append("".join(rec_te["rec_texts"]))
         # Return the list of recognized texts from each cell.
         return texts_list
 
@@ -270,9 +270,15 @@ class TableRecognitionPipeline(BasePipeline):
         """
         table_structure_pred = next(self.table_structure_model(image_array))
         if use_table_cells_ocr_results == True:
-            table_cells_result = list(map(lambda arr: arr.tolist(), table_structure_pred['bbox']))
-            table_cells_result = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result]
-            cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
+            table_cells_result = list(
+                map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
+            )
+            table_cells_result = [
+                [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result
+            ]
+            cells_texts_list = self.split_ocr_bboxes_by_table_cells(
+                image_array, table_cells_result
+            )
         else:
             cells_texts_list = []
         single_table_recognition_res = get_table_recognition_res(
@@ -309,7 +315,7 @@ class TableRecognitionPipeline(BasePipeline):
         text_det_box_thresh: Optional[float] = None,
         text_det_unclip_ratio: Optional[float] = None,
         text_rec_score_thresh: Optional[float] = None,
-        use_table_cells_ocr_results: Optional[bool] = False,
+        use_table_cells_ocr_results: bool = False,
         cell_sort_by_y_projection: Optional[bool] = None,
         **kwargs,
     ) -> TableRecognitionResult:
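
The branch above derives axis-aligned cell boxes from the 8-value polygons in `table_structure_pred["bbox"]` by taking indices 0, 1 and 4, 5. A small standalone sketch of that conversion, assuming the polygon lists its corners clockwise from the top-left (which is what makes indices 4 and 5 the bottom-right corner):

```python
import numpy as np

def polygons_to_cell_boxes(bbox_array):
    """Reduce 8-value cell polygons to axis-aligned [x1, y1, x3, y3] boxes."""
    polys = [arr.tolist() for arr in bbox_array]
    # Indices 0, 1 are the first corner and 4, 5 the third corner; with a
    # clockwise-from-top-left ordering these are top-left and bottom-right.
    return [[p[0], p[1], p[4], p[5]] for p in polys]

# Example: one rectangular cell spanning (10, 20) to (110, 70).
cells = polygons_to_cell_boxes(
    [np.array([10, 20, 110, 20, 110, 70, 10, 70], dtype=float)]
)
print(cells)  # [[10.0, 20.0, 110.0, 70.0]]
```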

+ 97 - 43
paddlex/inference/pipelines/table_recognition/pipeline_v2.py

@@ -22,7 +22,9 @@ from ..base import BasePipeline
 from ..components import CropByBoxes
 from .utils import get_neighbor_boxes_idx
 from .table_recognition_post_processing_v2 import get_table_recognition_res
-from .table_recognition_post_processing import get_table_recognition_res as get_table_recognition_res_e2e
+from .table_recognition_post_processing import (
+    get_table_recognition_res as get_table_recognition_res_e2e,
+)
 from .result import SingleTableRecognitionResult, TableRecognitionResult
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
@@ -276,8 +278,10 @@ class TableRecognitionPipelineV2(BasePipeline):
             return pred["structure"]
         else:
             return None
-    
-    def cells_det_results_nms(self, cells_det_results, cells_det_scores, cells_det_threshold=0.3):
+
+    def cells_det_results_nms(
+        self, cells_det_results, cells_det_scores, cells_det_threshold=0.3
+    ):
         """
         Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.
 
@@ -324,12 +328,14 @@ class TableRecognitionPipelineV2(BasePipeline):
             # Indices of boxes with IoU less than threshold
             inds = np.where(ovr <= cells_det_threshold)[0]
             # Update order, only keep boxes with IoU less than threshold
-            order = order[inds + 1]  # inds shifted by 1 because order[0] is the current box
+            order = order[
+                inds + 1
+            ]  # inds shifted by 1 because order[0] is the current box
         # Select the boxes and scores based on picked indices
         final_boxes = boxes[picked_indices].tolist()
         final_scores = scores[picked_indices].tolist()
         return final_boxes, final_scores
-    
+
     def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
         """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.
 
@@ -340,27 +346,33 @@ class TableRecognitionPipelineV2(BasePipeline):
         Returns:
             list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
         """
-        tol=0
+        tol = 0
         # Extract coordinates from table_box
         x_min_t, y_min_t, x_max_t, y_max_t = table_box
         adjusted_boxes = []
         for box in ocr_det_boxes:
             x_min_b, y_min_b, x_max_b, y_max_b = box
             # Check if the box is fully inside table_box
-            if (x_min_b+tol >= x_min_t and y_min_b+tol >= y_min_t and
-                x_max_b-tol <= x_max_t and y_max_b-tol <= y_max_t):
+            if (
+                x_min_b + tol >= x_min_t
+                and y_min_b + tol >= y_min_t
+                and x_max_b - tol <= x_max_t
+                and y_max_b - tol <= y_max_t
+            ):
                 # Adjust the coordinates to be relative to table_box
                 adjusted_box = [
                     x_min_b - x_min_t,  # Adjust x1
                     y_min_b - y_min_t,  # Adjust y1
                     x_max_b - x_min_t,  # Adjust x2
-                    y_max_b - y_min_t   # Adjust y2
+                    y_max_b - y_min_t,  # Adjust y2
                 ]
                 adjusted_boxes.append(adjusted_box)
             # Discard boxes not fully inside table_box
         return adjusted_boxes
 
-    def cells_det_results_reprocessing(self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums):
+    def cells_det_results_reprocessing(
+        self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums
+    ):
         """
         Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.
 
@@ -373,6 +385,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         Returns:
             List[List[float]]: The processed list of rectangles.
         """
+
         # Function to compute IoU between two rectangles
         def compute_iou(box1, box2):
             """
@@ -419,15 +432,17 @@ class TableRecognitionPipelineV2(BasePipeline):
             if N >= num_rects:
                 return rectangles
             # Compute the center points of the rectangles
-            centers = np.array([
+            centers = np.array(
                 [
-                    (rect[0] + rect[2]) / 2,  # Center x-coordinate
-                    (rect[1] + rect[3]) / 2   # Center y-coordinate
+                    [
+                        (rect[0] + rect[2]) / 2,  # Center x-coordinate
+                        (rect[1] + rect[3]) / 2,  # Center y-coordinate
+                    ]
+                    for rect in rectangles
                 ]
-                for rect in rectangles
-            ])
+            )
             # Perform KMeans clustering on the center points to group them into N clusters
-            kmeans = KMeans(n_clusters=N, random_state=0, n_init='auto')
+            kmeans = KMeans(n_clusters=N, random_state=0, n_init="auto")
             labels = kmeans.fit_predict(centers)
             # Initialize a list to store the combined rectangles
             combined_rectangles = []
@@ -478,28 +493,36 @@ class TableRecognitionPipelineV2(BasePipeline):
                 iou = compute_iou(ocr_rect, cell_rect)
                 if iou > 0:
                     merge_ocr_box_iou.append(iou)
-                if (iou>=iou_threshold) or (sum(merge_ocr_box_iou)>=iou_threshold):
+                if (iou >= iou_threshold) or (sum(merge_ocr_box_iou) >= iou_threshold):
                     has_large_iou = True
                     break
             if not has_large_iou:
                 ocr_miss_boxes.append(ocr_rect)
         # If no ocr_miss_boxes, return cells_det_results
         if len(ocr_miss_boxes) == 0:
-            final_results = cells_det_results if more_cells_flag==True else cells_det_results.tolist()
+            final_results = (
+                cells_det_results
+                if more_cells_flag == True
+                else cells_det_results.tolist()
+            )
         else:
             if more_cells_flag == True:
-                final_results = combine_rectangles(cells_det_results+ocr_miss_boxes, html_pred_boxes_nums)
+                final_results = combine_rectangles(
+                    cells_det_results + ocr_miss_boxes, html_pred_boxes_nums
+                )
             else:
                 # Need to combine ocr_miss_boxes into N rectangles
                 N = html_pred_boxes_nums - len(cells_det_results)
                 # Combine ocr_miss_boxes into N rectangles
                 ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
                 # Combine cells_det_results and ocr_supp_boxes
-                final_results = np.concatenate((cells_det_results, ocr_supp_boxes), axis=0).tolist()
-        if len(final_results) <= 0.6*html_pred_boxes_nums:
+                final_results = np.concatenate(
+                    (cells_det_results, ocr_supp_boxes), axis=0
+                ).tolist()
+        if len(final_results) <= 0.6 * html_pred_boxes_nums:
             final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
         return final_results
-    
+
     def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
         """
         Splits OCR bounding boxes by table cells and retrieves text.
@@ -523,7 +546,7 @@ class TableRecognitionPipelineV2(BasePipeline):
             # Perform OCR on the defined region of the image and get the recognized text.
             rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
             # Concatenate the texts and append them to the texts_list.
-            texts_list.append(''.join(rec_te["rec_texts"]))
+            texts_list.append("".join(rec_te["rec_texts"]))
         # Return the list of recognized texts from each cell.
         return texts_list
 
@@ -533,7 +556,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         overall_ocr_res: OCRResult,
         table_box: list,
         use_table_cells_ocr_results: bool = False,
-        use_e2e_wired_table_rec_model: bool = False, 
+        use_e2e_wired_table_rec_model: bool = False,
         use_e2e_wireless_table_rec_model: bool = False,
         flag_find_nei_text: bool = True,
     ) -> SingleTableRecognitionResult:
@@ -564,42 +587,73 @@ class TableRecognitionPipelineV2(BasePipeline):
             else:
                 table_cells_pred = next(
                     self.wired_table_cells_detection_model(image_array, threshold=0.3)
-                ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection. 
-                  # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
+                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
+                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
         elif table_cls_result == "wireless_table":
             table_structure_pred = next(self.wireless_table_rec_model(image_array))
             if use_e2e_wireless_table_rec_model == True:
                 use_e2e_model = True
             else:
                 table_cells_pred = next(
-                    self.wireless_table_cells_detection_model(image_array, threshold=0.3)
-                ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection. 
-                  # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
+                    self.wireless_table_cells_detection_model(
+                        image_array, threshold=0.3
+                    )
+                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
+                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
 
         if use_e2e_model == False:
-            table_structure_result = self.extract_results(table_structure_pred, "table_stru")
-            table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
-            table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
-            ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
+            table_structure_result = self.extract_results(
+                table_structure_pred, "table_stru"
+            )
+            table_cells_result, table_cells_score = self.extract_results(
+                table_cells_pred, "det"
+            )
+            table_cells_result, table_cells_score = self.cells_det_results_nms(
+                table_cells_result, table_cells_score
+            )
+            ocr_det_boxes = self.get_region_ocr_det_boxes(
+                overall_ocr_res["rec_boxes"].tolist(), table_box
+            )
             table_cells_result = self.cells_det_results_reprocessing(
-                table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
+                table_cells_result,
+                table_cells_score,
+                ocr_det_boxes,
+                len(table_structure_pred["bbox"]),
             )
             if use_table_cells_ocr_results == True:
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
+                    image_array, table_cells_result
+                )
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res(
-                table_box, table_structure_result, table_cells_result, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
+                table_box,
+                table_structure_result,
+                table_cells_result,
+                overall_ocr_res,
+                cells_texts_list,
+                use_table_cells_ocr_results,
             )
         else:
             if use_table_cells_ocr_results == True:
-                table_cells_result_e2e = list(map(lambda arr: arr.tolist(), table_structure_pred['bbox']))
-                table_cells_result_e2e = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e]
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result_e2e)
+                table_cells_result_e2e = list(
+                    map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
+                )
+                table_cells_result_e2e = [
+                    [rect[0], rect[1], rect[4], rect[5]]
+                    for rect in table_cells_result_e2e
+                ]
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
+                    image_array, table_cells_result_e2e
+                )
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res_e2e(
-                table_box, table_structure_pred, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
+                table_box,
+                table_structure_pred,
+                overall_ocr_res,
+                cells_texts_list,
+                use_table_cells_ocr_results,
             )
 
         neighbor_text = ""
@@ -628,9 +682,9 @@ class TableRecognitionPipelineV2(BasePipeline):
         text_det_box_thresh: Optional[float] = None,
         text_det_unclip_ratio: Optional[float] = None,
         text_rec_score_thresh: Optional[float] = None,
-        use_table_cells_ocr_results: Optional[bool] = False,
-        use_e2e_wired_table_rec_model: Optional[bool] = False,
-        use_e2e_wireless_table_rec_model: Optional[bool] = False,
+        use_table_cells_ocr_results: bool = False,
+        use_e2e_wired_table_rec_model: bool = False,
+        use_e2e_wireless_table_rec_model: bool = False,
         **kwargs,
     ) -> TableRecognitionResult:
         """

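For readers who want the `cells_det_results_nms` logic in isolation, here is a condensed NumPy sketch of the same greedy IoU-based suppression over cell detections; it is a re-implementation for illustration, not the pipeline's exact code.

```python
import numpy as np

def nms_cells(boxes, scores, iou_thresh=0.3):
    """Greedy IoU-based NMS over [x1, y1, x2, y2] cell boxes."""
    boxes = np.asarray(boxes, dtype=float)
    scores = np.asarray(scores, dtype=float)
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # Intersection of the current box with the remaining candidates.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # Keep only candidates whose overlap with the current box is below threshold.
        order = order[np.where(iou <= iou_thresh)[0] + 1]
    return boxes[keep].tolist(), scores[keep].tolist()
```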
+ 1 - 0
paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition.py

@@ -60,6 +60,7 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
             text_det_box_thresh=request.textDetBoxThresh,
             text_det_unclip_ratio=request.textDetUnclipRatio,
             text_rec_score_thresh=request.textRecScoreThresh,
+            use_table_cells_ocr_results=request.useTableCellsOcrResults,
         )
 
         table_rec_results: List[Dict[str, Any]] = []

+ 3 - 0
paddlex/inference/serving/basic_serving/_pipeline_apps/table_recognition_v2.py

@@ -60,6 +60,9 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
             text_det_box_thresh=request.textDetBoxThresh,
             text_det_unclip_ratio=request.textDetUnclipRatio,
             text_rec_score_thresh=request.textRecScoreThresh,
+            use_table_cells_ocr_results=request.useTableCellsOcrResults,
+            use_e2e_wired_table_rec_model=request.useE2eWiredTableRecModel,
+            use_e2e_wireless_table_rec_model=request.useE2eWirelessTableRecModel,
         )
 
         table_rec_results: List[Dict[str, Any]] = []

+ 1 - 0
paddlex/inference/serving/schemas/table_recognition.py

@@ -48,6 +48,7 @@ class InferRequest(ocr.BaseInferRequest):
     textDetBoxThresh: Optional[float] = None
     textDetUnclipRatio: Optional[float] = None
     textRecScoreThresh: Optional[float] = None
+    useTableCellsOcrResults: bool = False
 
 
 class TableRecResult(BaseModel):

+ 3 - 0
paddlex/inference/serving/schemas/table_recognition_v2.py

@@ -48,6 +48,9 @@ class InferRequest(ocr.BaseInferRequest):
     textDetBoxThresh: Optional[float] = None
     textDetUnclipRatio: Optional[float] = None
     textRecScoreThresh: Optional[float] = None
+    useTableCellsOcrResults: bool = False
+    useE2eWiredTableRecModel: bool = False
+    useE2eWirelessTableRecModel: bool = False
 
 
 class TableRecResult(BaseModel):
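
A self-contained sketch (not the actual PaddleX schema module) of how these camelCase switches behave on the wire: omitted fields fall back to `False`, matching the snake_case defaults on the pipelines' `predict` methods.

```python
from pydantic import BaseModel

class InferRequestSketch(BaseModel):
    # Mirrors only the fields added in this change; the real InferRequest also
    # inherits the shared OCR request fields (file, thresholds, and so on).
    useTableCellsOcrResults: bool = False
    useE2eWiredTableRecModel: bool = False
    useE2eWirelessTableRecModel: bool = False

req = InferRequestSketch(useTableCellsOcrResults=True)
print(req.useTableCellsOcrResults)   # True
print(req.useE2eWiredTableRecModel)  # False: omitted fields fall back to the default
```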