pipeline_v2.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from ....utils import logging
from ....utils.deps import (
    function_requires_deps,
    is_dep_available,
    pipeline_requires_extra,
)
from ...common.batch_sampler import ImageBatchSampler
from ...common.reader import ReadImage
from ...models.object_detection.result import DetResult
from ...utils.hpi import HPIConfig
from ...utils.pp_option import PaddlePredictorOption
from .._parallel import AutoParallelImageSimpleInferencePipeline
from ..base import BasePipeline
from ..components import CropByBoxes
from ..doc_preprocessor.result import DocPreprocessorResult
from ..ocr.result import OCRResult
from .result import SingleTableRecognitionResult, TableRecognitionResult
from .table_recognition_post_processing import (
    get_table_recognition_res as get_table_recognition_res_e2e,
)
from .table_recognition_post_processing_v2 import get_table_recognition_res
from .utils import get_neighbor_boxes_idx

if is_dep_available("scikit-learn"):
    from sklearn.cluster import KMeans


class _TableRecognitionPipelineV2(BasePipeline):
    """Table Recognition Pipeline"""

    def __init__(
        self,
        config: Dict,
        device: str = None,
        pp_option: PaddlePredictorOption = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
    ) -> None:
        """Initializes the table recognition pipeline.

        Args:
            config (Dict): Configuration dictionary containing various settings.
            device (str, optional): Device to run the predictions on. Defaults to None.
            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
            use_hpip (bool, optional): Whether to use the high-performance
                inference plugin (HPIP) by default. Defaults to False.
            hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                The default high-performance inference configuration dictionary.
                Defaults to None.
        """
        super().__init__(
            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
        )

        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
        if self.use_doc_preprocessor:
            doc_preprocessor_config = config.get("SubPipelines", {}).get(
                "DocPreprocessor",
                {
                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
                },
            )
            self.doc_preprocessor_pipeline = self.create_pipeline(
                doc_preprocessor_config
            )

        self.use_layout_detection = config.get("use_layout_detection", True)
        if self.use_layout_detection:
            layout_det_config = config.get("SubModules", {}).get(
                "LayoutDetection",
                {"model_config_error": "config error for layout_det_model!"},
            )
            self.layout_det_model = self.create_model(layout_det_config)

        table_cls_config = config.get("SubModules", {}).get(
            "TableClassification",
            {"model_config_error": "config error for table_classification_model!"},
        )
        self.table_cls_model = self.create_model(table_cls_config)

        wired_table_rec_config = config.get("SubModules", {}).get(
            "WiredTableStructureRecognition",
            {"model_config_error": "config error for wired_table_structure_model!"},
        )
        self.wired_table_rec_model = self.create_model(wired_table_rec_config)

        wireless_table_rec_config = config.get("SubModules", {}).get(
            "WirelessTableStructureRecognition",
            {"model_config_error": "config error for wireless_table_structure_model!"},
        )
        self.wireless_table_rec_model = self.create_model(wireless_table_rec_config)

        wired_table_cells_det_config = config.get("SubModules", {}).get(
            "WiredTableCellsDetection",
            {
                "model_config_error": "config error for wired_table_cells_detection_model!"
            },
        )
        self.wired_table_cells_detection_model = self.create_model(
            wired_table_cells_det_config
        )

        wireless_table_cells_det_config = config.get("SubModules", {}).get(
            "WirelessTableCellsDetection",
            {
                "model_config_error": "config error for wireless_table_cells_detection_model!"
            },
        )
        self.wireless_table_cells_detection_model = self.create_model(
            wireless_table_cells_det_config
        )

        self.use_ocr_model = config.get("use_ocr_model", True)
        if self.use_ocr_model:
            general_ocr_config = config.get("SubPipelines", {}).get(
                "GeneralOCR",
                {"pipeline_config_error": "config error for general_ocr_pipeline!"},
            )
            self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
        else:
            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
                "GeneralOCR", None
            )

        self._crop_by_boxes = CropByBoxes()

        self.batch_sampler = ImageBatchSampler(batch_size=config.get("batch_size", 1))
        self.img_reader = ReadImage(format="BGR")

    def get_model_settings(
        self,
        use_doc_orientation_classify: Optional[bool],
        use_doc_unwarping: Optional[bool],
        use_layout_detection: Optional[bool],
        use_ocr_model: Optional[bool],
    ) -> dict:
        """
        Get the model settings based on the provided parameters or default values.

        Args:
            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
            use_layout_detection (Optional[bool]): Whether to use layout detection.
            use_ocr_model (Optional[bool]): Whether to use OCR model.

        Returns:
            dict: A dictionary containing the model settings.
        """
        if use_doc_orientation_classify is None and use_doc_unwarping is None:
            use_doc_preprocessor = self.use_doc_preprocessor
        else:
            if use_doc_orientation_classify is True or use_doc_unwarping is True:
                use_doc_preprocessor = True
            else:
                use_doc_preprocessor = False

        if use_layout_detection is None:
            use_layout_detection = self.use_layout_detection

        if use_ocr_model is None:
            use_ocr_model = self.use_ocr_model

        return dict(
            use_doc_preprocessor=use_doc_preprocessor,
            use_layout_detection=use_layout_detection,
            use_ocr_model=use_ocr_model,
        )

    def check_model_settings_valid(
        self,
        model_settings: Dict,
        overall_ocr_res: OCRResult,
        layout_det_res: Union[DetResult, List[DetResult]],
    ) -> bool:
        """
        Check if the input parameters are valid based on the initialized models.

        Args:
            model_settings (Dict): A dictionary containing input parameters.
            overall_ocr_res (OCRResult): Overall OCR result obtained after running the OCR pipeline.
                The overall OCR result with convert_points_to_boxes information.
            layout_det_res (Union[DetResult, List[DetResult]]): The layout detection result(s).

        Returns:
            bool: True if all required models are initialized according to input parameters, False otherwise.
        """
        if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
            logging.error(
                "use_doc_preprocessor is set, but the doc preprocessor models are not initialized."
            )
            return False

        if model_settings["use_layout_detection"]:
            if layout_det_res is not None:
                logging.error(
                    "Layout detection results were provided, so please set use_layout_detection=False."
                )
                return False

            if not self.use_layout_detection:
                logging.error(
                    "use_layout_detection is set, but the layout detection model is not initialized."
                )
                return False

        if model_settings["use_ocr_model"]:
            if overall_ocr_res is not None:
                logging.error(
                    "OCR results were provided, so please set use_ocr_model=False."
                )
                return False

            if not self.use_ocr_model:
                logging.error(
                    "use_ocr_model is set, but the OCR models are not initialized."
                )
                return False
        else:
            if overall_ocr_res is None:
                logging.error("use_ocr_model is set to False, but no OCR results were provided.")
                return False

        return True

    def predict_doc_preprocessor_res(
        self, image_array: np.ndarray, input_params: dict
    ) -> Tuple[DocPreprocessorResult, np.ndarray]:
        """
        Preprocess the document image based on input parameters.

        Args:
            image_array (np.ndarray): The input image array.
            input_params (dict): Dictionary containing preprocessing parameters.

        Returns:
            tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
                result dictionary and the processed image array.
        """
        if input_params["use_doc_preprocessor"]:
            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
            use_doc_unwarping = input_params["use_doc_unwarping"]
            doc_preprocessor_res = next(
                self.doc_preprocessor_pipeline(
                    image_array,
                    use_doc_orientation_classify=use_doc_orientation_classify,
                    use_doc_unwarping=use_doc_unwarping,
                )
            )
            doc_preprocessor_image = doc_preprocessor_res["output_img"]
        else:
            doc_preprocessor_res = {}
            doc_preprocessor_image = image_array
        return doc_preprocessor_res, doc_preprocessor_image

    def extract_results(self, pred, task):
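        """Extract the relevant fields from a raw model prediction.

        For task "cls", returns the top-scoring label name; for task "det",
        returns a tuple (box coordinates, scores); for task "table_stru",
        returns the predicted table structure tokens; any other task returns
        None.
        """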
        if task == "cls":
            return pred["label_names"][np.argmax(pred["scores"])]
        elif task == "det":
            threshold = 0.0
            result = []
            cell_score = []
            if "boxes" in pred and isinstance(pred["boxes"], list):
                for box in pred["boxes"]:
                    if isinstance(box, dict) and "score" in box and "coordinate" in box:
                        score = box["score"]
                        coordinate = box["coordinate"]
                        if isinstance(score, float) and score > threshold:
                            result.append(coordinate)
                            cell_score.append(score)
            return result, cell_score
        elif task == "table_stru":
            return pred["structure"]
        else:
            return None

    def cells_det_results_nms(
        self, cells_det_results, cells_det_scores, cells_det_threshold=0.3
    ):
        """
        Apply Non-Maximum Suppression (NMS) on detection results to remove redundant overlapping bounding boxes.

        Args:
            cells_det_results (list): List of bounding boxes, each box is in format [x1, y1, x2, y2].
            cells_det_scores (list): List of confidence scores corresponding to the bounding boxes.
            cells_det_threshold (float): IoU threshold for suppression. Boxes with IoU greater than this threshold
                will be suppressed. Default is 0.3.

        Returns:
            Tuple[list, list]: A tuple containing the list of bounding boxes and confidence scores after NMS,
                while maintaining one-to-one correspondence.
        """
        # Convert lists to numpy arrays for efficient computation
        boxes = np.array(cells_det_results)
        scores = np.array(cells_det_scores)
        # Initialize list for picked indices
        picked_indices = []
        # Get coordinates of bounding boxes
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        # Compute the area of the bounding boxes
        areas = (x2 - x1) * (y2 - y1)
        # Sort the bounding boxes by the confidence scores in descending order
        order = scores.argsort()[::-1]
        # Process the boxes
        while order.size > 0:
            # Index of the current highest score box
            i = order[0]
            picked_indices.append(i)
            # Compute IoU between the highest score box and the rest
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            # Compute the width and height of the overlapping area
            w = np.maximum(0.0, xx2 - xx1)
            h = np.maximum(0.0, yy2 - yy1)
            # Compute the ratio of overlap (IoU)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            # Indices of boxes with IoU less than or equal to the threshold
            inds = np.where(ovr <= cells_det_threshold)[0]
            # Update order, keeping only boxes with IoU below the threshold
            order = order[
                inds + 1
            ]  # inds is shifted by 1 because order[0] is the current box
        # Select the boxes and scores based on picked indices
        final_boxes = boxes[picked_indices].tolist()
        final_scores = scores[picked_indices].tolist()
        return final_boxes, final_scores

    def get_region_ocr_det_boxes(self, ocr_det_boxes, table_box):
        """Adjust the coordinates of ocr_det_boxes that are fully inside table_box relative to table_box.

        Args:
            ocr_det_boxes (list of list): List of bounding boxes [x1, y1, x2, y2] in the original image.
            table_box (list): Bounding box [x1, y1, x2, y2] of the target region in the original image.

        Returns:
            list of list: List of adjusted bounding boxes relative to table_box, for boxes fully inside table_box.
        """
        tol = 0
        # Extract coordinates from table_box
        x_min_t, y_min_t, x_max_t, y_max_t = table_box
        adjusted_boxes = []
        for box in ocr_det_boxes:
            x_min_b, y_min_b, x_max_b, y_max_b = box
            # Check if the box is fully inside table_box
            if (
                x_min_b + tol >= x_min_t
                and y_min_b + tol >= y_min_t
                and x_max_b - tol <= x_max_t
                and y_max_b - tol <= y_max_t
            ):
                # Adjust the coordinates to be relative to table_box
                adjusted_box = [
                    x_min_b - x_min_t,  # Adjust x1
                    y_min_b - y_min_t,  # Adjust y1
                    x_max_b - x_min_t,  # Adjust x2
                    y_max_b - y_min_t,  # Adjust y2
                ]
                adjusted_boxes.append(adjusted_box)
            # Boxes not fully inside table_box are discarded
        return adjusted_boxes

    def cells_det_results_reprocessing(
        self, cells_det_results, cells_det_scores, ocr_det_results, html_pred_boxes_nums
    ):
        """
        Process and filter cells_det_results based on ocr_det_results and html_pred_boxes_nums.

        Args:
            cells_det_results (List[List[float]]): List of detected cell rectangles [[x1, y1, x2, y2], ...].
            cells_det_scores (List[float]): List of confidence scores for each rectangle in cells_det_results.
            ocr_det_results (List[List[float]]): List of OCR detected rectangles [[x1, y1, x2, y2], ...].
            html_pred_boxes_nums (int): The desired number of rectangles in the final output.

        Returns:
            List[List[float]]: The processed list of rectangles.
        """

        # Function to compute the overlap between two rectangles
        def compute_iou(box1, box2):
            """
            Compute the overlap ratio between two rectangles.

            Args:
                box1 (array-like): [x1, y1, x2, y2] of the first rectangle.
                box2 (array-like): [x1, y1, x2, y2] of the second rectangle.

            Returns:
                float: The intersection area divided by the area of box1.
            """
            # Determine the coordinates of the intersection rectangle
            x_left = max(box1[0], box2[0])
            y_top = max(box1[1], box2[1])
            x_right = min(box1[2], box2[2])
            y_bottom = min(box1[3], box2[3])
            if x_right <= x_left or y_bottom <= y_top:
                return 0.0
            # Calculate the area of the intersection rectangle
            intersection_area = (x_right - x_left) * (y_bottom - y_top)
            # Calculate the area of box1, which is used as the normalizer
            box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
            # Normalize the intersection by box1's area (not a symmetric IoU)
            iou = intersection_area / float(box1_area)
            return iou

        # Function to combine rectangles into N rectangles
        @function_requires_deps("scikit-learn")
        def combine_rectangles(rectangles, N):
            """
            Combine rectangles into N rectangles based on geometric proximity.

            Args:
                rectangles (list of list of int): A list of rectangles, each represented by [x1, y1, x2, y2].
                N (int): The desired number of combined rectangles.

            Returns:
                list of list of int: A list of N combined rectangles.
            """
            # Number of input rectangles
            num_rects = len(rectangles)
            # If N is greater than or equal to the number of rectangles, return the original rectangles
            if N >= num_rects:
                return rectangles
            # Compute the center points of the rectangles
            centers = np.array(
                [
                    [
                        (rect[0] + rect[2]) / 2,  # Center x-coordinate
                        (rect[1] + rect[3]) / 2,  # Center y-coordinate
                    ]
                    for rect in rectangles
                ]
            )
            # Perform KMeans clustering on the center points to group them into N clusters
            kmeans = KMeans(n_clusters=N, random_state=0, n_init="auto")
            labels = kmeans.fit_predict(centers)
            # Initialize a list to store the combined rectangles
            combined_rectangles = []
            # For each cluster, compute the minimal bounding rectangle that covers all rectangles in the cluster
            for i in range(N):
                # Get the indices of rectangles that belong to cluster i
                indices = np.where(labels == i)[0]
                if len(indices) == 0:
                    # If no rectangles fall in this cluster, skip it
                    continue
                # Extract the rectangles in cluster i
                cluster_rects = np.array([rectangles[idx] for idx in indices])
                # Compute the minimal x1, y1 (top-left corner) and maximal x2, y2 (bottom-right corner)
                x1_min = np.min(cluster_rects[:, 0])
                y1_min = np.min(cluster_rects[:, 1])
                x2_max = np.max(cluster_rects[:, 2])
                y2_max = np.max(cluster_rects[:, 3])
                # Append the combined rectangle to the list
                combined_rectangles.append([x1_min, y1_min, x2_max, y2_max])
            return combined_rectangles

        # Ensure that the inputs are numpy arrays for efficient computation
        cells_det_results = np.array(cells_det_results)
        cells_det_scores = np.array(cells_det_scores)
        ocr_det_results = np.array(ocr_det_results)
        more_cells_flag = False
        if len(cells_det_results) == html_pred_boxes_nums:
            return cells_det_results
        # Step 1: If cells_det_results has more rectangles than html_pred_boxes_nums
        elif len(cells_det_results) > html_pred_boxes_nums:
            more_cells_flag = True
            # Select the indices of the top html_pred_boxes_nums scores
            top_indices = np.argsort(-cells_det_scores)[:html_pred_boxes_nums]
            # Keep only the corresponding rectangles
            cells_det_results = cells_det_results[top_indices].tolist()
        # Threshold for IoU
        iou_threshold = 0.6
        # List to store ocr_miss_boxes
        ocr_miss_boxes = []
        # For each rectangle in ocr_det_results
        for ocr_rect in ocr_det_results:
            merge_ocr_box_iou = []
            # Flag to indicate if ocr_rect has IoU >= threshold with any cell_rect
            has_large_iou = False
            # For each rectangle in cells_det_results
            for cell_rect in cells_det_results:
                # Compute IoU
                iou = compute_iou(ocr_rect, cell_rect)
                if iou > 0:
                    merge_ocr_box_iou.append(iou)
                if (iou >= iou_threshold) or (sum(merge_ocr_box_iou) >= iou_threshold):
                    has_large_iou = True
                    break
            if not has_large_iou:
                ocr_miss_boxes.append(ocr_rect)
        # If no ocr_miss_boxes, return cells_det_results
        if len(ocr_miss_boxes) == 0:
            final_results = (
                cells_det_results
                if more_cells_flag
                else cells_det_results.tolist()
            )
        else:
            if more_cells_flag:
                final_results = combine_rectangles(
                    cells_det_results + ocr_miss_boxes, html_pred_boxes_nums
                )
            else:
                # Need to combine ocr_miss_boxes into N rectangles
                N = html_pred_boxes_nums - len(cells_det_results)
                # Combine ocr_miss_boxes into N rectangles
                ocr_supp_boxes = combine_rectangles(ocr_miss_boxes, N)
                # Combine cells_det_results and ocr_supp_boxes
                final_results = np.concatenate(
                    (cells_det_results, ocr_supp_boxes), axis=0
                ).tolist()
        if len(final_results) <= 0.6 * html_pred_boxes_nums:
            final_results = combine_rectangles(ocr_det_results, html_pred_boxes_nums)
        return final_results

    def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
        """
        Splits OCR bounding boxes by table cells and retrieves text.

        Args:
            ori_img (ndarray): The original image from which text regions will be extracted.
            cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.

        Returns:
            list: A list containing the recognized texts from each cell.
        """
        # Check if cells_bboxes is a list and convert it if not.
        if not isinstance(cells_bboxes, list):
            cells_bboxes = cells_bboxes.tolist()
        texts_list = []  # Initialize a list to store the recognized texts.
        # Process each bounding box provided in cells_bboxes.
        for i in range(len(cells_bboxes)):
            # Extract and round up the coordinates of the bounding box.
            x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
            # Perform OCR on the defined region of the image and get the recognized text.
            rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
            # Concatenate the texts and append them to the texts_list.
            texts_list.append("".join(rec_te["rec_texts"]))
        # Return the list of recognized texts from each cell.
        return texts_list

    def _predict(
        self,
        image_arrays: List[np.ndarray],
        overall_ocr_results: List[OCRResult],
        table_boxes: List[list],
        use_table_cells_ocr_results: bool = False,
        use_e2e_wired_table_rec_model: bool = False,
        use_e2e_wireless_table_rec_model: bool = False,
        flag_find_nei_text: bool = True,
    ) -> List[SingleTableRecognitionResult]:
        """
        Predict table recognition results from image arrays, table boxes, and OCR results.

        Args:
            image_arrays (List[np.ndarray]): The input image arrays.
            overall_ocr_results (List[OCRResult]): Overall OCR results obtained after running the OCR pipeline.
                The overall OCR results contain text recognition information.
            table_boxes (List[list]): The table box coordinates.
            use_table_cells_ocr_results (bool): Whether to use OCR results obtained per cell.
            use_e2e_wired_table_rec_model (bool): Whether to use the end-to-end wired table recognition model.
            use_e2e_wireless_table_rec_model (bool): Whether to use the end-to-end wireless table recognition model.
            flag_find_nei_text (bool): Whether to find neighboring text.

        Returns:
            List[SingleTableRecognitionResult]: Single table recognition results.
        """
        # TODO: Batch inference
        results = []
        for image_array, overall_ocr_res, table_box in zip(
            image_arrays, overall_ocr_results, table_boxes
        ):
            table_cls_pred = next(self.table_cls_model(image_array))
            table_cls_result = self.extract_results(table_cls_pred, "cls")
            use_e2e_model = False

            if table_cls_result == "wired_table":
                table_structure_pred = next(self.wired_table_rec_model(image_array))
                if use_e2e_wired_table_rec_model:
                    use_e2e_model = True
                else:
                    table_cells_pred = next(
                        self.wired_table_cells_detection_model(
                            image_array, threshold=0.3
                        )
                    )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                    # If you really want more or fewer table cell detection boxes, the threshold can be adjusted.
            elif table_cls_result == "wireless_table":
                table_structure_pred = next(self.wireless_table_rec_model(image_array))
                if use_e2e_wireless_table_rec_model:
                    use_e2e_model = True
                else:
                    table_cells_pred = next(
                        self.wireless_table_cells_detection_model(
                            image_array, threshold=0.3
                        )
                    )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                    # If you really want more or fewer table cell detection boxes, the threshold can be adjusted.

            if not use_e2e_model:
                table_structure_result = self.extract_results(
                    table_structure_pred, "table_stru"
                )
                table_cells_result, table_cells_score = self.extract_results(
                    table_cells_pred, "det"
                )
                table_cells_result, table_cells_score = self.cells_det_results_nms(
                    table_cells_result, table_cells_score
                )
                ocr_det_boxes = self.get_region_ocr_det_boxes(
                    overall_ocr_res["rec_boxes"].tolist(), table_box
                )
                table_cells_result = self.cells_det_results_reprocessing(
                    table_cells_result,
                    table_cells_score,
                    ocr_det_boxes,
                    len(table_structure_pred["bbox"]),
                )
                if use_table_cells_ocr_results:
                    cells_texts_list = self.split_ocr_bboxes_by_table_cells(
                        image_array, table_cells_result
                    )
                else:
                    cells_texts_list = []
                single_table_recognition_res = get_table_recognition_res(
                    table_box,
                    table_structure_result,
                    table_cells_result,
                    overall_ocr_res,
                    cells_texts_list,
                    use_table_cells_ocr_results,
                )
            else:
                if use_table_cells_ocr_results:
                    table_cells_result_e2e = list(
                        map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
                    )
                    table_cells_result_e2e = [
                        [rect[0], rect[1], rect[4], rect[5]]
                        for rect in table_cells_result_e2e
                    ]
                    cells_texts_list = self.split_ocr_bboxes_by_table_cells(
                        image_array, table_cells_result_e2e
                    )
                else:
                    cells_texts_list = []
                single_table_recognition_res = get_table_recognition_res_e2e(
                    table_box,
                    table_structure_pred,
                    overall_ocr_res,
                    cells_texts_list,
                    use_table_cells_ocr_results,
                )

            neighbor_text = ""
            if flag_find_nei_text:
                match_idx_list = get_neighbor_boxes_idx(
                    overall_ocr_res["rec_boxes"], table_box
                )
                if len(match_idx_list) > 0:
                    for idx in match_idx_list:
                        neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
            single_table_recognition_res["neighbor_texts"] = neighbor_text
            results.append(single_table_recognition_res)
        return results

    def predict(
        self,
        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
        use_doc_orientation_classify: Optional[bool] = None,
        use_doc_unwarping: Optional[bool] = None,
        use_layout_detection: Optional[bool] = None,
        use_ocr_model: Optional[bool] = None,
        overall_ocr_res: Optional[Union[OCRResult, List[OCRResult]]] = None,
        layout_det_res: Optional[Union[DetResult, List[DetResult]]] = None,
        text_det_limit_side_len: Optional[int] = None,
        text_det_limit_type: Optional[str] = None,
        text_det_thresh: Optional[float] = None,
        text_det_box_thresh: Optional[float] = None,
        text_det_unclip_ratio: Optional[float] = None,
        text_rec_score_thresh: Optional[float] = None,
        use_table_cells_ocr_results: bool = False,
        use_e2e_wired_table_rec_model: bool = False,
        use_e2e_wireless_table_rec_model: bool = False,
        **kwargs,
    ) -> TableRecognitionResult:
        """
        This function predicts the table recognition result for the given input.

        Args:
            input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or PDF(s) to be processed.
            use_doc_orientation_classify (bool): Whether to use document orientation classification.
            use_doc_unwarping (bool): Whether to use document unwarping.
            use_layout_detection (bool): Whether to use layout detection.
            use_ocr_model (bool): Whether to use the OCR model.
            overall_ocr_res (Union[OCRResult, List[OCRResult]]): The overall OCR result(s) with convert_points_to_boxes information.
                They will be used if they are not None and use_ocr_model is False.
            layout_det_res (Union[DetResult, List[DetResult]]): The layout detection result(s).
                They will be used if they are not None and use_layout_detection is False.
            use_table_cells_ocr_results (bool): Whether to use OCR results obtained per cell.
            use_e2e_wired_table_rec_model (bool): Whether to use the end-to-end wired table recognition model.
            use_e2e_wireless_table_rec_model (bool): Whether to use the end-to-end wireless table recognition model.
            **kwargs: Additional keyword arguments.

        Returns:
            TableRecognitionResult: The predicted table recognition result.
        """
        model_settings = self.get_model_settings(
            use_doc_orientation_classify,
            use_doc_unwarping,
            use_layout_detection,
            use_ocr_model,
        )

        if not self.check_model_settings_valid(
            model_settings, overall_ocr_res, layout_det_res
        ):
            yield {"error": "the input params for model settings are invalid!"}

        external_overall_ocr_results = overall_ocr_res
        if external_overall_ocr_results is not None:
            if not isinstance(external_overall_ocr_results, list):
                external_overall_ocr_results = [external_overall_ocr_results]
            external_overall_ocr_results = iter(external_overall_ocr_results)

        external_layout_det_results = layout_det_res
        if external_layout_det_results is not None:
            if not isinstance(external_layout_det_results, list):
                external_layout_det_results = [external_layout_det_results]
            external_layout_det_results = iter(external_layout_det_results)

        for _, batch_data in enumerate(self.batch_sampler(input)):
            image_arrays = self.img_reader(batch_data.instances)

            if model_settings["use_doc_preprocessor"]:
                doc_preprocessor_results = list(
                    self.doc_preprocessor_pipeline(
                        image_arrays,
                        use_doc_orientation_classify=use_doc_orientation_classify,
                        use_doc_unwarping=use_doc_unwarping,
                    )
                )
            else:
                doc_preprocessor_results = [{"output_img": arr} for arr in image_arrays]

            doc_preprocessor_images = [
                item["output_img"] for item in doc_preprocessor_results
            ]

            if model_settings["use_ocr_model"]:
                overall_ocr_results = list(
                    self.general_ocr_pipeline(
                        doc_preprocessor_images,
                        text_det_limit_side_len=text_det_limit_side_len,
                        text_det_limit_type=text_det_limit_type,
                        text_det_thresh=text_det_thresh,
                        text_det_box_thresh=text_det_box_thresh,
                        text_det_unclip_ratio=text_det_unclip_ratio,
                        text_rec_score_thresh=text_rec_score_thresh,
                    )
                )
            else:
                overall_ocr_results = []
                for _ in doc_preprocessor_images:
                    try:
                        overall_ocr_res = next(external_overall_ocr_results)
                    except StopIteration:
                        raise ValueError("No more overall OCR results")
                    overall_ocr_results.append(overall_ocr_res)

            if use_table_cells_ocr_results:
                # FIXME: This creates a new pipeline on each call.
                assert self.general_ocr_config_bak is not None
                self.general_ocr_pipeline = self.create_pipeline(
                    self.general_ocr_config_bak
                )

            if (
                not model_settings["use_layout_detection"]
                and external_layout_det_results is None
            ):
                layout_det_results = [{} for _ in doc_preprocessor_images]
                table_boxes = []
                for img in doc_preprocessor_images:
                    img_height, img_width = img.shape[:2]
                    table_box = [0, 0, img_width - 1, img_height - 1]
                    table_boxes.append(table_box)
                flat_table_results = self._predict(
                    doc_preprocessor_images,
                    overall_ocr_results,
                    table_boxes,
                    use_table_cells_ocr_results,
                    use_e2e_wired_table_rec_model,
                    use_e2e_wireless_table_rec_model,
                    flag_find_nei_text=False,
                )
                for table_res in flat_table_results:
                    table_res["table_region_id"] = 1
                table_results = [[item] for item in flat_table_results]
            else:
                if model_settings["use_layout_detection"]:
                    layout_det_results = list(
                        self.layout_det_model(doc_preprocessor_images)
                    )
                else:
                    layout_det_results = []
                    for _ in doc_preprocessor_images:
                        try:
                            layout_det_res = next(external_layout_det_results)
                        except StopIteration:
                            raise ValueError("No more layout det results")
                        layout_det_results.append(layout_det_res)
                cropped_imgs = []
                table_boxes = []
                repeated_overall_ocr_results = []
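                # chunk_indices records the cumulative number of cropped tables after
                # each image, so that the flat per-table results can be regrouped into
                # one list of tables per input image below.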
                chunk_indices = [0]
                for image_array, layout_det_res, overall_ocr_res in zip(
                    image_arrays, layout_det_results, overall_ocr_results
                ):
                    for box_info in layout_det_res["boxes"]:
                        if box_info["label"].lower() in ["table"]:
                            crop_img_info = self._crop_by_boxes(image_array, [box_info])
                            crop_img_info = crop_img_info[0]
                            cropped_imgs.append(crop_img_info["img"])
                            table_boxes.append(crop_img_info["box"])
                            repeated_overall_ocr_results.append(overall_ocr_res)
                    chunk_indices.append(len(cropped_imgs))
                flat_table_results = self._predict(
                    cropped_imgs,
                    repeated_overall_ocr_results,
                    table_boxes,
                    use_table_cells_ocr_results,
                    use_e2e_wired_table_rec_model,
                    use_e2e_wireless_table_rec_model,
                )
                table_results = [
                    flat_table_results[i:j]
                    for i, j in zip(chunk_indices[:-1], chunk_indices[1:])
                ]
                for table_results_for_img in table_results:
                    table_region_id = 1
                    for table_res in table_results_for_img:
                        table_res["table_region_id"] = table_region_id
                        table_region_id += 1

            for (
                input_path,
                page_index,
                doc_preprocessor_res,
                layout_det_res,
                overall_ocr_res,
                table_results_for_img,
            ) in zip(
                batch_data.input_paths,
                batch_data.page_indexes,
                doc_preprocessor_results,
                layout_det_results,
                overall_ocr_results,
                table_results,
            ):
                single_img_res = {
                    "input_path": input_path,
                    "page_index": page_index,
                    "doc_preprocessor_res": doc_preprocessor_res,
                    "layout_det_res": layout_det_res,
                    "overall_ocr_res": overall_ocr_res,
                    "table_res_list": table_results_for_img,
                    "model_settings": model_settings,
                }
                yield TableRecognitionResult(single_img_res)
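

# Public pipeline class registered under the entity name "table_recognition_v2". It
# wraps _TableRecognitionPipelineV2 via AutoParallelImageSimpleInferencePipeline which,
# as the base class name suggests, runs the per-image inference pipeline with automatic
# parallelization.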
@pipeline_requires_extra("ocr")
class TableRecognitionPipelineV2(AutoParallelImageSimpleInferencePipeline):
    entities = ["table_recognition_v2"]

    @property
    def _pipeline_cls(self):
        return _TableRecognitionPipelineV2

    def _get_batch_size(self, config):
        return config.get("batch_size", 1)
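

# A minimal usage sketch (hypothetical input path; assumes the standard PaddleX entry
# point `paddlex.create_pipeline`, which instantiates registered pipelines such as
# "table_recognition_v2" from their default configurations):
#
#     from paddlex import create_pipeline
#
#     pipeline = create_pipeline(pipeline="table_recognition_v2")
#     for res in pipeline.predict("table_image.jpg"):
#         res.print()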