@@ -128,6 +128,11 @@ class TableRecognitionPipelineV2(BasePipeline):
                 {"pipeline_config_error": "config error for general_ocr_pipeline!"},
             )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+        else:
+            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                None
+            )

         self._crop_by_boxes = CropByBoxes()

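For context: when `use_ocr_model` is false, the constructor previously dropped the `SubPipelines.GeneralOCR` entry entirely; the hunk above keeps it as `general_ocr_config_bak` so the OCR sub-pipeline can be built on demand later (see the final hunk). A minimal standalone sketch of the branch, with pipeline creation stubbed out and an illustrative config shape inferred from the `config.get("SubPipelines", {})` access:

def init_general_ocr(config: dict, use_ocr_model: bool):
    """Sketch only: returns (eager_pipeline_config_or_None, backed_up_config_or_None)."""
    if use_ocr_model:
        # Eager path: a missing sub-config is replaced by an error-marker dict.
        return (
            config.get("SubPipelines", {}).get(
                "GeneralOCR",
                {"pipeline_config_error": "config error for general_ocr_pipeline!"},
            ),
            None,
        )
    # Lazy path: remember the sub-config (or None) for later creation.
    return None, config.get("SubPipelines", {}).get("GeneralOCR", None)

# Illustrative config; the real pipeline YAML carries a full OCR sub-config.
cfg = {"SubPipelines": {"GeneralOCR": {"pipeline_name": "OCR"}}}
assert init_general_ocr(cfg, use_ocr_model=False)[1] == {"pipeline_name": "OCR"}
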
@@ -595,25 +600,15 @@ class TableRecognitionPipelineV2(BasePipeline):
                 use_e2e_model = True
             else:
                 table_cells_pred = next(
-                    self.wireless_table_cells_detection_model(
-                        image_array, threshold=0.3
-                    )
+                    self.wireless_table_cells_detection_model(image_array, threshold=0.3)
                 )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                 # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.

         if use_e2e_model == False:
-            table_structure_result = self.extract_results(
-                table_structure_pred, "table_stru"
-            )
-            table_cells_result, table_cells_score = self.extract_results(
-                table_cells_pred, "det"
-            )
-            table_cells_result, table_cells_score = self.cells_det_results_nms(
-                table_cells_result, table_cells_score
-            )
-            ocr_det_boxes = self.get_region_ocr_det_boxes(
-                overall_ocr_res["rec_boxes"].tolist(), table_box
-            )
+            table_structure_result = self.extract_results(table_structure_pred, "table_stru")
+            table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
+            table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
+            ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
             table_cells_result = self.cells_det_results_reprocessing(
                 table_cells_result,
                 table_cells_score,
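The reflowed block above chains cell detection, `cells_det_results_nms`, and OCR-box gathering. For readers new to the NMS step, here is a minimal, hedged sketch of score-ordered IoU suppression over `[x1, y1, x2, y2]` boxes; the repository's `cells_det_results_nms` may use different thresholds or tie-breaking:

import numpy as np

def nms_boxes(boxes, scores, iou_thresh=0.5):
    # Keep the highest-scoring box, drop overlapping lower-scored boxes, repeat.
    boxes = np.asarray(boxes, dtype=float)
    scores = np.asarray(scores, dtype=float)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        # Intersection of box i with every remaining box.
        xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_rest = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + area_rest - inter + 1e-9)
        order = rest[iou <= iou_thresh]
    return [boxes[k].tolist() for k in keep], [float(scores[k]) for k in keep]
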
@@ -621,9 +616,7 @@ class TableRecognitionPipelineV2(BasePipeline):
                 len(table_structure_pred["bbox"]),
             )
             if use_table_cells_ocr_results == True:
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result
-                )
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res(
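For orientation, `split_ocr_bboxes_by_table_cells` takes the page image plus per-cell boxes and returns one recognized text per cell, as its call sites here show. A hypothetical sketch of that contract (the helper below and its recognition interface are assumptions, not the repository implementation):

def texts_per_cell(image_array, table_cells_result, rec_model):
    # Hypothetical: crop each [x1, y1, x2, y2] cell and recognize its text.
    h, w = image_array.shape[:2]
    cells_texts_list = []
    for x1, y1, x2, y2 in table_cells_result:
        # Clamp to image bounds before cropping.
        x1, y1 = max(int(x1), 0), max(int(y1), 0)
        x2, y2 = min(int(x2), w), min(int(y2), h)
        crop = image_array[y1:y2, x1:x2]
        # Assumed recognition interface: model(list_of_crops) -> iterator of dicts.
        text = next(rec_model([crop]))["rec_text"] if crop.size else ""
        cells_texts_list.append(text)
    return cells_texts_list
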
@@ -636,16 +629,9 @@ class TableRecognitionPipelineV2(BasePipeline):
             )
         else:
             if use_table_cells_ocr_results == True:
-                table_cells_result_e2e = list(
-                    map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
-                )
-                table_cells_result_e2e = [
-                    [rect[0], rect[1], rect[4], rect[5]]
-                    for rect in table_cells_result_e2e
-                ]
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result_e2e
-                )
+                table_cells_result_e2e = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
+                table_cells_result_e2e = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e]
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result_e2e)
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res_e2e(
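The `rect[0], rect[1], rect[4], rect[5]` indexing above treats each end-to-end cell prediction as an 8-value quadrilateral `[x0, y0, x1, y1, x2, y2, x3, y3]` (four corners in order), so corners 0 and 2 give an axis-aligned box. A small illustration with invented coordinates, assuming the corners are emitted clockwise from top-left as the indexing implies:

# Four (x, y) pairs, clockwise from top-left; elements 0-1 and 4-5 are
# opposite corners of an axis-aligned cell.
quad = [10.0, 20.0, 110.0, 20.0, 110.0, 60.0, 10.0, 60.0]
xmin, ymin, xmax, ymax = quad[0], quad[1], quad[4], quad[5]
assert [xmin, ymin, xmax, ymax] == [10.0, 20.0, 110.0, 60.0]
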
@@ -749,6 +735,9 @@ class TableRecognitionPipelineV2(BasePipeline):
                         text_rec_score_thresh=text_rec_score_thresh,
                     )
                 )
+            elif use_table_cells_ocr_results == True:
+                assert self.general_ocr_config_bak != None
+                self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)

             table_res_list = []
             table_region_id = 1
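The final hunk closes the loop on the first one: if the overall OCR model was skipped at init time but the caller asks for per-cell OCR texts, the pipeline is created on demand from the backed-up config. A sketch of the resulting control flow (names from the diff; the early-return guard is my own framing):

def ensure_general_ocr_pipeline(self, use_ocr_model, use_table_cells_ocr_results):
    # Sketch: the elif added above, restated as a reusable guard.
    if use_ocr_model:
        return  # self.general_ocr_pipeline was built eagerly in __init__
    if use_table_cells_ocr_results:
        # Fails loudly when the pipeline YAML lacked a SubPipelines.GeneralOCR entry.
        assert self.general_ocr_config_bak is not None
        self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)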