# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from ....utils import logging
from ....utils.deps import (
    function_requires_deps,
    is_dep_available,
    pipeline_requires_extra,
)
from ...common.batch_sampler import ImageBatchSampler
from ...common.reader import ReadImage
from ...models.object_detection.result import DetResult
from ...utils.hpi import HPIConfig
from ...utils.pp_option import PaddlePredictorOption
from .._parallel import AutoParallelImageSimpleInferencePipeline
from ..base import BasePipeline
from ..components import CropByBoxes
from ..doc_preprocessor.result import DocPreprocessorResult
from ..layout_parsing.utils import get_sub_regions_ocr_res
from ..ocr.result import OCRResult
from .result import SingleTableRecognitionResult, TableRecognitionResult
from .table_recognition_post_processing import (
    get_table_recognition_res as get_table_recognition_res_e2e,
)
from .table_recognition_post_processing_v2 import get_table_recognition_res
from .utils import get_neighbor_boxes_idx

if is_dep_available("scikit-learn"):
    from sklearn.cluster import KMeans

class _TableRecognitionPipelineV2(BasePipeline):
    """Table Recognition Pipeline"""

    def __init__(
        self,
        config: Dict,
        device: str = None,
        pp_option: PaddlePredictorOption = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
    ) -> None:
        """Initializes the table recognition pipeline.

        Args:
            config (Dict): Configuration dictionary containing various settings.
            device (str, optional): Device to run the predictions on. Defaults to None.
            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
            use_hpip (bool, optional): Whether to use the high-performance
                inference plugin (HPIP) by default. Defaults to False.
            hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                The default high-performance inference configuration dictionary.
                Defaults to None.
        """
        super().__init__(
            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
        )

        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
        if self.use_doc_preprocessor:
            doc_preprocessor_config = config.get("SubPipelines", {}).get(
                "DocPreprocessor",
                {
                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
                },
            )
            self.doc_preprocessor_pipeline = self.create_pipeline(
                doc_preprocessor_config
            )

        self.use_layout_detection = config.get("use_layout_detection", True)
        if self.use_layout_detection:
            layout_det_config = config.get("SubModules", {}).get(
                "LayoutDetection",
                {"model_config_error": "config error for layout_det_model!"},
            )
            self.layout_det_model = self.create_model(layout_det_config)

        table_cls_config = config.get("SubModules", {}).get(
            "TableClassification",
            {"model_config_error": "config error for table_classification_model!"},
        )
        self.table_cls_model = self.create_model(table_cls_config)

        wired_table_rec_config = config.get("SubModules", {}).get(
            "WiredTableStructureRecognition",
            {"model_config_error": "config error for wired_table_structure_model!"},
        )
        self.wired_table_rec_model = self.create_model(wired_table_rec_config)

        wireless_table_rec_config = config.get("SubModules", {}).get(
            "WirelessTableStructureRecognition",
            {"model_config_error": "config error for wireless_table_structure_model!"},
        )
        self.wireless_table_rec_model = self.create_model(wireless_table_rec_config)

        wired_table_cells_det_config = config.get("SubModules", {}).get(
            "WiredTableCellsDetection",
            {
                "model_config_error": "config error for wired_table_cells_detection_model!"
            },
        )
        self.wired_table_cells_detection_model = self.create_model(
            wired_table_cells_det_config
        )

        wireless_table_cells_det_config = config.get("SubModules", {}).get(
            "WirelessTableCellsDetection",
            {
                "model_config_error": "config error for wireless_table_cells_detection_model!"
            },
        )
        self.wireless_table_cells_detection_model = self.create_model(
            wireless_table_cells_det_config
        )

        self.use_ocr_model = config.get("use_ocr_model", True)
        self.general_ocr_pipeline = None
        if self.use_ocr_model:
            general_ocr_config = config.get("SubPipelines", {}).get(
                "GeneralOCR",
                {"pipeline_config_error": "config error for general_ocr_pipeline!"},
            )
            self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
        else:
            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
                "GeneralOCR", None
            )

        self.table_orientation_classify_model = None
        self.table_orientation_classify_config = config.get("SubModules", {}).get(
            "TableOrientationClassify", None
        )

        self._crop_by_boxes = CropByBoxes()
        self.batch_sampler = ImageBatchSampler(batch_size=1)
        self.img_reader = ReadImage(format="BGR")

    def get_model_settings(
        self,
        use_doc_orientation_classify: Optional[bool],
        use_doc_unwarping: Optional[bool],
        use_layout_detection: Optional[bool],
        use_ocr_model: Optional[bool],
    ) -> dict:
        """
        Get the model settings based on the provided parameters or default values.

        Args:
            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
            use_layout_detection (Optional[bool]): Whether to use layout detection.
            use_ocr_model (Optional[bool]): Whether to use OCR model.

        Returns:
            dict: A dictionary containing the model settings.
        """
        if use_doc_orientation_classify is None and use_doc_unwarping is None:
            use_doc_preprocessor = self.use_doc_preprocessor
        else:
            if use_doc_orientation_classify is True or use_doc_unwarping is True:
                use_doc_preprocessor = True
            else:
                use_doc_preprocessor = False

        if use_layout_detection is None:
            use_layout_detection = self.use_layout_detection

        if use_ocr_model is None:
            use_ocr_model = self.use_ocr_model

        return dict(
            use_doc_preprocessor=use_doc_preprocessor,
            use_layout_detection=use_layout_detection,
            use_ocr_model=use_ocr_model,
        )

    def check_model_settings_valid(
        self,
        model_settings: Dict,
        overall_ocr_res: OCRResult,
        layout_det_res: DetResult,
    ) -> bool:
        """
        Check if the input parameters are valid based on the initialized models.

        Args:
            model_settings (Dict): A dictionary containing input parameters.
            overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
                The overall OCR result with convert_points_to_boxes information.
            layout_det_res (DetResult): The layout detection result.

        Returns:
            bool: True if all required models are initialized according to input parameters, False otherwise.
        """
        if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
            logging.error(
                "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
            )
            return False

        if model_settings["use_layout_detection"]:
            if layout_det_res is not None:
                logging.error(
                    "The layout detection model has already been initialized, please set use_layout_detection=False"
                )
                return False
            if not self.use_layout_detection:
                logging.error(
                    "Set use_layout_detection, but the models for layout detection are not initialized."
                )
                return False

        if model_settings["use_ocr_model"]:
            if overall_ocr_res is not None:
                logging.error(
                    "The OCR models have already been initialized, please set use_ocr_model=False"
                )
                return False
            if not self.use_ocr_model:
                logging.error(
                    "Set use_ocr_model, but the models for OCR are not initialized."
                )
                return False
        else:
            if overall_ocr_res is None:
                logging.error("Set use_ocr_model=False, but no OCR results were found.")
                return False
        return True

    def predict_doc_preprocessor_res(
        self, image_array: np.ndarray, input_params: dict
    ) -> Tuple[DocPreprocessorResult, np.ndarray]:
        """
        Preprocess the document image based on input parameters.

        Args:
            image_array (np.ndarray): The input image array.
            input_params (dict): Dictionary containing preprocessing parameters.

        Returns:
            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
            result dictionary and the processed image array.
        """
        if input_params["use_doc_preprocessor"]:
            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
            use_doc_unwarping = input_params["use_doc_unwarping"]
            doc_preprocessor_res = next(
                self.doc_preprocessor_pipeline(
                    image_array,
                    use_doc_orientation_classify=use_doc_orientation_classify,
                    use_doc_unwarping=use_doc_unwarping,
                )
            )
            doc_preprocessor_image = doc_preprocessor_res["output_img"]
        else:
            doc_preprocessor_res = {}
            doc_preprocessor_image = image_array
        return doc_preprocessor_res, doc_preprocessor_image

    def extract_results(self, pred, task):
        """Extract the task-specific fields from a raw model prediction."""
        if task == "cls":
            return pred["label_names"][np.argmax(pred["scores"])]
        elif task == "det":
            threshold = 0.0
            result = []
            cell_score = []
            if "boxes" in pred and isinstance(pred["boxes"], list):
                for box in pred["boxes"]:
                    if isinstance(box, dict) and "score" in box and "coordinate" in box:
                        score = box["score"]
                        coordinate = box["coordinate"]
                        if isinstance(score, float) and score > threshold:
                            result.append(coordinate)
                            cell_score.append(score)
            return result, cell_score
        elif task == "table_stru":
            return pred["structure"]
        else:
            return None

    def cells_det_results_nms(
        self, cells_det_results, cells_det_scores, cells_det_threshold=0.3
    ):
        """
        Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.

        Args:
            cells_det_results (list): List of bounding boxes, each box is in format [x1, y1, x2, y2].
            cells_det_scores (list): List of confidence scores corresponding to the bounding boxes.
            cells_det_threshold (float): IoU threshold for suppression. Boxes with IoU greater than this threshold
                will be suppressed. Default is 0.3.

        Returns:
            Tuple[list, list]: A tuple containing the list of bounding boxes and confidence scores after NMS,
            while maintaining one-to-one correspondence.
        """
        # Convert lists to numpy arrays for efficient computation
        boxes = np.array(cells_det_results)
        scores = np.array(cells_det_scores)
        # Initialize list for picked indices
        picked_indices = []
        # Get coordinates of bounding boxes
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        # Compute the area of the bounding boxes
        areas = (x2 - x1) * (y2 - y1)
        # Sort the bounding boxes by the confidence scores in descending order
        order = scores.argsort()[::-1]
        # Process the boxes
        while order.size > 0:
            # Index of the current highest score box
            i = order[0]
            picked_indices.append(i)
            # Compute IoU between the highest score box and the rest
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            # Compute the width and height of the overlapping area
            w = np.maximum(0.0, xx2 - xx1)
            h = np.maximum(0.0, yy2 - yy1)
            # Compute the ratio of overlap (IoU)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            # Indices of boxes with IoU less than threshold
            inds = np.where(ovr <= cells_det_threshold)[0]
            # Update order, only keep boxes with IoU less than threshold
            order = order[
                inds + 1
            ]  # inds shifted by 1 because order[0] is the current box
        # Select the boxes and scores based on picked indices
        final_boxes = boxes[picked_indices].tolist()
        final_scores = scores[picked_indices].tolist()
        return final_boxes, final_scores

    def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
        """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.

        Args:
            ocr_det_boxes (list of list): List of bounding boxes [x1, y1, x2, y2] in the original image.
            table_box (list): Bounding box [x1, y1, x2, y2] of the target region in the original image.

        Returns:
            list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
        """
        tol = 0
        # Extract coordinates from table_box
        x_min_t, y_min_t, x_max_t, y_max_t = table_box
        adjusted_boxes = []
        for box in ocr_det_boxes:
            x_min_b, y_min_b, x_max_b, y_max_b = box
            # Check if the box is fully inside table_box
            if (
                x_min_b + tol >= x_min_t
                and y_min_b + tol >= y_min_t
                and x_max_b - tol <= x_max_t
                and y_max_b - tol <= y_max_t
            ):
                # Adjust the coordinates to be relative to table_box
                adjusted_box = [
                    x_min_b - x_min_t,  # Adjust x1
                    y_min_b - y_min_t,  # Adjust y1
                    x_max_b - x_min_t,  # Adjust x2
                    y_max_b - y_min_t,  # Adjust y2
                ]
                adjusted_boxes.append(adjusted_box)
            # Boxes not fully inside table_box are discarded
        return adjusted_boxes

    def cells_det_results_reprocessing(
        self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums
    ):
        """
        Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.

        Args:
            cells_det_results (List[List[float]]): List of detected cell rectangles [[x1, y1, x2, y2], ...].
            cells_det_scores (List[float]): List of confidence scores for each rectangle in cells_det_results.
            ocr_det_results (List[List[float]]): List of OCR detected rectangles [[x1, y1, x2, y2], ...].
            html_pred_boxes_nums (int): The desired number of rectangles in the final output.

        Returns:
            List[List[float]]: The processed list of rectangles.
        """

        # Function to compute the overlap between two rectangles
        def compute_iou(box1, box2):
            """
            Compute the overlap between two rectangles, measured as the intersection
            area divided by the area of box1.

            Args:
                box1 (array-like): [x1, y1, x2, y2] of the first rectangle.
                box2 (array-like): [x1, y1, x2, y2] of the second rectangle.

            Returns:
                float: The overlap ratio between the two rectangles.
            """
            # Determine the coordinates of the intersection rectangle
            x_left = max(box1[0], box2[0])
            y_top = max(box1[1], box2[1])
            x_right = min(box1[2], box2[2])
            y_bottom = min(box1[3], box2[3])
            if x_right <= x_left or y_bottom <= y_top:
                return 0.0
            # Calculate the area of the intersection rectangle
            intersection_area = (x_right - x_left) * (y_bottom - y_top)
            # Normalize the intersection by the area of box1
            box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
            iou = intersection_area / float(box1_area)
            return iou

        # Function to combine rectangles into N rectangles
        @function_requires_deps("scikit-learn")
        def combine_rectangles(rectangles, N):
            """
            Combine rectangles into N rectangles based on geometric proximity.

            Args:
                rectangles (list of list of int): A list of rectangles, each represented by [x1, y1, x2, y2].
                N (int): The desired number of combined rectangles.

            Returns:
                list of list of int: A list of N combined rectangles.
            """
            # Number of input rectangles
            num_rects = len(rectangles)
            # If N is greater than or equal to the number of rectangles, return the original rectangles
            if N >= num_rects:
                return rectangles
            # Compute the center points of the rectangles
            centers = np.array(
                [
                    [
                        (rect[0] + rect[2]) / 2,  # Center x-coordinate
                        (rect[1] + rect[3]) / 2,  # Center y-coordinate
                    ]
                    for rect in rectangles
                ]
            )
            # Perform KMeans clustering on the center points to group them into N clusters
            kmeans = KMeans(n_clusters=N, random_state=0, n_init="auto")
            labels = kmeans.fit_predict(centers)
            # Initialize a list to store the combined rectangles
            combined_rectangles = []
            # For each cluster, compute the minimal bounding rectangle that covers all rectangles in the cluster
            for i in range(N):
                # Get the indices of rectangles that belong to cluster i
                indices = np.where(labels == i)[0]
                if len(indices) == 0:
                    # If no rectangles in this cluster, skip it
                    continue
                # Extract the rectangles in cluster i
                cluster_rects = np.array([rectangles[idx] for idx in indices])
                # Compute the minimal x1, y1 (top-left corner) and maximal x2, y2 (bottom-right corner)
                x1_min = np.min(cluster_rects[:, 0])
                y1_min = np.min(cluster_rects[:, 1])
                x2_max = np.max(cluster_rects[:, 2])
                y2_max = np.max(cluster_rects[:, 3])
                # Append the combined rectangle to the list
                combined_rectangles.append([x1_min, y1_min, x2_max, y2_max])
            return combined_rectangles

        # Ensure that the inputs are numpy arrays for efficient computation
        cells_det_results = np.array(cells_det_results)
        cells_det_scores = np.array(cells_det_scores)
        ocr_det_results = np.array(ocr_det_results)
        more_cells_flag = False
        if len(cells_det_results) == html_pred_boxes_nums:
            return cells_det_results
        # Step 1: If cells_det_results has more rectangles than html_pred_boxes_nums
        elif len(cells_det_results) > html_pred_boxes_nums:
            more_cells_flag = True
            # Select the indices of the top html_pred_boxes_nums scores
            top_indices = np.argsort(-cells_det_scores)[:html_pred_boxes_nums]
            # Adjust the corresponding rectangles
            cells_det_results = cells_det_results[top_indices].tolist()
        # Threshold for IoU
        iou_threshold = 0.6
        # List to store ocr_miss_boxes
        ocr_miss_boxes = []
        # For each rectangle in ocr_det_results
        for ocr_rect in ocr_det_results:
            merge_ocr_box_iou = []
            # Flag to indicate if ocr_rect has IoU >= threshold with any cell_rect
            has_large_iou = False
            # For each rectangle in cells_det_results
            for cell_rect in cells_det_results:
                # Compute IoU
                iou = compute_iou(ocr_rect, cell_rect)
                if iou > 0:
                    merge_ocr_box_iou.append(iou)
                if (iou >= iou_threshold) or (sum(merge_ocr_box_iou) >= iou_threshold):
                    has_large_iou = True
                    break
            if not has_large_iou:
                ocr_miss_boxes.append(ocr_rect)
        # If no ocr_miss_boxes, return cells_det_results
        if len(ocr_miss_boxes) == 0:
            final_results = (
                cells_det_results
                if more_cells_flag == True
                else cells_det_results.tolist()
            )
        else:
            if more_cells_flag == True:
                final_results = combine_rectangles(
                    cells_det_results + ocr_miss_boxes, html_pred_boxes_nums
                )
            else:
                # Need to combine ocr_miss_boxes into N rectangles
                N = html_pred_boxes_nums - len(cells_det_results)
                # Combine ocr_miss_boxes into N rectangles
                ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
                # Combine cells_det_results and ocr_supp_boxes
                final_results = np.concatenate(
                    (cells_det_results, ocr_supp_boxes), axis=0
                ).tolist()
        if len(final_results) <= 0.6 * html_pred_boxes_nums:
            final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
        return final_results

    def split_ocr_bboxes_by_table_cells(
        self, cells_det_results, overall_ocr_res, ori_img, k=2
    ):
        """
        Split OCR bounding boxes based on table cell boundaries when they span multiple cells horizontally.

        Args:
            cells_det_results (list): List of cell bounding boxes in format [x1, y1, x2, y2].
            overall_ocr_res (dict): Dictionary containing OCR results with keys:
                - 'rec_boxes': OCR bounding boxes (will be converted to list)
                - 'rec_texts': OCR recognized texts
            ori_img (np.array): Original input image array.
            k (int): Threshold for determining when to split (minimum number of cells spanned).

        Returns:
            dict: Modified overall_ocr_res with split boxes and texts.
        """

        def calculate_iou(box1, box2):
            """
            Calculate the overlap between two bounding boxes, measured as the
            intersection area divided by the area of box2.

            Args:
                box1 (list): [x1, y1, x2, y2]
                box2 (list): [x1, y1, x2, y2]

            Returns:
                float: Overlap value.
            """
            # Determine intersection coordinates
            x_left = max(box1[0], box2[0])
            y_top = max(box1[1], box2[1])
            x_right = min(box1[2], box2[2])
            y_bottom = min(box1[3], box2[3])
            if x_right < x_left or y_bottom < y_top:
                return 0.0
            # Calculate areas
            intersection_area = (x_right - x_left) * (y_bottom - y_top)
            box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
            box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
            # return intersection_area / float(box1_area + box2_area - intersection_area)
            return intersection_area / box2_area

        def get_overlapping_cells(ocr_box, cells):
            """
            Find cells that overlap significantly with the OCR box (overlap > 0.5).

            Args:
                ocr_box (list): OCR bounding box [x1, y1, x2, y2].
                cells (list): List of cell bounding boxes.

            Returns:
                list: Indices of overlapping cells, sorted by x-coordinate.
            """
            overlapping = []
            for idx, cell in enumerate(cells):
                if calculate_iou(ocr_box, cell) > 0.5:
                    overlapping.append(idx)
            # Sort overlapping cells by their x-coordinate (left to right)
            overlapping.sort(key=lambda i: cells[i][0])
            return overlapping

        def split_box_by_cells(ocr_box, cell_indices, cells):
            """
            Split OCR box vertically at cell boundaries.

            Args:
                ocr_box (list): Original OCR box [x1, y1, x2, y2].
                cell_indices (list): Indices of cells to split by.
                cells (list): All cell bounding boxes.

            Returns:
                list: List of split boxes.
            """
            if not cell_indices:
                return [ocr_box]
            split_boxes = []
            cells_to_split = [cells[i] for i in cell_indices]
            if ocr_box[0] < cells_to_split[0][0]:
                split_boxes.append(
                    [ocr_box[0], ocr_box[1], cells_to_split[0][0], ocr_box[3]]
                )
            for i in range(len(cells_to_split)):
                current_cell = cells_to_split[i]
                split_boxes.append(
                    [
                        max(ocr_box[0], current_cell[0]),
                        ocr_box[1],
                        min(ocr_box[2], current_cell[2]),
                        ocr_box[3],
                    ]
                )
                if i < len(cells_to_split) - 1:
                    next_cell = cells_to_split[i + 1]
                    if current_cell[2] < next_cell[0]:
                        split_boxes.append(
                            [current_cell[2], ocr_box[1], next_cell[0], ocr_box[3]]
                        )
            last_cell = cells_to_split[-1]
            if last_cell[2] < ocr_box[2]:
                split_boxes.append([last_cell[2], ocr_box[1], ocr_box[2], ocr_box[3]])
            unique_boxes = []
            seen = set()
            for box in split_boxes:
                box_tuple = tuple(box)
                if box_tuple not in seen:
                    seen.add(box_tuple)
                    unique_boxes.append(box)
            return unique_boxes

        # Convert OCR boxes to list if needed
        if hasattr(overall_ocr_res["rec_boxes"], "tolist"):
            ocr_det_results = overall_ocr_res["rec_boxes"].tolist()
        else:
            ocr_det_results = overall_ocr_res["rec_boxes"]
        ocr_texts = overall_ocr_res["rec_texts"]
        # Make copies to modify
        new_boxes = []
        new_texts = []
        # Process each OCR box
        i = 0
        while i < len(ocr_det_results):
            ocr_box = ocr_det_results[i]
            text = ocr_texts[i]
            # Find cells that significantly overlap with this OCR box
            overlapping_cells = get_overlapping_cells(ocr_box, cells_det_results)
            # Check if we need to split (spans >= k cells)
            if len(overlapping_cells) >= k:
                # Split the box at cell boundaries
                split_boxes = split_box_by_cells(
                    ocr_box, overlapping_cells, cells_det_results
                )
                # Process each split box
                split_texts = []
                for box in split_boxes:
                    x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
                    if y2 - y1 > 1 and x2 - x1 > 1:
                        ocr_result = next(
                            self.general_ocr_pipeline.text_rec_model(
                                ori_img[y1:y2, x1:x2, :]
                            )
                        )
                        # Extract the recognized text from the OCR result
                        if "rec_text" in ocr_result:
                            result = ocr_result[
                                "rec_text"
                            ]  # "rec_text" contains a single string
                        else:
                            result = ""
                    else:
                        result = ""
                    split_texts.append(result)
                # Add split boxes and texts to results
                new_boxes.extend(split_boxes)
                new_texts.extend(split_texts)
            else:
                # Keep original box and text
                new_boxes.append(ocr_box)
                new_texts.append(text)
            i += 1
        # Update the results dictionary
        overall_ocr_res["rec_boxes"] = new_boxes
        overall_ocr_res["rec_texts"] = new_texts
        return overall_ocr_res

    def gen_ocr_with_table_cells(self, ori_img, cells_bboxes):
        """
        Splits OCR bounding boxes by table cells and retrieves text.

        Args:
            ori_img (ndarray): The original image from which text regions will be extracted.
            cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.

        Returns:
            list: A list containing the recognized texts from each cell.
        """
        # Check if cells_bboxes is a list and convert it if not.
        if not isinstance(cells_bboxes, list):
            cells_bboxes = cells_bboxes.tolist()
        texts_list = []  # Initialize a list to store the recognized texts.
        # Process each bounding box provided in cells_bboxes.
        for i in range(len(cells_bboxes)):
            # Extract and round up the coordinates of the bounding box.
            x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
            # Perform OCR on the defined region of the image and get the recognized text.
            if y2 - y1 > 1 and x2 - x1 > 1:
                rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
                # Concatenate the texts and append them to the texts_list.
                texts_list.append("".join(rec_te["rec_texts"]))
        # Return the list of recognized texts from each cell.
        return texts_list

    def map_cells_to_original_image(
        self, detections, table_angle, img_width, img_height
    ):
        """
        Map bounding boxes from the rotated image back to the original image.

        Args:
            detections (list): List of numpy arrays, each containing bounding box coordinates [x1, y1, x2, y2].
            table_angle (str): Rotation angle in degrees ("90", "180", or "270").
            img_width (int): Width of the original image.
            img_height (int): Height of the original image.

        Returns:
            list: List of numpy arrays with mapped bounding box coordinates.
        """
        mapped_detections = []
        for i in range(len(detections)):
            tbx1, tby1, tbx2, tby2 = (
                detections[i][0],
                detections[i][1],
                detections[i][2],
                detections[i][3],
            )
            if table_angle == "270":
                new_x1, new_y1 = tby1, img_width - tbx2
                new_x2, new_y2 = tby2, img_width - tbx1
            elif table_angle == "180":
                new_x1, new_y1 = img_width - tbx2, img_height - tby2
                new_x2, new_y2 = img_width - tbx1, img_height - tby1
            elif table_angle == "90":
                new_x1, new_y1 = img_height - tby2, tbx1
                new_x2, new_y2 = img_height - tby1, tbx2
            new_box = np.array([new_x1, new_y1, new_x2, new_y2])
            mapped_detections.append(new_box)
        return mapped_detections

    def split_string_by_keywords(self, html_string):
        """
        Split an HTML string by keywords.

        Args:
            html_string (str): The HTML string.

        Returns:
            split_html (list): The list of HTML keywords.
        """
        keywords = [
            "<thead>",
            "</thead>",
            "<tbody>",
            "</tbody>",
            "<tr>",
            "</tr>",
            "<td>",
            "<td",
            ">",
            "</td>",
            'colspan="2"',
            'colspan="3"',
            'colspan="4"',
            'colspan="5"',
            'colspan="6"',
            'colspan="7"',
            'colspan="8"',
            'colspan="9"',
            'colspan="10"',
            'colspan="11"',
            'colspan="12"',
            'colspan="13"',
            'colspan="14"',
            'colspan="15"',
            'colspan="16"',
            'colspan="17"',
            'colspan="18"',
            'colspan="19"',
            'colspan="20"',
            'rowspan="2"',
            'rowspan="3"',
            'rowspan="4"',
            'rowspan="5"',
            'rowspan="6"',
            'rowspan="7"',
            'rowspan="8"',
            'rowspan="9"',
            'rowspan="10"',
            'rowspan="11"',
            'rowspan="12"',
            'rowspan="13"',
            'rowspan="14"',
            'rowspan="15"',
            'rowspan="16"',
            'rowspan="17"',
            'rowspan="18"',
            'rowspan="19"',
            'rowspan="20"',
        ]
        regex_pattern = "|".join(re.escape(keyword) for keyword in keywords)
        split_result = re.split(f"({regex_pattern})", html_string)
        split_html = [part for part in split_result if part]
        return split_html

    def cluster_positions(self, positions, tolerance):
        """Cluster nearby coordinate values within `tolerance` and return the mean of each cluster."""
        if not positions:
            return []
        positions = sorted(set(positions))
        clustered = []
        current_cluster = [positions[0]]
        for pos in positions[1:]:
            if abs(pos - current_cluster[-1]) <= tolerance:
                current_cluster.append(pos)
            else:
                clustered.append(sum(current_cluster) / len(current_cluster))
                current_cluster = [pos]
        clustered.append(sum(current_cluster) / len(current_cluster))
        return clustered

    def trans_cells_det_results_to_html(self, cells_det_results):
        """
        Translate table cell bboxes to HTML.

        Args:
            cells_det_results (list): The table cells detection results.

        Returns:
            html (list): The list of HTML keywords.
        """
        tolerance = 5
        x_coords = [x for cell in cells_det_results for x in (cell[0], cell[2])]
        y_coords = [y for cell in cells_det_results for y in (cell[1], cell[3])]
        x_positions = self.cluster_positions(x_coords, tolerance)
        y_positions = self.cluster_positions(y_coords, tolerance)
        x_position_to_index = {x: i for i, x in enumerate(x_positions)}
        y_position_to_index = {y: i for i, y in enumerate(y_positions)}
        num_rows = len(y_positions) - 1
        num_cols = len(x_positions) - 1
        grid = [[None for _ in range(num_cols)] for _ in range(num_rows)]
        cells_info = []
        cell_index = 0
        cell_map = {}
        for index, cell in enumerate(cells_det_results):
            x1, y1, x2, y2 = cell
            x1_idx = min(
                range(len(x_positions)), key=lambda i: abs(x_positions[i] - x1)
            )
            x2_idx = min(
                range(len(x_positions)), key=lambda i: abs(x_positions[i] - x2)
            )
            y1_idx = min(
                range(len(y_positions)), key=lambda i: abs(y_positions[i] - y1)
            )
            y2_idx = min(
                range(len(y_positions)), key=lambda i: abs(y_positions[i] - y2)
            )
            col_start = min(x1_idx, x2_idx)
            col_end = max(x1_idx, x2_idx)
            row_start = min(y1_idx, y2_idx)
            row_end = max(y1_idx, y2_idx)
            rowspan = row_end - row_start
            colspan = col_end - col_start
            if rowspan == 0:
                rowspan = 1
            if colspan == 0:
                colspan = 1
            cells_info.append(
                {
                    "row_start": row_start,
                    "col_start": col_start,
                    "rowspan": rowspan,
                    "colspan": colspan,
                    "content": "",
                }
            )
            for r in range(row_start, row_start + rowspan):
                for c in range(col_start, col_start + colspan):
                    key = (r, c)
                    if key in cell_map:
                        continue
                    else:
                        cell_map[key] = index
        html = "<table><tbody>"
        for r in range(num_rows):
            html += "<tr>"
            c = 0
            while c < num_cols:
                key = (r, c)
                if key in cell_map:
                    cell_index = cell_map[key]
                    cell_info = cells_info[cell_index]
                    if cell_info["row_start"] == r and cell_info["col_start"] == c:
                        rowspan = cell_info["rowspan"]
                        colspan = cell_info["colspan"]
                        rowspan_attr = f' rowspan="{rowspan}"' if rowspan > 1 else ""
                        colspan_attr = f' colspan="{colspan}"' if colspan > 1 else ""
                        content = cell_info["content"]
                        html += f"<td{rowspan_attr}{colspan_attr}>{content}</td>"
                    c += cell_info["colspan"]
                else:
                    html += "<td></td>"
                    c += 1
            html += "</tr>"
        html += "</tbody></table>"
        html = self.split_string_by_keywords(html)
        return html

    def predict_single_table_recognition_res(
        self,
        image_array: np.ndarray,
        overall_ocr_res: OCRResult,
        table_box: list,
        use_e2e_wired_table_rec_model: bool = False,
        use_e2e_wireless_table_rec_model: bool = False,
        use_wired_table_cells_trans_to_html: bool = False,
        use_wireless_table_cells_trans_to_html: bool = False,
        use_ocr_results_with_table_cells: bool = True,
        flag_find_nei_text: bool = True,
    ) -> SingleTableRecognitionResult:
        """
        Predict table recognition results from an image array, layout detection results, and OCR results.

        Args:
            image_array (np.ndarray): The input image represented as a numpy array.
            overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
                The overall OCR results containing text recognition information.
            table_box (list): The table box coordinates.
            use_e2e_wired_table_rec_model (bool): Whether to use the end-to-end wired table recognition model.
            use_e2e_wireless_table_rec_model (bool): Whether to use the end-to-end wireless table recognition model.
            use_wired_table_cells_trans_to_html (bool): Whether to translate wired table cells directly to HTML.
            use_wireless_table_cells_trans_to_html (bool): Whether to translate wireless table cells directly to HTML.
            use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells.
            flag_find_nei_text (bool): Whether to find neighboring text.

        Returns:
            SingleTableRecognitionResult: single table recognition result.
        """
        table_cls_pred = next(self.table_cls_model(image_array))
        table_cls_result = self.extract_results(table_cls_pred, "cls")
        use_e2e_model = False
        cells_trans_to_html = False

        if table_cls_result == "wired_table":
            if use_wired_table_cells_trans_to_html == True:
                cells_trans_to_html = True
            else:
                table_structure_pred = next(self.wired_table_rec_model(image_array))
            if use_e2e_wired_table_rec_model == True:
                use_e2e_model = True
                if cells_trans_to_html == True:
                    table_structure_pred = next(self.wired_table_rec_model(image_array))
            else:
                table_cells_pred = next(
                    self.wired_table_cells_detection_model(image_array, threshold=0.3)
                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
        elif table_cls_result == "wireless_table":
            if use_wireless_table_cells_trans_to_html == True:
                cells_trans_to_html = True
            else:
                table_structure_pred = next(self.wireless_table_rec_model(image_array))
            if use_e2e_wireless_table_rec_model == True:
                use_e2e_model = True
                if cells_trans_to_html == True:
                    table_structure_pred = next(
                        self.wireless_table_rec_model(image_array)
                    )
            else:
                table_cells_pred = next(
                    self.wireless_table_cells_detection_model(
                        image_array, threshold=0.3
                    )
                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.

        if use_e2e_model == False:
            table_cells_result, table_cells_score = self.extract_results(
                table_cells_pred, "det"
            )
            table_cells_result, table_cells_score = self.cells_det_results_nms(
                table_cells_result, table_cells_score
            )
            if cells_trans_to_html == True:
                table_structure_result = self.trans_cells_det_results_to_html(
                    table_cells_result
                )
            else:
                table_structure_result = self.extract_results(
                    table_structure_pred, "table_stru"
                )
                ocr_det_boxes = self.get_region_ocr_det_boxes(
                    overall_ocr_res["rec_boxes"].tolist(), table_box
                )
                table_cells_result = self.cells_det_results_reprocessing(
                    table_cells_result,
                    table_cells_score,
                    ocr_det_boxes,
                    len(table_structure_pred["bbox"]),
                )
            if use_ocr_results_with_table_cells == True:
                if self.cells_split_ocr == True:
                    table_box_copy = np.array([table_box])
                    table_ocr_pred = get_sub_regions_ocr_res(
                        overall_ocr_res, table_box_copy
                    )
                    table_ocr_pred = self.split_ocr_bboxes_by_table_cells(
                        table_cells_result, table_ocr_pred, image_array
                    )
                    cells_texts_list = []
                else:
                    cells_texts_list = self.gen_ocr_with_table_cells(
                        image_array, table_cells_result
                    )
                    table_ocr_pred = {}
            else:
                table_ocr_pred = {}
                cells_texts_list = []
            single_table_recognition_res = get_table_recognition_res(
                table_box,
                table_structure_result,
                table_cells_result,
                overall_ocr_res,
                table_ocr_pred,
                cells_texts_list,
                use_ocr_results_with_table_cells,
                self.cells_split_ocr,
            )
        else:
            cells_texts_list = []
            use_ocr_results_with_table_cells = False
            table_cells_result_e2e = table_structure_pred["bbox"]
            table_cells_result_e2e = [
                [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e
            ]
            if cells_trans_to_html == True:
                table_structure_pred["structure"] = (
                    self.trans_cells_det_results_to_html(table_cells_result_e2e)
                )
            single_table_recognition_res = get_table_recognition_res_e2e(
                table_box,
                table_structure_pred,
                overall_ocr_res,
                cells_texts_list,
                use_ocr_results_with_table_cells,
            )

        neighbor_text = ""
        if flag_find_nei_text:
            match_idx_list = get_neighbor_boxes_idx(
                overall_ocr_res["rec_boxes"], table_box
            )
            if len(match_idx_list) > 0:
                for idx in match_idx_list:
                    neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
        single_table_recognition_res["neighbor_texts"] = neighbor_text
        return single_table_recognition_res

    def predict(
        self,
        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
        use_doc_orientation_classify: Optional[bool] = None,
        use_doc_unwarping: Optional[bool] = None,
        use_layout_detection: Optional[bool] = None,
        use_ocr_model: Optional[bool] = None,
        overall_ocr_res: Optional[OCRResult] = None,
        layout_det_res: Optional[DetResult] = None,
        text_det_limit_side_len: Optional[int] = None,
        text_det_limit_type: Optional[str] = None,
        text_det_thresh: Optional[float] = None,
        text_det_box_thresh: Optional[float] = None,
        text_det_unclip_ratio: Optional[float] = None,
        text_rec_score_thresh: Optional[float] = None,
        use_e2e_wired_table_rec_model: bool = False,
        use_e2e_wireless_table_rec_model: bool = False,
        use_wired_table_cells_trans_to_html: bool = False,
        use_wireless_table_cells_trans_to_html: bool = False,
        use_table_orientation_classify: bool = True,
        use_ocr_results_with_table_cells: bool = True,
        **kwargs,
    ) -> TableRecognitionResult:
        """
        This function predicts the table recognition result for the given input.

        Args:
            input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or PDF(s) to be processed.
            use_layout_detection (bool): Whether to use layout detection.
            use_doc_orientation_classify (bool): Whether to use document orientation classification.
            use_doc_unwarping (bool): Whether to use document unwarping.
            overall_ocr_res (OCRResult): The overall OCR result with convert_points_to_boxes information.
                It will be used if it is not None and use_ocr_model is False.
            layout_det_res (DetResult): The layout detection result.
                It will be used if it is not None and use_layout_detection is False.
            use_e2e_wired_table_rec_model (bool): Whether to use the end-to-end wired table recognition model.
            use_e2e_wireless_table_rec_model (bool): Whether to use the end-to-end wireless table recognition model.
            use_wired_table_cells_trans_to_html (bool): Whether to translate wired table cells directly to HTML.
            use_wireless_table_cells_trans_to_html (bool): Whether to translate wireless table cells directly to HTML.
            use_table_orientation_classify (bool): Whether to use table orientation classification.
            use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells.
            **kwargs: Additional keyword arguments.

        Returns:
            TableRecognitionResult: The predicted table recognition result.
        """
        self.cells_split_ocr = True

        if use_table_orientation_classify == True and (
            self.table_orientation_classify_model is None
        ):
            assert self.table_orientation_classify_config is not None
            self.table_orientation_classify_model = self.create_model(
                self.table_orientation_classify_config
            )

        model_settings = self.get_model_settings(
            use_doc_orientation_classify,
            use_doc_unwarping,
            use_layout_detection,
            use_ocr_model,
        )

        if not self.check_model_settings_valid(
            model_settings, overall_ocr_res, layout_det_res
        ):
            yield {"error": "the input params for model settings are invalid!"}

        for img_id, batch_data in enumerate(self.batch_sampler(input)):
            image_array = self.img_reader(batch_data.instances)[0]

            if model_settings["use_doc_preprocessor"]:
                doc_preprocessor_res = next(
                    self.doc_preprocessor_pipeline(
                        image_array,
                        use_doc_orientation_classify=use_doc_orientation_classify,
                        use_doc_unwarping=use_doc_unwarping,
                    )
                )
            else:
                doc_preprocessor_res = {"output_img": image_array}

            doc_preprocessor_image = doc_preprocessor_res["output_img"]

            if model_settings["use_ocr_model"]:
                overall_ocr_res = next(
                    self.general_ocr_pipeline(
                        doc_preprocessor_image,
                        text_det_limit_side_len=text_det_limit_side_len,
                        text_det_limit_type=text_det_limit_type,
                        text_det_thresh=text_det_thresh,
                        text_det_box_thresh=text_det_box_thresh,
                        text_det_unclip_ratio=text_det_unclip_ratio,
                        text_rec_score_thresh=text_rec_score_thresh,
                    )
                )
            elif self.general_ocr_pipeline is None and (
                (
                    use_ocr_results_with_table_cells == True
                    and self.cells_split_ocr == False
                )
                or use_table_orientation_classify == True
            ):
                assert self.general_ocr_config_bak is not None
                self.general_ocr_pipeline = self.create_pipeline(
                    self.general_ocr_config_bak
                )

            if use_table_orientation_classify == False:
                table_angle = "0"

            table_res_list = []
            table_region_id = 1

            if not model_settings["use_layout_detection"] and layout_det_res is None:
                img_height, img_width = doc_preprocessor_image.shape[:2]
                table_box = [0, 0, img_width - 1, img_height - 1]
                if use_table_orientation_classify == True:
                    table_angle = next(
                        self.table_orientation_classify_model(doc_preprocessor_image)
                    )["label_names"][0]
                    if table_angle == "90":
                        doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=1)
                    elif table_angle == "180":
                        doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=2)
                    elif table_angle == "270":
                        doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=3)
                    if table_angle in ["90", "180", "270"]:
                        overall_ocr_res = next(
                            self.general_ocr_pipeline(
                                doc_preprocessor_image,
                                text_det_limit_side_len=text_det_limit_side_len,
                                text_det_limit_type=text_det_limit_type,
                                text_det_thresh=text_det_thresh,
                                text_det_box_thresh=text_det_box_thresh,
                                text_det_unclip_ratio=text_det_unclip_ratio,
                                text_rec_score_thresh=text_rec_score_thresh,
                            )
                        )
                        tbx1, tby1, tbx2, tby2 = (
                            table_box[0],
                            table_box[1],
                            table_box[2],
                            table_box[3],
                        )
                        if table_angle == "90":
                            new_x1, new_y1 = tby1, img_width - tbx2
                            new_x2, new_y2 = tby2, img_width - tbx1
                        elif table_angle == "180":
                            new_x1, new_y1 = img_width - tbx2, img_height - tby2
                            new_x2, new_y2 = img_width - tbx1, img_height - tby1
                        elif table_angle == "270":
                            new_x1, new_y1 = img_height - tby2, tbx1
                            new_x2, new_y2 = img_height - tby1, tbx2
                        table_box = [new_x1, new_y1, new_x2, new_y2]
                layout_det_res = {}
                single_table_rec_res = self.predict_single_table_recognition_res(
                    doc_preprocessor_image,
                    overall_ocr_res,
                    table_box,
                    use_e2e_wired_table_rec_model,
                    use_e2e_wireless_table_rec_model,
                    use_wired_table_cells_trans_to_html,
                    use_wireless_table_cells_trans_to_html,
                    use_ocr_results_with_table_cells,
                    flag_find_nei_text=False,
                )
                single_table_rec_res["table_region_id"] = table_region_id
                if use_table_orientation_classify == True and table_angle != "0":
                    img_height, img_width = doc_preprocessor_image.shape[:2]
                    single_table_rec_res["cell_box_list"] = (
                        self.map_cells_to_original_image(
                            single_table_rec_res["cell_box_list"],
                            table_angle,
                            img_width,
                            img_height,
                        )
                    )
                table_res_list.append(single_table_rec_res)
                table_region_id += 1
            else:
                if model_settings["use_layout_detection"]:
                    layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
                img_height, img_width = doc_preprocessor_image.shape[:2]
                for box_info in layout_det_res["boxes"]:
                    if box_info["label"].lower() in ["table"]:
                        crop_img_info = self._crop_by_boxes(
                            doc_preprocessor_image, [box_info]
                        )
                        crop_img_info = crop_img_info[0]
                        table_box = crop_img_info["box"]
                        if use_table_orientation_classify == True:
                            doc_preprocessor_image_copy = doc_preprocessor_image.copy()
                            table_angle = next(
                                self.table_orientation_classify_model(
                                    crop_img_info["img"]
                                )
                            )["label_names"][0]
                            if table_angle == "90":
                                crop_img_info["img"] = np.rot90(
                                    crop_img_info["img"], k=1
                                )
                                doc_preprocessor_image_copy = np.rot90(
                                    doc_preprocessor_image_copy, k=1
                                )
                            elif table_angle == "180":
                                crop_img_info["img"] = np.rot90(
                                    crop_img_info["img"], k=2
                                )
                                doc_preprocessor_image_copy = np.rot90(
                                    doc_preprocessor_image_copy, k=2
                                )
                            elif table_angle == "270":
                                crop_img_info["img"] = np.rot90(
                                    crop_img_info["img"], k=3
                                )
                                doc_preprocessor_image_copy = np.rot90(
                                    doc_preprocessor_image_copy, k=3
                                )
                            if table_angle in ["90", "180", "270"]:
                                overall_ocr_res = next(
                                    self.general_ocr_pipeline(
                                        doc_preprocessor_image_copy,
                                        text_det_limit_side_len=text_det_limit_side_len,
                                        text_det_limit_type=text_det_limit_type,
                                        text_det_thresh=text_det_thresh,
                                        text_det_box_thresh=text_det_box_thresh,
                                        text_det_unclip_ratio=text_det_unclip_ratio,
                                        text_rec_score_thresh=text_rec_score_thresh,
                                    )
                                )
                                tbx1, tby1, tbx2, tby2 = (
                                    table_box[0],
                                    table_box[1],
                                    table_box[2],
                                    table_box[3],
                                )
                                if table_angle == "90":
                                    new_x1, new_y1 = tby1, img_width - tbx2
                                    new_x2, new_y2 = tby2, img_width - tbx1
                                elif table_angle == "180":
                                    new_x1, new_y1 = img_width - tbx2, img_height - tby2
                                    new_x2, new_y2 = img_width - tbx1, img_height - tby1
                                elif table_angle == "270":
                                    new_x1, new_y1 = img_height - tby2, tbx1
                                    new_x2, new_y2 = img_height - tby1, tbx2
                                table_box = [new_x1, new_y1, new_x2, new_y2]
                        single_table_rec_res = (
                            self.predict_single_table_recognition_res(
                                crop_img_info["img"],
                                overall_ocr_res,
                                table_box,
                                use_e2e_wired_table_rec_model,
                                use_e2e_wireless_table_rec_model,
                                use_wired_table_cells_trans_to_html,
                                use_wireless_table_cells_trans_to_html,
                                use_ocr_results_with_table_cells,
                            )
                        )
                        single_table_rec_res["table_region_id"] = table_region_id
                        if (
                            use_table_orientation_classify == True
                            and table_angle != "0"
                        ):
                            img_height_copy, img_width_copy = (
                                doc_preprocessor_image_copy.shape[:2]
                            )
                            single_table_rec_res["cell_box_list"] = (
                                self.map_cells_to_original_image(
                                    single_table_rec_res["cell_box_list"],
                                    table_angle,
                                    img_width_copy,
                                    img_height_copy,
                                )
                            )
                        table_res_list.append(single_table_rec_res)
                        table_region_id += 1
            single_img_res = {
                "input_path": batch_data.input_paths[0],
                "page_index": batch_data.page_indexes[0],
                "doc_preprocessor_res": doc_preprocessor_res,
                "layout_det_res": layout_det_res,
                "overall_ocr_res": overall_ocr_res,
                "table_res_list": table_res_list,
                "model_settings": model_settings,
            }
            yield TableRecognitionResult(single_img_res)


@pipeline_requires_extra("ocr")
class TableRecognitionPipelineV2(AutoParallelImageSimpleInferencePipeline):
    entities = ["table_recognition_v2"]

    @property
    def _pipeline_cls(self):
        return _TableRecognitionPipelineV2

    def _get_batch_size(self, config):
        return 1
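

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline itself).
# It assumes the pipeline is invoked through PaddleX's `create_pipeline` entry
# point with the "table_recognition_v2" name registered above; the input path
# "table_sample.png" is a placeholder.
#
#     from paddlex import create_pipeline
#
#     pipeline = create_pipeline(pipeline="table_recognition_v2")
#     for res in pipeline.predict("table_sample.png"):
#         res.print()  # each item is a TableRecognitionResult
# ---------------------------------------------------------------------------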