|
|
@@ -22,6 +22,7 @@ from ..base import BasePipeline
|
|
|
from ..components import CropByBoxes
|
|
|
from .utils import get_neighbor_boxes_idx
|
|
|
from .table_recognition_post_processing_v2 import get_table_recognition_res
|
|
|
+from .table_recognition_post_processing import get_table_recognition_res as get_table_recognition_res_e2e
|
|
|
from .result import SingleTableRecognitionResult, TableRecognitionResult
|
|
|
from ....utils import logging
|
|
|
from ...utils.pp_option import PaddlePredictorOption
|
|
|
@@ -532,6 +533,8 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
overall_ocr_res: OCRResult,
|
|
|
table_box: list,
|
|
|
use_table_cells_ocr_results: bool = False,
|
|
|
+ use_e2e_wired_table_rec_model: bool = False,
|
|
|
+ use_e2e_wireless_table_rec_model: bool = False,
|
|
|
flag_find_nei_text: bool = True,
|
|
|
) -> SingleTableRecognitionResult:
|
|
|
"""
|
|
|
@@ -542,6 +545,9 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
|
|
|
The overall OCR results containing text recognition information.
|
|
|
table_box (list): The table box coordinates.
|
|
|
+ use_table_cells_ocr_results (bool): Whether to use OCR results with cells.
|
|
|
+ use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
|
|
|
+ use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
|
|
|
flag_find_nei_text (bool): Whether to find neighboring text.
|
|
|
Returns:
|
|
|
SingleTableRecognitionResult: single table recognition result.
|
|
|
@@ -549,32 +555,53 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
|
|
|
table_cls_pred = next(self.table_cls_model(image_array))
|
|
|
table_cls_result = self.extract_results(table_cls_pred, "cls")
|
|
|
+ use_e2e_model = False
|
|
|
+
|
|
|
if table_cls_result == "wired_table":
|
|
|
table_structure_pred = next(self.wired_table_rec_model(image_array))
|
|
|
- table_cells_pred = next(
|
|
|
- self.wired_table_cells_detection_model(image_array, threshold=0.3)
|
|
|
- ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
|
|
|
- # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
|
|
|
+ if use_e2e_wired_table_rec_model == True:
|
|
|
+ use_e2e_model = True
|
|
|
+ else:
|
|
|
+ table_cells_pred = next(
|
|
|
+ self.wired_table_cells_detection_model(image_array, threshold=0.3)
|
|
|
+ ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
|
|
|
+ # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
|
|
|
elif table_cls_result == "wireless_table":
|
|
|
table_structure_pred = next(self.wireless_table_rec_model(image_array))
|
|
|
- table_cells_pred = next(
|
|
|
- self.wireless_table_cells_detection_model(image_array, threshold=0.3)
|
|
|
- ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
|
|
|
- # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
|
|
|
- table_structure_result = self.extract_results(table_structure_pred, "table_stru")
|
|
|
- table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
|
|
|
- table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
|
|
|
- ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
|
|
|
- table_cells_result = self.cells_det_results_reprocessing(
|
|
|
- table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
|
|
|
- )
|
|
|
- if use_table_cells_ocr_results == True:
|
|
|
- cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
|
|
|
+ if use_e2e_wireless_table_rec_model == True:
|
|
|
+ use_e2e_model = True
|
|
|
+ else:
|
|
|
+ table_cells_pred = next(
|
|
|
+ self.wireless_table_cells_detection_model(image_array, threshold=0.3)
|
|
|
+ ) # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
|
|
|
+ # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
|
|
|
+
|
|
|
+ if use_e2e_model == False:
|
|
|
+ table_structure_result = self.extract_results(table_structure_pred, "table_stru")
|
|
|
+ table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
|
|
|
+ table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
|
|
|
+ ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
|
|
|
+ table_cells_result = self.cells_det_results_reprocessing(
|
|
|
+ table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
|
|
|
+ )
|
|
|
+ if use_table_cells_ocr_results == True:
|
|
|
+ cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
|
|
|
+ else:
|
|
|
+ cells_texts_list = []
|
|
|
+ single_table_recognition_res = get_table_recognition_res(
|
|
|
+ table_box, table_structure_result, table_cells_result, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
|
|
|
+ )
|
|
|
else:
|
|
|
- cells_texts_list = []
|
|
|
- single_table_recognition_res = get_table_recognition_res(
|
|
|
- table_box, table_structure_result, table_cells_result, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
|
|
|
- )
|
|
|
+ if use_table_cells_ocr_results == True:
|
|
|
+ table_cells_result_e2e = list(map(lambda arr: arr.tolist(), table_structure_pred['bbox']))
|
|
|
+ table_cells_result_e2e = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e]
|
|
|
+ cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result_e2e)
|
|
|
+ else:
|
|
|
+ cells_texts_list = []
|
|
|
+ single_table_recognition_res = get_table_recognition_res_e2e(
|
|
|
+ table_box, table_structure_pred, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
|
|
|
+ )
|
|
|
+
|
|
|
neighbor_text = ""
|
|
|
if flag_find_nei_text:
|
|
|
match_idx_list = get_neighbor_boxes_idx(
|
|
|
@@ -602,6 +629,8 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
text_det_unclip_ratio: Optional[float] = None,
|
|
|
text_rec_score_thresh: Optional[float] = None,
|
|
|
use_table_cells_ocr_results: Optional[bool] = False,
|
|
|
+ use_e2e_wired_table_rec_model: Optional[bool] = False,
|
|
|
+ use_e2e_wireless_table_rec_model: Optional[bool] = False,
|
|
|
**kwargs,
|
|
|
) -> TableRecognitionResult:
|
|
|
"""
|
|
|
@@ -616,6 +645,10 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
It will be used if it is not None and use_ocr_model is False.
|
|
|
layout_det_res (DetResult): The layout detection result.
|
|
|
It will be used if it is not None and use_layout_detection is False.
|
|
|
+ use_table_cells_ocr_results (bool): Whether to use OCR results with cells.
|
|
|
+ use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
|
|
|
+ use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
|
|
|
+ flag_find_nei_text (bool): Whether to find neighboring text.
|
|
|
**kwargs: Additional keyword arguments.
|
|
|
|
|
|
Returns:
|
|
|
@@ -674,6 +707,8 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
overall_ocr_res,
|
|
|
table_box,
|
|
|
use_table_cells_ocr_results,
|
|
|
+ use_e2e_wired_table_rec_model,
|
|
|
+ use_e2e_wireless_table_rec_model,
|
|
|
flag_find_nei_text=False,
|
|
|
)
|
|
|
single_table_rec_res["table_region_id"] = table_region_id
|
|
|
@@ -690,7 +725,12 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
table_box = crop_img_info["box"]
|
|
|
single_table_rec_res = (
|
|
|
self.predict_single_table_recognition_res(
|
|
|
- crop_img_info["img"], overall_ocr_res, table_box, use_table_cells_ocr_results
|
|
|
+ crop_img_info["img"],
|
|
|
+ overall_ocr_res,
|
|
|
+ table_box,
|
|
|
+ use_table_cells_ocr_results,
|
|
|
+ use_e2e_wired_table_rec_model,
|
|
|
+ use_e2e_wireless_table_rec_model,
|
|
|
)
|
|
|
)
|
|
|
single_table_rec_res["table_region_id"] = table_region_id
|