|
|
@@ -15,6 +15,7 @@
|
|
|
import os, sys
|
|
|
from typing import Any, Dict, Optional, Union, List, Tuple
|
|
|
import numpy as np
|
|
|
+import math
|
|
|
import cv2
|
|
|
from sklearn.cluster import KMeans
|
|
|
from ..base import BasePipeline
|
|
|
@@ -497,12 +498,40 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
if len(final_results) <= 0.6*html_pred_boxes_nums:
|
|
|
final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
|
|
|
return final_results
|
|
|
+
|
|
|
def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
    """
    Run OCR on every table-cell crop and collect the recognized text.

    Args:
        ori_img (ndarray): Image the cells are cropped from. It is indexed
            as ``ori_img[y1:y2, x1:x2, :]``, so it must have three dimensions.
        cells_bboxes (list or ndarray): Cell boxes, each ``(x1, y1, x2, y2)``.
            A non-list input is converted with its ``tolist()`` method.

    Returns:
        list: One string per input box, in input order. Each string is the
        concatenation of the ``"rec_texts"`` entries returned by
        ``self.general_ocr_pipeline`` for that crop, or ``""`` when the box
        has no area inside the image.
    """
    # Normalize array-like input to plain Python lists.
    if not isinstance(cells_bboxes, list):
        cells_bboxes = cells_bboxes.tolist()

    img_h, img_w = ori_img.shape[:2]
    texts_list = []
    for box in cells_bboxes:
        # Coordinates are rounded up to integer pixel indices.
        x1, y1, x2, y2 = (math.ceil(k) for k in box)

        # Clamp to the image: a negative bound would otherwise be read by
        # NumPy as "count from the far edge" and select the wrong region.
        x1 = max(0, min(x1, img_w))
        x2 = max(0, min(x2, img_w))
        y1 = max(0, min(y1, img_h))
        y2 = max(0, min(y2, img_h))

        # An empty crop has nothing to recognize; keep a placeholder so the
        # result still has exactly one entry per cell.
        if x2 <= x1 or y2 <= y1:
            texts_list.append("")
            continue

        rec_res = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
        texts_list.append("".join(rec_res["rec_texts"]))
    return texts_list
|
|
|
|
|
|
def predict_single_table_recognition_res(
|
|
|
self,
|
|
|
image_array: np.ndarray,
|
|
|
overall_ocr_res: OCRResult,
|
|
|
table_box: list,
|
|
|
+ use_table_cells_ocr_results: bool = False,
|
|
|
flag_find_nei_text: bool = True,
|
|
|
) -> SingleTableRecognitionResult:
|
|
|
"""
|
|
|
@@ -517,6 +546,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
Returns:
|
|
|
SingleTableRecognitionResult: single table recognition result.
|
|
|
"""
|
|
|
+
|
|
|
table_cls_pred = next(self.table_cls_model(image_array))
|
|
|
table_cls_result = self.extract_results(table_cls_pred, "cls")
|
|
|
if table_cls_result == "wired_table":
|
|
|
@@ -538,8 +568,12 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
table_cells_result = self.cells_det_results_reprocessing(
|
|
|
table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
|
|
|
)
|
|
|
+ if use_table_cells_ocr_results == True:
|
|
|
+ cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
|
|
|
+ else:
|
|
|
+ cells_texts_list = []
|
|
|
single_table_recognition_res = get_table_recognition_res(
|
|
|
- table_box, table_structure_result, table_cells_result, overall_ocr_res
|
|
|
+ table_box, table_structure_result, table_cells_result, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
|
|
|
)
|
|
|
neighbor_text = ""
|
|
|
if flag_find_nei_text:
|
|
|
@@ -567,6 +601,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
text_det_box_thresh: Optional[float] = None,
|
|
|
text_det_unclip_ratio: Optional[float] = None,
|
|
|
text_rec_score_thresh: Optional[float] = None,
|
|
|
+ use_table_cells_ocr_results: Optional[bool] = False,
|
|
|
**kwargs,
|
|
|
) -> TableRecognitionResult:
|
|
|
"""
|
|
|
@@ -638,6 +673,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
doc_preprocessor_image,
|
|
|
overall_ocr_res,
|
|
|
table_box,
|
|
|
+ use_table_cells_ocr_results,
|
|
|
flag_find_nei_text=False,
|
|
|
)
|
|
|
single_table_rec_res["table_region_id"] = table_region_id
|
|
|
@@ -654,7 +690,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
table_box = crop_img_info["box"]
|
|
|
single_table_rec_res = (
|
|
|
self.predict_single_table_recognition_res(
|
|
|
- crop_img_info["img"], overall_ocr_res, table_box
|
|
|
+ crop_img_info["img"], overall_ocr_res, table_box, use_table_cells_ocr_results
|
|
|
)
|
|
|
)
|
|
|
single_table_rec_res["table_region_id"] = table_region_id
|