pipeline_v2.py

  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import re
  16. from typing import Any, Dict, List, Optional, Tuple, Union
  17. import numpy as np
  18. from ....utils import logging
  19. from ....utils.deps import (
  20. function_requires_deps,
  21. is_dep_available,
  22. pipeline_requires_extra,
  23. )
  24. from ...common.batch_sampler import ImageBatchSampler
  25. from ...common.reader import ReadImage
  26. from ...models.object_detection.result import DetResult
  27. from ...utils.benchmark import benchmark
  28. from ...utils.hpi import HPIConfig
  29. from ...utils.pp_option import PaddlePredictorOption
  30. from .._parallel import AutoParallelImageSimpleInferencePipeline
  31. from ..base import BasePipeline
  32. from ..components import CropByBoxes
  33. from ..doc_preprocessor.result import DocPreprocessorResult
  34. from ..layout_parsing.utils import get_sub_regions_ocr_res
  35. from ..ocr.result import OCRResult
  36. from .result import SingleTableRecognitionResult, TableRecognitionResult
  37. from .table_recognition_post_processing import (
  38. get_table_recognition_res as get_table_recognition_res_e2e,
  39. )
  40. from .table_recognition_post_processing_v2 import get_table_recognition_res
  41. from .utils import get_neighbor_boxes_idx
  42. if is_dep_available("scikit-learn"):
  43. from sklearn.cluster import KMeans
  44. @benchmark.time_methods
  45. class _TableRecognitionPipelineV2(BasePipeline):
  46. """Table Recognition Pipeline"""
  47. def __init__(
  48. self,
  49. config: Dict,
  50. device: str = None,
  51. pp_option: PaddlePredictorOption = None,
  52. use_hpip: bool = False,
  53. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  54. ) -> None:
  55. """Initializes the layout parsing pipeline.
  56. Args:
  57. config (Dict): Configuration dictionary containing various settings.
  58. device (str, optional): Device to run the predictions on. Defaults to None.
  59. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  60. use_hpip (bool, optional): Whether to use the high-performance
  61. inference plugin (HPIP) by default. Defaults to False.
  62. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
  63. The default high-performance inference configuration dictionary.
  64. Defaults to None.
  65. """
  66. super().__init__(
  67. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
  68. )
  69. self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
  70. if self.use_doc_preprocessor:
  71. doc_preprocessor_config = config.get("SubPipelines", {}).get(
  72. "DocPreprocessor",
  73. {
  74. "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
  75. },
  76. )
  77. self.doc_preprocessor_pipeline = self.create_pipeline(
  78. doc_preprocessor_config
  79. )
  80. self.use_layout_detection = config.get("use_layout_detection", True)
  81. if self.use_layout_detection:
  82. layout_det_config = config.get("SubModules", {}).get(
  83. "LayoutDetection",
  84. {"model_config_error": "config error for layout_det_model!"},
  85. )
  86. self.layout_det_model = self.create_model(layout_det_config)
  87. table_cls_config = config.get("SubModules", {}).get(
  88. "TableClassification",
  89. {"model_config_error": "config error for table_classification_model!"},
  90. )
  91. self.table_cls_model = self.create_model(table_cls_config)
  92. wired_table_rec_config = config.get("SubModules", {}).get(
  93. "WiredTableStructureRecognition",
  94. {"model_config_error": "config error for wired_table_structure_model!"},
  95. )
  96. self.wired_table_rec_model = self.create_model(wired_table_rec_config)
  97. wireless_table_rec_config = config.get("SubModules", {}).get(
  98. "WirelessTableStructureRecognition",
  99. {"model_config_error": "config error for wireless_table_structure_model!"},
  100. )
  101. self.wireless_table_rec_model = self.create_model(wireless_table_rec_config)
  102. wired_table_cells_det_config = config.get("SubModules", {}).get(
  103. "WiredTableCellsDetection",
  104. {
  105. "model_config_error": "config error for wired_table_cells_detection_model!"
  106. },
  107. )
  108. self.wired_table_cells_detection_model = self.create_model(
  109. wired_table_cells_det_config
  110. )
  111. wireless_table_cells_det_config = config.get("SubModules", {}).get(
  112. "WirelessTableCellsDetection",
  113. {
  114. "model_config_error": "config error for wireless_table_cells_detection_model!"
  115. },
  116. )
  117. self.wireless_table_cells_detection_model = self.create_model(
  118. wireless_table_cells_det_config
  119. )
  120. self.use_ocr_model = config.get("use_ocr_model", True)
  121. self.general_ocr_pipeline = None
  122. if self.use_ocr_model:
  123. general_ocr_config = config.get("SubPipelines", {}).get(
  124. "GeneralOCR",
  125. {"pipeline_config_error": "config error for general_ocr_pipeline!"},
  126. )
  127. self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
  128. else:
  129. self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
  130. "GeneralOCR", None
  131. )
  132. self.table_orientation_classify_model = None
  133. self.table_orientation_classify_config = config.get("SubModules", {}).get(
  134. "TableOrientationClassify", None
  135. )
  136. self._crop_by_boxes = CropByBoxes()
  137. self.batch_sampler = ImageBatchSampler(batch_size=1)
  138. self.img_reader = ReadImage(format="BGR")
  139. def get_model_settings(
  140. self,
  141. use_doc_orientation_classify: Optional[bool],
  142. use_doc_unwarping: Optional[bool],
  143. use_layout_detection: Optional[bool],
  144. use_ocr_model: Optional[bool],
  145. ) -> dict:
  146. """
  147. Get the model settings based on the provided parameters or default values.
  148. Args:
  149. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  150. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  151. use_layout_detection (Optional[bool]): Whether to use layout detection.
  152. use_ocr_model (Optional[bool]): Whether to use OCR model.
  153. Returns:
  154. dict: A dictionary containing the model settings.
  155. """
  156. if use_doc_orientation_classify is None and use_doc_unwarping is None:
  157. use_doc_preprocessor = self.use_doc_preprocessor
  158. else:
  159. if use_doc_orientation_classify is True or use_doc_unwarping is True:
  160. use_doc_preprocessor = True
  161. else:
  162. use_doc_preprocessor = False
  163. if use_layout_detection is None:
  164. use_layout_detection = self.use_layout_detection
  165. if use_ocr_model is None:
  166. use_ocr_model = self.use_ocr_model
  167. return dict(
  168. use_doc_preprocessor=use_doc_preprocessor,
  169. use_layout_detection=use_layout_detection,
  170. use_ocr_model=use_ocr_model,
  171. )
  172. def check_model_settings_valid(
  173. self,
  174. model_settings: Dict,
  175. overall_ocr_res: OCRResult,
  176. layout_det_res: DetResult,
  177. ) -> bool:
  178. """
  179. Check if the input parameters are valid based on the initialized models.
  180. Args:
  181. model_settings (Dict): A dictionary containing input parameters.
  182. overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
  183. The overall OCR result with convert_points_to_boxes information.
  184. layout_det_res (DetResult): The layout detection result.
  185. Returns:
  186. bool: True if all required models are initialized according to input parameters, False otherwise.
  187. """
  188. if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
  189. logging.error(
  190. "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
  191. )
  192. return False
  193. if model_settings["use_layout_detection"]:
  194. if layout_det_res is not None:
  195. logging.error(
  196. "The layout detection model has already been initialized, please set use_layout_detection=False"
  197. )
  198. return False
  199. if not self.use_layout_detection:
  200. logging.error(
  201. "Set use_layout_detection, but the models for layout detection are not initialized."
  202. )
  203. return False
  204. if model_settings["use_ocr_model"]:
  205. if overall_ocr_res is not None:
  206. logging.error(
  207. "The OCR models have already been initialized, please set use_ocr_model=False"
  208. )
  209. return False
  210. if not self.use_ocr_model:
  211. logging.error(
  212. "Set use_ocr_model, but the models for OCR are not initialized."
  213. )
  214. return False
  215. else:
  216. if overall_ocr_res is None:
  217. logging.error("Set use_ocr_model=False, but no OCR results were found.")
  218. return False
  219. return True
  220. def predict_doc_preprocessor_res(
  221. self, image_array: np.ndarray, input_params: dict
  222. ) -> Tuple[DocPreprocessorResult, np.ndarray]:
  223. """
  224. Preprocess the document image based on input parameters.
  225. Args:
  226. image_array (np.ndarray): The input image array.
  227. input_params (dict): Dictionary containing preprocessing parameters.
  228. Returns:
  229. tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
  230. result dictionary and the processed image array.
  231. """
  232. if input_params["use_doc_preprocessor"]:
  233. use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
  234. use_doc_unwarping = input_params["use_doc_unwarping"]
  235. doc_preprocessor_res = list(
  236. self.doc_preprocessor_pipeline(
  237. image_array,
  238. use_doc_orientation_classify=use_doc_orientation_classify,
  239. use_doc_unwarping=use_doc_unwarping,
  240. )
  241. )[0]
  242. doc_preprocessor_image = doc_preprocessor_res["output_img"]
  243. else:
  244. doc_preprocessor_res = {}
  245. doc_preprocessor_image = image_array
  246. return doc_preprocessor_res, doc_preprocessor_image
  247. def extract_results(self, pred, task):
  248. if task == "cls":
  249. return pred["label_names"][np.argmax(pred["scores"])]
  250. elif task == "det":
  251. threshold = 0.0
  252. result = []
  253. cell_score = []
  254. if "boxes" in pred and isinstance(pred["boxes"], list):
  255. for box in pred["boxes"]:
  256. if isinstance(box, dict) and "score" in box and "coordinate" in box:
  257. score = box["score"]
  258. coordinate = box["coordinate"]
  259. if isinstance(score, float) and score > threshold:
  260. result.append(coordinate)
  261. cell_score.append(score)
  262. return result, cell_score
  263. elif task == "table_stru":
  264. return pred["structure"]
  265. else:
  266. return None
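# --- Illustrative sketch (not part of the original code) ---
# What extract_results returns for each task, assuming minimal prediction dicts
# shaped like the corresponding model outputs (the values below are made up for
# illustration):
#
#   cls_pred = {"label_names": ["wired_table", "wireless_table"], "scores": [0.1, 0.9]}
#   self.extract_results(cls_pred, "cls")          # -> "wireless_table"
#
#   det_pred = {"boxes": [{"score": 0.8, "coordinate": [10, 10, 50, 30]}]}
#   self.extract_results(det_pred, "det")          # -> ([[10, 10, 50, 30]], [0.8])
#
#   stru_pred = {"structure": ["<tr>", "<td>", "</td>", "</tr>"]}
#   self.extract_results(stru_pred, "table_stru")  # -> ["<tr>", "<td>", "</td>", "</tr>"]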
  267. def cells_det_results_nms(
  268. self, cells_det_results, cells_det_scores, cells_det_threshold=0.3
  269. ):
  270. """
  271. Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.
  272. Args:
  273. cells_det_results (list): List of bounding boxes, each box is in format [x1, y1, x2, y2].
  274. cells_det_scores (list): List of confidence scores corresponding to the bounding boxes.
  275. cells_det_threshold (float): IoU threshold for suppression. Boxes with IoU greater than this threshold
  276. will be suppressed. Default is 0.3.
  277. Returns:
  278. Tuple[list, list]: A tuple containing the list of bounding boxes and confidence scores after NMS,
  279. while maintaining one-to-one correspondence.
  280. """
  281. # Convert lists to numpy arrays for efficient computation
  282. boxes = np.array(cells_det_results)
  283. scores = np.array(cells_det_scores)
  284. # Initialize list for picked indices
  285. picked_indices = []
  286. # Get coordinates of bounding boxes
  287. x1 = boxes[:, 0]
  288. y1 = boxes[:, 1]
  289. x2 = boxes[:, 2]
  290. y2 = boxes[:, 3]
  291. # Compute the area of the bounding boxes
  292. areas = (x2 - x1) * (y2 - y1)
  293. # Sort the bounding boxes by the confidence scores in descending order
  294. order = scores.argsort()[::-1]
  295. # Process the boxes
  296. while order.size > 0:
  297. # Index of the current highest score box
  298. i = order[0]
  299. picked_indices.append(i)
  300. # Compute IoU between the highest score box and the rest
  301. xx1 = np.maximum(x1[i], x1[order[1:]])
  302. yy1 = np.maximum(y1[i], y1[order[1:]])
  303. xx2 = np.minimum(x2[i], x2[order[1:]])
  304. yy2 = np.minimum(y2[i], y2[order[1:]])
  305. # Compute the width and height of the overlapping area
  306. w = np.maximum(0.0, xx2 - xx1)
  307. h = np.maximum(0.0, yy2 - yy1)
  308. # Compute the ratio of overlap (IoU)
  309. inter = w * h
  310. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  311. # Indices of boxes with IoU less than threshold
  312. inds = np.where(ovr <= cells_det_threshold)[0]
  313. # Update order, only keep boxes with IoU less than threshold
  314. order = order[
  315. inds + 1
  316. ] # inds shifted by 1 because order[0] is the current box
  317. # Select the boxes and scores based on picked indices
  318. final_boxes = boxes[picked_indices].tolist()
  319. final_scores = scores[picked_indices].tolist()
  320. return final_boxes, final_scores
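# --- Illustrative sketch (not part of the original code) ---
# A minimal NMS walk-through with made-up boxes: the two left boxes overlap with
# IoU of roughly 0.9, so only the higher-scoring one survives at the default
# threshold of 0.3:
#
#   boxes  = [[0, 0, 100, 50], [5, 0, 105, 50], [200, 0, 300, 50]]
#   scores = [0.9, 0.6, 0.8]
#   self.cells_det_results_nms(boxes, scores)
#   # -> ([[0, 0, 100, 50], [200, 0, 300, 50]], [0.9, 0.8])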
  321. def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
  322. """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.
  323. Args:
  324. ocr_det_boxes (list of list): List of bounding boxes [x1, y1, x2, y2] in the original image.
  325. table_box (list): Bounding box [x1, y1, x2, y2] of the target region in the original image.
  326. Returns:
  327. list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
  328. """
  329. tol = 0
  330. # Extract coordinates from table_box
  331. x_min_t, y_min_t, x_max_t, y_max_t = table_box
  332. adjusted_boxes = []
  333. for box in ocr_det_boxes:
  334. x_min_b, y_min_b, x_max_b, y_max_b = box
  335. # Check if the box is fully inside table_box
  336. if (
  337. x_min_b + tol >= x_min_t
  338. and y_min_b + tol >= y_min_t
  339. and x_max_b - tol <= x_max_t
  340. and y_max_b - tol <= y_max_t
  341. ):
  342. # Adjust the coordinates to be relative to table_box
  343. adjusted_box = [
  344. x_min_b - x_min_t, # Adjust x1
  345. y_min_b - y_min_t, # Adjust y1
  346. x_max_b - x_min_t, # Adjust x2
  347. y_max_b - y_min_t, # Adjust y2
  348. ]
  349. adjusted_boxes.append(adjusted_box)
  350. # Discard boxes not fully inside table_box
  351. return adjusted_boxes
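# --- Illustrative sketch (not part of the original code) ---
# Only OCR boxes fully contained in table_box are kept, re-expressed in
# table-local coordinates (made-up numbers):
#
#   table_box = [100, 100, 400, 300]
#   ocr_boxes = [[120, 110, 200, 130],   # fully inside -> kept and shifted
#                [ 90, 110, 200, 130]]   # crosses the left edge -> discarded
#   self.get_region_ocr_det_boxes(ocr_boxes, table_box)
#   # -> [[20, 10, 100, 30]]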
  352. def cells_det_results_reprocessing(
  353. self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums
  354. ):
  355. """
  356. Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.
  357. Args:
  358. cells_det_results (List[List[float]]): List of detected cell rectangles [[x1, y1, x2, y2], ...].
  359. cells_det_scores (List[float]): List of confidence scores for each rectangle in cells_det_results.
  360. ocr_det_results (List[List[float]]): List of OCR detected rectangles [[x1, y1, x2, y2], ...].
  361. html_pred_boxes_nums (int): The desired number of rectangles in the final output.
  362. Returns:
  363. List[List[float]]: The processed list of rectangles.
  364. """
  365. # Function to compute IoU between two rectangles
  366. def compute_iou(box1, box2):
  367. """
  368. Compute the Intersection over Union (IoU) between two rectangles.
  369. Args:
  370. box1 (array-like): [x1, y1, x2, y2] of the first rectangle.
  371. box2 (array-like): [x1, y1, x2, y2] of the second rectangle.
  372. Returns:
  373. float: The IoU between the two rectangles.
  374. """
  375. # Determine the coordinates of the intersection rectangle
  376. x_left = max(box1[0], box2[0])
  377. y_top = max(box1[1], box2[1])
  378. x_right = min(box1[2], box2[2])
  379. y_bottom = min(box1[3], box2[3])
  380. if x_right <= x_left or y_bottom <= y_top:
  381. return 0.0
  382. # Calculate the area of intersection rectangle
  383. intersection_area = (x_right - x_left) * (y_bottom - y_top)
  384. # Calculate the area of box1 (box2's area is not needed for this ratio)
  385. box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
  386. # Calculate the overlap ratio: intersection over box1's area
  387. # (note: this is not the symmetric IoU, despite the function name)
  388. iou = intersection_area / float(box1_area)
  389. return iou
  390. # Function to combine rectangles into N rectangles
  391. @function_requires_deps("scikit-learn")
  392. def combine_rectangles(rectangles, N):
  393. """
  394. Combine rectangles into N rectangles based on geometric proximity.
  395. Args:
  396. rectangles (list of list of int): A list of rectangles, each represented by [x1, y1, x2, y2].
  397. N (int): The desired number of combined rectangles.
  398. Returns:
  399. list of list of int: A list of N combined rectangles.
  400. """
  401. # Number of input rectangles
  402. num_rects = len(rectangles)
  403. # If N is greater than or equal to the number of rectangles, return the original rectangles
  404. if N >= num_rects:
  405. return rectangles
  406. # Compute the center points of the rectangles
  407. centers = np.array(
  408. [
  409. [
  410. (rect[0] + rect[2]) / 2, # Center x-coordinate
  411. (rect[1] + rect[3]) / 2, # Center y-coordinate
  412. ]
  413. for rect in rectangles
  414. ]
  415. )
  416. # Perform KMeans clustering on the center points to group them into N clusters
  417. kmeans = KMeans(n_clusters=N, random_state=0, n_init="auto")
  418. labels = kmeans.fit_predict(centers)
  419. # Initialize a list to store the combined rectangles
  420. combined_rectangles = []
  421. # For each cluster, compute the minimal bounding rectangle that covers all rectangles in the cluster
  422. for i in range(N):
  423. # Get the indices of rectangles that belong to cluster i
  424. indices = np.where(labels == i)[0]
  425. if len(indices) == 0:
  426. # If no rectangles in this cluster, skip it
  427. continue
  428. # Extract the rectangles in cluster i
  429. cluster_rects = np.array([rectangles[idx] for idx in indices])
  430. # Compute the minimal x1, y1 (top-left corner) and maximal x2, y2 (bottom-right corner)
  431. x1_min = np.min(cluster_rects[:, 0])
  432. y1_min = np.min(cluster_rects[:, 1])
  433. x2_max = np.max(cluster_rects[:, 2])
  434. y2_max = np.max(cluster_rects[:, 3])
  435. # Append the combined rectangle to the list
  436. combined_rectangles.append([x1_min, y1_min, x2_max, y2_max])
  437. return combined_rectangles
  438. # Ensure that the inputs are numpy arrays for efficient computation
  439. cells_det_results = np.array(cells_det_results)
  440. cells_det_scores = np.array(cells_det_scores)
  441. ocr_det_results = np.array(ocr_det_results)
  442. more_cells_flag = False
  443. if len(cells_det_results) == html_pred_boxes_nums:
  444. return cells_det_results
  445. # Step 1: If cells_det_results has more rectangles than html_pred_boxes_nums
  446. elif len(cells_det_results) > html_pred_boxes_nums:
  447. more_cells_flag = True
  448. # Select the indices of the top html_pred_boxes_nums scores
  449. top_indices = np.argsort(-cells_det_scores)[:html_pred_boxes_nums]
  450. # Adjust the corresponding rectangles
  451. cells_det_results = cells_det_results[top_indices].tolist()
  452. # Threshold for IoU
  453. iou_threshold = 0.6
  454. # List to store ocr_miss_boxes
  455. ocr_miss_boxes = []
  456. # For each rectangle in ocr_det_results
  457. for ocr_rect in ocr_det_results:
  458. merge_ocr_box_iou = []
  459. # Flag to indicate if ocr_rect has IoU >= threshold with any cell_rect
  460. has_large_iou = False
  461. # For each rectangle in cells_det_results
  462. for cell_rect in cells_det_results:
  463. # Compute IoU
  464. iou = compute_iou(ocr_rect, cell_rect)
  465. if iou > 0:
  466. merge_ocr_box_iou.append(iou)
  467. if (iou >= iou_threshold) or (sum(merge_ocr_box_iou) >= iou_threshold):
  468. has_large_iou = True
  469. break
  470. if not has_large_iou:
  471. ocr_miss_boxes.append(ocr_rect)
  472. # If no ocr_miss_boxes, return cells_det_results
  473. if len(ocr_miss_boxes) == 0:
  474. final_results = (
  475. cells_det_results
  476. if more_cells_flag == True
  477. else cells_det_results.tolist()
  478. )
  479. else:
  480. if more_cells_flag == True:
  481. final_results = combine_rectangles(
  482. cells_det_results + ocr_miss_boxes, html_pred_boxes_nums
  483. )
  484. else:
  485. # Need to combine ocr_miss_boxes into N rectangles
  486. N = html_pred_boxes_nums - len(cells_det_results)
  487. # Combine ocr_miss_boxes into N rectangles
  488. ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
  489. # Combine cells_det_results and ocr_supp_boxes
  490. final_results = np.concatenate(
  491. (cells_det_results, ocr_supp_boxes), axis=0
  492. ).tolist()
  493. if len(final_results) <= 0.6 * html_pred_boxes_nums:
  494. final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
  495. return final_results
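# --- Illustrative sketch (not part of the original code; requires scikit-learn) ---
# How the nested combine_rectangles helper merges boxes down to N rectangles by
# KMeans-clustering their centers; with centers this far apart the grouping is
# stable, though the order of the returned rectangles follows the cluster labels:
#
#   rects = [[0, 0, 10, 10], [12, 0, 22, 10], [100, 0, 110, 10]]
#   combine_rectangles(rects, 2)
#   # -> [[0, 0, 22, 10], [100, 0, 110, 10]]   (up to ordering: the two nearby
#   #    boxes fall in one cluster and are covered by their minimal bounding
#   #    rectangle; the far box forms its own cluster)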
  496. def split_ocr_bboxes_by_table_cells(
  497. self, cells_det_results, overall_ocr_res, ori_img, k=2
  498. ):
  499. """
  500. Split OCR bounding boxes based on table cell boundaries when they span multiple cells horizontally.
  501. Args:
  502. cells_det_results (list): List of cell bounding boxes in format [x1, y1, x2, y2]
  503. overall_ocr_res (dict): Dictionary containing OCR results with keys:
  504. - 'rec_boxes': OCR bounding boxes (will be converted to list)
  505. - 'rec_texts': OCR recognized texts
  506. ori_img (np.array): Original input image array
  507. k (int): Threshold for determining when to split (minimum number of cells spanned)
  508. Returns:
  509. dict: Modified overall_ocr_res with split boxes and texts
  510. """
  511. def calculate_iou(box1, box2):
  512. """
  513. Calculate Intersection over Union (IoU) between two bounding boxes.
  514. Args:
  515. box1 (list): [x1, y1, x2, y2]
  516. box2 (list): [x1, y1, x2, y2]
  517. Returns:
  518. float: IoU value
  519. """
  520. # Determine intersection coordinates
  521. x_left = max(box1[0], box2[0])
  522. y_top = max(box1[1], box2[1])
  523. x_right = min(box1[2], box2[2])
  524. y_bottom = min(box1[3], box2[3])
  525. if x_right < x_left or y_bottom < y_top:
  526. return 0.0
  527. # Calculate areas
  528. intersection_area = (x_right - x_left) * (y_bottom - y_top)
  529. box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
  530. box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
  531. # True IoU would be: intersection_area / float(box1_area + box2_area - intersection_area)
  532. return intersection_area / box2_area  # overlap ratio relative to box2 (the cell)
  533. def get_overlapping_cells(ocr_box, cells):
  534. """
  535. Find cells that overlap significantly with the OCR box (IoU > 0.5).
  536. Args:
  537. ocr_box (list): OCR bounding box [x1, y1, x2, y2]
  538. cells (list): List of cell bounding boxes
  539. Returns:
  540. list: Indices of overlapping cells, sorted by x-coordinate
  541. """
  542. overlapping = []
  543. for idx, cell in enumerate(cells):
  544. if calculate_iou(ocr_box, cell) > 0.5:
  545. overlapping.append(idx)
  546. # Sort overlapping cells by their x-coordinate (left to right)
  547. overlapping.sort(key=lambda i: cells[i][0])
  548. return overlapping
  549. def split_box_by_cells(ocr_box, cell_indices, cells):
  550. """
  551. Split OCR box vertically at cell boundaries.
  552. Args:
  553. ocr_box (list): Original OCR box [x1, y1, x2, y2]
  554. cell_indices (list): Indices of cells to split by
  555. cells (list): All cell bounding boxes
  556. Returns:
  557. list: List of split boxes
  558. """
  559. if not cell_indices:
  560. return [ocr_box]
  561. split_boxes = []
  562. cells_to_split = [cells[i] for i in cell_indices]
  563. if ocr_box[0] < cells_to_split[0][0]:
  564. split_boxes.append(
  565. [ocr_box[0], ocr_box[1], cells_to_split[0][0], ocr_box[3]]
  566. )
  567. for i in range(len(cells_to_split)):
  568. current_cell = cells_to_split[i]
  569. split_boxes.append(
  570. [
  571. max(ocr_box[0], current_cell[0]),
  572. ocr_box[1],
  573. min(ocr_box[2], current_cell[2]),
  574. ocr_box[3],
  575. ]
  576. )
  577. if i < len(cells_to_split) - 1:
  578. next_cell = cells_to_split[i + 1]
  579. if current_cell[2] < next_cell[0]:
  580. split_boxes.append(
  581. [current_cell[2], ocr_box[1], next_cell[0], ocr_box[3]]
  582. )
  583. last_cell = cells_to_split[-1]
  584. if last_cell[2] < ocr_box[2]:
  585. split_boxes.append([last_cell[2], ocr_box[1], ocr_box[2], ocr_box[3]])
  586. unique_boxes = []
  587. seen = set()
  588. for box in split_boxes:
  589. box_tuple = tuple(box)
  590. if box_tuple not in seen:
  591. seen.add(box_tuple)
  592. unique_boxes.append(box)
  593. return unique_boxes
  594. # Convert OCR boxes to list if needed
  595. if hasattr(overall_ocr_res["rec_boxes"], "tolist"):
  596. ocr_det_results = overall_ocr_res["rec_boxes"].tolist()
  597. else:
  598. ocr_det_results = overall_ocr_res["rec_boxes"]
  599. ocr_texts = overall_ocr_res["rec_texts"]
  600. # Make copies to modify
  601. new_boxes = []
  602. new_texts = []
  603. # Process each OCR box
  604. i = 0
  605. while i < len(ocr_det_results):
  606. ocr_box = ocr_det_results[i]
  607. text = ocr_texts[i]
  608. # Find cells that significantly overlap with this OCR box
  609. overlapping_cells = get_overlapping_cells(ocr_box, cells_det_results)
  610. # Check if we need to split (spans >= k cells)
  611. if len(overlapping_cells) >= k:
  612. # Split the box at cell boundaries
  613. split_boxes = split_box_by_cells(
  614. ocr_box, overlapping_cells, cells_det_results
  615. )
  616. # Process each split box
  617. split_texts = []
  618. for box in split_boxes:
  619. x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
  620. if y2 - y1 > 1 and x2 - x1 > 1:
  621. ocr_result = list(
  622. self.general_ocr_pipeline.text_rec_model(
  623. ori_img[y1:y2, x1:x2, :]
  624. )
  625. )[0]
  626. # Extract the recognized text from the OCR result
  627. if "rec_text" in ocr_result:
  628. result = ocr_result[
  629. "rec_text"
  630. ] # "rec_text" holds the single recognized string for the crop
  631. else:
  632. result = ""
  633. else:
  634. result = ""
  635. split_texts.append(result)
  636. # Add split boxes and texts to results
  637. new_boxes.extend(split_boxes)
  638. new_texts.extend(split_texts)
  639. else:
  640. # Keep original box and text
  641. new_boxes.append(ocr_box)
  642. new_texts.append(text)
  643. i += 1
  644. # Update the results dictionary
  645. overall_ocr_res["rec_boxes"] = new_boxes
  646. overall_ocr_res["rec_texts"] = new_texts
  647. return overall_ocr_res
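# --- Illustrative sketch (not part of the original code) ---
# An OCR box spanning two adjacent cells (overlap ratio > 0.5 with each cell) is
# split at the shared cell boundary and each half is re-recognized with the text
# recognition model (made-up coordinates):
#
#   cells   = [[0, 0, 50, 20], [50, 0, 100, 20]]
#   ocr_box = [5, 2, 95, 18]
#   # With k=2 this box is replaced by [5, 2, 50, 18] and [50, 2, 95, 18] in
#   # overall_ocr_res["rec_boxes"], and their texts are re-read from the crops.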
  648. def gen_ocr_with_table_cells(self, ori_img, cells_bboxes):
  649. """
  650. Splits OCR bounding boxes by table cells and retrieves text.
  651. Args:
  652. ori_img (ndarray): The original image from which text regions will be extracted.
  653. cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.
  654. Returns:
  655. list: A list containing the recognized texts from each cell.
  656. """
  657. # Check if cells_bboxes is a list and convert it if not.
  658. if not isinstance(cells_bboxes, list):
  659. cells_bboxes = cells_bboxes.tolist()
  660. texts_list = [] # Initialize a list to store the recognized texts.
  661. # Process each bounding box provided in cells_bboxes.
  662. for i in range(len(cells_bboxes)):
  663. # Extract and round up the coordinates of the bounding box.
  664. x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
  665. # Perform OCR on the defined region of the image and get the recognized text.
  666. if y2 - y1 > 1 and x2 - x1 > 1:
  667. rec_te = list(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))[0]
  668. # Concatenate the texts and append them to the texts_list.
  669. texts_list.append("".join(rec_te["rec_texts"]))
  670. # Return the list of recognized texts from each cell.
  671. return texts_list
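# --- Illustrative sketch (not part of the original code) ---
# Per-cell OCR: each cell crop is run through the full OCR pipeline and its
# recognized fragments are concatenated into one string per cell (output values
# below are hypothetical); cells narrower or shorter than 2 pixels are skipped:
#
#   cells_bboxes = [[0, 0, 120, 40], [120, 0, 240, 40]]
#   self.gen_ocr_with_table_cells(ori_img, cells_bboxes)
#   # -> ["Item", "Quantity"]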
  672. def map_cells_to_original_image(
  673. self, detections, table_angle, img_width, img_height
  674. ):
  675. """
  676. Map bounding boxes from the rotated image back to the original image.
  677. Parameters:
  678. - detections: list of numpy arrays, each containing bounding box coordinates [x1, y1, x2, y2]
  679. - table_angle: rotation angle as a string ("90", "180", or "270")
  680. - img_width: width of the rotated image the detections were computed on
  681. - img_height: height of the rotated image the detections were computed on
  682. Returns:
  683. - mapped_detections: list of numpy arrays with mapped bounding box coordinates
  684. """
  685. mapped_detections = []
  686. for i in range(len(detections)):
  687. tbx1, tby1, tbx2, tby2 = (
  688. detections[i][0],
  689. detections[i][1],
  690. detections[i][2],
  691. detections[i][3],
  692. )
  693. if table_angle == "270":
  694. new_x1, new_y1 = tby1, img_width - tbx2
  695. new_x2, new_y2 = tby2, img_width - tbx1
  696. elif table_angle == "180":
  697. new_x1, new_y1 = img_width - tbx2, img_height - tby2
  698. new_x2, new_y2 = img_width - tbx1, img_height - tby1
  699. elif table_angle == "90":
  700. new_x1, new_y1 = img_height - tby2, tbx1
  701. new_x2, new_y2 = img_height - tby1, tbx2
  702. new_box = np.array([new_x1, new_y1, new_x2, new_y2])
  703. mapped_detections.append(new_box)
  704. return mapped_detections
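# --- Illustrative sketch (not part of the original code) ---
# Mapping a box detected on a 90-degree-rotated image back to the original
# orientation. Assume the original image is 200x100 (width x height), so the
# np.rot90(k=1) image is 100x200 and the caller passes the rotated image's
# dimensions, img_width=100 and img_height=200:
#
#   dets = [np.array([10, 20, 30, 60])]             # box on the rotated image
#   self.map_cells_to_original_image(dets, "90", 100, 200)
#   # -> [array([140, 10, 180, 30])]                # same cell in the original image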
  705. def split_string_by_keywords(self, html_string):
  706. """
  707. Split HTML string by keywords.
  708. Args:
  709. html_string (str): The HTML string.
  710. Returns:
  711. split_html (list): The list of html keywords.
  712. """
  713. keywords = [
  714. "<thead>",
  715. "</thead>",
  716. "<tbody>",
  717. "</tbody>",
  718. "<tr>",
  719. "</tr>",
  720. "<td>",
  721. "<td",
  722. ">",
  723. "</td>",
  724. 'colspan="2"',
  725. 'colspan="3"',
  726. 'colspan="4"',
  727. 'colspan="5"',
  728. 'colspan="6"',
  729. 'colspan="7"',
  730. 'colspan="8"',
  731. 'colspan="9"',
  732. 'colspan="10"',
  733. 'colspan="11"',
  734. 'colspan="12"',
  735. 'colspan="13"',
  736. 'colspan="14"',
  737. 'colspan="15"',
  738. 'colspan="16"',
  739. 'colspan="17"',
  740. 'colspan="18"',
  741. 'colspan="19"',
  742. 'colspan="20"',
  743. 'rowspan="2"',
  744. 'rowspan="3"',
  745. 'rowspan="4"',
  746. 'rowspan="5"',
  747. 'rowspan="6"',
  748. 'rowspan="7"',
  749. 'rowspan="8"',
  750. 'rowspan="9"',
  751. 'rowspan="10"',
  752. 'rowspan="11"',
  753. 'rowspan="12"',
  754. 'rowspan="13"',
  755. 'rowspan="14"',
  756. 'rowspan="15"',
  757. 'rowspan="16"',
  758. 'rowspan="17"',
  759. 'rowspan="18"',
  760. 'rowspan="19"',
  761. 'rowspan="20"',
  762. ]
  763. regex_pattern = "|".join(re.escape(keyword) for keyword in keywords)
  764. split_result = re.split(f"({regex_pattern})", html_string)
  765. split_html = [part for part in split_result if part]
  766. return split_html
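# --- Illustrative sketch (not part of the original code) ---
# The regex split keeps the structural tags as separate tokens while plain cell
# text stays intact between them:
#
#   self.split_string_by_keywords("<tr><td>Total</td></tr>")
#   # -> ["<tr>", "<td>", "Total", "</td>", "</tr>"]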
  767. def cluster_positions(self, positions, tolerance):
  768. if not positions:
  769. return []
  770. positions = sorted(set(positions))
  771. clustered = []
  772. current_cluster = [positions[0]]
  773. for pos in positions[1:]:
  774. if abs(pos - current_cluster[-1]) <= tolerance:
  775. current_cluster.append(pos)
  776. else:
  777. clustered.append(sum(current_cluster) / len(current_cluster))
  778. current_cluster = [pos]
  779. clustered.append(sum(current_cluster) / len(current_cluster))
  780. return clustered
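# --- Illustrative sketch (not part of the original code) ---
# cluster_positions merges coordinates that lie within `tolerance` of their
# previous neighbor and returns the mean of each run, i.e. the shared grid lines:
#
#   self.cluster_positions([10, 12, 13, 50, 52, 90], tolerance=5)
#   # -> [11.67, 51.0, 90.0]   (approximately; each cluster is replaced by its mean)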
  781. def trans_cells_det_results_to_html(self, cells_det_results):
  782. """
  783. Trans table cells bboxes to HTML.
  784. Args:
  785. cells_det_results (list): The table cells detection results.
  786. Returns:
  787. html (list): The list of html keywords.
  788. """
  789. tolerance = 5
  790. x_coords = [x for cell in cells_det_results for x in (cell[0], cell[2])]
  791. y_coords = [y for cell in cells_det_results for y in (cell[1], cell[3])]
  792. x_positions = self.cluster_positions(x_coords, tolerance)
  793. y_positions = self.cluster_positions(y_coords, tolerance)
  794. x_position_to_index = {x: i for i, x in enumerate(x_positions)}
  795. y_position_to_index = {y: i for i, y in enumerate(y_positions)}
  796. num_rows = len(y_positions) - 1
  797. num_cols = len(x_positions) - 1
  798. grid = [[None for _ in range(num_cols)] for _ in range(num_rows)]
  799. cells_info = []
  800. cell_index = 0
  801. cell_map = {}
  802. for index, cell in enumerate(cells_det_results):
  803. x1, y1, x2, y2 = cell
  804. x1_idx = min(
  805. range(len(x_positions)), key=lambda i: abs(x_positions[i] - x1)
  806. )
  807. x2_idx = min(
  808. range(len(x_positions)), key=lambda i: abs(x_positions[i] - x2)
  809. )
  810. y1_idx = min(
  811. range(len(y_positions)), key=lambda i: abs(y_positions[i] - y1)
  812. )
  813. y2_idx = min(
  814. range(len(y_positions)), key=lambda i: abs(y_positions[i] - y2)
  815. )
  816. col_start = min(x1_idx, x2_idx)
  817. col_end = max(x1_idx, x2_idx)
  818. row_start = min(y1_idx, y2_idx)
  819. row_end = max(y1_idx, y2_idx)
  820. rowspan = row_end - row_start
  821. colspan = col_end - col_start
  822. if rowspan == 0:
  823. rowspan = 1
  824. if colspan == 0:
  825. colspan = 1
  826. cells_info.append(
  827. {
  828. "row_start": row_start,
  829. "col_start": col_start,
  830. "rowspan": rowspan,
  831. "colspan": colspan,
  832. "content": "",
  833. }
  834. )
  835. for r in range(row_start, row_start + rowspan):
  836. for c in range(col_start, col_start + colspan):
  837. key = (r, c)
  838. if key in cell_map:
  839. continue
  840. else:
  841. cell_map[key] = index
  842. html = "<table><tbody>"
  843. for r in range(num_rows):
  844. html += "<tr>"
  845. c = 0
  846. while c < num_cols:
  847. key = (r, c)
  848. if key in cell_map:
  849. cell_index = cell_map[key]
  850. cell_info = cells_info[cell_index]
  851. if cell_info["row_start"] == r and cell_info["col_start"] == c:
  852. rowspan = cell_info["rowspan"]
  853. colspan = cell_info["colspan"]
  854. rowspan_attr = f' rowspan="{rowspan}"' if rowspan > 1 else ""
  855. colspan_attr = f' colspan="{colspan}"' if colspan > 1 else ""
  856. content = cell_info["content"]
  857. html += f"<td{rowspan_attr}{colspan_attr}>{content}</td>"
  858. c += cell_info["colspan"]
  859. else:
  860. html += "<td></td>"
  861. c += 1
  862. html += "</tr>"
  863. html += "</tbody></table>"
  864. html = self.split_string_by_keywords(html)
  865. return html
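# --- Illustrative sketch (not part of the original code) ---
# A 2x2 grid where the first detected cell spans both columns of the top row
# (made-up coordinates). The reconstructed HTML, before it is tokenized by
# split_string_by_keywords, is shown with whitespace added for readability:
#
#   cells = [[0, 0, 100, 20],     # top row, spans both columns
#            [0, 20, 50, 40],     # bottom-left cell
#            [50, 20, 100, 40]]   # bottom-right cell
#   self.trans_cells_det_results_to_html(cells)
#   # <table><tbody>
#   #   <tr><td colspan="2"></td></tr>
#   #   <tr><td></td><td></td></tr>
#   # </tbody></table>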
  866. def predict_single_table_recognition_res(
  867. self,
  868. image_array: np.ndarray,
  869. overall_ocr_res: OCRResult,
  870. table_box: list,
  871. use_e2e_wired_table_rec_model: bool = False,
  872. use_e2e_wireless_table_rec_model: bool = False,
  873. use_wired_table_cells_trans_to_html: bool = False,
  874. use_wireless_table_cells_trans_to_html: bool = False,
  875. use_ocr_results_with_table_cells: bool = True,
  876. flag_find_nei_text: bool = True,
  877. ) -> SingleTableRecognitionResult:
  878. """
  879. Predict table recognition results from an image array, layout detection results, and OCR results.
  880. Args:
  881. image_array (np.ndarray): The input image represented as a numpy array.
  882. overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
  883. The overall OCR results containing text recognition information.
  884. table_box (list): The table box coordinates.
  885. use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
  886. use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
  887. use_wired_table_cells_trans_to_html (bool): Whether to use wired table cells trans to HTML.
  888. use_wireless_table_cells_trans_to_html (bool): Whether to use wireless table cells trans to HTML.
  889. use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells.
  890. flag_find_nei_text (bool): Whether to find neighboring text.
  891. Returns:
  892. SingleTableRecognitionResult: single table recognition result.
  893. """
  894. table_cls_pred = list(self.table_cls_model(image_array))[0]
  895. table_cls_result = self.extract_results(table_cls_pred, "cls")
  896. use_e2e_model = False
  897. cells_trans_to_html = False
  898. if table_cls_result == "wired_table":
  899. if use_wired_table_cells_trans_to_html == True:
  900. cells_trans_to_html = True
  901. else:
  902. table_structure_pred = list(self.wired_table_rec_model(image_array))[0]
  903. if use_e2e_wired_table_rec_model == True:
  904. use_e2e_model = True
  905. if cells_trans_to_html == True:
  906. table_structure_pred = list(
  907. self.wired_table_rec_model(image_array)
  908. )[0]
  909. else:
  910. table_cells_pred = list(
  911. self.wired_table_cells_detection_model(image_array, threshold=0.3)
  912. )[
  913. 0
  914. ] # A threshold of 0.3 improves the accuracy of table cell detection;
  915. # adjust it if you need more or fewer cell detection boxes.
  916. elif table_cls_result == "wireless_table":
  917. if use_wireless_table_cells_trans_to_html == True:
  918. cells_trans_to_html = True
  919. else:
  920. table_structure_pred = list(self.wireless_table_rec_model(image_array))[
  921. 0
  922. ]
  923. if use_e2e_wireless_table_rec_model == True:
  924. use_e2e_model = True
  925. if cells_trans_to_html == True:
  926. table_structure_pred = list(
  927. self.wireless_table_rec_model(image_array)
  928. )[0]
  929. else:
  930. table_cells_pred = list(
  931. self.wireless_table_cells_detection_model(
  932. image_array, threshold=0.3
  933. )
  934. )[
  935. 0
  936. ] # A threshold of 0.3 improves the accuracy of table cell detection;
  937. # adjust it if you need more or fewer cell detection boxes.
  938. if use_e2e_model == False:
  939. table_cells_result, table_cells_score = self.extract_results(
  940. table_cells_pred, "det"
  941. )
  942. table_cells_result, table_cells_score = self.cells_det_results_nms(
  943. table_cells_result, table_cells_score
  944. )
  945. if cells_trans_to_html == True:
  946. table_structure_result = self.trans_cells_det_results_to_html(
  947. table_cells_result
  948. )
  949. else:
  950. table_structure_result = self.extract_results(
  951. table_structure_pred, "table_stru"
  952. )
  953. ocr_det_boxes = self.get_region_ocr_det_boxes(
  954. overall_ocr_res["rec_boxes"].tolist(), table_box
  955. )
  956. table_cells_result = self.cells_det_results_reprocessing(
  957. table_cells_result,
  958. table_cells_score,
  959. ocr_det_boxes,
  960. len(table_structure_pred["bbox"]),
  961. )
  962. if use_ocr_results_with_table_cells == True:
  963. if self.cells_split_ocr == True:
  964. table_box_copy = np.array([table_box])
  965. table_ocr_pred = get_sub_regions_ocr_res(
  966. overall_ocr_res, table_box_copy
  967. )
  968. table_ocr_pred = self.split_ocr_bboxes_by_table_cells(
  969. table_cells_result, table_ocr_pred, image_array
  970. )
  971. cells_texts_list = []
  972. else:
  973. cells_texts_list = self.gen_ocr_with_table_cells(
  974. image_array, table_cells_result
  975. )
  976. table_ocr_pred = {}
  977. else:
  978. table_ocr_pred = {}
  979. cells_texts_list = []
  980. single_table_recognition_res = get_table_recognition_res(
  981. table_box,
  982. table_structure_result,
  983. table_cells_result,
  984. overall_ocr_res,
  985. table_ocr_pred,
  986. cells_texts_list,
  987. use_ocr_results_with_table_cells,
  988. self.cells_split_ocr,
  989. )
  990. else:
  991. cells_texts_list = []
  992. use_ocr_results_with_table_cells = False
  993. table_cells_result_e2e = table_structure_pred["bbox"]
  994. table_cells_result_e2e = [
  995. [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e
  996. ]
  997. if cells_trans_to_html == True:
  998. table_structure_pred["structure"] = (
  999. self.trans_cells_det_results_to_html(table_cells_result_e2e)
  1000. )
  1001. single_table_recognition_res = get_table_recognition_res_e2e(
  1002. table_box,
  1003. table_structure_pred,
  1004. overall_ocr_res,
  1005. cells_texts_list,
  1006. use_ocr_results_with_table_cells,
  1007. )
  1008. neighbor_text = ""
  1009. if flag_find_nei_text:
  1010. match_idx_list = get_neighbor_boxes_idx(
  1011. overall_ocr_res["rec_boxes"], table_box
  1012. )
  1013. if len(match_idx_list) > 0:
  1014. for idx in match_idx_list:
  1015. neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
  1016. single_table_recognition_res["neighbor_texts"] = neighbor_text
  1017. return single_table_recognition_res
  1018. def predict(
  1019. self,
  1020. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  1021. use_doc_orientation_classify: Optional[bool] = None,
  1022. use_doc_unwarping: Optional[bool] = None,
  1023. use_layout_detection: Optional[bool] = None,
  1024. use_ocr_model: Optional[bool] = None,
  1025. overall_ocr_res: Optional[OCRResult] = None,
  1026. layout_det_res: Optional[DetResult] = None,
  1027. text_det_limit_side_len: Optional[int] = None,
  1028. text_det_limit_type: Optional[str] = None,
  1029. text_det_thresh: Optional[float] = None,
  1030. text_det_box_thresh: Optional[float] = None,
  1031. text_det_unclip_ratio: Optional[float] = None,
  1032. text_rec_score_thresh: Optional[float] = None,
  1033. use_e2e_wired_table_rec_model: bool = False,
  1034. use_e2e_wireless_table_rec_model: bool = False,
  1035. use_wired_table_cells_trans_to_html: bool = False,
  1036. use_wireless_table_cells_trans_to_html: bool = False,
  1037. use_table_orientation_classify: bool = True,
  1038. use_ocr_results_with_table_cells: bool = True,
  1039. **kwargs,
  1040. ) -> TableRecognitionResult:
  1041. """
  1042. This function predicts the table recognition result for the given input.
  1043. Args:
  1044. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or PDF(s) to be processed.
  1045. use_layout_detection (bool): Whether to use layout detection.
  1046. use_doc_orientation_classify (bool): Whether to use document orientation classification.
  1047. use_doc_unwarping (bool): Whether to use document unwarping.
  1048. overall_ocr_res (OCRResult): The overall OCR result with convert_points_to_boxes information.
  1049. It will be used if it is not None and use_ocr_model is False.
  1050. layout_det_res (DetResult): The layout detection result.
  1051. It will be used if it is not None and use_layout_detection is False.
  1052. use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
  1053. use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
  1054. use_wired_table_cells_trans_to_html (bool): Whether to use wired table cells trans to HTML.
  1055. use_wireless_table_cells_trans_to_html (bool): Whether to use wireless table cells trans to HTML.
  1056. use_table_orientation_classify (bool): Whether to use table orientation classification.
  1057. use_ocr_results_with_table_cells (bool): Whether to use OCR results processed by table cells.
  1058. **kwargs: Additional keyword arguments.
  1059. Returns:
  1060. TableRecognitionResult: The predicted table recognition result.
  1061. """
  1062. # self.cells_split_ocr = True
  1063. # Added by zhch158: allow callers to control whether cell contents are split and re-recognized (defaults to True for backward compatibility)
  1064. use_table_cells_split_ocr = kwargs.get("use_table_cells_split_ocr", None)
  1065. self.cells_split_ocr = (
  1066. bool(use_table_cells_split_ocr) if use_table_cells_split_ocr is not None else True
  1067. )
  1068. if use_table_orientation_classify == True and (
  1069. self.table_orientation_classify_model is None
  1070. ):
  1071. assert self.table_orientation_classify_config is not None
  1072. self.table_orientation_classify_model = self.create_model(
  1073. self.table_orientation_classify_config
  1074. )
  1075. model_settings = self.get_model_settings(
  1076. use_doc_orientation_classify,
  1077. use_doc_unwarping,
  1078. use_layout_detection,
  1079. use_ocr_model,
  1080. )
  1081. if not self.check_model_settings_valid(
  1082. model_settings, overall_ocr_res, layout_det_res
  1083. ):
  1084. yield {"error": "the input params for model settings are invalid!"}
  1085. for img_id, batch_data in enumerate(self.batch_sampler(input)):
  1086. image_array = self.img_reader(batch_data.instances)[0]
  1087. if model_settings["use_doc_preprocessor"]:
  1088. doc_preprocessor_res = list(
  1089. self.doc_preprocessor_pipeline(
  1090. image_array,
  1091. use_doc_orientation_classify=use_doc_orientation_classify,
  1092. use_doc_unwarping=use_doc_unwarping,
  1093. )
  1094. )[0]
  1095. else:
  1096. doc_preprocessor_res = {"output_img": image_array}
  1097. doc_preprocessor_image = doc_preprocessor_res["output_img"]
  1098. if model_settings["use_ocr_model"]:
  1099. overall_ocr_res = list(
  1100. self.general_ocr_pipeline(
  1101. doc_preprocessor_image,
  1102. text_det_limit_side_len=text_det_limit_side_len,
  1103. text_det_limit_type=text_det_limit_type,
  1104. text_det_thresh=text_det_thresh,
  1105. text_det_box_thresh=text_det_box_thresh,
  1106. text_det_unclip_ratio=text_det_unclip_ratio,
  1107. text_rec_score_thresh=text_rec_score_thresh,
  1108. )
  1109. )[0]
  1110. elif self.general_ocr_pipeline is None and (
  1111. (
  1112. use_ocr_results_with_table_cells == True
  1113. and self.cells_split_ocr == False
  1114. )
  1115. or use_table_orientation_classify == True
  1116. ):
  1117. assert self.general_ocr_config_bak is not None
  1118. self.general_ocr_pipeline = self.create_pipeline(
  1119. self.general_ocr_config_bak
  1120. )
  1121. if use_table_orientation_classify == False:
  1122. table_angle = "0"
  1123. table_res_list = []
  1124. table_region_id = 1
  1125. if not model_settings["use_layout_detection"] and layout_det_res is None:
  1126. img_height, img_width = doc_preprocessor_image.shape[:2]
  1127. table_box = [0, 0, img_width - 1, img_height - 1]
  1128. if use_table_orientation_classify == True:
  1129. table_angle = list(
  1130. self.table_orientation_classify_model(doc_preprocessor_image)
  1131. )[0]["label_names"][0]
  1132. if table_angle == "90":
  1133. doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=1)
  1134. elif table_angle == "180":
  1135. doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=2)
  1136. elif table_angle == "270":
  1137. doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=3)
  1138. if table_angle in ["90", "180", "270"]:
  1139. overall_ocr_res = list(
  1140. self.general_ocr_pipeline(
  1141. doc_preprocessor_image,
  1142. text_det_limit_side_len=text_det_limit_side_len,
  1143. text_det_limit_type=text_det_limit_type,
  1144. text_det_thresh=text_det_thresh,
  1145. text_det_box_thresh=text_det_box_thresh,
  1146. text_det_unclip_ratio=text_det_unclip_ratio,
  1147. text_rec_score_thresh=text_rec_score_thresh,
  1148. )
  1149. )[0]
  1150. tbx1, tby1, tbx2, tby2 = (
  1151. table_box[0],
  1152. table_box[1],
  1153. table_box[2],
  1154. table_box[3],
  1155. )
  1156. if table_angle == "90":
  1157. new_x1, new_y1 = tby1, img_width - tbx2
  1158. new_x2, new_y2 = tby2, img_width - tbx1
  1159. elif table_angle == "180":
  1160. new_x1, new_y1 = img_width - tbx2, img_height - tby2
  1161. new_x2, new_y2 = img_width - tbx1, img_height - tby1
  1162. elif table_angle == "270":
  1163. new_x1, new_y1 = img_height - tby2, tbx1
  1164. new_x2, new_y2 = img_height - tby1, tbx2
  1165. table_box = [new_x1, new_y1, new_x2, new_y2]
  1166. single_table_rec_res = self.predict_single_table_recognition_res(
  1167. doc_preprocessor_image,
  1168. overall_ocr_res,
  1169. table_box,
  1170. use_e2e_wired_table_rec_model,
  1171. use_e2e_wireless_table_rec_model,
  1172. use_wired_table_cells_trans_to_html,
  1173. use_wireless_table_cells_trans_to_html,
  1174. use_ocr_results_with_table_cells,
  1175. flag_find_nei_text=False,
  1176. )
  1177. single_table_rec_res["table_region_id"] = table_region_id
  1178. if use_table_orientation_classify == True and table_angle != "0":
  1179. img_height, img_width = doc_preprocessor_image.shape[:2]
  1180. single_table_rec_res["cell_box_list"] = (
  1181. self.map_cells_to_original_image(
  1182. single_table_rec_res["cell_box_list"],
  1183. table_angle,
  1184. img_width,
  1185. img_height,
  1186. )
  1187. )
  1188. table_res_list.append(single_table_rec_res)
  1189. table_region_id += 1
  1190. else:
  1191. if model_settings["use_layout_detection"]:
  1192. layout_det_res = list(
  1193. self.layout_det_model(doc_preprocessor_image)
  1194. )[0]
  1195. img_height, img_width = doc_preprocessor_image.shape[:2]
  1196. for box_info in layout_det_res["boxes"]:
  1197. if box_info["label"].lower() in ["table"]:
  1198. crop_img_info = self._crop_by_boxes(
  1199. doc_preprocessor_image, [box_info]
  1200. )
  1201. crop_img_info = crop_img_info[0]
  1202. table_box = crop_img_info["box"]
  1203. if use_table_orientation_classify == True:
  1204. doc_preprocessor_image_copy = doc_preprocessor_image.copy()
  1205. table_angle = list(
  1206. self.table_orientation_classify_model(
  1207. crop_img_info["img"]
  1208. )
  1209. )[0]["label_names"][0]
  1210. if table_angle == "90":
  1211. crop_img_info["img"] = np.rot90(crop_img_info["img"], k=1)
  1212. doc_preprocessor_image_copy = np.rot90(
  1213. doc_preprocessor_image_copy, k=1
  1214. )
  1215. elif table_angle == "180":
  1216. crop_img_info["img"] = np.rot90(crop_img_info["img"], k=2)
  1217. doc_preprocessor_image_copy = np.rot90(
  1218. doc_preprocessor_image_copy, k=2
  1219. )
  1220. elif table_angle == "270":
  1221. crop_img_info["img"] = np.rot90(crop_img_info["img"], k=3)
  1222. doc_preprocessor_image_copy = np.rot90(
  1223. doc_preprocessor_image_copy, k=3
  1224. )
  1225. if table_angle in ["90", "180", "270"]:
  1226. overall_ocr_res = list(
  1227. self.general_ocr_pipeline(
  1228. doc_preprocessor_image_copy,
  1229. text_det_limit_side_len=text_det_limit_side_len,
  1230. text_det_limit_type=text_det_limit_type,
  1231. text_det_thresh=text_det_thresh,
  1232. text_det_box_thresh=text_det_box_thresh,
  1233. text_det_unclip_ratio=text_det_unclip_ratio,
  1234. text_rec_score_thresh=text_rec_score_thresh,
  1235. )
  1236. )[0]
  1237. tbx1, tby1, tbx2, tby2 = (
  1238. table_box[0],
  1239. table_box[1],
  1240. table_box[2],
  1241. table_box[3],
  1242. )
  1243. if table_angle == "90":
  1244. new_x1, new_y1 = tby1, img_width - tbx2
  1245. new_x2, new_y2 = tby2, img_width - tbx1
  1246. elif table_angle == "180":
  1247. new_x1, new_y1 = img_width - tbx2, img_height - tby2
  1248. new_x2, new_y2 = img_width - tbx1, img_height - tby1
  1249. elif table_angle == "270":
  1250. new_x1, new_y1 = img_height - tby2, tbx1
  1251. new_x2, new_y2 = img_height - tby1, tbx2
  1252. table_box = [new_x1, new_y1, new_x2, new_y2]
  1253. single_table_rec_res = (
  1254. self.predict_single_table_recognition_res(
  1255. crop_img_info["img"],
  1256. overall_ocr_res,
  1257. table_box,
  1258. use_e2e_wired_table_rec_model,
  1259. use_e2e_wireless_table_rec_model,
  1260. use_wired_table_cells_trans_to_html,
  1261. use_wireless_table_cells_trans_to_html,
  1262. use_ocr_results_with_table_cells,
  1263. )
  1264. )
  1265. single_table_rec_res["table_region_id"] = table_region_id
  1266. if (
  1267. use_table_orientation_classify == True
  1268. and table_angle != "0"
  1269. ):
  1270. img_height_copy, img_width_copy = (
  1271. doc_preprocessor_image_copy.shape[:2]
  1272. )
  1273. single_table_rec_res["cell_box_list"] = (
  1274. self.map_cells_to_original_image(
  1275. single_table_rec_res["cell_box_list"],
  1276. table_angle,
  1277. img_width_copy,
  1278. img_height_copy,
  1279. )
  1280. )
  1281. table_res_list.append(single_table_rec_res)
  1282. table_region_id += 1
  1283. single_img_res = {
  1284. "input_path": batch_data.input_paths[0],
  1285. "page_index": batch_data.page_indexes[0],
  1286. "doc_preprocessor_res": doc_preprocessor_res,
  1287. "layout_det_res": layout_det_res,
  1288. "overall_ocr_res": overall_ocr_res,
  1289. "table_res_list": table_res_list,
  1290. "model_settings": model_settings,
  1291. }
  1292. yield TableRecognitionResult(single_img_res)
  1293. @pipeline_requires_extra("ocr")
  1294. class TableRecognitionPipelineV2(AutoParallelImageSimpleInferencePipeline):
  1295. entities = ["table_recognition_v2"]
  1296. @property
  1297. def _pipeline_cls(self):
  1298. return _TableRecognitionPipelineV2
  1299. def _get_batch_size(self, config):
  1300. return 1
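# --- Illustrative usage sketch (not part of this file) ---
# How this pipeline is typically reached from the public PaddleX API. The entry
# point below follows the registered entity name "table_recognition_v2"; the
# image path and the result-saving calls are assumptions for illustration:
#
#   from paddlex import create_pipeline
#
#   pipeline = create_pipeline(pipeline="table_recognition_v2")
#   for res in pipeline.predict(
#       "table.jpg",
#       use_doc_orientation_classify=False,
#       use_doc_unwarping=False,
#   ):
#       res.print()
#       res.save_to_html(save_path="./output/")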