pipeline.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from email.mime import image
  15. from typing import Any, Dict, Optional, Union, List, Tuple
  16. import numpy as np
  17. from ..base import BasePipeline
  18. from .utils import get_sub_regions_ocr_res, sorted_layout_boxes
  19. from ..components import convert_points_to_boxes
  20. from .result import LayoutParsingResult
  21. from ....utils import logging
  22. from ...utils.pp_option import PaddlePredictorOption
  23. from ...common.reader import ReadImage
  24. from ...common.batch_sampler import ImageBatchSampler
  25. from ..ocr.result import OCRResult
  26. # [TODO] Update this import from models_new to models
  27. from ...models_new.object_detection.result import DetResult
  28. class LayoutParsingPipeline(BasePipeline):
  29. """Layout Parsing Pipeline"""
  30. entities = ["layout_parsing"]
  def __init__(
      self,
      config: Dict,
      device: Optional[str] = None,
      pp_option: Optional[PaddlePredictorOption] = None,
      use_hpip: bool = False,
  ) -> None:
      """Initialize the layout parsing pipeline.

      Args:
          config (Dict): Configuration dictionary; handed to
              ``inintial_predictor`` to build the sub-models and sub-pipelines.
          device (Optional[str]): Device to run the predictions on. Forwarded
              to the base class. Defaults to None.
          pp_option (Optional[PaddlePredictorOption]): PaddlePredictor
              options. Forwarded to the base class. Defaults to None.
          use_hpip (bool): Whether to use high-performance inference (hpip)
              for prediction. Forwarded to the base class. Defaults to False.
      """
      super().__init__(device=device, pp_option=pp_option, use_hpip=use_hpip)

      # Sub-models are built after the base class is set up; create_model /
      # create_pipeline are not defined in this class, so they are presumably
      # inherited from BasePipeline.
      self.inintial_predictor(config)

      # predict() consumes one input per batch and reads images with the
      # reader configured for the "BGR" format.
      self.batch_sampler = ImageBatchSampler(batch_size=1)
      self.img_reader = ReadImage(format="BGR")
  49. def inintial_predictor(self, config: Dict) -> None:
  50. """Initializes the predictor based on the provided configuration.
  51. Args:
  52. config (Dict): A dictionary containing the configuration for the predictor.
  53. Returns:
  54. None
  55. """
  56. self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
  57. self.use_general_ocr = config.get("use_general_ocr", True)
  58. self.use_table_recognition = config.get("use_table_recognition", True)
  59. self.use_seal_recognition = config.get("use_seal_recognition", True)
  60. self.use_formula_recognition = config.get("use_formula_recognition", True)
  61. if self.use_doc_preprocessor:
  62. doc_preprocessor_config = config.get("SubPipelines", {}).get(
  63. "DocPreprocessor",
  64. {
  65. "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
  66. },
  67. )
  68. self.doc_preprocessor_pipeline = self.create_pipeline(
  69. doc_preprocessor_config
  70. )
  71. layout_det_config = config.get("SubModules", {}).get(
  72. "LayoutDetection",
  73. {"model_config_error": "config error for layout_det_model!"},
  74. )
  75. self.layout_det_model = self.create_model(layout_det_config)
  76. layout_kwargs = {}
  77. if (threshold := layout_det_config.get("threshold", None)) is not None:
  78. layout_kwargs["threshold"] = threshold
  79. if (layout_nms := layout_det_config.get("layout_nms", None)) is not None:
  80. layout_kwargs["layout_nms"] = layout_nms
  81. if (
  82. layout_unclip_ratio := layout_det_config.get("layout_unclip_ratio", None)
  83. ) is not None:
  84. layout_kwargs["layout_unclip_ratio"] = layout_unclip_ratio
  85. if (
  86. layout_merge_bboxes_mode := layout_det_config.get(
  87. "layout_merge_bboxes_mode", None
  88. )
  89. ) is not None:
  90. layout_kwargs["layout_merge_bboxes_mode"] = layout_merge_bboxes_mode
  91. self.layout_det_model = self.create_model(layout_det_config, **layout_kwargs)
  92. if self.use_general_ocr or self.use_table_recognition:
  93. general_ocr_config = config.get("SubPipelines", {}).get(
  94. "GeneralOCR",
  95. {"pipeline_config_error": "config error for general_ocr_pipeline!"},
  96. )
  97. self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
  98. if self.use_seal_recognition:
  99. seal_recognition_config = config.get("SubPipelines", {}).get(
  100. "SealRecognition",
  101. {
  102. "pipeline_config_error": "config error for seal_recognition_pipeline!"
  103. },
  104. )
  105. self.seal_recognition_pipeline = self.create_pipeline(
  106. seal_recognition_config
  107. )
  108. if self.use_table_recognition:
  109. table_recognition_config = config.get("SubPipelines", {}).get(
  110. "TableRecognition",
  111. {
  112. "pipeline_config_error": "config error for table_recognition_pipeline!"
  113. },
  114. )
  115. self.table_recognition_pipeline = self.create_pipeline(
  116. table_recognition_config
  117. )
  118. if self.use_formula_recognition:
  119. formula_recognition_config = config.get("SubPipelines", {}).get(
  120. "FormulaRecognition",
  121. {
  122. "pipeline_config_error": "config error for formula_recognition_pipeline!"
  123. },
  124. )
  125. self.formula_recognition_pipeline = self.create_pipeline(
  126. formula_recognition_config
  127. )
  128. return
  def get_text_paragraphs_ocr_res(
      self, overall_ocr_res: OCRResult, layout_det_res: DetResult
  ) -> OCRResult:
      """Return the OCR result for text paragraphs only.

      Collects the coordinates of every layout box labelled formula, table or
      seal (case-insensitive) and passes them to ``get_sub_regions_ocr_res``
      with ``flag_within=False``.

      Args:
          overall_ocr_res (OCRResult): The overall OCR result containing text
              information.
          layout_det_res (DetResult): The layout detection result; its
              ``"boxes"`` entries provide ``"label"`` and ``"coordinate"``.

      Returns:
          OCRResult: The OCR result after excluding formulas, tables and seals.
      """
      excluded_labels = ("formula", "table", "seal")
      excluded_boxes = np.array(
          [
              region["coordinate"]
              for region in layout_det_res["boxes"]
              if region["label"].lower() in excluded_labels
          ]
      )
      # NOTE(review): flag_within=False presumably keeps the OCR entries lying
      # outside the given boxes; confirm against get_sub_regions_ocr_res.
      return get_sub_regions_ocr_res(
          overall_ocr_res, excluded_boxes, flag_within=False
      )
  def get_layout_parsing_res(
      self,
      image: list,
      layout_det_res: DetResult,
      overall_ocr_res: OCRResult,
      table_res_list: list,
      seal_res_list: list,
      formula_res_list: list,
      text_det_limit_side_len: Optional[int] = None,
      text_det_limit_type: Optional[str] = None,
      text_det_thresh: Optional[float] = None,
      text_det_box_thresh: Optional[float] = None,
      text_det_unclip_ratio: Optional[float] = None,
      text_rec_score_thresh: Optional[float] = None,
  ) -> list:
      """Merge layout detection, OCR and recognition results per region.

      Every layout box yields one dict holding ``"layout_bbox"`` plus one of:

      * ``"formula"`` / ``"table"`` / ``"seal"``: the next unused entry of
        the matching result list (``rec_formula``, ``pred_html``, joined
        ``rec_texts``).
      * ``"image"`` and ``"<label>_text"`` for image-like labels.
      * ``"text"`` for everything else.

      A formula/table/seal box with no recognition result left (for example
      because that recognizer is disabled and its list is empty) is handled
      like an ordinary text region instead of raising ``IndexError``.

      Layout boxes that share an OCR line with another box are re-OCRed in
      isolation on a white canvas, and the result is stored under ``"text"``.
      OCR lines returned by ``get_sub_regions_ocr_res(..., flag_within=False)``
      for the full set of layout boxes are appended with the key
      ``"text_without_layout"``.

      Args:
          image (list): The input image (converted with ``np.array``).
          layout_det_res (DetResult): The layout detection result.
          overall_ocr_res (OCRResult): The overall OCR result.
          table_res_list (list): Table recognition results, consumed in the
              order table boxes appear in ``layout_det_res``.
          seal_res_list (list): Seal recognition results, consumed likewise.
          formula_res_list (list): Formula recognition results, consumed
              likewise.
          text_det_limit_side_len (Optional[int]): Forwarded to the general
              OCR pipeline when a region is re-OCRed. Defaults to None.
          text_det_limit_type (Optional[str]): Forwarded likewise.
          text_det_thresh (Optional[float]): Forwarded likewise.
          text_det_box_thresh (Optional[float]): Forwarded likewise.
          text_det_unclip_ratio (Optional[float]): Forwarded likewise.
          text_rec_score_thresh (Optional[float]): Forwarded likewise.

      Returns:
          list: Region dicts as ordered by ``sorted_layout_boxes``.
      """
      image = np.array(image)
      image_labels = ("image", "figure", "img", "fig")

      layout_parsing_res = []
      object_boxes = []
      # OCR line index -> indices of the layout boxes that line matched.
      matched_ocr_dict = {}
      formula_index = 0
      table_index = 0
      seal_index = 0

      for object_box_idx, box_info in enumerate(layout_det_res["boxes"]):
          box = box_info["coordinate"]
          label = box_info["label"].lower()
          single_box_res = {"layout_bbox": box}
          object_boxes.append(box)

          if label == "formula" and formula_index < len(formula_res_list):
              single_box_res["formula"] = formula_res_list[formula_index][
                  "rec_formula"
              ]
              formula_index += 1
          elif label == "table" and table_index < len(table_res_list):
              single_box_res["table"] = table_res_list[table_index]["pred_html"]
              table_index += 1
          elif label == "seal" and seal_index < len(seal_res_list):
              single_box_res["seal"] = "".join(seal_res_list[seal_index]["rec_texts"])
              seal_index += 1
          else:
              ocr_res_in_box, matched_idxs = get_sub_regions_ocr_res(
                  overall_ocr_res, [box], return_match_idx=True
              )
              for matched_idx in matched_idxs:
                  matched_ocr_dict.setdefault(matched_idx, []).append(object_box_idx)
              box_text = "\n".join(ocr_res_in_box["rec_texts"])
              if label in image_labels:
                  x1, y1, x2, y2 = [int(i) for i in box]
                  single_box_res["image"] = image[y1:y2, x1:x2, :]
                  single_box_res[f"{label}_text"] = box_text
              else:
                  single_box_res["text"] = box_text

          # Always append: the re-OCR pass below indexes layout_parsing_res
          # by layout box index, so the two lists must stay aligned.
          layout_parsing_res.append(single_box_res)

      # Boxes that share at least one OCR line with another box, deduplicated
      # in first-seen order so each box is re-OCRed only once.
      shared_box_ids = dict.fromkeys(
          box_idx
          for layout_box_ids in matched_ocr_dict.values()
          if len(layout_box_ids) > 1
          for box_idx in layout_box_ids
      )
      for idx in shared_box_ids:
          # Copy just this box onto a white canvas so the OCR pipeline sees
          # nothing from the neighbouring regions.
          wht_im = np.full_like(image, 255)
          x1, y1, x2, y2 = [int(i) for i in layout_parsing_res[idx]["layout_bbox"]]
          wht_im[y1:y2, x1:x2, :] = image[y1:y2, x1:x2, :]
          sub_ocr_res = next(
              self.general_ocr_pipeline(
                  wht_im,
                  text_det_limit_side_len=text_det_limit_side_len,
                  text_det_limit_type=text_det_limit_type,
                  text_det_thresh=text_det_thresh,
                  text_det_box_thresh=text_det_box_thresh,
                  text_det_unclip_ratio=text_det_unclip_ratio,
                  text_rec_score_thresh=text_rec_score_thresh,
              )
          )
          # NOTE(review): image-like boxes also get their re-OCRed text under
          # "text" rather than f"{label}_text"; confirm this is intended.
          layout_parsing_res[idx]["text"] = "\n".join(sub_ocr_res["rec_texts"])

      ocr_without_layout_boxes = get_sub_regions_ocr_res(
          overall_ocr_res, object_boxes, flag_within=False
      )
      for ocr_rec_box, ocr_rec_text in zip(
          ocr_without_layout_boxes["rec_boxes"], ocr_without_layout_boxes["rec_texts"]
      ):
          layout_parsing_res.append(
              {"layout_bbox": ocr_rec_box, "text_without_layout": ocr_rec_text}
          )

      return sorted_layout_boxes(layout_parsing_res, w=image.shape[1])
  261. def check_model_settings_valid(self, input_params: Dict) -> bool:
  262. """
  263. Check if the input parameters are valid based on the initialized models.
  264. Args:
  265. input_params (Dict): A dictionary containing input parameters.
  266. Returns:
  267. bool: True if all required models are initialized according to input parameters, False otherwise.
  268. """
  269. if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
  270. logging.error(
  271. "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
  272. )
  273. return False
  274. if input_params["use_general_ocr"] and not self.use_general_ocr:
  275. logging.error(
  276. "Set use_general_ocr, but the models for general OCR are not initialized."
  277. )
  278. return False
  279. if input_params["use_seal_recognition"] and not self.use_seal_recognition:
  280. logging.error(
  281. "Set use_seal_recognition, but the models for seal recognition are not initialized."
  282. )
  283. return False
  284. if input_params["use_table_recognition"] and not self.use_table_recognition:
  285. logging.error(
  286. "Set use_table_recognition, but the models for table recognition are not initialized."
  287. )
  288. return False
  289. return True
  290. def get_model_settings(
  291. self,
  292. use_doc_orientation_classify: Optional[bool],
  293. use_doc_unwarping: Optional[bool],
  294. use_general_ocr: Optional[bool],
  295. use_seal_recognition: Optional[bool],
  296. use_table_recognition: Optional[bool],
  297. use_formula_recognition: Optional[bool],
  298. ) -> dict:
  299. """
  300. Get the model settings based on the provided parameters or default values.
  301. Args:
  302. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  303. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  304. use_general_ocr (Optional[bool]): Whether to use general OCR.
  305. use_seal_recognition (Optional[bool]): Whether to use seal recognition.
  306. use_table_recognition (Optional[bool]): Whether to use table recognition.
  307. Returns:
  308. dict: A dictionary containing the model settings.
  309. """
  310. if use_doc_orientation_classify is None and use_doc_unwarping is None:
  311. use_doc_preprocessor = self.use_doc_preprocessor
  312. else:
  313. if use_doc_orientation_classify is True or use_doc_unwarping is True:
  314. use_doc_preprocessor = True
  315. else:
  316. use_doc_preprocessor = False
  317. if use_general_ocr is None:
  318. use_general_ocr = self.use_general_ocr
  319. if use_seal_recognition is None:
  320. use_seal_recognition = self.use_seal_recognition
  321. if use_table_recognition is None:
  322. use_table_recognition = self.use_table_recognition
  323. if use_formula_recognition is None:
  324. use_formula_recognition = self.use_formula_recognition
  325. return dict(
  326. use_doc_preprocessor=use_doc_preprocessor,
  327. use_general_ocr=use_general_ocr,
  328. use_seal_recognition=use_seal_recognition,
  329. use_table_recognition=use_table_recognition,
  330. use_formula_recognition=use_formula_recognition,
  331. )
  332. def predict(
  333. self,
  334. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  335. use_doc_orientation_classify: Optional[bool] = None,
  336. use_doc_unwarping: Optional[bool] = None,
  337. use_general_ocr: Optional[bool] = None,
  338. use_seal_recognition: Optional[bool] = None,
  339. use_table_recognition: Optional[bool] = None,
  340. use_formula_recognition: Optional[bool] = None,
  341. text_det_limit_side_len: Optional[int] = None,
  342. text_det_limit_type: Optional[str] = None,
  343. text_det_thresh: Optional[float] = None,
  344. text_det_box_thresh: Optional[float] = None,
  345. text_det_unclip_ratio: Optional[float] = None,
  346. text_rec_score_thresh: Optional[float] = None,
  347. seal_det_limit_side_len: Optional[int] = None,
  348. seal_det_limit_type: Optional[str] = None,
  349. seal_det_thresh: Optional[float] = None,
  350. seal_det_box_thresh: Optional[float] = None,
  351. seal_det_unclip_ratio: Optional[float] = None,
  352. seal_rec_score_thresh: Optional[float] = None,
  353. layout_threshold: Optional[Union[float, dict]] = None,
  354. layout_nms: Optional[bool] = None,
  355. layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
  356. layout_merge_bboxes_mode: Optional[str] = None,
  357. **kwargs,
  358. ) -> LayoutParsingResult:
  359. """
  360. This function predicts the layout parsing result for the given input.
  361. Args:
  362. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or pdf(s) to be processed.
  363. use_doc_orientation_classify (bool): Whether to use document orientation classification.
  364. use_doc_unwarping (bool): Whether to use document unwarping.
  365. use_general_ocr (bool): Whether to use general OCR.
  366. use_seal_recognition (bool): Whether to use seal recognition.
  367. use_table_recognition (bool): Whether to use table recognition.
  368. **kwargs: Additional keyword arguments.
  369. Returns:
  370. LayoutParsingResult: The predicted layout parsing result.
  371. """
  372. model_settings = self.get_model_settings(
  373. use_doc_orientation_classify,
  374. use_doc_unwarping,
  375. use_general_ocr,
  376. use_seal_recognition,
  377. use_table_recognition,
  378. use_formula_recognition,
  379. )
  380. if not self.check_model_settings_valid(model_settings):
  381. yield {"error": "the input params for model settings are invalid!"}
  382. for img_id, batch_data in enumerate(self.batch_sampler(input)):
  383. image_array = self.img_reader(batch_data.instances)[0]
  384. if model_settings["use_doc_preprocessor"]:
  385. doc_preprocessor_res = next(
  386. self.doc_preprocessor_pipeline(
  387. image_array,
  388. use_doc_orientation_classify=use_doc_orientation_classify,
  389. use_doc_unwarping=use_doc_unwarping,
  390. )
  391. )
  392. else:
  393. doc_preprocessor_res = {"output_img": image_array}
  394. doc_preprocessor_image = doc_preprocessor_res["output_img"]
  395. layout_det_res = next(
  396. self.layout_det_model(
  397. doc_preprocessor_image,
  398. threshold=layout_threshold,
  399. layout_nms=layout_nms,
  400. layout_unclip_ratio=layout_unclip_ratio,
  401. layout_merge_bboxes_mode=layout_merge_bboxes_mode,
  402. )
  403. )
  404. if (
  405. model_settings["use_general_ocr"]
  406. or model_settings["use_table_recognition"]
  407. ):
  408. overall_ocr_res = next(
  409. self.general_ocr_pipeline(
  410. doc_preprocessor_image,
  411. text_det_limit_side_len=text_det_limit_side_len,
  412. text_det_limit_type=text_det_limit_type,
  413. text_det_thresh=text_det_thresh,
  414. text_det_box_thresh=text_det_box_thresh,
  415. text_det_unclip_ratio=text_det_unclip_ratio,
  416. text_rec_score_thresh=text_rec_score_thresh,
  417. )
  418. )
  419. else:
  420. overall_ocr_res = {}
  421. if model_settings["use_general_ocr"]:
  422. text_paragraphs_ocr_res = self.get_text_paragraphs_ocr_res(
  423. overall_ocr_res, layout_det_res
  424. )
  425. else:
  426. text_paragraphs_ocr_res = {}
  427. if model_settings["use_table_recognition"]:
  428. table_res_all = next(
  429. self.table_recognition_pipeline(
  430. doc_preprocessor_image,
  431. use_doc_orientation_classify=False,
  432. use_doc_unwarping=False,
  433. use_layout_detection=False,
  434. use_ocr_model=False,
  435. overall_ocr_res=overall_ocr_res,
  436. layout_det_res=layout_det_res,
  437. )
  438. )
  439. table_res_list = table_res_all["table_res_list"]
  440. else:
  441. table_res_list = []
  442. if model_settings["use_seal_recognition"]:
  443. seal_res_all = next(
  444. self.seal_recognition_pipeline(
  445. doc_preprocessor_image,
  446. use_doc_orientation_classify=False,
  447. use_doc_unwarping=False,
  448. use_layout_detection=False,
  449. layout_det_res=layout_det_res,
  450. seal_det_limit_side_len=seal_det_limit_side_len,
  451. seal_det_limit_type=seal_det_limit_type,
  452. seal_det_thresh=seal_det_thresh,
  453. seal_det_box_thresh=seal_det_box_thresh,
  454. seal_det_unclip_ratio=seal_det_unclip_ratio,
  455. seal_rec_score_thresh=seal_rec_score_thresh,
  456. )
  457. )
  458. seal_res_list = seal_res_all["seal_res_list"]
  459. else:
  460. seal_res_list = []
  461. if model_settings["use_formula_recognition"]:
  462. formula_res_all = next(
  463. self.formula_recognition_pipeline(
  464. doc_preprocessor_image,
  465. use_layout_detection=False,
  466. use_doc_orientation_classify=False,
  467. use_doc_unwarping=False,
  468. layout_det_res=layout_det_res,
  469. )
  470. )
  471. formula_res_list = formula_res_all["formula_res_list"]
  472. else:
  473. formula_res_list = []
  474. parsing_res_list = self.get_layout_parsing_res(
  475. doc_preprocessor_image,
  476. layout_det_res=layout_det_res,
  477. overall_ocr_res=overall_ocr_res,
  478. table_res_list=table_res_list,
  479. seal_res_list=seal_res_list,
  480. formula_res_list=formula_res_list,
  481. text_det_limit_side_len=text_det_limit_side_len,
  482. text_det_limit_type=text_det_limit_type,
  483. text_det_thresh=text_det_thresh,
  484. text_det_box_thresh=text_det_box_thresh,
  485. text_det_unclip_ratio=text_det_unclip_ratio,
  486. text_rec_score_thresh=text_rec_score_thresh,
  487. )
  488. single_img_res = {
  489. "input_path": batch_data.input_paths[0],
  490. "page_index": batch_data.page_indexes[0],
  491. "doc_preprocessor_res": doc_preprocessor_res,
  492. "layout_det_res": layout_det_res,
  493. "overall_ocr_res": overall_ocr_res,
  494. "text_paragraphs_ocr_res": text_paragraphs_ocr_res,
  495. "table_res_list": table_res_list,
  496. "seal_res_list": seal_res_list,
  497. "formula_res_list": formula_res_list,
  498. "parsing_res_list": parsing_res_list,
  499. "model_settings": model_settings,
  500. }
  501. yield LayoutParsingResult(single_img_res)