- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- import re
- from typing import Any, Dict, List, Optional, Tuple, Union
- import numpy as np
- from ....utils import logging
- from ....utils.deps import (
- function_requires_deps,
- is_dep_available,
- pipeline_requires_extra,
- )
- from ...common.batch_sampler import ImageBatchSampler
- from ...common.reader import ReadImage
- from ...models.object_detection.result import DetResult
- from ...utils.benchmark import benchmark
- from ...utils.hpi import HPIConfig
- from ...utils.pp_option import PaddlePredictorOption
- from .._parallel import AutoParallelImageSimpleInferencePipeline
- from ..base import BasePipeline
- from ..components import CropByBoxes
- from ..doc_preprocessor.result import DocPreprocessorResult
- from ..layout_parsing.utils import get_sub_regions_ocr_res
- from ..ocr.result import OCRResult
- from .result import SingleTableRecognitionResult, TableRecognitionResult
- from .table_recognition_post_processing import (
- get_table_recognition_res as get_table_recognition_res_e2e,
- )
- from .table_recognition_post_processing_v2 import get_table_recognition_res
- from .utils import get_neighbor_boxes_idx
- if is_dep_available("scikit-learn"):
- from sklearn.cluster import KMeans
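- # KMeans is only needed by combine_rectangles() inside cells_det_results_reprocessing(); guarding the import keeps scikit-learn an optional dependency.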
- @benchmark.time_methods
- class _TableRecognitionPipelineV2(BasePipeline):
- """Table Recognition Pipeline"""
- def __init__(
- self,
- config: Dict,
- device: str = None,
- pp_option: PaddlePredictorOption = None,
- use_hpip: bool = False,
- hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
- ) -> None:
- """Initializes the layout parsing pipeline.
- Args:
- config (Dict): Configuration dictionary containing various settings.
- device (str, optional): Device to run the predictions on. Defaults to None.
- pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
- use_hpip (bool, optional): Whether to use the high-performance
- inference plugin (HPIP) by default. Defaults to False.
- hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
- The default high-performance inference configuration dictionary.
- Defaults to None.
- """
- super().__init__(
- device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
- )
- self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
- if self.use_doc_preprocessor:
- doc_preprocessor_config = config.get("SubPipelines", {}).get(
- "DocPreprocessor",
- {
- "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
- },
- )
- self.doc_preprocessor_pipeline = self.create_pipeline(
- doc_preprocessor_config
- )
- self.use_layout_detection = config.get("use_layout_detection", True)
- if self.use_layout_detection:
- layout_det_config = config.get("SubModules", {}).get(
- "LayoutDetection",
- {"model_config_error": "config error for layout_det_model!"},
- )
- self.layout_det_model = self.create_model(layout_det_config)
- table_cls_config = config.get("SubModules", {}).get(
- "TableClassification",
- {"model_config_error": "config error for table_classification_model!"},
- )
- self.table_cls_model = self.create_model(table_cls_config)
- wired_table_rec_config = config.get("SubModules", {}).get(
- "WiredTableStructureRecognition",
- {"model_config_error": "config error for wired_table_structure_model!"},
- )
- self.wired_table_rec_model = self.create_model(wired_table_rec_config)
- wireless_table_rec_config = config.get("SubModules", {}).get(
- "WirelessTableStructureRecognition",
- {"model_config_error": "config error for wireless_table_structure_model!"},
- )
- self.wireless_table_rec_model = self.create_model(wireless_table_rec_config)
- wired_table_cells_det_config = config.get("SubModules", {}).get(
- "WiredTableCellsDetection",
- {
- "model_config_error": "config error for wired_table_cells_detection_model!"
- },
- )
- self.wired_table_cells_detection_model = self.create_model(
- wired_table_cells_det_config
- )
- wireless_table_cells_det_config = config.get("SubModules", {}).get(
- "WirelessTableCellsDetection",
- {
- "model_config_error": "config error for wireless_table_cells_detection_model!"
- },
- )
- self.wireless_table_cells_detection_model = self.create_model(
- wireless_table_cells_det_config
- )
- self.use_ocr_model = config.get("use_ocr_model", True)
- self.general_ocr_pipeline = None
- if self.use_ocr_model:
- general_ocr_config = config.get("SubPipelines", {}).get(
- "GeneralOCR",
- {"pipeline_config_error": "config error for general_ocr_pipeline!"},
- )
- self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
- else:
- self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
- "GeneralOCR", None
- )
- self.table_orientation_classify_model = None
- self.table_orientation_classify_config = config.get("SubModules", {}).get(
- "TableOrientationClassify", None
- )
- self._crop_by_boxes = CropByBoxes()
- self.batch_sampler = ImageBatchSampler(batch_size=1)
- self.img_reader = ReadImage(format="BGR")
- def get_model_settings(
- self,
- use_doc_orientation_classify: Optional[bool],
- use_doc_unwarping: Optional[bool],
- use_layout_detection: Optional[bool],
- use_ocr_model: Optional[bool],
- ) -> dict:
- """
- Get the model settings based on the provided parameters or default values.
- Args:
- use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
- use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
- use_layout_detection (Optional[bool]): Whether to use layout detection.
- use_ocr_model (Optional[bool]): Whether to use OCR model.
- Returns:
- dict: A dictionary containing the model settings.
- """
- if use_doc_orientation_classify is None and use_doc_unwarping is None:
- use_doc_preprocessor = self.use_doc_preprocessor
- else:
- if use_doc_orientation_classify is True or use_doc_unwarping is True:
- use_doc_preprocessor = True
- else:
- use_doc_preprocessor = False
- if use_layout_detection is None:
- use_layout_detection = self.use_layout_detection
- if use_ocr_model is None:
- use_ocr_model = self.use_ocr_model
- return dict(
- use_doc_preprocessor=use_doc_preprocessor,
- use_layout_detection=use_layout_detection,
- use_ocr_model=use_ocr_model,
- )
- def check_model_settings_valid(
- self,
- model_settings: Dict,
- overall_ocr_res: OCRResult,
- layout_det_res: DetResult,
- ) -> bool:
- """
- Check if the input parameters are valid based on the initialized models.
- Args:
- model_settings (Dict): A dictionary containing input parameters.
- overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline,
- including the convert_points_to_boxes information.
- layout_det_res (DetResult): The layout detection result.
- Returns:
- bool: True if all required models are initialized according to input parameters, False otherwise.
- """
- if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
- logging.error(
- "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
- )
- return False
- if model_settings["use_layout_detection"]:
- if layout_det_res is not None:
- logging.error(
- "The layout detection model has already been initialized, please set use_layout_detection=False"
- )
- return False
- if not self.use_layout_detection:
- logging.error(
- "Set use_layout_detection, but the models for layout detection are not initialized."
- )
- return False
- if model_settings["use_ocr_model"]:
- if overall_ocr_res is not None:
- logging.error(
- "The OCR models have already been initialized, please set use_ocr_model=False"
- )
- return False
- if not self.use_ocr_model:
- logging.error(
- "Set use_ocr_model, but the models for OCR are not initialized."
- )
- return False
- else:
- if overall_ocr_res is None:
- logging.error("Set use_ocr_model=False, but no OCR results were found.")
- return False
- return True
- def predict_doc_preprocessor_res(
- self, image_array: np.ndarray, input_params: dict
- ) -> Tuple[DocPreprocessorResult, np.ndarray]:
- """
- Preprocess the document image based on input parameters.
- Args:
- image_array (np.ndarray): The input image array.
- input_params (dict): Dictionary containing preprocessing parameters.
- Returns:
- tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
- result dictionary and the processed image array.
- """
- if input_params["use_doc_preprocessor"]:
- use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
- use_doc_unwarping = input_params["use_doc_unwarping"]
- doc_preprocessor_res = list(
- self.doc_preprocessor_pipeline(
- image_array,
- use_doc_orientation_classify=use_doc_orientation_classify,
- use_doc_unwarping=use_doc_unwarping,
- )
- )[0]
- doc_preprocessor_image = doc_preprocessor_res["output_img"]
- else:
- doc_preprocessor_res = {}
- doc_preprocessor_image = image_array
- return doc_preprocessor_res, doc_preprocessor_image
- def extract_results(self, pred, task):
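- """Extract the fields needed downstream from a raw model prediction: the top-scoring label for "cls", (boxes, scores) for "det", or the HTML structure tokens for "table_stru"."""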
- if task == "cls":
- return pred["label_names"][np.argmax(pred["scores"])]
- elif task == "det":
- threshold = 0.0
- result = []
- cell_score = []
- if "boxes" in pred and isinstance(pred["boxes"], list):
- for box in pred["boxes"]:
- if isinstance(box, dict) and "score" in box and "coordinate" in box:
- score = box["score"]
- coordinate = box["coordinate"]
- if isinstance(score, float) and score > threshold:
- result.append(coordinate)
- cell_score.append(score)
- return result, cell_score
- elif task == "table_stru":
- return pred["structure"]
- else:
- return None
- def cells_det_results_nms(
- self, cells_det_results, cells_det_scores, cells_det_threshold=0.3
- ):
- """
- Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.
- Args:
- cells_det_results (list): List of bounding boxes, each box is in format [x1, y1, x2, y2].
- cells_det_scores (list): List of confidence scores corresponding to the bounding boxes.
- cells_det_threshold (float): IoU threshold for suppression. Boxes with IoU greater than this threshold
- will be suppressed. Default is 0.3.
- Returns:
- Tuple[list, list]: A tuple containing the list of bounding boxes and confidence scores after NMS,
- while maintaining one-to-one correspondence.
- """
- # Convert lists to numpy arrays for efficient computation
- boxes = np.array(cells_det_results)
- scores = np.array(cells_det_scores)
- # Initialize list for picked indices
- picked_indices = []
- # Get coordinates of bounding boxes
- x1 = boxes[:, 0]
- y1 = boxes[:, 1]
- x2 = boxes[:, 2]
- y2 = boxes[:, 3]
- # Compute the area of the bounding boxes
- areas = (x2 - x1) * (y2 - y1)
- # Sort the bounding boxes by the confidence scores in descending order
- order = scores.argsort()[::-1]
- # Process the boxes
- while order.size > 0:
- # Index of the current highest score box
- i = order[0]
- picked_indices.append(i)
- # Compute IoU between the highest score box and the rest
- xx1 = np.maximum(x1[i], x1[order[1:]])
- yy1 = np.maximum(y1[i], y1[order[1:]])
- xx2 = np.minimum(x2[i], x2[order[1:]])
- yy2 = np.minimum(y2[i], y2[order[1:]])
- # Compute the width and height of the overlapping area
- w = np.maximum(0.0, xx2 - xx1)
- h = np.maximum(0.0, yy2 - yy1)
- # Compute the ratio of overlap (IoU)
- inter = w * h
- ovr = inter / (areas[i] + areas[order[1:]] - inter)
- # Indices of boxes with IoU less than threshold
- inds = np.where(ovr <= cells_det_threshold)[0]
- # Update order, only keep boxes with IoU less than threshold
- order = order[
- inds + 1
- ] # inds shifted by 1 because order[0] is the current box
- # Select the boxes and scores based on picked indices
- final_boxes = boxes[picked_indices].tolist()
- final_scores = scores[picked_indices].tolist()
- return final_boxes, final_scores
- def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
- """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.
- Args:
- ocr_det_boxes (list of list): List of bounding boxes [x1, y1, x2, y2] in the original image.
- table_box (list): Bounding box [x1, y1, x2, y2] of the target region in the original image.
- Returns:
- list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
- """
- tol = 0
- # Extract coordinates from table_box
- x_min_t, y_min_t, x_max_t, y_max_t = table_box
- adjusted_boxes = []
- for box in ocr_det_boxes:
- x_min_b, y_min_b, x_max_b, y_max_b = box
- # Check if the box is fully inside table_box
- if (
- x_min_b + tol >= x_min_t
- and y_min_b + tol >= y_min_t
- and x_max_b - tol <= x_max_t
- and y_max_b - tol <= y_max_t
- ):
- # Adjust the coordinates to be relative to table_box
- adjusted_box = [
- x_min_b - x_min_t, # Adjust x1
- y_min_b - y_min_t, # Adjust y1
- x_max_b - x_min_t, # Adjust x2
- y_max_b - y_min_t, # Adjust y2
- ]
- adjusted_boxes.append(adjusted_box)
- # Discard boxes not fully inside table_box
- return adjusted_boxes
- def cells_det_results_reprocessing(
- self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums
- ):
- """
- Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.
- Args:
- cells_det_results (List[List[float]]): List of detected cell rectangles [[x1, y1, x2, y2], ...].
- cells_det_scores (List[float]): List of confidence scores for each rectangle in cells_det_results.
- ocr_det_results (List[List[float]]): List of OCR detected rectangles [[x1, y1, x2, y2], ...].
- html_pred_boxes_nums (int): The desired number of rectangles in the final output.
- Returns:
- List[List[float]]: The processed list of rectangles.
- """
- # Function to compute the overlap ratio between two rectangles
- def compute_iou(box1, box2):
- """
- Compute the overlap ratio between two rectangles, i.e. the intersection
- area divided by the area of box1 (an asymmetric variant of IoU).
- Args:
- box1 (array-like): [x1, y1, x2, y2] of the first rectangle.
- box2 (array-like): [x1, y1, x2, y2] of the second rectangle.
- Returns:
- float: The overlap ratio (intersection area / area of box1).
- """
- # Determine the coordinates of the intersection rectangle
- x_left = max(box1[0], box2[0])
- y_top = max(box1[1], box2[1])
- x_right = min(box1[2], box2[2])
- y_bottom = min(box1[3], box2[3])
- if x_right <= x_left or y_bottom <= y_top:
- return 0.0
- # Calculate the area of intersection rectangle
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
- # Calculate the area of the reference rectangle (box1)
- box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
- # Calculate the overlap ratio relative to box1
- iou = intersection_area / float(box1_area)
- return iou
- # Function to combine rectangles into N rectangles
- @function_requires_deps("scikit-learn")
- def combine_rectangles(rectangles, N):
- """
- Combine rectangles into N rectangles based on geometric proximity.
- Args:
- rectangles (list of list of int): A list of rectangles, each represented by [x1, y1, x2, y2].
- N (int): The desired number of combined rectangles.
- Returns:
- list of list of int: A list of N combined rectangles.
- """
- # Number of input rectangles
- num_rects = len(rectangles)
- # If N is greater than or equal to the number of rectangles, return the original rectangles
- if N >= num_rects:
- return rectangles
- # Compute the center points of the rectangles
- centers = np.array(
- [
- [
- (rect[0] + rect[2]) / 2, # Center x-coordinate
- (rect[1] + rect[3]) / 2, # Center y-coordinate
- ]
- for rect in rectangles
- ]
- )
- # Perform KMeans clustering on the center points to group them into N clusters
- kmeans = KMeans(n_clusters=N, random_state=0, n_init="auto")
- labels = kmeans.fit_predict(centers)
- # Initialize a list to store the combined rectangles
- combined_rectangles = []
- # For each cluster, compute the minimal bounding rectangle that covers all rectangles in the cluster
- for i in range(N):
- # Get the indices of rectangles that belong to cluster i
- indices = np.where(labels == i)[0]
- if len(indices) == 0:
- # If no rectangles in this cluster, skip it
- continue
- # Extract the rectangles in cluster i
- cluster_rects = np.array([rectangles[idx] for idx in indices])
- # Compute the minimal x1, y1 (top-left corner) and maximal x2, y2 (bottom-right corner)
- x1_min = np.min(cluster_rects[:, 0])
- y1_min = np.min(cluster_rects[:, 1])
- x2_max = np.max(cluster_rects[:, 2])
- y2_max = np.max(cluster_rects[:, 3])
- # Append the combined rectangle to the list
- combined_rectangles.append([x1_min, y1_min, x2_max, y2_max])
- return combined_rectangles
- # Ensure that the inputs are numpy arrays for efficient computation
- cells_det_results = np.array(cells_det_results)
- cells_det_scores = np.array(cells_det_scores)
- ocr_det_results = np.array(ocr_det_results)
- more_cells_flag = False
- if len(cells_det_results) == html_pred_boxes_nums:
- return cells_det_results
- # Step 1: If cells_det_results has more rectangles than html_pred_boxes_nums
- elif len(cells_det_results) > html_pred_boxes_nums:
- more_cells_flag = True
- # Select the indices of the top html_pred_boxes_nums scores
- top_indices = np.argsort(-cells_det_scores)[:html_pred_boxes_nums]
- # Adjust the corresponding rectangles
- cells_det_results = cells_det_results[top_indices].tolist()
- # Threshold for IoU
- iou_threshold = 0.6
- # List to store ocr_miss_boxes
- ocr_miss_boxes = []
- # For each rectangle in ocr_det_results
- for ocr_rect in ocr_det_results:
- merge_ocr_box_iou = []
- # Flag to indicate if ocr_rect has IoU >= threshold with any cell_rect
- has_large_iou = False
- # For each rectangle in cells_det_results
- for cell_rect in cells_det_results:
- # Compute IoU
- iou = compute_iou(ocr_rect, cell_rect)
- if iou > 0:
- merge_ocr_box_iou.append(iou)
- if (iou >= iou_threshold) or (sum(merge_ocr_box_iou) >= iou_threshold):
- has_large_iou = True
- break
- if not has_large_iou:
- ocr_miss_boxes.append(ocr_rect)
- # If no ocr_miss_boxes, return cells_det_results
- if len(ocr_miss_boxes) == 0:
- final_results = (
- cells_det_results
- if more_cells_flag == True
- else cells_det_results.tolist()
- )
- else:
- if more_cells_flag == True:
- final_results = combine_rectangles(
- cells_det_results + ocr_miss_boxes, html_pred_boxes_nums
- )
- else:
- # Need to combine ocr_miss_boxes into N rectangles
- N = html_pred_boxes_nums - len(cells_det_results)
- # Combine ocr_miss_boxes into N rectangles
- ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
- # Combine cells_det_results and ocr_supp_boxes
- final_results = np.concatenate(
- (cells_det_results, ocr_supp_boxes), axis=0
- ).tolist()
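- # Fallback: if far fewer boxes remain than the HTML structure expects, rebuild the cells by clustering the OCR detections into html_pred_boxes_nums boxes.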
- if len(final_results) <= 0.6 * html_pred_boxes_nums:
- final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
- return final_results
- def split_ocr_bboxes_by_table_cells(
- self, cells_det_results, overall_ocr_res, ori_img, k=2
- ):
- """
- Split OCR bounding boxes based on table cell boundaries when they span multiple cells horizontally.
- Args:
- cells_det_results (list): List of cell bounding boxes in format [x1, y1, x2, y2]
- overall_ocr_res (dict): Dictionary containing OCR results with keys:
- - 'rec_boxes': OCR bounding boxes (will be converted to list)
- - 'rec_texts': OCR recognized texts
- ori_img (np.array): Original input image array
- k (int): Threshold for determining when to split (minimum number of cells spanned)
- Returns:
- dict: Modified overall_ocr_res with split boxes and texts
- """
- def calculate_iou(box1, box2):
- """
- Calculate the overlap ratio between two bounding boxes, defined as the
- intersection area divided by the area of box2 (the cell box).
- Args:
- box1 (list): [x1, y1, x2, y2] of the OCR box
- box2 (list): [x1, y1, x2, y2] of the cell box
- Returns:
- float: Overlap ratio (intersection area / area of box2)
- """
- # Determine intersection coordinates
- x_left = max(box1[0], box2[0])
- y_top = max(box1[1], box2[1])
- x_right = min(box1[2], box2[2])
- y_bottom = min(box1[3], box2[3])
- if x_right < x_left or y_bottom < y_top:
- return 0.0
- # Calculate areas
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
- box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
- box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
- # return intersection_area / float(box1_area + box2_area - intersection_area)
- return intersection_area / box2_area
- def get_overlapping_cells(ocr_box, cells):
- """
- Find cells whose overlap ratio with the OCR box exceeds 0.5.
- Args:
- ocr_box (list): OCR bounding box [x1, y1, x2, y2]
- cells (list): List of cell bounding boxes
- Returns:
- list: Indices of overlapping cells, sorted by x-coordinate
- """
- overlapping = []
- for idx, cell in enumerate(cells):
- if calculate_iou(ocr_box, cell) > 0.5:
- overlapping.append(idx)
- # Sort overlapping cells by their x-coordinate (left to right)
- overlapping.sort(key=lambda i: cells[i][0])
- return overlapping
- def split_box_by_cells(ocr_box, cell_indices, cells):
- """
- Split an OCR box along the x-axis at the vertical boundaries of the given cells.
- Args:
- ocr_box (list): Original OCR box [x1, y1, x2, y2]
- cell_indices (list): Indices of cells to split by
- cells (list): All cell bounding boxes
- Returns:
- list: List of split boxes
- """
- if not cell_indices:
- return [ocr_box]
- split_boxes = []
- cells_to_split = [cells[i] for i in cell_indices]
- if ocr_box[0] < cells_to_split[0][0]:
- split_boxes.append(
- [ocr_box[0], ocr_box[1], cells_to_split[0][0], ocr_box[3]]
- )
- for i in range(len(cells_to_split)):
- current_cell = cells_to_split[i]
- split_boxes.append(
- [
- max(ocr_box[0], current_cell[0]),
- ocr_box[1],
- min(ocr_box[2], current_cell[2]),
- ocr_box[3],
- ]
- )
- if i < len(cells_to_split) - 1:
- next_cell = cells_to_split[i + 1]
- if current_cell[2] < next_cell[0]:
- split_boxes.append(
- [current_cell[2], ocr_box[1], next_cell[0], ocr_box[3]]
- )
- last_cell = cells_to_split[-1]
- if last_cell[2] < ocr_box[2]:
- split_boxes.append([last_cell[2], ocr_box[1], ocr_box[2], ocr_box[3]])
- unique_boxes = []
- seen = set()
- for box in split_boxes:
- box_tuple = tuple(box)
- if box_tuple not in seen:
- seen.add(box_tuple)
- unique_boxes.append(box)
- return unique_boxes
- # Convert OCR boxes to list if needed
- if hasattr(overall_ocr_res["rec_boxes"], "tolist"):
- ocr_det_results = overall_ocr_res["rec_boxes"].tolist()
- else:
- ocr_det_results = overall_ocr_res["rec_boxes"]
- ocr_texts = overall_ocr_res["rec_texts"]
- # Make copies to modify
- new_boxes = []
- new_texts = []
- # Process each OCR box
- i = 0
- while i < len(ocr_det_results):
- ocr_box = ocr_det_results[i]
- text = ocr_texts[i]
- # Find cells that significantly overlap with this OCR box
- overlapping_cells = get_overlapping_cells(ocr_box, cells_det_results)
- # Check if we need to split (spans >= k cells)
- if len(overlapping_cells) >= k:
- # Split the box at cell boundaries
- split_boxes = split_box_by_cells(
- ocr_box, overlapping_cells, cells_det_results
- )
- # Process each split box
- split_texts = []
- for box in split_boxes:
- x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
- if y2 - y1 > 1 and x2 - x1 > 1:
- ocr_result = list(
- self.general_ocr_pipeline.text_rec_model(
- ori_img[y1:y2, x1:x2, :]
- )
- )[0]
- # Extract the recognized text from the OCR result
- if "rec_text" in ocr_result:
- result = ocr_result[
- "rec_text"
- ] # "rec_text" holds the single recognized string for this crop
- else:
- result = ""
- else:
- result = ""
- split_texts.append(result)
- # Add split boxes and texts to results
- new_boxes.extend(split_boxes)
- new_texts.extend(split_texts)
- else:
- # Keep original box and text
- new_boxes.append(ocr_box)
- new_texts.append(text)
- i += 1
- # Update the results dictionary
- overall_ocr_res["rec_boxes"] = new_boxes
- overall_ocr_res["rec_texts"] = new_texts
- return overall_ocr_res
- def gen_ocr_with_table_cells(self, ori_img, cells_bboxes):
- """
- Run OCR on each table cell crop of the original image and collect the recognized texts.
- Args:
- ori_img (ndarray): The original image from which text regions will be extracted.
- cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.
- Returns:
- list: A list containing the recognized texts from each cell.
- """
- # Check if cells_bboxes is a list and convert it if not.
- if not isinstance(cells_bboxes, list):
- cells_bboxes = cells_bboxes.tolist()
- texts_list = [] # Initialize a list to store the recognized texts.
- # Process each bounding box provided in cells_bboxes.
- for i in range(len(cells_bboxes)):
- # Extract and round up the coordinates of the bounding box.
- x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
- # Perform OCR on the defined region of the image and get the recognized text.
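- # Regions narrower or shorter than 2 px are skipped and contribute no entry to texts_list.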
- if y2 - y1 > 1 and x2 - x1 > 1:
- rec_te = list(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))[0]
- # Concatenate the texts and append them to the texts_list.
- texts_list.append("".join(rec_te["rec_texts"]))
- # Return the list of recognized texts from each cell.
- return texts_list
- def map_cells_to_original_image(
- self, detections, table_angle, img_width, img_height
- ):
- """
- Map bounding boxes from the rotated image back to the original image.
- Parameters:
- - detections: list of numpy arrays, each containing bounding box coordinates [x1, y1, x2, y2]
- - table_angle: rotation angle as a string ("90", "180", or "270")
- - img_width: width of the rotated image in which the detections were produced
- - img_height: height of the rotated image in which the detections were produced
- Returns:
- - mapped_detections: list of numpy arrays with mapped bounding box coordinates
- """
- mapped_detections = []
- for i in range(len(detections)):
- tbx1, tby1, tbx2, tby2 = (
- detections[i][0],
- detections[i][1],
- detections[i][2],
- detections[i][3],
- )
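- # Invert the np.rot90-based rotation for the given angle, mapping each box corner from the rotated frame back to the original frame.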
- if table_angle == "270":
- new_x1, new_y1 = tby1, img_width - tbx2
- new_x2, new_y2 = tby2, img_width - tbx1
- elif table_angle == "180":
- new_x1, new_y1 = img_width - tbx2, img_height - tby2
- new_x2, new_y2 = img_width - tbx1, img_height - tby1
- elif table_angle == "90":
- new_x1, new_y1 = img_height - tby2, tbx1
- new_x2, new_y2 = img_height - tby1, tbx2
- new_box = np.array([new_x1, new_y1, new_x2, new_y2])
- mapped_detections.append(new_box)
- return mapped_detections
- def split_string_by_keywords(self, html_string):
- """
- Split HTML string by keywords.
- Args:
- html_string (str): The HTML string.
- Returns:
- split_html (list): The HTML string split into tag/attribute tokens and text segments.
- """
- keywords = [
- "<thead>",
- "</thead>",
- "<tbody>",
- "</tbody>",
- "<tr>",
- "</tr>",
- "<td>",
- "<td",
- ">",
- "</td>",
- 'colspan="2"',
- 'colspan="3"',
- 'colspan="4"',
- 'colspan="5"',
- 'colspan="6"',
- 'colspan="7"',
- 'colspan="8"',
- 'colspan="9"',
- 'colspan="10"',
- 'colspan="11"',
- 'colspan="12"',
- 'colspan="13"',
- 'colspan="14"',
- 'colspan="15"',
- 'colspan="16"',
- 'colspan="17"',
- 'colspan="18"',
- 'colspan="19"',
- 'colspan="20"',
- 'rowspan="2"',
- 'rowspan="3"',
- 'rowspan="4"',
- 'rowspan="5"',
- 'rowspan="6"',
- 'rowspan="7"',
- 'rowspan="8"',
- 'rowspan="9"',
- 'rowspan="10"',
- 'rowspan="11"',
- 'rowspan="12"',
- 'rowspan="13"',
- 'rowspan="14"',
- 'rowspan="15"',
- 'rowspan="16"',
- 'rowspan="17"',
- 'rowspan="18"',
- 'rowspan="19"',
- 'rowspan="20"',
- ]
- regex_pattern = "|".join(re.escape(keyword) for keyword in keywords)
- split_result = re.split(f"({regex_pattern})", html_string)
- split_html = [part for part in split_result if part]
- return split_html
- def cluster_positions(self, positions, tolerance):
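- """
- Cluster 1-D coordinate values: consecutive sorted positions within `tolerance`
- of each other are grouped, and each group is replaced by its mean.
- """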
- if not positions:
- return []
- positions = sorted(set(positions))
- clustered = []
- current_cluster = [positions[0]]
- for pos in positions[1:]:
- if abs(pos - current_cluster[-1]) <= tolerance:
- current_cluster.append(pos)
- else:
- clustered.append(sum(current_cluster) / len(current_cluster))
- current_cluster = [pos]
- clustered.append(sum(current_cluster) / len(current_cluster))
- return clustered
- def trans_cells_det_results_to_html(self, cells_det_results):
- """
- Convert table cell bounding boxes into an HTML table structure.
- Args:
- cells_det_results (list): The table cells detection results.
- Returns:
- html (list): The generated HTML, split into a list of structural tokens.
- """
- tolerance = 5
- x_coords = [x for cell in cells_det_results for x in (cell[0], cell[2])]
- y_coords = [y for cell in cells_det_results for y in (cell[1], cell[3])]
- x_positions = self.cluster_positions(x_coords, tolerance)
- y_positions = self.cluster_positions(y_coords, tolerance)
- x_position_to_index = {x: i for i, x in enumerate(x_positions)}
- y_position_to_index = {y: i for i, y in enumerate(y_positions)}
- num_rows = len(y_positions) - 1
- num_cols = len(x_positions) - 1
- grid = [[None for _ in range(num_cols)] for _ in range(num_rows)]
- cells_info = []
- cell_index = 0
- cell_map = {}
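- # Snap each cell's edges to the nearest clustered grid lines; the index spans give its rowspan/colspan, and cell_map records which cell occupies each (row, col) slot.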
- for index, cell in enumerate(cells_det_results):
- x1, y1, x2, y2 = cell
- x1_idx = min(
- range(len(x_positions)), key=lambda i: abs(x_positions[i] - x1)
- )
- x2_idx = min(
- range(len(x_positions)), key=lambda i: abs(x_positions[i] - x2)
- )
- y1_idx = min(
- range(len(y_positions)), key=lambda i: abs(y_positions[i] - y1)
- )
- y2_idx = min(
- range(len(y_positions)), key=lambda i: abs(y_positions[i] - y2)
- )
- col_start = min(x1_idx, x2_idx)
- col_end = max(x1_idx, x2_idx)
- row_start = min(y1_idx, y2_idx)
- row_end = max(y1_idx, y2_idx)
- rowspan = row_end - row_start
- colspan = col_end - col_start
- if rowspan == 0:
- rowspan = 1
- if colspan == 0:
- colspan = 1
- cells_info.append(
- {
- "row_start": row_start,
- "col_start": col_start,
- "rowspan": rowspan,
- "colspan": colspan,
- "content": "",
- }
- )
- for r in range(row_start, row_start + rowspan):
- for c in range(col_start, col_start + colspan):
- key = (r, c)
- if key in cell_map:
- continue
- else:
- cell_map[key] = index
- html = "<table><tbody>"
- for r in range(num_rows):
- html += "<tr>"
- c = 0
- while c < num_cols:
- key = (r, c)
- if key in cell_map:
- cell_index = cell_map[key]
- cell_info = cells_info[cell_index]
- if cell_info["row_start"] == r and cell_info["col_start"] == c:
- rowspan = cell_info["rowspan"]
- colspan = cell_info["colspan"]
- rowspan_attr = f' rowspan="{rowspan}"' if rowspan > 1 else ""
- colspan_attr = f' colspan="{colspan}"' if colspan > 1 else ""
- content = cell_info["content"]
- html += f"<td{rowspan_attr}{colspan_attr}>{content}</td>"
- c += cell_info["colspan"]
- else:
- html += "<td></td>"
- c += 1
- html += "</tr>"
- html += "</tbody></table>"
- html = self.split_string_by_keywords(html)
- return html
- def predict_single_table_recognition_res(
- self,
- image_array: np.ndarray,
- overall_ocr_res: OCRResult,
- table_box: list,
- use_e2e_wired_table_rec_model: bool = False,
- use_e2e_wireless_table_rec_model: bool = False,
- use_wired_table_cells_trans_to_html: bool = False,
- use_wireless_table_cells_trans_to_html: bool = False,
- use_ocr_results_with_table_cells: bool = True,
- flag_find_nei_text: bool = True,
- ) -> SingleTableRecognitionResult:
- """
- Predict table recognition results from an image array, layout detection results, and OCR results.
- Args:
- image_array (np.ndarray): The input image represented as a numpy array.
- overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline,
- containing the text detection and recognition information.
- table_box (list): The table box coordinates.
- use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
- use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
- use_wired_table_cells_trans_to_html (bool): Whether to convert wired table cell detections directly to HTML.
- use_wireless_table_cells_trans_to_html (bool): Whether to convert wireless table cell detections directly to HTML.
- use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells.
- flag_find_nei_text (bool): Whether to find neighboring text.
- Returns:
- SingleTableRecognitionResult: single table recognition result.
- """
- table_cls_pred = list(self.table_cls_model(image_array))[0]
- table_cls_result = self.extract_results(table_cls_pred, "cls")
- use_e2e_model = False
- cells_trans_to_html = False
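- # Dispatch on the classified table type: wired and wireless tables use separate structure recognition and cell detection models.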
- if table_cls_result == "wired_table":
- if use_wired_table_cells_trans_to_html == True:
- cells_trans_to_html = True
- else:
- table_structure_pred = list(self.wired_table_rec_model(image_array))[0]
- if use_e2e_wired_table_rec_model == True:
- use_e2e_model = True
- if cells_trans_to_html == True:
- table_structure_pred = list(
- self.wired_table_rec_model(image_array)
- )[0]
- else:
- table_cells_pred = list(
- self.wired_table_cells_detection_model(image_array, threshold=0.3)
- )[
- 0
- ] # A threshold of 0.3 improves the accuracy of table cell detection;
- # adjust it if you need more or fewer detected cell boxes.
- elif table_cls_result == "wireless_table":
- if use_wireless_table_cells_trans_to_html == True:
- cells_trans_to_html = True
- else:
- table_structure_pred = list(self.wireless_table_rec_model(image_array))[
- 0
- ]
- if use_e2e_wireless_table_rec_model == True:
- use_e2e_model = True
- if cells_trans_to_html == True:
- table_structure_pred = list(
- self.wireless_table_rec_model(image_array)
- )[0]
- else:
- table_cells_pred = list(
- self.wireless_table_cells_detection_model(
- image_array, threshold=0.3
- )
- )[
- 0
- ] # A threshold of 0.3 improves the accuracy of table cell detection;
- # adjust it if you need more or fewer detected cell boxes.
- if use_e2e_model == False:
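- # Non end-to-end path: deduplicate the detected cells with NMS, then either convert them directly to HTML or align them with the predicted table structure.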
- table_cells_result, table_cells_score = self.extract_results(
- table_cells_pred, "det"
- )
- table_cells_result, table_cells_score = self.cells_det_results_nms(
- table_cells_result, table_cells_score
- )
- if cells_trans_to_html == True:
- table_structure_result = self.trans_cells_det_results_to_html(
- table_cells_result
- )
- else:
- table_structure_result = self.extract_results(
- table_structure_pred, "table_stru"
- )
- ocr_det_boxes = self.get_region_ocr_det_boxes(
- overall_ocr_res["rec_boxes"].tolist(), table_box
- )
- table_cells_result = self.cells_det_results_reprocessing(
- table_cells_result,
- table_cells_score,
- ocr_det_boxes,
- len(table_structure_pred["bbox"]),
- )
- if use_ocr_results_with_table_cells == True:
- if self.cells_split_ocr == True:
- table_box_copy = np.array([table_box])
- table_ocr_pred = get_sub_regions_ocr_res(
- overall_ocr_res, table_box_copy
- )
- table_ocr_pred = self.split_ocr_bboxes_by_table_cells(
- table_cells_result, table_ocr_pred, image_array
- )
- cells_texts_list = []
- else:
- cells_texts_list = self.gen_ocr_with_table_cells(
- image_array, table_cells_result
- )
- table_ocr_pred = {}
- else:
- table_ocr_pred = {}
- cells_texts_list = []
- single_table_recognition_res = get_table_recognition_res(
- table_box,
- table_structure_result,
- table_cells_result,
- overall_ocr_res,
- table_ocr_pred,
- cells_texts_list,
- use_ocr_results_with_table_cells,
- self.cells_split_ocr,
- )
- else:
- cells_texts_list = []
- use_ocr_results_with_table_cells = False
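- # End-to-end path: the structure model's "bbox" entries are 8-value boxes; take the first and third corner points as the [x1, y1, x2, y2] rectangle.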
- table_cells_result_e2e = table_structure_pred["bbox"]
- table_cells_result_e2e = [
- [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e
- ]
- if cells_trans_to_html == True:
- table_structure_pred["structure"] = (
- self.trans_cells_det_results_to_html(table_cells_result_e2e)
- )
- single_table_recognition_res = get_table_recognition_res_e2e(
- table_box,
- table_structure_pred,
- overall_ocr_res,
- cells_texts_list,
- use_ocr_results_with_table_cells,
- )
- neighbor_text = ""
- if flag_find_nei_text:
- match_idx_list = get_neighbor_boxes_idx(
- overall_ocr_res["rec_boxes"], table_box
- )
- if len(match_idx_list) > 0:
- for idx in match_idx_list:
- neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
- single_table_recognition_res["neighbor_texts"] = neighbor_text
- return single_table_recognition_res
- def predict(
- self,
- input: Union[str, List[str], np.ndarray, List[np.ndarray]],
- use_doc_orientation_classify: Optional[bool] = None,
- use_doc_unwarping: Optional[bool] = None,
- use_layout_detection: Optional[bool] = None,
- use_ocr_model: Optional[bool] = None,
- overall_ocr_res: Optional[OCRResult] = None,
- layout_det_res: Optional[DetResult] = None,
- text_det_limit_side_len: Optional[int] = None,
- text_det_limit_type: Optional[str] = None,
- text_det_thresh: Optional[float] = None,
- text_det_box_thresh: Optional[float] = None,
- text_det_unclip_ratio: Optional[float] = None,
- text_rec_score_thresh: Optional[float] = None,
- use_e2e_wired_table_rec_model: bool = False,
- use_e2e_wireless_table_rec_model: bool = False,
- use_wired_table_cells_trans_to_html: bool = False,
- use_wireless_table_cells_trans_to_html: bool = False,
- use_table_orientation_classify: bool = True,
- use_ocr_results_with_table_cells: bool = True,
- **kwargs,
- ) -> TableRecognitionResult:
- """
- This function predicts the layout parsing result for the given input.
- Args:
- input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or PDF(s) to be processed.
- use_layout_detection (bool): Whether to use layout detection.
- use_doc_orientation_classify (bool): Whether to use document orientation classification.
- use_doc_unwarping (bool): Whether to use document unwarping.
- overall_ocr_res (OCRResult): The overall OCR result with convert_points_to_boxes information.
- It will be used if it is not None and use_ocr_model is False.
- layout_det_res (DetResult): The layout detection result.
- It will be used if it is not None and use_layout_detection is False.
- use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
- use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
- use_wired_table_cells_trans_to_html (bool): Whether to convert wired table cell detections directly to HTML.
- use_wireless_table_cells_trans_to_html (bool): Whether to convert wireless table cell detections directly to HTML.
- use_table_orientation_classify (bool): Whether to use table orientation classification.
- use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells.
- **kwargs: Additional keyword arguments.
- Returns:
- TableRecognitionResult: The predicted table recognition result.
- """
- # Added by zhch158: allow callers to control whether OCR boxes inside table cells
- # are split and re-recognized (defaults to True for backward compatibility).
- use_table_cells_split_ocr = kwargs.get("use_table_cells_split_ocr", None)
- self.cells_split_ocr = (
- bool(use_table_cells_split_ocr) if use_table_cells_split_ocr is not None else True
- )
- if use_table_orientation_classify == True and (
- self.table_orientation_classify_model is None
- ):
- assert self.table_orientation_classify_config is not None
- self.table_orientation_classify_model = self.create_model(
- self.table_orientation_classify_config
- )
- model_settings = self.get_model_settings(
- use_doc_orientation_classify,
- use_doc_unwarping,
- use_layout_detection,
- use_ocr_model,
- )
- if not self.check_model_settings_valid(
- model_settings, overall_ocr_res, layout_det_res
- ):
- yield {"error": "the input params for model settings are invalid!"}
- for img_id, batch_data in enumerate(self.batch_sampler(input)):
- image_array = self.img_reader(batch_data.instances)[0]
- if model_settings["use_doc_preprocessor"]:
- doc_preprocessor_res = list(
- self.doc_preprocessor_pipeline(
- image_array,
- use_doc_orientation_classify=use_doc_orientation_classify,
- use_doc_unwarping=use_doc_unwarping,
- )
- )[0]
- else:
- doc_preprocessor_res = {"output_img": image_array}
- doc_preprocessor_image = doc_preprocessor_res["output_img"]
- if model_settings["use_ocr_model"]:
- overall_ocr_res = list(
- self.general_ocr_pipeline(
- doc_preprocessor_image,
- text_det_limit_side_len=text_det_limit_side_len,
- text_det_limit_type=text_det_limit_type,
- text_det_thresh=text_det_thresh,
- text_det_box_thresh=text_det_box_thresh,
- text_det_unclip_ratio=text_det_unclip_ratio,
- text_rec_score_thresh=text_rec_score_thresh,
- )
- )[0]
- elif self.general_ocr_pipeline is None and (
- (
- use_ocr_results_with_table_cells == True
- and self.cells_split_ocr == False
- )
- or use_table_orientation_classify == True
- ):
- assert self.general_ocr_config_bak is not None
- self.general_ocr_pipeline = self.create_pipeline(
- self.general_ocr_config_bak
- )
- if use_table_orientation_classify == False:
- table_angle = "0"
- table_res_list = []
- table_region_id = 1
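- # Without layout detection (and no external layout result), treat the whole preprocessed image as a single table region.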
- if not model_settings["use_layout_detection"] and layout_det_res is None:
- img_height, img_width = doc_preprocessor_image.shape[:2]
- table_box = [0, 0, img_width - 1, img_height - 1]
- if use_table_orientation_classify == True:
- table_angle = list(
- self.table_orientation_classify_model(doc_preprocessor_image)
- )[0]["label_names"][0]
- if table_angle == "90":
- doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=1)
- elif table_angle == "180":
- doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=2)
- elif table_angle == "270":
- doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=3)
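- # For rotated tables, re-run OCR on the rotated image and remap the whole-image table box into the rotated coordinate frame.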
- if table_angle in ["90", "180", "270"]:
- overall_ocr_res = list(
- self.general_ocr_pipeline(
- doc_preprocessor_image,
- text_det_limit_side_len=text_det_limit_side_len,
- text_det_limit_type=text_det_limit_type,
- text_det_thresh=text_det_thresh,
- text_det_box_thresh=text_det_box_thresh,
- text_det_unclip_ratio=text_det_unclip_ratio,
- text_rec_score_thresh=text_rec_score_thresh,
- )
- )[0]
- tbx1, tby1, tbx2, tby2 = (
- table_box[0],
- table_box[1],
- table_box[2],
- table_box[3],
- )
- if table_angle == "90":
- new_x1, new_y1 = tby1, img_width - tbx2
- new_x2, new_y2 = tby2, img_width - tbx1
- elif table_angle == "180":
- new_x1, new_y1 = img_width - tbx2, img_height - tby2
- new_x2, new_y2 = img_width - tbx1, img_height - tby1
- elif table_angle == "270":
- new_x1, new_y1 = img_height - tby2, tbx1
- new_x2, new_y2 = img_height - tby1, tbx2
- table_box = [new_x1, new_y1, new_x2, new_y2]
- single_table_rec_res = self.predict_single_table_recognition_res(
- doc_preprocessor_image,
- overall_ocr_res,
- table_box,
- use_e2e_wired_table_rec_model,
- use_e2e_wireless_table_rec_model,
- use_wired_table_cells_trans_to_html,
- use_wireless_table_cells_trans_to_html,
- use_ocr_results_with_table_cells,
- flag_find_nei_text=False,
- )
- single_table_rec_res["table_region_id"] = table_region_id
- if use_table_orientation_classify == True and table_angle != "0":
- img_height, img_width = doc_preprocessor_image.shape[:2]
- single_table_rec_res["cell_box_list"] = (
- self.map_cells_to_original_image(
- single_table_rec_res["cell_box_list"],
- table_angle,
- img_width,
- img_height,
- )
- )
- table_res_list.append(single_table_rec_res)
- table_region_id += 1
- else:
- if model_settings["use_layout_detection"]:
- layout_det_res = list(
- self.layout_det_model(doc_preprocessor_image)
- )[0]
- img_height, img_width = doc_preprocessor_image.shape[:2]
- for box_info in layout_det_res["boxes"]:
- if box_info["label"].lower() in ["table"]:
- crop_img_info = self._crop_by_boxes(
- doc_preprocessor_image, [box_info]
- )
- crop_img_info = crop_img_info[0]
- table_box = crop_img_info["box"]
- if use_table_orientation_classify == True:
- doc_preprocessor_image_copy = doc_preprocessor_image.copy()
- table_angle = list(
- self.table_orientation_classify_model(
- crop_img_info["img"]
- )
- )[0]["label_names"][0]
- if table_angle == "90":
- crop_img_info["img"] = np.rot90(crop_img_info["img"], k=1)
- doc_preprocessor_image_copy = np.rot90(
- doc_preprocessor_image_copy, k=1
- )
- elif table_angle == "180":
- crop_img_info["img"] = np.rot90(crop_img_info["img"], k=2)
- doc_preprocessor_image_copy = np.rot90(
- doc_preprocessor_image_copy, k=2
- )
- elif table_angle == "270":
- crop_img_info["img"] = np.rot90(crop_img_info["img"], k=3)
- doc_preprocessor_image_copy = np.rot90(
- doc_preprocessor_image_copy, k=3
- )
- if table_angle in ["90", "180", "270"]:
- overall_ocr_res = list(
- self.general_ocr_pipeline(
- doc_preprocessor_image_copy,
- text_det_limit_side_len=text_det_limit_side_len,
- text_det_limit_type=text_det_limit_type,
- text_det_thresh=text_det_thresh,
- text_det_box_thresh=text_det_box_thresh,
- text_det_unclip_ratio=text_det_unclip_ratio,
- text_rec_score_thresh=text_rec_score_thresh,
- )
- )[0]
- tbx1, tby1, tbx2, tby2 = (
- table_box[0],
- table_box[1],
- table_box[2],
- table_box[3],
- )
- if table_angle == "90":
- new_x1, new_y1 = tby1, img_width - tbx2
- new_x2, new_y2 = tby2, img_width - tbx1
- elif table_angle == "180":
- new_x1, new_y1 = img_width - tbx2, img_height - tby2
- new_x2, new_y2 = img_width - tbx1, img_height - tby1
- elif table_angle == "270":
- new_x1, new_y1 = img_height - tby2, tbx1
- new_x2, new_y2 = img_height - tby1, tbx2
- table_box = [new_x1, new_y1, new_x2, new_y2]
- single_table_rec_res = (
- self.predict_single_table_recognition_res(
- crop_img_info["img"],
- overall_ocr_res,
- table_box,
- use_e2e_wired_table_rec_model,
- use_e2e_wireless_table_rec_model,
- use_wired_table_cells_trans_to_html,
- use_wireless_table_cells_trans_to_html,
- use_ocr_results_with_table_cells,
- )
- )
- single_table_rec_res["table_region_id"] = table_region_id
- if (
- use_table_orientation_classify == True
- and table_angle != "0"
- ):
- img_height_copy, img_width_copy = (
- doc_preprocessor_image_copy.shape[:2]
- )
- single_table_rec_res["cell_box_list"] = (
- self.map_cells_to_original_image(
- single_table_rec_res["cell_box_list"],
- table_angle,
- img_width_copy,
- img_height_copy,
- )
- )
- table_res_list.append(single_table_rec_res)
- table_region_id += 1
- single_img_res = {
- "input_path": batch_data.input_paths[0],
- "page_index": batch_data.page_indexes[0],
- "doc_preprocessor_res": doc_preprocessor_res,
- "layout_det_res": layout_det_res,
- "overall_ocr_res": overall_ocr_res,
- "table_res_list": table_res_list,
- "model_settings": model_settings,
- }
- yield TableRecognitionResult(single_img_res)
- @pipeline_requires_extra("ocr")
- class TableRecognitionPipelineV2(AutoParallelImageSimpleInferencePipeline):
- entities = ["table_recognition_v2"]
- @property
- def _pipeline_cls(self):
- return _TableRecognitionPipelineV2
- def _get_batch_size(self, config):
- return 1
|