# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from ....utils import logging
from ....utils.deps import (
    function_requires_deps,
    is_dep_available,
    pipeline_requires_extra,
)
from ...common.batch_sampler import ImageBatchSampler
from ...common.reader import ReadImage
from ...models.object_detection.result import DetResult
from ...utils.benchmark import benchmark
from ...utils.hpi import HPIConfig
from ...utils.pp_option import PaddlePredictorOption
from .._parallel import AutoParallelImageSimpleInferencePipeline
from ..base import BasePipeline
from ..components import CropByBoxes
from ..doc_preprocessor.result import DocPreprocessorResult
from ..layout_parsing.utils import get_sub_regions_ocr_res
from ..ocr.result import OCRResult
from .result import SingleTableRecognitionResult, TableRecognitionResult
from .table_recognition_post_processing import (
    get_table_recognition_res as get_table_recognition_res_e2e,
)
from .table_recognition_post_processing_v2 import get_table_recognition_res
from .utils import get_neighbor_boxes_idx

if is_dep_available("scikit-learn"):
    from sklearn.cluster import KMeans


@benchmark.time_methods
class _TableRecognitionPipelineV2(BasePipeline):
    """Table Recognition Pipeline"""

    def __init__(
        self,
        config: Dict,
        device: str = None,
        pp_option: PaddlePredictorOption = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
    ) -> None:
        """Initializes the table recognition pipeline.

        Args:
            config (Dict): Configuration dictionary containing various settings.
            device (str, optional): Device to run the predictions on. Defaults to None.
            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
            use_hpip (bool, optional): Whether to use the high-performance
                inference plugin (HPIP) by default. Defaults to False.
            hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                The default high-performance inference configuration dictionary.
                Defaults to None.
        """
        super().__init__(
            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
        )

        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
        if self.use_doc_preprocessor:
            doc_preprocessor_config = config.get("SubPipelines", {}).get(
                "DocPreprocessor",
                {
                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
}, ) self.doc_preprocessor_pipeline = self.create_pipeline( doc_preprocessor_config ) self.use_layout_detection = config.get("use_layout_detection", True) if self.use_layout_detection: layout_det_config = config.get("SubModules", {}).get( "LayoutDetection", {"model_config_error": "config error for layout_det_model!"}, ) self.layout_det_model = self.create_model(layout_det_config) table_cls_config = config.get("SubModules", {}).get( "TableClassification", {"model_config_error": "config error for table_classification_model!"}, ) self.table_cls_model = self.create_model(table_cls_config) wired_table_rec_config = config.get("SubModules", {}).get( "WiredTableStructureRecognition", {"model_config_error": "config error for wired_table_structure_model!"}, ) self.wired_table_rec_model = self.create_model(wired_table_rec_config) wireless_table_rec_config = config.get("SubModules", {}).get( "WirelessTableStructureRecognition", {"model_config_error": "config error for wireless_table_structure_model!"}, ) self.wireless_table_rec_model = self.create_model(wireless_table_rec_config) wired_table_cells_det_config = config.get("SubModules", {}).get( "WiredTableCellsDetection", { "model_config_error": "config error for wired_table_cells_detection_model!" }, ) self.wired_table_cells_detection_model = self.create_model( wired_table_cells_det_config ) wireless_table_cells_det_config = config.get("SubModules", {}).get( "WirelessTableCellsDetection", { "model_config_error": "config error for wireless_table_cells_detection_model!" }, ) self.wireless_table_cells_detection_model = self.create_model( wireless_table_cells_det_config ) self.use_ocr_model = config.get("use_ocr_model", True) self.general_ocr_pipeline = None if self.use_ocr_model: general_ocr_config = config.get("SubPipelines", {}).get( "GeneralOCR", {"pipeline_config_error": "config error for general_ocr_pipeline!"}, ) self.general_ocr_pipeline = self.create_pipeline(general_ocr_config) else: self.general_ocr_config_bak = config.get("SubPipelines", {}).get( "GeneralOCR", None ) self.table_orientation_classify_model = None self.table_orientation_classify_config = config.get("SubModules", {}).get( "TableOrientationClassify", None ) self._crop_by_boxes = CropByBoxes() self.batch_sampler = ImageBatchSampler(batch_size=1) self.img_reader = ReadImage(format="BGR") def get_model_settings( self, use_doc_orientation_classify: Optional[bool], use_doc_unwarping: Optional[bool], use_layout_detection: Optional[bool], use_ocr_model: Optional[bool], ) -> dict: """ Get the model settings based on the provided parameters or default values. Args: use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification. use_doc_unwarping (Optional[bool]): Whether to use document unwarping. use_layout_detection (Optional[bool]): Whether to use layout detection. use_ocr_model (Optional[bool]): Whether to use OCR model. Returns: dict: A dictionary containing the model settings. 
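
        Example (illustrative; assumes a pipeline built with the default
        config, where the doc preprocessor, layout detection, and OCR
        sub-models are all enabled):
            >>> pipeline.get_model_settings(None, None, None, None)
            {'use_doc_preprocessor': True, 'use_layout_detection': True, 'use_ocr_model': True}
            >>> pipeline.get_model_settings(False, False, None, None)
            {'use_doc_preprocessor': False, 'use_layout_detection': True, 'use_ocr_model': True}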
""" if use_doc_orientation_classify is None and use_doc_unwarping is None: use_doc_preprocessor = self.use_doc_preprocessor else: if use_doc_orientation_classify is True or use_doc_unwarping is True: use_doc_preprocessor = True else: use_doc_preprocessor = False if use_layout_detection is None: use_layout_detection = self.use_layout_detection if use_ocr_model is None: use_ocr_model = self.use_ocr_model return dict( use_doc_preprocessor=use_doc_preprocessor, use_layout_detection=use_layout_detection, use_ocr_model=use_ocr_model, ) def check_model_settings_valid( self, model_settings: Dict, overall_ocr_res: OCRResult, layout_det_res: DetResult, ) -> bool: """ Check if the input parameters are valid based on the initialized models. Args: model_settings (Dict): A dictionary containing input parameters. overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline. The overall OCR result with convert_points_to_boxes information. layout_det_res (DetResult): The layout detection result. Returns: bool: True if all required models are initialized according to input parameters, False otherwise. """ if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor: logging.error( "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized." ) return False if model_settings["use_layout_detection"]: if layout_det_res is not None: logging.error( "The layout detection model has already been initialized, please set use_layout_detection=False" ) return False if not self.use_layout_detection: logging.error( "Set use_layout_detection, but the models for layout detection are not initialized." ) return False if model_settings["use_ocr_model"]: if overall_ocr_res is not None: logging.error( "The OCR models have already been initialized, please set use_ocr_model=False" ) return False if not self.use_ocr_model: logging.error( "Set use_ocr_model, but the models for OCR are not initialized." ) return False else: if overall_ocr_res is None: logging.error("Set use_ocr_model=False, but no OCR results were found.") return False return True def predict_doc_preprocessor_res( self, image_array: np.ndarray, input_params: dict ) -> Tuple[DocPreprocessorResult, np.ndarray]: """ Preprocess the document image based on input parameters. Args: image_array (np.ndarray): The input image array. input_params (dict): Dictionary containing preprocessing parameters. Returns: tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing result dictionary and the processed image array. 
""" if input_params["use_doc_preprocessor"]: use_doc_orientation_classify = input_params["use_doc_orientation_classify"] use_doc_unwarping = input_params["use_doc_unwarping"] doc_preprocessor_res = list( self.doc_preprocessor_pipeline( image_array, use_doc_orientation_classify=use_doc_orientation_classify, use_doc_unwarping=use_doc_unwarping, ) )[0] doc_preprocessor_image = doc_preprocessor_res["output_img"] else: doc_preprocessor_res = {} doc_preprocessor_image = image_array return doc_preprocessor_res, doc_preprocessor_image def extract_results(self, pred, task): if task == "cls": return pred["label_names"][np.argmax(pred["scores"])] elif task == "det": threshold = 0.0 result = [] cell_score = [] if "boxes" in pred and isinstance(pred["boxes"], list): for box in pred["boxes"]: if isinstance(box, dict) and "score" in box and "coordinate" in box: score = box["score"] coordinate = box["coordinate"] if isinstance(score, float) and score > threshold: result.append(coordinate) cell_score.append(score) return result, cell_score elif task == "table_stru": return pred["structure"] else: return None def cells_det_results_nms( self, cells_det_results, cells_det_scores, cells_det_threshold=0.3 ): """ Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes. Args: cells_det_results (list): List of bounding boxes, each box is in format [x1, y1, x2, y2]. cells_det_scores (list): List of confidence scores corresponding to the bounding boxes. cells_det_threshold (float): IoU threshold for suppression. Boxes with IoU greater than this threshold will be suppressed. Default is 0.5. Returns: Tuple[list, list]: A tuple containing the list of bounding boxes and confidence scores after NMS, while maintaining one-to-one correspondence. """ # Convert lists to numpy arrays for efficient computation boxes = np.array(cells_det_results) scores = np.array(cells_det_scores) # Initialize list for picked indices picked_indices = [] # Get coordinates of bounding boxes x1 = boxes[:, 0] y1 = boxes[:, 1] x2 = boxes[:, 2] y2 = boxes[:, 3] # Compute the area of the bounding boxes areas = (x2 - x1) * (y2 - y1) # Sort the bounding boxes by the confidence scores in descending order order = scores.argsort()[::-1] # Process the boxes while order.size > 0: # Index of the current highest score box i = order[0] picked_indices.append(i) # Compute IoU between the highest score box and the rest xx1 = np.maximum(x1[i], x1[order[1:]]) yy1 = np.maximum(y1[i], y1[order[1:]]) xx2 = np.minimum(x2[i], x2[order[1:]]) yy2 = np.minimum(y2[i], y2[order[1:]]) # Compute the width and height of the overlapping area w = np.maximum(0.0, xx2 - xx1) h = np.maximum(0.0, yy2 - yy1) # Compute the ratio of overlap (IoU) inter = w * h ovr = inter / (areas[i] + areas[order[1:]] - inter) # Indices of boxes with IoU less than threshold inds = np.where(ovr <= cells_det_threshold)[0] # Update order, only keep boxes with IoU less than threshold order = order[ inds + 1 ] # inds shifted by 1 because order[0] is the current box # Select the boxes and scores based on picked indices final_boxes = boxes[picked_indices].tolist() final_scores = scores[picked_indices].tolist() return final_boxes, final_scores def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box): """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box. Args: ocr_det_boxes (list of list): List of bounding boxes [x1, y1, x2, y2] in the original image. 
            table_box (list): Bounding box [x1, y1, x2, y2] of the target region in the original image.

        Returns:
            list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
        """
        tol = 0
        # Extract coordinates from table_box
        x_min_t, y_min_t, x_max_t, y_max_t = table_box
        adjusted_boxes = []
        for box in ocr_det_boxes:
            x_min_b, y_min_b, x_max_b, y_max_b = box
            # Check if the box is fully inside table_box
            if (
                x_min_b + tol >= x_min_t
                and y_min_b + tol >= y_min_t
                and x_max_b - tol <= x_max_t
                and y_max_b - tol <= y_max_t
            ):
                # Adjust the coordinates to be relative to table_box
                adjusted_box = [
                    x_min_b - x_min_t,  # Adjust x1
                    y_min_b - y_min_t,  # Adjust y1
                    x_max_b - x_min_t,  # Adjust x2
                    y_max_b - y_min_t,  # Adjust y2
                ]
                adjusted_boxes.append(adjusted_box)
            # Discard boxes not fully inside table_box
        return adjusted_boxes

    def cells_det_results_reprocessing(
        self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums
    ):
        """
        Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.

        Args:
            cells_det_results (List[List[float]]): List of detected cell rectangles [[x1, y1, x2, y2], ...].
            cells_det_scores (List[float]): List of confidence scores for each rectangle in cells_det_results.
            ocr_det_results (List[List[float]]): List of OCR detected rectangles [[x1, y1, x2, y2], ...].
            html_pred_boxes_nums (int): The desired number of rectangles in the final output.

        Returns:
            List[List[float]]: The processed list of rectangles.
        """

        # Function to compute the overlap ratio between two rectangles
        def compute_iou(box1, box2):
            """
            Compute the overlap between two rectangles, measured as the
            intersection area divided by the area of box1 (not the symmetric IoU).

            Args:
                box1 (array-like): [x1, y1, x2, y2] of the first rectangle.
                box2 (array-like): [x1, y1, x2, y2] of the second rectangle.

            Returns:
                float: The intersection area over the area of box1.
            """
            # Determine the coordinates of the intersection rectangle
            x_left = max(box1[0], box2[0])
            y_top = max(box1[1], box2[1])
            x_right = min(box1[2], box2[2])
            y_bottom = min(box1[3], box2[3])
            if x_right <= x_left or y_bottom <= y_top:
                return 0.0
            # Calculate the area of the intersection rectangle
            intersection_area = (x_right - x_left) * (y_bottom - y_top)
            # Calculate the area of box1; the overlap is normalized by it
            box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
            iou = intersection_area / float(box1_area)
            return iou

        # Function to combine rectangles into N rectangles
        @function_requires_deps("scikit-learn")
        def combine_rectangles(rectangles, N):
            """
            Combine rectangles into N rectangles based on geometric proximity.

            Args:
                rectangles (list of list of int): A list of rectangles, each represented by [x1, y1, x2, y2].
                N (int): The desired number of combined rectangles.

            Returns:
                list of list of int: A list of N combined rectangles.
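
            Example (illustrative): combining
            [[0, 0, 10, 10], [1, 1, 11, 11], [100, 100, 110, 110]] with N=2
            groups the two overlapping boxes into one KMeans cluster and the
            distant box into another, yielding boxes equivalent to
            [[0, 0, 11, 11], [100, 100, 110, 110]] (cluster order may vary).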
""" # Number of input rectangles num_rects = len(rectangles) # If N is greater than or equal to the number of rectangles, return the original rectangles if N >= num_rects: return rectangles # Compute the center points of the rectangles centers = np.array( [ [ (rect[0] + rect[2]) / 2, # Center x-coordinate (rect[1] + rect[3]) / 2, # Center y-coordinate ] for rect in rectangles ] ) # Perform KMeans clustering on the center points to group them into N clusters kmeans = KMeans(n_clusters=N, random_state=0, n_init="auto") labels = kmeans.fit_predict(centers) # Initialize a list to store the combined rectangles combined_rectangles = [] # For each cluster, compute the minimal bounding rectangle that covers all rectangles in the cluster for i in range(N): # Get the indices of rectangles that belong to cluster i indices = np.where(labels == i)[0] if len(indices) == 0: # If no rectangles in this cluster, skip it continue # Extract the rectangles in cluster i cluster_rects = np.array([rectangles[idx] for idx in indices]) # Compute the minimal x1, y1 (top-left corner) and maximal x2, y2 (bottom-right corner) x1_min = np.min(cluster_rects[:, 0]) y1_min = np.min(cluster_rects[:, 1]) x2_max = np.max(cluster_rects[:, 2]) y2_max = np.max(cluster_rects[:, 3]) # Append the combined rectangle to the list combined_rectangles.append([x1_min, y1_min, x2_max, y2_max]) return combined_rectangles # Ensure that the inputs are numpy arrays for efficient computation cells_det_results = np.array(cells_det_results) cells_det_scores = np.array(cells_det_scores) ocr_det_results = np.array(ocr_det_results) more_cells_flag = False if len(cells_det_results) == html_pred_boxes_nums: return cells_det_results # Step 1: If cells_det_results has more rectangles than html_pred_boxes_nums elif len(cells_det_results) > html_pred_boxes_nums: more_cells_flag = True # Select the indices of the top html_pred_boxes_nums scores top_indices = np.argsort(-cells_det_scores)[:html_pred_boxes_nums] # Adjust the corresponding rectangles cells_det_results = cells_det_results[top_indices].tolist() # Threshold for IoU iou_threshold = 0.6 # List to store ocr_miss_boxes ocr_miss_boxes = [] # For each rectangle in ocr_det_results for ocr_rect in ocr_det_results: merge_ocr_box_iou = [] # Flag to indicate if ocr_rect has IoU >= threshold with any cell_rect has_large_iou = False # For each rectangle in cells_det_results for cell_rect in cells_det_results: # Compute IoU iou = compute_iou(ocr_rect, cell_rect) if iou > 0: merge_ocr_box_iou.append(iou) if (iou >= iou_threshold) or (sum(merge_ocr_box_iou) >= iou_threshold): has_large_iou = True break if not has_large_iou: ocr_miss_boxes.append(ocr_rect) # If no ocr_miss_boxes, return cells_det_results if len(ocr_miss_boxes) == 0: final_results = ( cells_det_results if more_cells_flag == True else cells_det_results.tolist() ) else: if more_cells_flag == True: final_results = combine_rectangles( cells_det_results + ocr_miss_boxes, html_pred_boxes_nums ) else: # Need to combine ocr_miss_boxes into N rectangles N = html_pred_boxes_nums - len(cells_det_results) # Combine ocr_miss_boxes into N rectangles ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N) # Combine cells_det_results and ocr_supp_boxes final_results = np.concatenate( (cells_det_results, ocr_supp_boxes), axis=0 ).tolist() if len(final_results) <= 0.6 * html_pred_boxes_nums: final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums) return final_results def split_ocr_bboxes_by_table_cells( self, cells_det_results, overall_ocr_res, 
        ori_img,
        k=2,
    ):
        """
        Split OCR bounding boxes based on table cell boundaries when they span multiple cells horizontally.

        Args:
            cells_det_results (list): List of cell bounding boxes in format [x1, y1, x2, y2]
            overall_ocr_res (dict): Dictionary containing OCR results with keys:
                - 'rec_boxes': OCR bounding boxes (will be converted to list)
                - 'rec_texts': OCR recognized texts
            ori_img (np.array): Original input image array
            k (int): Threshold for determining when to split (minimum number of cells spanned)

        Returns:
            dict: Modified overall_ocr_res with split boxes and texts
        """

        def calculate_iou(box1, box2):
            """
            Calculate the overlap between two bounding boxes, normalized by the
            area of box2 (the cell box). This is an intersection-over-box2
            ratio rather than the symmetric IoU.

            Args:
                box1 (list): [x1, y1, x2, y2]
                box2 (list): [x1, y1, x2, y2]

            Returns:
                float: Overlap ratio in [0, 1].
            """
            # Determine intersection coordinates
            x_left = max(box1[0], box2[0])
            y_top = max(box1[1], box2[1])
            x_right = min(box1[2], box2[2])
            y_bottom = min(box1[3], box2[3])
            if x_right < x_left or y_bottom < y_top:
                return 0.0
            # Calculate areas
            intersection_area = (x_right - x_left) * (y_bottom - y_top)
            box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
            box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
            # The symmetric IoU would instead be:
            # intersection_area / float(box1_area + box2_area - intersection_area)
            return intersection_area / box2_area

        def get_overlapping_cells(ocr_box, cells):
            """
            Find cells that overlap significantly with the OCR box (overlap ratio > 0.5).

            Args:
                ocr_box (list): OCR bounding box [x1, y1, x2, y2]
                cells (list): List of cell bounding boxes

            Returns:
                list: Indices of overlapping cells, sorted by x-coordinate
            """
            overlapping = []
            for idx, cell in enumerate(cells):
                if calculate_iou(ocr_box, cell) > 0.5:
                    overlapping.append(idx)
            # Sort overlapping cells by their x-coordinate (left to right)
            overlapping.sort(key=lambda i: cells[i][0])
            return overlapping

        def split_box_by_cells(ocr_box, cell_indices, cells):
            """
            Split OCR box vertically at cell boundaries.
Args: ocr_box (list): Original OCR box [x1, y1, x2, y2] cell_indices (list): Indices of cells to split by cells (list): All cell bounding boxes Returns: list: List of split boxes """ if not cell_indices: return [ocr_box] split_boxes = [] cells_to_split = [cells[i] for i in cell_indices] if ocr_box[0] < cells_to_split[0][0]: split_boxes.append( [ocr_box[0], ocr_box[1], cells_to_split[0][0], ocr_box[3]] ) for i in range(len(cells_to_split)): current_cell = cells_to_split[i] split_boxes.append( [ max(ocr_box[0], current_cell[0]), ocr_box[1], min(ocr_box[2], current_cell[2]), ocr_box[3], ] ) if i < len(cells_to_split) - 1: next_cell = cells_to_split[i + 1] if current_cell[2] < next_cell[0]: split_boxes.append( [current_cell[2], ocr_box[1], next_cell[0], ocr_box[3]] ) last_cell = cells_to_split[-1] if last_cell[2] < ocr_box[2]: split_boxes.append([last_cell[2], ocr_box[1], ocr_box[2], ocr_box[3]]) unique_boxes = [] seen = set() for box in split_boxes: box_tuple = tuple(box) if box_tuple not in seen: seen.add(box_tuple) unique_boxes.append(box) return unique_boxes # Convert OCR boxes to list if needed if hasattr(overall_ocr_res["rec_boxes"], "tolist"): ocr_det_results = overall_ocr_res["rec_boxes"].tolist() else: ocr_det_results = overall_ocr_res["rec_boxes"] ocr_texts = overall_ocr_res["rec_texts"] # Make copies to modify new_boxes = [] new_texts = [] # Process each OCR box i = 0 while i < len(ocr_det_results): ocr_box = ocr_det_results[i] text = ocr_texts[i] # Find cells that significantly overlap with this OCR box overlapping_cells = get_overlapping_cells(ocr_box, cells_det_results) # Check if we need to split (spans >= k cells) if len(overlapping_cells) >= k: # Split the box at cell boundaries split_boxes = split_box_by_cells( ocr_box, overlapping_cells, cells_det_results ) # Process each split box split_texts = [] for box in split_boxes: x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3]) if y2 - y1 > 1 and x2 - x1 > 1: ocr_result = list( self.general_ocr_pipeline.text_rec_model( ori_img[y1:y2, x1:x2, :] ) )[0] # Extract the recognized text from the OCR result if "rec_text" in ocr_result: result = ocr_result[ "rec_text" ] # Assumes "rec_texts" contains a single string else: result = "" else: result = "" split_texts.append(result) # Add split boxes and texts to results new_boxes.extend(split_boxes) new_texts.extend(split_texts) else: # Keep original box and text new_boxes.append(ocr_box) new_texts.append(text) i += 1 # Update the results dictionary overall_ocr_res["rec_boxes"] = new_boxes overall_ocr_res["rec_texts"] = new_texts return overall_ocr_res def gen_ocr_with_table_cells(self, ori_img, cells_bboxes): """ Splits OCR bounding boxes by table cells and retrieves text. Args: ori_img (ndarray): The original image from which text regions will be extracted. cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from. Returns: list: A list containing the recognized texts from each cell. """ # Check if cells_bboxes is a list and convert it if not. if not isinstance(cells_bboxes, list): cells_bboxes = cells_bboxes.tolist() texts_list = [] # Initialize a list to store the recognized texts. # Process each bounding box provided in cells_bboxes. for i in range(len(cells_bboxes)): # Extract and round up the coordinates of the bounding box. x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]] # Perform OCR on the defined region of the image and get the recognized text. 
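            # Degenerate crops (width or height <= 1 pixel) are skipped; note
            # that skipping appends nothing, so the returned list can be
            # shorter than cells_bboxes when such boxes occur.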
            if y2 - y1 > 1 and x2 - x1 > 1:
                rec_te = list(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))[0]
                # Concatenate the texts and append them to the texts_list.
                texts_list.append("".join(rec_te["rec_texts"]))
        # Return the list of recognized texts from each cell.
        return texts_list

    def map_cells_to_original_image(
        self, detections, table_angle, img_width, img_height
    ):
        """
        Map bounding boxes from the rotated image back to the original image.

        Parameters:
        - detections: list of numpy arrays, each containing bounding box coordinates [x1, y1, x2, y2]
        - table_angle: rotation angle in degrees, as a string ("90", "180", or "270")
        - img_width: width of the original image
        - img_height: height of the original image

        Returns:
        - mapped_detections: list of numpy arrays with mapped bounding box coordinates
        """
        mapped_detections = []
        for i in range(len(detections)):
            tbx1, tby1, tbx2, tby2 = (
                detections[i][0],
                detections[i][1],
                detections[i][2],
                detections[i][3],
            )
            if table_angle == "270":
                new_x1, new_y1 = tby1, img_width - tbx2
                new_x2, new_y2 = tby2, img_width - tbx1
            elif table_angle == "180":
                new_x1, new_y1 = img_width - tbx2, img_height - tby2
                new_x2, new_y2 = img_width - tbx1, img_height - tby1
            elif table_angle == "90":
                new_x1, new_y1 = img_height - tby2, tbx1
                new_x2, new_y2 = img_height - tby1, tbx2
            new_box = np.array([new_x1, new_y1, new_x2, new_y2])
            mapped_detections.append(new_box)
        return mapped_detections

    def split_string_by_keywords(self, html_string):
        """
        Split HTML string by keywords.

        Args:
            html_string (str): The HTML string.

        Returns:
            split_html (list): The list of HTML keywords.
        """
        keywords = [
            "<thead>",
            "</thead>",
            "<tbody>",
            "</tbody>",
            "<tr>",
            "</tr>",
            "<td>",
            "<td",
            "</td>",
            'colspan="2"',
            'colspan="3"',
            'colspan="4"',
            'colspan="5"',
            'colspan="6"',
            'colspan="7"',
            'colspan="8"',
            'colspan="9"',
            'colspan="10"',
            'colspan="11"',
            'colspan="12"',
            'colspan="13"',
            'colspan="14"',
            'colspan="15"',
            'colspan="16"',
            'colspan="17"',
            'colspan="18"',
            'colspan="19"',
            'colspan="20"',
            'rowspan="2"',
            'rowspan="3"',
            'rowspan="4"',
            'rowspan="5"',
            'rowspan="6"',
            'rowspan="7"',
            'rowspan="8"',
            'rowspan="9"',
            'rowspan="10"',
            'rowspan="11"',
            'rowspan="12"',
            'rowspan="13"',
            'rowspan="14"',
            'rowspan="15"',
            'rowspan="16"',
            'rowspan="17"',
            'rowspan="18"',
            'rowspan="19"',
            'rowspan="20"',
        ]
        regex_pattern = "|".join(re.escape(keyword) for keyword in keywords)
        split_result = re.split(f"({regex_pattern})", html_string)
        split_html = [part for part in split_result if part]
        return split_html

    def cluster_positions(self, positions, tolerance):
        if not positions:
            return []
        positions = sorted(set(positions))
        clustered = []
        current_cluster = [positions[0]]
        for pos in positions[1:]:
            if abs(pos - current_cluster[-1]) <= tolerance:
                current_cluster.append(pos)
            else:
                clustered.append(sum(current_cluster) / len(current_cluster))
                current_cluster = [pos]
        clustered.append(sum(current_cluster) / len(current_cluster))
        return clustered

    def trans_cells_det_results_to_html(self, cells_det_results):
        """
        Convert table cell bounding boxes to HTML structure tokens.

        Args:
            cells_det_results (list): The table cells detection results.

        Returns:
            html (list): The list of HTML keywords.
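
        Example (illustrative): two side-by-side cells [[0, 0, 10, 10],
        [10, 0, 20, 10]] cluster into one row and two columns, producing
        "<table><tbody><tr><td></td><td></td></tr></tbody></table>", which
        split_string_by_keywords then breaks into the token list
        ['<table>', '<tbody>', '<tr>', '<td>', '</td>', '<td>', '</td>',
        '</tr>', '</tbody>', '</table>'].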
""" tolerance = 5 x_coords = [x for cell in cells_det_results for x in (cell[0], cell[2])] y_coords = [y for cell in cells_det_results for y in (cell[1], cell[3])] x_positions = self.cluster_positions(x_coords, tolerance) y_positions = self.cluster_positions(y_coords, tolerance) x_position_to_index = {x: i for i, x in enumerate(x_positions)} y_position_to_index = {y: i for i, y in enumerate(y_positions)} num_rows = len(y_positions) - 1 num_cols = len(x_positions) - 1 grid = [[None for _ in range(num_cols)] for _ in range(num_rows)] cells_info = [] cell_index = 0 cell_map = {} for index, cell in enumerate(cells_det_results): x1, y1, x2, y2 = cell x1_idx = min( range(len(x_positions)), key=lambda i: abs(x_positions[i] - x1) ) x2_idx = min( range(len(x_positions)), key=lambda i: abs(x_positions[i] - x2) ) y1_idx = min( range(len(y_positions)), key=lambda i: abs(y_positions[i] - y1) ) y2_idx = min( range(len(y_positions)), key=lambda i: abs(y_positions[i] - y2) ) col_start = min(x1_idx, x2_idx) col_end = max(x1_idx, x2_idx) row_start = min(y1_idx, y2_idx) row_end = max(y1_idx, y2_idx) rowspan = row_end - row_start colspan = col_end - col_start if rowspan == 0: rowspan = 1 if colspan == 0: colspan = 1 cells_info.append( { "row_start": row_start, "col_start": col_start, "rowspan": rowspan, "colspan": colspan, "content": "", } ) for r in range(row_start, row_start + rowspan): for c in range(col_start, col_start + colspan): key = (r, c) if key in cell_map: continue else: cell_map[key] = index html = "" for r in range(num_rows): html += "" c = 0 while c < num_cols: key = (r, c) if key in cell_map: cell_index = cell_map[key] cell_info = cells_info[cell_index] if cell_info["row_start"] == r and cell_info["col_start"] == c: rowspan = cell_info["rowspan"] colspan = cell_info["colspan"] rowspan_attr = f' rowspan="{rowspan}"' if rowspan > 1 else "" colspan_attr = f' colspan="{colspan}"' if colspan > 1 else "" content = cell_info["content"] html += f"{content}" c += cell_info["colspan"] else: html += "" c += 1 html += "" html += "
" html = self.split_string_by_keywords(html) return html def predict_single_table_recognition_res( self, image_array: np.ndarray, overall_ocr_res: OCRResult, table_box: list, use_e2e_wired_table_rec_model: bool = False, use_e2e_wireless_table_rec_model: bool = False, use_wired_table_cells_trans_to_html: bool = False, use_wireless_table_cells_trans_to_html: bool = False, use_ocr_results_with_table_cells: bool = True, flag_find_nei_text: bool = True, ) -> SingleTableRecognitionResult: """ Predict table recognition results from an image array, layout detection results, and OCR results. Args: image_array (np.ndarray): The input image represented as a numpy array. overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline. The overall OCR results containing text recognition information. table_box (list): The table box coordinates. use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model. use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model. use_wired_table_cells_trans_to_html (bool): Whether to use wired table cells trans to HTML. use_wireless_table_cells_trans_to_html (bool): Whether to use wireless table cells trans to HTML. use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells. flag_find_nei_text (bool): Whether to find neighboring text. Returns: SingleTableRecognitionResult: single table recognition result. """ table_cls_pred = list(self.table_cls_model(image_array))[0] table_cls_result = self.extract_results(table_cls_pred, "cls") use_e2e_model = False cells_trans_to_html = False if table_cls_result == "wired_table": if use_wired_table_cells_trans_to_html == True: cells_trans_to_html = True else: table_structure_pred = list(self.wired_table_rec_model(image_array))[0] if use_e2e_wired_table_rec_model == True: use_e2e_model = True if cells_trans_to_html == True: table_structure_pred = list( self.wired_table_rec_model(image_array) )[0] else: table_cells_pred = list( self.wired_table_cells_detection_model(image_array, threshold=0.3) )[ 0 ] # Setting the threshold to 0.3 can improve the accuracy of table cells detection. # If you really want more or fewer table cells detection boxes, the threshold can be adjusted. elif table_cls_result == "wireless_table": if use_wireless_table_cells_trans_to_html == True: cells_trans_to_html = True else: table_structure_pred = list(self.wireless_table_rec_model(image_array))[ 0 ] if use_e2e_wireless_table_rec_model == True: use_e2e_model = True if cells_trans_to_html == True: table_structure_pred = list( self.wireless_table_rec_model(image_array) )[0] else: table_cells_pred = list( self.wireless_table_cells_detection_model( image_array, threshold=0.3 ) )[ 0 ] # Setting the threshold to 0.3 can improve the accuracy of table cells detection. # If you really want more or fewer table cells detection boxes, the threshold can be adjusted. 
        if use_e2e_model == False:
            table_cells_result, table_cells_score = self.extract_results(
                table_cells_pred, "det"
            )
            table_cells_result, table_cells_score = self.cells_det_results_nms(
                table_cells_result, table_cells_score
            )
            if cells_trans_to_html == True:
                table_structure_result = self.trans_cells_det_results_to_html(
                    table_cells_result
                )
            else:
                table_structure_result = self.extract_results(
                    table_structure_pred, "table_stru"
                )
                ocr_det_boxes = self.get_region_ocr_det_boxes(
                    overall_ocr_res["rec_boxes"].tolist(), table_box
                )
                table_cells_result = self.cells_det_results_reprocessing(
                    table_cells_result,
                    table_cells_score,
                    ocr_det_boxes,
                    len(table_structure_pred["bbox"]),
                )
            if use_ocr_results_with_table_cells == True:
                if self.cells_split_ocr == True:
                    table_box_copy = np.array([table_box])
                    table_ocr_pred = get_sub_regions_ocr_res(
                        overall_ocr_res, table_box_copy
                    )
                    table_ocr_pred = self.split_ocr_bboxes_by_table_cells(
                        table_cells_result, table_ocr_pred, image_array
                    )
                    cells_texts_list = []
                else:
                    cells_texts_list = self.gen_ocr_with_table_cells(
                        image_array, table_cells_result
                    )
                    table_ocr_pred = {}
            else:
                table_ocr_pred = {}
                cells_texts_list = []
            single_table_recognition_res = get_table_recognition_res(
                table_box,
                table_structure_result,
                table_cells_result,
                overall_ocr_res,
                table_ocr_pred,
                cells_texts_list,
                use_ocr_results_with_table_cells,
                self.cells_split_ocr,
            )
        else:
            cells_texts_list = []
            use_ocr_results_with_table_cells = False
            table_cells_result_e2e = table_structure_pred["bbox"]
            table_cells_result_e2e = [
                [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e
            ]
            if cells_trans_to_html == True:
                table_structure_pred["structure"] = (
                    self.trans_cells_det_results_to_html(table_cells_result_e2e)
                )
            single_table_recognition_res = get_table_recognition_res_e2e(
                table_box,
                table_structure_pred,
                overall_ocr_res,
                cells_texts_list,
                use_ocr_results_with_table_cells,
            )

        neighbor_text = ""
        if flag_find_nei_text:
            match_idx_list = get_neighbor_boxes_idx(
                overall_ocr_res["rec_boxes"], table_box
            )
            if len(match_idx_list) > 0:
                for idx in match_idx_list:
                    neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
        single_table_recognition_res["neighbor_texts"] = neighbor_text
        return single_table_recognition_res

    def predict(
        self,
        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
        use_doc_orientation_classify: Optional[bool] = None,
        use_doc_unwarping: Optional[bool] = None,
        use_layout_detection: Optional[bool] = None,
        use_ocr_model: Optional[bool] = None,
        overall_ocr_res: Optional[OCRResult] = None,
        layout_det_res: Optional[DetResult] = None,
        text_det_limit_side_len: Optional[int] = None,
        text_det_limit_type: Optional[str] = None,
        text_det_thresh: Optional[float] = None,
        text_det_box_thresh: Optional[float] = None,
        text_det_unclip_ratio: Optional[float] = None,
        text_rec_score_thresh: Optional[float] = None,
        use_e2e_wired_table_rec_model: bool = False,
        use_e2e_wireless_table_rec_model: bool = False,
        use_wired_table_cells_trans_to_html: bool = False,
        use_wireless_table_cells_trans_to_html: bool = False,
        use_table_orientation_classify: bool = True,
        use_ocr_results_with_table_cells: bool = True,
        **kwargs,
    ) -> TableRecognitionResult:
        """
        This function predicts the table recognition result for the given input.

        Args:
            input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or PDF(s) to be processed.
            use_layout_detection (bool): Whether to use layout detection.
            use_doc_orientation_classify (bool): Whether to use document orientation classification.
use_doc_unwarping (bool): Whether to use document unwarping. overall_ocr_res (OCRResult): The overall OCR result with convert_points_to_boxes information. It will be used if it is not None and use_ocr_model is False. layout_det_res (DetResult): The layout detection result. It will be used if it is not None and use_layout_detection is False. use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model. use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model. use_wired_table_cells_trans_to_html (bool): Whether to use wired table cells trans to HTML. use_wireless_table_cells_trans_to_html (bool): Whether to use wireless table cells trans to HTML. use_table_orientation_classify (bool): Whether to use table orientation classification. use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells. **kwargs: Additional keyword arguments. Returns: TableRecognitionResult: The predicted table recognition result. """ self.cells_split_ocr = True if use_table_orientation_classify == True and ( self.table_orientation_classify_model is None ): assert self.table_orientation_classify_config != None self.table_orientation_classify_model = self.create_model( self.table_orientation_classify_config ) model_settings = self.get_model_settings( use_doc_orientation_classify, use_doc_unwarping, use_layout_detection, use_ocr_model, ) if not self.check_model_settings_valid( model_settings, overall_ocr_res, layout_det_res ): yield {"error": "the input params for model settings are invalid!"} for img_id, batch_data in enumerate(self.batch_sampler(input)): image_array = self.img_reader(batch_data.instances)[0] if model_settings["use_doc_preprocessor"]: doc_preprocessor_res = list( self.doc_preprocessor_pipeline( image_array, use_doc_orientation_classify=use_doc_orientation_classify, use_doc_unwarping=use_doc_unwarping, ) )[0] else: doc_preprocessor_res = {"output_img": image_array} doc_preprocessor_image = doc_preprocessor_res["output_img"] if model_settings["use_ocr_model"]: overall_ocr_res = list( self.general_ocr_pipeline( doc_preprocessor_image, text_det_limit_side_len=text_det_limit_side_len, text_det_limit_type=text_det_limit_type, text_det_thresh=text_det_thresh, text_det_box_thresh=text_det_box_thresh, text_det_unclip_ratio=text_det_unclip_ratio, text_rec_score_thresh=text_rec_score_thresh, ) )[0] elif self.general_ocr_pipeline is None and ( ( use_ocr_results_with_table_cells == True and self.cells_split_ocr == False ) or use_table_orientation_classify == True ): assert self.general_ocr_config_bak != None self.general_ocr_pipeline = self.create_pipeline( self.general_ocr_config_bak ) if use_table_orientation_classify == False: table_angle = "0" table_res_list = [] table_region_id = 1 if not model_settings["use_layout_detection"] and layout_det_res is None: img_height, img_width = doc_preprocessor_image.shape[:2] table_box = [0, 0, img_width - 1, img_height - 1] if use_table_orientation_classify == True: table_angle = list( self.table_orientation_classify_model(doc_preprocessor_image) )[0]["label_names"][0] if table_angle == "90": doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=1) elif table_angle == "180": doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=2) elif table_angle == "270": doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=3) if table_angle in ["90", "180", "270"]: overall_ocr_res = list( self.general_ocr_pipeline( doc_preprocessor_image, 
text_det_limit_side_len=text_det_limit_side_len, text_det_limit_type=text_det_limit_type, text_det_thresh=text_det_thresh, text_det_box_thresh=text_det_box_thresh, text_det_unclip_ratio=text_det_unclip_ratio, text_rec_score_thresh=text_rec_score_thresh, ) )[0] tbx1, tby1, tbx2, tby2 = ( table_box[0], table_box[1], table_box[2], table_box[3], ) if table_angle == "90": new_x1, new_y1 = tby1, img_width - tbx2 new_x2, new_y2 = tby2, img_width - tbx1 elif table_angle == "180": new_x1, new_y1 = img_width - tbx2, img_height - tby2 new_x2, new_y2 = img_width - tbx1, img_height - tby1 elif table_angle == "270": new_x1, new_y1 = img_height - tby2, tbx1 new_x2, new_y2 = img_height - tby1, tbx2 table_box = [new_x1, new_y1, new_x2, new_y2] single_table_rec_res = self.predict_single_table_recognition_res( doc_preprocessor_image, overall_ocr_res, table_box, use_e2e_wired_table_rec_model, use_e2e_wireless_table_rec_model, use_wired_table_cells_trans_to_html, use_wireless_table_cells_trans_to_html, use_ocr_results_with_table_cells, flag_find_nei_text=False, ) single_table_rec_res["table_region_id"] = table_region_id if use_table_orientation_classify == True and table_angle != "0": img_height, img_width = doc_preprocessor_image.shape[:2] single_table_rec_res["cell_box_list"] = ( self.map_cells_to_original_image( single_table_rec_res["cell_box_list"], table_angle, img_width, img_height, ) ) table_res_list.append(single_table_rec_res) table_region_id += 1 else: if model_settings["use_layout_detection"]: layout_det_res = list( self.layout_det_model(doc_preprocessor_image) )[0] img_height, img_width = doc_preprocessor_image.shape[:2] for box_info in layout_det_res["boxes"]: if box_info["label"].lower() in ["table"]: crop_img_info = self._crop_by_boxes( doc_preprocessor_image, [box_info] ) crop_img_info = crop_img_info[0] table_box = crop_img_info["box"] if use_table_orientation_classify == True: doc_preprocessor_image_copy = doc_preprocessor_image.copy() table_angle = list( self.table_orientation_classify_model( crop_img_info["img"] ) )[0]["label_names"][0] if table_angle == "90": crop_img_info["img"] = np.rot90(crop_img_info["img"], k=1) doc_preprocessor_image_copy = np.rot90( doc_preprocessor_image_copy, k=1 ) elif table_angle == "180": crop_img_info["img"] = np.rot90(crop_img_info["img"], k=2) doc_preprocessor_image_copy = np.rot90( doc_preprocessor_image_copy, k=2 ) elif table_angle == "270": crop_img_info["img"] = np.rot90(crop_img_info["img"], k=3) doc_preprocessor_image_copy = np.rot90( doc_preprocessor_image_copy, k=3 ) if table_angle in ["90", "180", "270"]: overall_ocr_res = list( self.general_ocr_pipeline( doc_preprocessor_image_copy, text_det_limit_side_len=text_det_limit_side_len, text_det_limit_type=text_det_limit_type, text_det_thresh=text_det_thresh, text_det_box_thresh=text_det_box_thresh, text_det_unclip_ratio=text_det_unclip_ratio, text_rec_score_thresh=text_rec_score_thresh, ) )[0] tbx1, tby1, tbx2, tby2 = ( table_box[0], table_box[1], table_box[2], table_box[3], ) if table_angle == "90": new_x1, new_y1 = tby1, img_width - tbx2 new_x2, new_y2 = tby2, img_width - tbx1 elif table_angle == "180": new_x1, new_y1 = img_width - tbx2, img_height - tby2 new_x2, new_y2 = img_width - tbx1, img_height - tby1 elif table_angle == "270": new_x1, new_y1 = img_height - tby2, tbx1 new_x2, new_y2 = img_height - tby1, tbx2 table_box = [new_x1, new_y1, new_x2, new_y2] single_table_rec_res = ( self.predict_single_table_recognition_res( crop_img_info["img"], overall_ocr_res, table_box, 
use_e2e_wired_table_rec_model, use_e2e_wireless_table_rec_model, use_wired_table_cells_trans_to_html, use_wireless_table_cells_trans_to_html, use_ocr_results_with_table_cells, ) ) single_table_rec_res["table_region_id"] = table_region_id if ( use_table_orientation_classify == True and table_angle != "0" ): img_height_copy, img_width_copy = ( doc_preprocessor_image_copy.shape[:2] ) single_table_rec_res["cell_box_list"] = ( self.map_cells_to_original_image( single_table_rec_res["cell_box_list"], table_angle, img_width_copy, img_height_copy, ) ) table_res_list.append(single_table_rec_res) table_region_id += 1 single_img_res = { "input_path": batch_data.input_paths[0], "page_index": batch_data.page_indexes[0], "doc_preprocessor_res": doc_preprocessor_res, "layout_det_res": layout_det_res, "overall_ocr_res": overall_ocr_res, "table_res_list": table_res_list, "model_settings": model_settings, } yield TableRecognitionResult(single_img_res) @pipeline_requires_extra("ocr") class TableRecognitionPipelineV2(AutoParallelImageSimpleInferencePipeline): entities = ["table_recognition_v2"] @property def _pipeline_cls(self): return _TableRecognitionPipelineV2 def _get_batch_size(self, config): return 1
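

# Minimal usage sketch (kept under a __main__ guard so it never runs on
# import). It assumes a PaddleX installation with the "ocr" extra and network
# access to fetch the models; "table.jpg" is a placeholder input path.
if __name__ == "__main__":
    from paddlex import create_pipeline

    # Build the pipeline from its registered name and run it on a single image.
    pipeline = create_pipeline(pipeline="table_recognition_v2")
    for res in pipeline.predict("table.jpg"):
        res.print()  # summary, including the predicted table HTML
        res.save_to_xlsx(save_path="./output/")  # export the table as a spreadsheet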