@@ -22,7 +22,9 @@ from ..base import BasePipeline
 from ..components import CropByBoxes
 from .utils import get_neighbor_boxes_idx
 from .table_recognition_post_processing_v2 import get_table_recognition_res
-from .table_recognition_post_processing import get_table_recognition_res as get_table_recognition_res_e2e
+from .table_recognition_post_processing import (
+    get_table_recognition_res as get_table_recognition_res_e2e,
+)
 from .result import SingleTableRecognitionResult, TableRecognitionResult
 from ....utils import logging
 from ...utils.pp_option import PaddlePredictorOption
@@ -276,8 +278,10 @@ class TableRecognitionPipelineV2(BasePipeline):
             return pred["structure"]
         else:
             return None
-
-    def cells_det_results_nms(self, cells_det_results, cells_det_scores, cells_det_threshold=0.3):
+
+    def cells_det_results_nms(
+        self, cells_det_results, cells_det_scores, cells_det_threshold=0.3
+    ):
         """
         Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.

@@ -324,12 +328,14 @@ class TableRecognitionPipelineV2(BasePipeline):
             # Indices of boxes with IoU less than threshold
             inds = np.where(ovr <= cells_det_threshold)[0]
             # Update order, only keep boxes with IoU less than threshold
-            order = order[inds + 1]  # inds shifted by 1 because order[0] is the current box
+            order = order[
+                inds + 1
+            ]  # inds shifted by 1 because order[0] is the current box
         # Select the boxes and scores based on picked indices
         final_boxes = boxes[picked_indices].tolist()
         final_scores = scores[picked_indices].tolist()
         return final_boxes, final_scores
-
+
     def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
         """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.

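As a reference for the reformatted `cells_det_results_nms`, here is a minimal standalone sketch of the same greedy IoU-based suppression; the helper name and vectorized layout are illustrative, not part of the pipeline API:

```python
import numpy as np


def nms(boxes, scores, iou_threshold=0.3):
    """Greedy NMS over [x1, y1, x2, y2] boxes, mirroring the method above."""
    boxes = np.asarray(boxes, dtype=float)
    scores = np.asarray(scores, dtype=float)
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # Intersection of the current box with all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        ovr = inter / (areas[i] + areas[order[1:]] - inter)
        # Keep only boxes overlapping less than the threshold;
        # inds is shifted by 1 because order[0] is the current box.
        inds = np.where(ovr <= iou_threshold)[0]
        order = order[inds + 1]
    return boxes[keep].tolist(), scores[keep].tolist()
```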
@@ -340,27 +346,33 @@ class TableRecognitionPipelineV2(BasePipeline):
         Returns:
             list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
         """
-        tol=0
+        tol = 0
         # Extract coordinates from table_box
         x_min_t, y_min_t, x_max_t, y_max_t = table_box
         adjusted_boxes = []
         for box in ocr_det_boxes:
             x_min_b, y_min_b, x_max_b, y_max_b = box
             # Check if the box is fully inside table_box
-            if (x_min_b+tol >= x_min_t and y_min_b+tol >= y_min_t and
-                x_max_b-tol <= x_max_t and y_max_b-tol <= y_max_t):
+            if (
+                x_min_b + tol >= x_min_t
+                and y_min_b + tol >= y_min_t
+                and x_max_b - tol <= x_max_t
+                and y_max_b - tol <= y_max_t
+            ):
                 # Adjust the coordinates to be relative to table_box
                 adjusted_box = [
                     x_min_b - x_min_t,  # Adjust x1
                     y_min_b - y_min_t,  # Adjust y1
                     x_max_b - x_min_t,  # Adjust x2
-                    y_max_b - y_min_t  # Adjust y2
+                    y_max_b - y_min_t,  # Adjust y2
                 ]
                 adjusted_boxes.append(adjusted_box)
         # Discard boxes not fully inside table_box
         return adjusted_boxes

-    def cells_det_results_reprocessing(self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums):
+    def cells_det_results_reprocessing(
+        self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums
+    ):
         """
         Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.

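The condition being reflowed above is a plain box-in-box test followed by an origin shift. A tiny self-contained sketch (hypothetical helper name, same `[x1, y1, x2, y2]` convention):

```python
def shift_into_table(box, table_box, tol=0):
    """Return box coordinates relative to table_box, or None if not fully inside."""
    x1, y1, x2, y2 = box
    tx1, ty1, tx2, ty2 = table_box
    if x1 + tol >= tx1 and y1 + tol >= ty1 and x2 - tol <= tx2 and y2 - tol <= ty2:
        # Translate so that table_box's top-left corner becomes the origin
        return [x1 - tx1, y1 - ty1, x2 - tx1, y2 - ty1]
    return None


# e.g. shift_into_table([120, 55, 180, 70], [100, 50, 400, 300]) -> [20, 5, 80, 20]
```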
@@ -373,6 +385,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         Returns:
             List[List[float]]: The processed list of rectangles.
         """
+
        # Function to compute IoU between two rectangles
        def compute_iou(box1, box2):
            """
@@ -419,15 +432,17 @@ class TableRecognitionPipelineV2(BasePipeline):
             if N >= num_rects:
                 return rectangles
             # Compute the center points of the rectangles
-            centers = np.array([
+            centers = np.array(
                 [
-                    (rect[0] + rect[2]) / 2,  # Center x-coordinate
-                    (rect[1] + rect[3]) / 2  # Center y-coordinate
+                    [
+                        (rect[0] + rect[2]) / 2,  # Center x-coordinate
+                        (rect[1] + rect[3]) / 2,  # Center y-coordinate
+                    ]
+                    for rect in rectangles
                 ]
-                for rect in rectangles
-            ])
+            )
             # Perform KMeans clustering on the center points to group them into N clusters
-            kmeans = KMeans(n_clusters=N, random_state=0, n_init='auto')
+            kmeans = KMeans(n_clusters=N, random_state=0, n_init="auto")
             labels = kmeans.fit_predict(centers)
             # Initialize a list to store the combined rectangles
             combined_rectangles = []
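`combine_rectangles`, whose body is reformatted above, clusters box centers with scikit-learn's KMeans and merges each cluster into one enclosing rectangle. A minimal sketch of that idea (illustrative name; assumes scikit-learn >= 1.2 for `n_init="auto"`):

```python
import numpy as np
from sklearn.cluster import KMeans


def combine_into_n(rectangles, n):
    """Merge [x1, y1, x2, y2] rectangles into at most n enclosing rectangles."""
    rects = np.asarray(rectangles, dtype=float)
    if n >= len(rects):
        return rects.tolist()
    centers = (rects[:, :2] + rects[:, 2:]) / 2  # box center points
    labels = KMeans(n_clusters=n, random_state=0, n_init="auto").fit_predict(centers)
    combined = []
    for k in range(n):
        group = rects[labels == k]
        # Enclosing rectangle of all boxes assigned to cluster k
        combined.append(
            [group[:, 0].min(), group[:, 1].min(), group[:, 2].max(), group[:, 3].max()]
        )
    return combined
```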
@@ -478,28 +493,36 @@ class TableRecognitionPipelineV2(BasePipeline):
                 iou = compute_iou(ocr_rect, cell_rect)
                 if iou > 0:
                     merge_ocr_box_iou.append(iou)
-                if (iou>=iou_threshold) or (sum(merge_ocr_box_iou)>=iou_threshold):
+                if (iou >= iou_threshold) or (sum(merge_ocr_box_iou) >= iou_threshold):
                     has_large_iou = True
                     break
             if not has_large_iou:
                 ocr_miss_boxes.append(ocr_rect)
         # If no ocr_miss_boxes, return cells_det_results
         if len(ocr_miss_boxes) == 0:
-            final_results = cells_det_results if more_cells_flag==True else cells_det_results.tolist()
+            final_results = (
+                cells_det_results
+                if more_cells_flag == True
+                else cells_det_results.tolist()
+            )
         else:
             if more_cells_flag == True:
-                final_results = combine_rectangles(cells_det_results+ocr_miss_boxes, html_pred_boxes_nums)
+                final_results = combine_rectangles(
+                    cells_det_results + ocr_miss_boxes, html_pred_boxes_nums
+                )
             else:
                 # Need to combine ocr_miss_boxes into N rectangles
                 N = html_pred_boxes_nums - len(cells_det_results)
                 # Combine ocr_miss_boxes into N rectangles
                 ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
                 # Combine cells_det_results and ocr_supp_boxes
-                final_results = np.concatenate((cells_det_results, ocr_supp_boxes), axis=0).tolist()
-        if len(final_results) <= 0.6*html_pred_boxes_nums:
+                final_results = np.concatenate(
+                    (cells_det_results, ocr_supp_boxes), axis=0
+                ).tolist()
+        if len(final_results) <= 0.6 * html_pred_boxes_nums:
             final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
         return final_results
-
+
     def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
         """
         Splits OCR bounding boxes by table cells and retrieves text.
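The branch logic above supplements detected cells with OCR boxes that no cell covers. A simplified sketch of the idea (it omits the diff's cumulative-IoU merging and `more_cells_flag` handling; `combine_into_n` is the KMeans sketch shown earlier, and the threshold default is illustrative):

```python
def iou(a, b):
    """IoU of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter) if inter else 0.0


def supplement_cells(cells, ocr_boxes, html_pred_boxes_nums, iou_threshold=0.1):
    """Add rectangles for OCR boxes that no detected cell covers (simplified)."""
    missed = [
        ocr for ocr in ocr_boxes
        if all(iou(ocr, cell) < iou_threshold for cell in cells)
    ]
    if not missed:
        return cells
    # Cluster the missed OCR boxes into the number of cells still unaccounted for
    n = max(1, html_pred_boxes_nums - len(cells))
    return cells + combine_into_n(missed, n)
```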
@@ -523,7 +546,7 @@ class TableRecognitionPipelineV2(BasePipeline):
             # Perform OCR on the defined region of the image and get the recognized text.
             rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
             # Concatenate the texts and append them to the texts_list.
-            texts_list.append(''.join(rec_te["rec_texts"]))
+            texts_list.append("".join(rec_te["rec_texts"]))
         # Return the list of recognized texts from each cell.
         return texts_list

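The method patched above reduces to cropping each cell and joining the recognized fragments. A usage-style sketch, assuming a PaddleOCR-style pipeline whose per-image result carries a `rec_texts` list as in the diff:

```python
def texts_per_cell(ocr_pipeline, ori_img, cells_bboxes):
    """Run OCR on each table-cell crop and join the recognized fragments."""
    texts_list = []
    for box in cells_bboxes:
        x1, y1, x2, y2 = (int(v) for v in box)
        # Crop the cell region and run the OCR pipeline on it
        result = next(ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
        texts_list.append("".join(result["rec_texts"]))
    return texts_list
```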
@@ -533,7 +556,7 @@ class TableRecognitionPipelineV2(BasePipeline):
         overall_ocr_res: OCRResult,
         table_box: list,
         use_table_cells_ocr_results: bool = False,
-        use_e2e_wired_table_rec_model: bool = False,
+        use_e2e_wired_table_rec_model: bool = False,
         use_e2e_wireless_table_rec_model: bool = False,
         flag_find_nei_text: bool = True,
     ) -> SingleTableRecognitionResult:
@@ -564,42 +587,73 @@ class TableRecognitionPipelineV2(BasePipeline):
             else:
                 table_cells_pred = next(
                     self.wired_table_cells_detection_model(image_array, threshold=0.3)
-            )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
-            # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
+                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
+                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
         elif table_cls_result == "wireless_table":
             table_structure_pred = next(self.wireless_table_rec_model(image_array))
             if use_e2e_wireless_table_rec_model == True:
                 use_e2e_model = True
             else:
                 table_cells_pred = next(
-                    self.wireless_table_cells_detection_model(image_array, threshold=0.3)
-                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
-                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
+                    self.wireless_table_cells_detection_model(
+                        image_array, threshold=0.3
+                    )
+                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
+                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.

         if use_e2e_model == False:
-            table_structure_result = self.extract_results(table_structure_pred, "table_stru")
-            table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
-            table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
-            ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
+            table_structure_result = self.extract_results(
+                table_structure_pred, "table_stru"
+            )
+            table_cells_result, table_cells_score = self.extract_results(
+                table_cells_pred, "det"
+            )
+            table_cells_result, table_cells_score = self.cells_det_results_nms(
+                table_cells_result, table_cells_score
+            )
+            ocr_det_boxes = self.get_region_ocr_det_boxes(
+                overall_ocr_res["rec_boxes"].tolist(), table_box
+            )
             table_cells_result = self.cells_det_results_reprocessing(
-                table_cells_result, table_cells_score, ocr_det_boxes, len(table_structure_pred['bbox'])
+                table_cells_result,
+                table_cells_score,
+                ocr_det_boxes,
+                len(table_structure_pred["bbox"]),
             )
             if use_table_cells_ocr_results == True:
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
+                    image_array, table_cells_result
+                )
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res(
-                table_box, table_structure_result, table_cells_result, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
+                table_box,
+                table_structure_result,
+                table_cells_result,
+                overall_ocr_res,
+                cells_texts_list,
+                use_table_cells_ocr_results,
             )
         else:
             if use_table_cells_ocr_results == True:
-                table_cells_result_e2e = list(map(lambda arr: arr.tolist(), table_structure_pred['bbox']))
-                table_cells_result_e2e = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e]
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result_e2e)
+                table_cells_result_e2e = list(
+                    map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
+                )
+                table_cells_result_e2e = [
+                    [rect[0], rect[1], rect[4], rect[5]]
+                    for rect in table_cells_result_e2e
+                ]
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
+                    image_array, table_cells_result_e2e
+                )
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res_e2e(
-                table_box, table_structure_pred, overall_ocr_res, cells_texts_list, use_table_cells_ocr_results
+                table_box,
+                table_structure_pred,
+                overall_ocr_res,
+                cells_texts_list,
+                use_table_cells_ocr_results,
             )

         neighbor_text = ""
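For the e2e branch above: each row of `table_structure_pred["bbox"]` is an 8-value quadrilateral (x1, y1, ..., x4, y4), so indices 0, 1, 4, 5 pick the top-left and bottom-right corners as an axis-aligned box. A quick illustration with made-up coordinates:

```python
import numpy as np

# One predicted cell as a clockwise quadrilateral: (x1, y1) ... (x4, y4)
quad = np.array([10.0, 20.0, 90.0, 20.0, 90.0, 60.0, 10.0, 60.0])
rect = quad.tolist()
# Top-left corner (rect[0], rect[1]) and bottom-right corner (rect[4], rect[5])
axis_aligned = [rect[0], rect[1], rect[4], rect[5]]
assert axis_aligned == [10.0, 20.0, 90.0, 60.0]
```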
@@ -628,9 +682,9 @@ class TableRecognitionPipelineV2(BasePipeline):
         text_det_box_thresh: Optional[float] = None,
         text_det_unclip_ratio: Optional[float] = None,
         text_rec_score_thresh: Optional[float] = None,
-        use_table_cells_ocr_results: Optional[bool] = False,
-        use_e2e_wired_table_rec_model: Optional[bool] = False,
-        use_e2e_wireless_table_rec_model: Optional[bool] = False,
+        use_table_cells_ocr_results: bool = False,
+        use_e2e_wired_table_rec_model: bool = False,
+        use_e2e_wireless_table_rec_model: bool = False,
         **kwargs,
     ) -> TableRecognitionResult:
         """