# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from ....utils import logging
from ....utils.deps import pipeline_requires_extra
from ...common.batch_sampler import ImageBatchSampler
from ...common.reader import ReadImage
from ...models.object_detection.result import DetResult
from ...utils.benchmark import benchmark
from ...utils.hpi import HPIConfig
from ...utils.pp_option import PaddlePredictorOption
from .._parallel import AutoParallelImageSimpleInferencePipeline
from ..base import BasePipeline
from ..components import CropByBoxes
from ..doc_preprocessor.result import DocPreprocessorResult
from ..ocr.result import OCRResult
from .result import SingleTableRecognitionResult, TableRecognitionResult
from .table_recognition_post_processing import get_table_recognition_res
from .utils import get_neighbor_boxes_idx


@benchmark.time_methods
class _TableRecognitionPipeline(BasePipeline):
    """Table Recognition Pipeline"""

    def __init__(
        self,
        config: Dict,
        device: Optional[str] = None,
        pp_option: Optional[PaddlePredictorOption] = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
    ) -> None:
        """Initializes the table recognition pipeline.

        Args:
            config (Dict): Configuration dictionary containing various settings.
            device (str, optional): Device to run the predictions on. Defaults to None.
            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
            use_hpip (bool, optional): Whether to use the high-performance
                inference plugin (HPIP) by default. Defaults to False.
            hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                The default high-performance inference configuration dictionary.
                Defaults to None.
        """
        super().__init__(
            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
        )

        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
        if self.use_doc_preprocessor:
            doc_preprocessor_config = config.get("SubPipelines", {}).get(
                "DocPreprocessor",
                {
                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
                },
            )
            self.doc_preprocessor_pipeline = self.create_pipeline(
                doc_preprocessor_config
            )

        self.use_layout_detection = config.get("use_layout_detection", True)
        if self.use_layout_detection:
            layout_det_config = config.get("SubModules", {}).get(
                "LayoutDetection",
                {"model_config_error": "config error for layout_det_model!"},
            )
            self.layout_det_model = self.create_model(layout_det_config)

        table_structure_config = config.get("SubModules", {}).get(
            "TableStructureRecognition",
            {"model_config_error": "config error for table_structure_model!"},
        )
        self.table_structure_model = self.create_model(table_structure_config)

        self.use_ocr_model = config.get("use_ocr_model", True)
        if self.use_ocr_model:
            general_ocr_config = config.get("SubPipelines", {}).get(
                "GeneralOCR",
                {"pipeline_config_error": "config error for general_ocr_pipeline!"},
            )
            self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
        else:
            # Keep the OCR config so an OCR pipeline can be created lazily if
            # cell-level OCR is requested at prediction time.
            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
                "GeneralOCR", None
            )

        self._crop_by_boxes = CropByBoxes()
        self.batch_sampler = ImageBatchSampler(batch_size=1)
        self.img_reader = ReadImage(format="BGR")

    def get_model_settings(
        self,
        use_doc_orientation_classify: Optional[bool],
        use_doc_unwarping: Optional[bool],
        use_layout_detection: Optional[bool],
        use_ocr_model: Optional[bool],
    ) -> dict:
        """
        Get the model settings based on the provided parameters or default values.

        Args:
            use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
            use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
            use_layout_detection (Optional[bool]): Whether to use layout detection.
            use_ocr_model (Optional[bool]): Whether to use the OCR model.

        Returns:
            dict: A dictionary containing the model settings.
        """
        if use_doc_orientation_classify is None and use_doc_unwarping is None:
            use_doc_preprocessor = self.use_doc_preprocessor
        else:
            use_doc_preprocessor = (
                use_doc_orientation_classify is True or use_doc_unwarping is True
            )
        if use_layout_detection is None:
            use_layout_detection = self.use_layout_detection
        if use_ocr_model is None:
            use_ocr_model = self.use_ocr_model
        return dict(
            use_doc_preprocessor=use_doc_preprocessor,
            use_layout_detection=use_layout_detection,
            use_ocr_model=use_ocr_model,
        )
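
    # Illustrative resolution examples (assuming the pipeline was built with
    # the default use_doc_preprocessor=True):
    #
    #   get_model_settings(None, None, None, None)
    #       -> use_doc_preprocessor=True   (falls back to the __init__ value)
    #   get_model_settings(False, False, None, None)
    #       -> use_doc_preprocessor=False  (both preprocessing stages disabled)
    #   get_model_settings(None, True, None, None)
    #       -> use_doc_preprocessor=True   (one stage explicitly enabled)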

    def check_model_settings_valid(
        self,
        model_settings: Dict,
        overall_ocr_res: OCRResult,
        layout_det_res: DetResult,
    ) -> bool:
        """
        Check if the input parameters are valid based on the initialized models.

        Args:
            model_settings (Dict): A dictionary containing input parameters.
            overall_ocr_res (OCRResult): Overall OCR result obtained by running the OCR
                pipeline, including convert_points_to_boxes information.
            layout_det_res (DetResult): The layout detection result.

        Returns:
            bool: True if all required models are initialized according to the input
                parameters, False otherwise.
        """
        if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
            logging.error(
                "use_doc_preprocessor is set, but the doc preprocessor models are not initialized."
            )
            return False
        if model_settings["use_layout_detection"]:
            if layout_det_res is not None:
                logging.error(
                    "An external layout detection result was provided; please set use_layout_detection=False."
                )
                return False
            if not self.use_layout_detection:
                logging.error(
                    "use_layout_detection is set, but the layout detection models are not initialized."
                )
                return False
        if model_settings["use_ocr_model"]:
            if overall_ocr_res is not None:
                logging.error(
                    "An external OCR result was provided; please set use_ocr_model=False."
                )
                return False
            if not self.use_ocr_model:
                logging.error(
                    "use_ocr_model is set, but the OCR models are not initialized."
                )
                return False
        else:
            if overall_ocr_res is None:
                logging.error(
                    "use_ocr_model=False was set, but no OCR results were provided."
                )
                return False
        return True

    def predict_doc_preprocessor_res(
        self, image_array: np.ndarray, input_params: dict
    ) -> Tuple[DocPreprocessorResult, np.ndarray]:
        """
        Preprocess the document image based on input parameters.

        Args:
            image_array (np.ndarray): The input image array.
            input_params (dict): Dictionary containing preprocessing parameters.

        Returns:
            Tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
                result dictionary and the processed image array.
        """
        if input_params["use_doc_preprocessor"]:
            use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
            use_doc_unwarping = input_params["use_doc_unwarping"]
            doc_preprocessor_res = list(
                self.doc_preprocessor_pipeline(
                    image_array,
                    use_doc_orientation_classify=use_doc_orientation_classify,
                    use_doc_unwarping=use_doc_unwarping,
                )
            )[0]
            doc_preprocessor_image = doc_preprocessor_res["output_img"]
        else:
            doc_preprocessor_res = {}
            doc_preprocessor_image = image_array
        return doc_preprocessor_res, doc_preprocessor_image

    def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
        """
        Splits OCR bounding boxes by table cells and retrieves the text in each cell.

        Args:
            ori_img (ndarray): The original image from which text regions will be extracted.
            cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.

        Returns:
            list: A list containing the recognized text of each cell.
        """
        # Ensure cells_bboxes is a plain list of [x1, y1, x2, y2] boxes.
        if not isinstance(cells_bboxes, list):
            cells_bboxes = cells_bboxes.tolist()
        texts_list = []
        for bbox in cells_bboxes:
            # Round the coordinates up to integer pixel indices.
            x1, y1, x2, y2 = [math.ceil(k) for k in bbox]
            # Run OCR on the cell region and concatenate the recognized texts.
            ocr_res = list(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))[0]
            texts_list.append("".join(ocr_res["rec_texts"]))
        return texts_list

    def predict_single_table_recognition_res(
        self,
        image_array: np.ndarray,
        overall_ocr_res: OCRResult,
        table_box: list,
        use_ocr_results_with_table_cells: bool = False,
        flag_find_nei_text: bool = True,
        cell_sort_by_y_projection: bool = False,
    ) -> SingleTableRecognitionResult:
        """
        Predict the table recognition result for a single table from an image array
        and the corresponding OCR results.

        Args:
            image_array (np.ndarray): The input image represented as a numpy array.
            overall_ocr_res (OCRResult): Overall OCR result obtained by running the OCR
                pipeline, containing text recognition information.
            table_box (list): The table box coordinates.
            use_ocr_results_with_table_cells (bool): Whether to re-run OCR on each table cell.
            flag_find_nei_text (bool): Whether to find neighboring text.
            cell_sort_by_y_projection (bool): Whether to sort the matched OCR boxes by y-projection.

        Returns:
            SingleTableRecognitionResult: The single table recognition result.
        """
        table_structure_pred = list(self.table_structure_model(image_array))[0]
        if use_ocr_results_with_table_cells:
            # Convert each 8-value cell quadrilateral (x1, y1, ..., x4, y4) into
            # an axis-aligned rectangle built from its first and third corners.
            table_cells_result = table_structure_pred["bbox"]
            table_cells_result = [
                [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result
            ]
            cells_texts_list = self.split_ocr_bboxes_by_table_cells(
                image_array, table_cells_result
            )
        else:
            cells_texts_list = []
        single_table_recognition_res = get_table_recognition_res(
            table_box,
            table_structure_pred,
            overall_ocr_res,
            cells_texts_list,
            use_ocr_results_with_table_cells,
            cell_sort_by_y_projection=cell_sort_by_y_projection,
        )
        neighbor_text = ""
        if flag_find_nei_text:
            match_idx_list = get_neighbor_boxes_idx(
                overall_ocr_res["rec_boxes"], table_box
            )
            if len(match_idx_list) > 0:
                for idx in match_idx_list:
                    neighbor_text += overall_ocr_res["rec_texts"][idx] + "; "
        single_table_recognition_res["neighbor_texts"] = neighbor_text
        return single_table_recognition_res

    def predict(
        self,
        input: Union[str, List[str], np.ndarray, List[np.ndarray]],
        use_doc_orientation_classify: Optional[bool] = None,
        use_doc_unwarping: Optional[bool] = None,
        use_layout_detection: Optional[bool] = None,
        use_ocr_model: Optional[bool] = None,
        overall_ocr_res: Optional[OCRResult] = None,
        layout_det_res: Optional[DetResult] = None,
        text_det_limit_side_len: Optional[int] = None,
        text_det_limit_type: Optional[str] = None,
        text_det_thresh: Optional[float] = None,
        text_det_box_thresh: Optional[float] = None,
        text_det_unclip_ratio: Optional[float] = None,
        text_rec_score_thresh: Optional[float] = None,
        use_ocr_results_with_table_cells: bool = False,
        cell_sort_by_y_projection: Optional[bool] = None,
        **kwargs,
    ) -> TableRecognitionResult:
        """
        Predicts the table recognition result for the given input.

        Args:
            input (Union[str, List[str], np.ndarray, List[np.ndarray]]): The input image(s) or PDF(s) to be processed.
            use_doc_orientation_classify (bool): Whether to use document orientation classification.
            use_doc_unwarping (bool): Whether to use document unwarping.
            use_layout_detection (bool): Whether to use layout detection.
            use_ocr_model (bool): Whether to use the OCR model.
            overall_ocr_res (OCRResult): The overall OCR result with convert_points_to_boxes information.
                It will be used if it is not None and use_ocr_model is False.
            layout_det_res (DetResult): The layout detection result.
                It will be used if it is not None and use_layout_detection is False.
            text_det_* / text_rec_score_thresh (optional): Overrides for the OCR
                pipeline's detection and recognition thresholds; forwarded to the
                general OCR pipeline.
            use_ocr_results_with_table_cells (bool): Whether to re-run OCR on each table cell.
            cell_sort_by_y_projection (bool): Whether to sort the matched OCR boxes by y-projection.
            **kwargs: Additional keyword arguments.

        Returns:
            TableRecognitionResult: The predicted table recognition result.
        """
        model_settings = self.get_model_settings(
            use_doc_orientation_classify,
            use_doc_unwarping,
            use_layout_detection,
            use_ocr_model,
        )
        if cell_sort_by_y_projection is None:
            cell_sort_by_y_projection = False

        if not self.check_model_settings_valid(
            model_settings, overall_ocr_res, layout_det_res
        ):
            yield {"error": "the input params for model settings are invalid!"}
            return

        for img_id, batch_data in enumerate(self.batch_sampler(input)):
            image_array = self.img_reader(batch_data.instances)[0]

            if model_settings["use_doc_preprocessor"]:
                doc_preprocessor_res = list(
                    self.doc_preprocessor_pipeline(
                        image_array,
                        use_doc_orientation_classify=use_doc_orientation_classify,
                        use_doc_unwarping=use_doc_unwarping,
                    )
                )[0]
            else:
                doc_preprocessor_res = {"output_img": image_array}
            doc_preprocessor_image = doc_preprocessor_res["output_img"]

            if model_settings["use_ocr_model"]:
                overall_ocr_res = list(
                    self.general_ocr_pipeline(
                        doc_preprocessor_image,
                        text_det_limit_side_len=text_det_limit_side_len,
                        text_det_limit_type=text_det_limit_type,
                        text_det_thresh=text_det_thresh,
                        text_det_box_thresh=text_det_box_thresh,
                        text_det_unclip_ratio=text_det_unclip_ratio,
                        text_rec_score_thresh=text_rec_score_thresh,
                    )
                )[0]
            elif use_ocr_results_with_table_cells:
                # Cell-level OCR was requested, but no OCR pipeline was built at
                # init time; create one from the saved config.
                assert self.general_ocr_config_bak is not None
                self.general_ocr_pipeline = self.create_pipeline(
                    self.general_ocr_config_bak
                )

            table_res_list = []
            table_region_id = 1
            if not model_settings["use_layout_detection"] and layout_det_res is None:
                # No layout detection: treat the whole image as a single table.
                layout_det_res = {}
                img_height, img_width = doc_preprocessor_image.shape[:2]
                table_box = [0, 0, img_width - 1, img_height - 1]
                single_table_rec_res = self.predict_single_table_recognition_res(
                    doc_preprocessor_image,
                    overall_ocr_res,
                    table_box,
                    use_ocr_results_with_table_cells,
                    flag_find_nei_text=False,
                    cell_sort_by_y_projection=cell_sort_by_y_projection,
                )
                single_table_rec_res["table_region_id"] = table_region_id
                table_res_list.append(single_table_rec_res)
                table_region_id += 1
            else:
                if model_settings["use_layout_detection"]:
                    layout_det_res = list(
                        self.layout_det_model(doc_preprocessor_image)
                    )[0]
                for box_info in layout_det_res["boxes"]:
                    if box_info["label"].lower() in ["table"]:
                        crop_img_info = self._crop_by_boxes(image_array, [box_info])
                        crop_img_info = crop_img_info[0]
                        table_box = crop_img_info["box"]
                        single_table_rec_res = (
                            self.predict_single_table_recognition_res(
                                crop_img_info["img"],
                                overall_ocr_res,
                                table_box,
                                use_ocr_results_with_table_cells,
                                cell_sort_by_y_projection=cell_sort_by_y_projection,
                            )
                        )
                        single_table_rec_res["table_region_id"] = table_region_id
                        table_res_list.append(single_table_rec_res)
                        table_region_id += 1

            single_img_res = {
                "input_path": batch_data.input_paths[0],
                "page_index": batch_data.page_indexes[0],
                "doc_preprocessor_res": doc_preprocessor_res,
                "layout_det_res": layout_det_res,
                "overall_ocr_res": overall_ocr_res,
                "table_res_list": table_res_list,
                "model_settings": model_settings,
            }
            yield TableRecognitionResult(single_img_res)


@pipeline_requires_extra("ocr")
class TableRecognitionPipeline(AutoParallelImageSimpleInferencePipeline):
    entities = ["table_recognition"]

    @property
    def _pipeline_cls(self):
        return _TableRecognitionPipeline

    def _get_batch_size(self, config):
        return 1
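

if __name__ == "__main__":
    # Usage sketch (illustrative): pipelines are normally obtained through
    # PaddleX's top-level create_pipeline() factory rather than by
    # instantiating the classes above directly. The input path and the
    # save_to_img() call are placeholders; adjust them to your setup.
    from paddlex import create_pipeline

    pipeline = create_pipeline(pipeline="table_recognition")
    for res in pipeline.predict("table.jpg"):
        res.print()
        res.save_to_img("./output/")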