# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
import re
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np

from ....utils import logging
from ...common.batch_sampler import ImageBatchSampler
from ...common.reader import ReadImage
from ...models.object_detection.result import DetResult
from ...utils.hpi import HPIConfig
from ...utils.pp_option import PaddlePredictorOption
from ..base import BasePipeline
from ..ocr.result import OCRResult
from .result_v2 import LayoutParsingResultV2
from .utils import (
    gather_imgs,
    get_single_block_parsing_res,
    get_sub_regions_ocr_res,
)


class LayoutParsingPipelineV2(BasePipeline):
    """Layout Parsing Pipeline V2"""

    entities = ["PP-StructureV3"]

    def __init__(
        self,
        config: dict,
        device: str = None,
        pp_option: PaddlePredictorOption = None,
        use_hpip: bool = False,
        hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
    ) -> None:
  40. """Initializes the layout parsing pipeline.
  41. Args:
  42. config (Dict): Configuration dictionary containing various settings.
  43. device (str, optional): Device to run the predictions on. Defaults to None.
  44. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  45. use_hpip (bool, optional): Whether to use the high-performance
  46. inference plugin (HPIP) by default. Defaults to False.
  47. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
  48. The default high-performance inference configuration dictionary.
  49. Defaults to None.
  50. """
        super().__init__(
            device=device,
            pp_option=pp_option,
            use_hpip=use_hpip,
            hpi_config=hpi_config,
        )

        self.initialize_predictor(config)

        self.batch_sampler = ImageBatchSampler(batch_size=1)
        self.img_reader = ReadImage(format="BGR")
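
    # A minimal sketch of the config dict this pipeline consumes, inferred
    # from the keys read in initialize_predictor below (illustrative only;
    # real PP-StructureV3 configs carry full model and pipeline settings):
    #
    #     config = {
    #         "use_doc_preprocessor": True,
    #         "use_general_ocr": True,
    #         "use_seal_recognition": True,
    #         "use_table_recognition": True,
    #         "use_formula_recognition": True,
    #         "SubModules": {
    #             # may also carry "threshold", "layout_nms",
    #             # "layout_unclip_ratio", "layout_merge_bboxes_mode"
    #             "LayoutDetection": {...},
    #         },
    #         "SubPipelines": {
    #             "DocPreprocessor": {...},
    #             "GeneralOCR": {...},
    #             "SealRecognition": {...},
    #             "TableRecognition": {...},
    #             "FormulaRecognition": {...},
    #         },
    #     }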

    def initialize_predictor(self, config: dict) -> None:
        """Initializes the predictor based on the provided configuration.

        Args:
            config (Dict): A dictionary containing the configuration for the predictor.

        Returns:
            None
        """
        self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
        self.use_general_ocr = config.get("use_general_ocr", True)
        self.use_table_recognition = config.get("use_table_recognition", True)
        self.use_seal_recognition = config.get("use_seal_recognition", True)
        self.use_formula_recognition = config.get(
            "use_formula_recognition",
            True,
        )

        if self.use_doc_preprocessor:
            doc_preprocessor_config = config.get("SubPipelines", {}).get(
                "DocPreprocessor",
                {
                    "pipeline_config_error": "config error for doc_preprocessor_pipeline!",
                },
            )
            self.doc_preprocessor_pipeline = self.create_pipeline(
                doc_preprocessor_config,
            )

        layout_det_config = config.get("SubModules", {}).get(
            "LayoutDetection",
            {"model_config_error": "config error for layout_det_model!"},
        )
        layout_kwargs = {}
        if (threshold := layout_det_config.get("threshold", None)) is not None:
            layout_kwargs["threshold"] = threshold
        if (layout_nms := layout_det_config.get("layout_nms", None)) is not None:
            layout_kwargs["layout_nms"] = layout_nms
        if (
            layout_unclip_ratio := layout_det_config.get("layout_unclip_ratio", None)
        ) is not None:
            layout_kwargs["layout_unclip_ratio"] = layout_unclip_ratio
        if (
            layout_merge_bboxes_mode := layout_det_config.get(
                "layout_merge_bboxes_mode", None
            )
        ) is not None:
            layout_kwargs["layout_merge_bboxes_mode"] = layout_merge_bboxes_mode
        self.layout_det_model = self.create_model(layout_det_config, **layout_kwargs)

        if self.use_general_ocr or self.use_table_recognition:
            general_ocr_config = config.get("SubPipelines", {}).get(
                "GeneralOCR",
                {"pipeline_config_error": "config error for general_ocr_pipeline!"},
            )
            self.general_ocr_pipeline = self.create_pipeline(
                general_ocr_config,
            )

        if self.use_seal_recognition:
            seal_recognition_config = config.get("SubPipelines", {}).get(
                "SealRecognition",
                {
                    "pipeline_config_error": "config error for seal_recognition_pipeline!",
                },
            )
            self.seal_recognition_pipeline = self.create_pipeline(
                seal_recognition_config,
            )

        if self.use_table_recognition:
            table_recognition_config = config.get("SubPipelines", {}).get(
                "TableRecognition",
                {
                    "pipeline_config_error": "config error for table_recognition_pipeline!",
                },
            )
            self.table_recognition_pipeline = self.create_pipeline(
                table_recognition_config,
            )

        if self.use_formula_recognition:
            formula_recognition_config = config.get("SubPipelines", {}).get(
                "FormulaRecognition",
                {
                    "pipeline_config_error": "config error for formula_recognition_pipeline!",
                },
            )
            self.formula_recognition_pipeline = self.create_pipeline(
                formula_recognition_config,
            )

    def get_text_paragraphs_ocr_res(
        self,
        overall_ocr_res: OCRResult,
        layout_det_res: DetResult,
    ) -> OCRResult:
        """
        Retrieves the OCR results for text paragraphs, excluding those of formulas, tables, and seals.

        Args:
            overall_ocr_res (OCRResult): The overall OCR result containing text information.
            layout_det_res (DetResult): The detection result containing the layout information of the document.

        Returns:
            OCRResult: The OCR result for text paragraphs after excluding formulas, tables, and seals.
        """
        object_boxes = []
        for box_info in layout_det_res["boxes"]:
            if box_info["label"].lower() in ["formula", "table", "seal"]:
                object_boxes.append(box_info["coordinate"])
        object_boxes = np.array(object_boxes)
        # flag_within=False keeps only the OCR regions that fall outside the
        # collected formula/table/seal boxes.
        sub_regions_ocr_res = get_sub_regions_ocr_res(
            overall_ocr_res, object_boxes, flag_within=False
        )
        return sub_regions_ocr_res
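
    # Illustrative call (hypothetical variable names): given page-level OCR
    # and layout results, this strips formula/table/seal regions so only
    # paragraph text remains:
    #
    #     text_ocr_res = pipeline.get_text_paragraphs_ocr_res(
    #         overall_ocr_res, layout_det_res
    #     )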

    def check_model_settings_valid(self, input_params: dict) -> bool:
        """
        Check if the input parameters are valid based on the initialized models.

        Args:
            input_params (Dict): A dictionary containing input parameters.

        Returns:
            bool: True if all required models are initialized according to input parameters, False otherwise.
        """
        if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
            logging.error(
                "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized.",
            )
            return False

        if input_params["use_general_ocr"] and not self.use_general_ocr:
            logging.error(
                "Set use_general_ocr, but the models for general OCR are not initialized.",
            )
            return False

        if input_params["use_seal_recognition"] and not self.use_seal_recognition:
            logging.error(
                "Set use_seal_recognition, but the models for seal recognition are not initialized.",
            )
            return False

        if input_params["use_table_recognition"] and not self.use_table_recognition:
            logging.error(
                "Set use_table_recognition, but the models for table recognition are not initialized.",
            )
            return False

        return True

    def get_layout_parsing_res(
        self,
        image: list,
        layout_det_res: DetResult,
        overall_ocr_res: OCRResult,
        table_res_list: list,
        seal_res_list: list,
        formula_res_list: list,
        imgs_in_doc: list,
        text_det_limit_side_len: Optional[int] = None,
        text_det_limit_type: Optional[str] = None,
        text_det_thresh: Optional[float] = None,
        text_det_box_thresh: Optional[float] = None,
        text_det_unclip_ratio: Optional[float] = None,
        text_rec_score_thresh: Optional[float] = None,
    ) -> list:
        """
        Retrieves the layout parsing result based on the layout detection result, OCR result, and other recognition results.

        Args:
            image (list): The input image.
            layout_det_res (DetResult): The detection result containing the layout information of the document.
            overall_ocr_res (OCRResult): The overall OCR result containing text information.
            table_res_list (list): A list of table recognition results.
            seal_res_list (list): A list of seal recognition results.
            formula_res_list (list): A list of formula recognition results.
            imgs_in_doc (list): A list of images gathered from the document.
            text_det_limit_side_len (Optional[int], optional): The maximum side length of the text detection region. Defaults to None.
            text_det_limit_type (Optional[str], optional): The type of limit for the text detection region. Defaults to None.
            text_det_thresh (Optional[float], optional): The confidence threshold for text detection. Defaults to None.
            text_det_box_thresh (Optional[float], optional): The confidence threshold for text detection bounding boxes. Defaults to None.
            text_det_unclip_ratio (Optional[float], optional): The unclip ratio for text detection. Defaults to None.
            text_rec_score_thresh (Optional[float], optional): The score threshold for text recognition. Defaults to None.

        Returns:
            list: A list of dictionaries representing the layout parsing result.
        """
        matched_ocr_dict = {}
        image = np.array(image)
        object_boxes = []
        footnote_list = []
        max_bottom_text_coordinate = 0
        for object_box_idx, box_info in enumerate(layout_det_res["boxes"]):
            box = box_info["coordinate"]
            label = box_info["label"].lower()
            object_boxes.append(box)
            # Relabel a footnote as text when it sits above the bottom-most
            # text box (handled in the loop after this one).
            if label == "footnote":
                footnote_list.append(object_box_idx)
            if label == "text" and box[3] > max_bottom_text_coordinate:
                max_bottom_text_coordinate = box[3]
            if label not in ["formula", "table", "seal"]:
                _, matched_idxs = get_sub_regions_ocr_res(
                    overall_ocr_res, [box], return_match_idx=True
                )
                for matched_idx in matched_idxs:
                    if matched_ocr_dict.get(matched_idx, None) is None:
                        matched_ocr_dict[matched_idx] = [object_box_idx]
                    else:
                        matched_ocr_dict[matched_idx].append(object_box_idx)

        for footnote_idx in footnote_list:
            if (
                layout_det_res["boxes"][footnote_idx]["coordinate"][3]
                < max_bottom_text_coordinate
            ):
                layout_det_res["boxes"][footnote_idx]["label"] = "text"

        already_processed = set()
        for matched_idx, layout_box_ids in matched_ocr_dict.items():
            if len(layout_box_ids) <= 1:
                continue
            # One OCR box is matched to multiple layout boxes: re-run OCR on
            # each layout region in isolation and replace the shared result.
            for idx in layout_box_ids:
                if idx in already_processed:
                    continue
                already_processed.add(idx)
                # Copy just this layout region onto a white canvas so the OCR
                # pipeline only sees the current box.
                wht_im = np.ones(image.shape, dtype=image.dtype) * 255
                box = object_boxes[idx]
                x1, y1, x2, y2 = [int(i) for i in box]
                wht_im[y1:y2, x1:x2, :] = image[y1:y2, x1:x2, :]
                sub_ocr_res = next(
                    self.general_ocr_pipeline(
                        wht_im,
                        text_det_limit_side_len=text_det_limit_side_len,
                        text_det_limit_type=text_det_limit_type,
                        text_det_thresh=text_det_thresh,
                        text_det_box_thresh=text_det_box_thresh,
                        text_det_unclip_ratio=text_det_unclip_ratio,
                        text_rec_score_thresh=text_rec_score_thresh,
                    )
                )
                _, matched_idxs = get_sub_regions_ocr_res(
                    overall_ocr_res, [box], return_match_idx=True
                )
                # Delete stale entries in reverse index order so the remaining
                # indices stay valid while deleting.
                for matched_idx in sorted(matched_idxs, reverse=True):
                    del overall_ocr_res["dt_polys"][matched_idx]
                    del overall_ocr_res["rec_texts"][matched_idx]
                    overall_ocr_res["rec_boxes"] = np.delete(
                        overall_ocr_res["rec_boxes"], matched_idx, axis=0
                    )
                    del overall_ocr_res["rec_polys"][matched_idx]
                    del overall_ocr_res["rec_scores"][matched_idx]
                if sub_ocr_res["rec_boxes"].size > 0:
                    sub_ocr_res["rec_labels"] = ["text"] * len(sub_ocr_res["rec_texts"])

                    overall_ocr_res["dt_polys"].extend(sub_ocr_res["dt_polys"])
                    overall_ocr_res["rec_texts"].extend(sub_ocr_res["rec_texts"])
                    overall_ocr_res["rec_boxes"] = np.concatenate(
                        [overall_ocr_res["rec_boxes"], sub_ocr_res["rec_boxes"]], axis=0
                    )
                    overall_ocr_res["rec_polys"].extend(sub_ocr_res["rec_polys"])
                    overall_ocr_res["rec_scores"].extend(sub_ocr_res["rec_scores"])
                    overall_ocr_res["rec_labels"].extend(sub_ocr_res["rec_labels"])
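
        # Recognized formulas are injected back into the OCR result as inline
        # LaTeX (wrapped in $...$) with full confidence, so the downstream
        # block parser treats them like ordinary text spans.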
        for formula_res in formula_res_list:
            x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
            poly_points = [
                (x_min, y_min),
                (x_max, y_min),
                (x_max, y_max),
                (x_min, y_max),
            ]
            overall_ocr_res["dt_polys"].append(poly_points)
            overall_ocr_res["rec_texts"].append(f"${formula_res['rec_formula']}$")
            overall_ocr_res["rec_boxes"] = np.vstack(
                (overall_ocr_res["rec_boxes"], [formula_res["dt_polys"]])
            )
            overall_ocr_res["rec_labels"].append("formula")
            overall_ocr_res["rec_polys"].append(poly_points)
            overall_ocr_res["rec_scores"].append(1)

        parsing_res_list = get_single_block_parsing_res(
            self.general_ocr_pipeline,
            overall_ocr_res=overall_ocr_res,
            layout_det_res=layout_det_res,
            table_res_list=table_res_list,
            seal_res_list=seal_res_list,
        )

        return parsing_res_list

    def get_model_settings(
        self,
        use_doc_orientation_classify: Union[bool, None],
        use_doc_unwarping: Union[bool, None],
        use_general_ocr: Union[bool, None],
        use_seal_recognition: Union[bool, None],
        use_table_recognition: Union[bool, None],
        use_formula_recognition: Union[bool, None],
    ) -> dict:
        """
        Get the model settings based on the provided parameters or default values.

        Args:
            use_doc_orientation_classify (Union[bool, None]): Enables document orientation classification if True. Defaults to system setting if None.
            use_doc_unwarping (Union[bool, None]): Enables document unwarping if True. Defaults to system setting if None.
            use_general_ocr (Union[bool, None]): Enables general OCR if True. Defaults to system setting if None.
            use_seal_recognition (Union[bool, None]): Enables seal recognition if True. Defaults to system setting if None.
            use_table_recognition (Union[bool, None]): Enables table recognition if True. Defaults to system setting if None.
            use_formula_recognition (Union[bool, None]): Enables formula recognition if True. Defaults to system setting if None.

        Returns:
            dict: A dictionary containing the model settings.
        """
        if use_doc_orientation_classify is None and use_doc_unwarping is None:
            use_doc_preprocessor = self.use_doc_preprocessor
        else:
            if use_doc_orientation_classify is True or use_doc_unwarping is True:
                use_doc_preprocessor = True
            else:
                use_doc_preprocessor = False

        if use_general_ocr is None:
            use_general_ocr = self.use_general_ocr

        if use_seal_recognition is None:
            use_seal_recognition = self.use_seal_recognition

        if use_table_recognition is None:
            use_table_recognition = self.use_table_recognition

        if use_formula_recognition is None:
            use_formula_recognition = self.use_formula_recognition

        return dict(
            use_doc_preprocessor=use_doc_preprocessor,
            use_general_ocr=use_general_ocr,
            use_seal_recognition=use_seal_recognition,
            use_table_recognition=use_table_recognition,
            use_formula_recognition=use_formula_recognition,
        )
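
    # Resolution examples (illustrative): with a pipeline default that enables
    # the doc preprocessor, get_model_settings(None, None, ...) keeps it on,
    # get_model_settings(False, False, ...) turns it off, and
    # get_model_settings(True, None, ...) forces it on.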

    def predict(
        self,
        input: Union[str, list[str], np.ndarray, list[np.ndarray]],
        use_doc_orientation_classify: Union[bool, None] = None,
        use_doc_unwarping: Union[bool, None] = None,
        use_textline_orientation: Optional[bool] = None,
        use_general_ocr: Union[bool, None] = None,
        use_seal_recognition: Union[bool, None] = None,
        use_table_recognition: Union[bool, None] = None,
        use_formula_recognition: Union[bool, None] = None,
        layout_threshold: Optional[Union[float, dict]] = None,
        layout_nms: Optional[bool] = None,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
        layout_merge_bboxes_mode: Optional[str] = None,
        text_det_limit_side_len: Union[int, None] = None,
        text_det_limit_type: Union[str, None] = None,
        text_det_thresh: Union[float, None] = None,
        text_det_box_thresh: Union[float, None] = None,
        text_det_unclip_ratio: Union[float, None] = None,
        text_rec_score_thresh: Union[float, None] = None,
        seal_det_limit_side_len: Union[int, None] = None,
        seal_det_limit_type: Union[str, None] = None,
        seal_det_thresh: Union[float, None] = None,
        seal_det_box_thresh: Union[float, None] = None,
        seal_det_unclip_ratio: Union[float, None] = None,
        seal_rec_score_thresh: Union[float, None] = None,
        use_table_cells_ocr_results: bool = False,
        use_e2e_wired_table_rec_model: bool = False,
        use_e2e_wireless_table_rec_model: bool = True,
        **kwargs,
    ) -> LayoutParsingResultV2:
  401. """
  402. Predicts the layout parsing result for the given input.
  403. Args:
  404. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  405. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  406. use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
  407. use_general_ocr (Optional[bool]): Whether to use general OCR.
  408. use_seal_recognition (Optional[bool]): Whether to use seal recognition.
  409. use_table_recognition (Optional[bool]): Whether to use table recognition.
  410. use_formula_recognition (Optional[bool]): Whether to use formula recognition.
  411. layout_threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None.
  412. layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
  413. layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
  414. Defaults to None.
  415. If it's a single number, then both width and height are used.
  416. If it's a tuple of two numbers, then they are used separately for width and height respectively.
  417. If it's None, then no unclipping will be performed.
  418. layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
  419. text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
  420. text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
  421. text_det_thresh (Optional[float]): Threshold for text detection.
  422. text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
  423. text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
  424. text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
  425. seal_det_limit_side_len (Optional[int]): Maximum side length for seal detection.
  426. seal_det_limit_type (Optional[str]): Type of limit to apply for seal detection.
  427. seal_det_thresh (Optional[float]): Threshold for seal detection.
  428. seal_det_box_thresh (Optional[float]): Threshold for seal detection boxes.
  429. seal_det_unclip_ratio (Optional[float]): Ratio for unclipping seal detection boxes.
  430. seal_rec_score_thresh (Optional[float]): Score threshold for seal recognition.
  431. use_table_cells_ocr_results (bool): whether to use OCR results with cells.
  432. use_e2e_wired_table_rec_model (bool): Whether to use end-to-end wired table recognition model.
  433. use_e2e_wireless_table_rec_model (bool): Whether to use end-to-end wireless table recognition model.
  434. **kwargs (Any): Additional settings to extend functionality.
  435. Returns:
  436. LayoutParsingResultV2: The predicted layout parsing result.
  437. """
        model_settings = self.get_model_settings(
            use_doc_orientation_classify,
            use_doc_unwarping,
            use_general_ocr,
            use_seal_recognition,
            use_table_recognition,
            use_formula_recognition,
        )

        if not self.check_model_settings_valid(model_settings):
            yield {"error": "the input params for model settings are invalid!"}
            return

        for batch_data in self.batch_sampler(input):
            image_array = self.img_reader(batch_data.instances)[0]

            if model_settings["use_doc_preprocessor"]:
                doc_preprocessor_res = next(
                    self.doc_preprocessor_pipeline(
                        image_array,
                        use_doc_orientation_classify=use_doc_orientation_classify,
                        use_doc_unwarping=use_doc_unwarping,
                    ),
                )
            else:
                doc_preprocessor_res = {"output_img": image_array}
            doc_preprocessor_image = doc_preprocessor_res["output_img"]

            layout_det_res = next(
                self.layout_det_model(
                    doc_preprocessor_image,
                    threshold=layout_threshold,
                    layout_nms=layout_nms,
                    layout_unclip_ratio=layout_unclip_ratio,
                    layout_merge_bboxes_mode=layout_merge_bboxes_mode,
                )
            )
            imgs_in_doc = gather_imgs(doc_preprocessor_image, layout_det_res["boxes"])

            if model_settings["use_formula_recognition"]:
                formula_res_all = next(
                    self.formula_recognition_pipeline(
                        doc_preprocessor_image,
                        use_layout_detection=False,
                        use_doc_orientation_classify=False,
                        use_doc_unwarping=False,
                        layout_det_res=layout_det_res,
                    ),
                )
                formula_res_list = formula_res_all["formula_res_list"]
            else:
                formula_res_list = []

            # Blank out formula regions so the text detector does not re-read
            # them; the original pixels are restored after parsing (below).
            for formula_res in formula_res_list:
                x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
                doc_preprocessor_image[y_min:y_max, x_min:x_max, :] = 255.0

            if (
                model_settings["use_general_ocr"]
                or model_settings["use_table_recognition"]
            ):
                overall_ocr_res = next(
                    self.general_ocr_pipeline(
                        doc_preprocessor_image,
                        use_textline_orientation=use_textline_orientation,
                        text_det_limit_side_len=text_det_limit_side_len,
                        text_det_limit_type=text_det_limit_type,
                        text_det_thresh=text_det_thresh,
                        text_det_box_thresh=text_det_box_thresh,
                        text_det_unclip_ratio=text_det_unclip_ratio,
                        text_rec_score_thresh=text_rec_score_thresh,
                    ),
                )
            else:
                overall_ocr_res = {}

            overall_ocr_res["rec_labels"] = ["text"] * len(
                overall_ocr_res.get("rec_texts", [])
            )
  506. if model_settings["use_table_recognition"]:
  507. table_contents = copy.deepcopy(overall_ocr_res)
  508. for formula_res in formula_res_list:
  509. x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
  510. poly_points = [
  511. (x_min, y_min),
  512. (x_max, y_min),
  513. (x_max, y_max),
  514. (x_min, y_max),
  515. ]
  516. table_contents["dt_polys"].append(poly_points)
  517. table_contents["rec_texts"].append(
  518. f"${formula_res['rec_formula']}$"
  519. )
  520. table_contents["rec_boxes"] = np.vstack(
  521. (table_contents["rec_boxes"], [formula_res["dt_polys"]])
  522. )
  523. table_contents["rec_polys"].append(poly_points)
  524. table_contents["rec_scores"].append(1)
  525. for img in imgs_in_doc:
  526. img_path = img["path"]
  527. x_min, y_min, x_max, y_max = img["coordinate"]
  528. poly_points = [
  529. (x_min, y_min),
  530. (x_max, y_min),
  531. (x_max, y_max),
  532. (x_min, y_max),
  533. ]
  534. table_contents["dt_polys"].append(poly_points)
  535. table_contents["rec_texts"].append(
  536. f'<div style="text-align: center;"><img src="{img_path}" alt="Image" /></div>'
  537. )
  538. if table_contents["rec_boxes"].size == 0:
  539. table_contents["rec_boxes"] = np.array([img["coordinate"]])
  540. else:
  541. table_contents["rec_boxes"] = np.vstack(
  542. (table_contents["rec_boxes"], img["coordinate"])
  543. )
  544. table_contents["rec_polys"].append(poly_points)
  545. table_contents["rec_scores"].append(img["score"])
  546. table_res_all = next(
  547. self.table_recognition_pipeline(
  548. doc_preprocessor_image,
  549. use_doc_orientation_classify=False,
  550. use_doc_unwarping=False,
  551. use_layout_detection=False,
  552. use_ocr_model=False,
  553. overall_ocr_res=table_contents,
  554. layout_det_res=layout_det_res,
  555. cell_sort_by_y_projection=True,
  556. use_table_cells_ocr_results=use_table_cells_ocr_results,
  557. use_e2e_wired_table_rec_model=use_e2e_wired_table_rec_model,
  558. use_e2e_wireless_table_rec_model=use_e2e_wireless_table_rec_model,
  559. ),
  560. )
  561. table_res_list = table_res_all["table_res_list"]
  562. else:
  563. table_res_list = []
  564. if model_settings["use_seal_recognition"]:
  565. seal_res_all = next(
  566. self.seal_recognition_pipeline(
  567. doc_preprocessor_image,
  568. use_doc_orientation_classify=False,
  569. use_doc_unwarping=False,
  570. use_layout_detection=False,
  571. layout_det_res=layout_det_res,
  572. seal_det_limit_side_len=seal_det_limit_side_len,
  573. seal_det_limit_type=seal_det_limit_type,
  574. seal_det_thresh=seal_det_thresh,
  575. seal_det_box_thresh=seal_det_box_thresh,
  576. seal_det_unclip_ratio=seal_det_unclip_ratio,
  577. seal_rec_score_thresh=seal_rec_score_thresh,
  578. ),
  579. )
  580. seal_res_list = seal_res_all["seal_res_list"]
  581. else:
  582. seal_res_list = []

            parsing_res_list = self.get_layout_parsing_res(
                doc_preprocessor_image,
                layout_det_res=layout_det_res,
                overall_ocr_res=overall_ocr_res,
                table_res_list=table_res_list,
                seal_res_list=seal_res_list,
                formula_res_list=formula_res_list,
                imgs_in_doc=imgs_in_doc,
                text_det_limit_side_len=text_det_limit_side_len,
                text_det_limit_type=text_det_limit_type,
                text_det_thresh=text_det_thresh,
                text_det_box_thresh=text_det_box_thresh,
                text_det_unclip_ratio=text_det_unclip_ratio,
                text_rec_score_thresh=text_rec_score_thresh,
            )

            # Restore the original formula pixels that were blanked out before
            # OCR.
            for formula_res in formula_res_list:
                x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
                doc_preprocessor_image[y_min:y_max, x_min:x_max, :] = formula_res[
                    "input_img"
                ]

            single_img_res = {
                "input_path": batch_data.input_paths[0],
                "page_index": batch_data.page_indexes[0],
                "doc_preprocessor_res": doc_preprocessor_res,
                "layout_det_res": layout_det_res,
                "overall_ocr_res": overall_ocr_res,
                "table_res_list": table_res_list,
                "seal_res_list": seal_res_list,
                "formula_res_list": formula_res_list,
                "parsing_res_list": parsing_res_list,
                "imgs_in_doc": imgs_in_doc,
                "model_settings": model_settings,
            }
            yield LayoutParsingResultV2(single_img_res)

    def concatenate_markdown_pages(self, markdown_list: list) -> str:
        """
        Concatenate Markdown content from multiple pages into a single document.

        Args:
            markdown_list (list): A list containing Markdown data for each page.

        Returns:
            str: The concatenated Markdown text.
        """
        markdown_texts = ""
        previous_page_last_element_paragraph_end_flag = True

        for res in markdown_list:
            # Get the paragraph flags for the current page
            page_first_element_paragraph_start_flag: bool = res[
                "page_continuation_flags"
            ][0]
            page_last_element_paragraph_end_flag: bool = res["page_continuation_flags"][
                1
            ]

            # Determine whether to add a space or a newline
            if (
                not page_first_element_paragraph_start_flag
                and not previous_page_last_element_paragraph_end_flag
            ):
                last_char_of_markdown = markdown_texts[-1] if markdown_texts else ""
                first_char_of_handler = (
                    res["markdown_texts"][0] if res["markdown_texts"] else ""
                )
                # Check if the last character and the first character are Chinese characters
                last_is_chinese_char = (
                    re.match(r"[\u4e00-\u9fff]", last_char_of_markdown)
                    if last_char_of_markdown
                    else False
                )
                first_is_chinese_char = (
                    re.match(r"[\u4e00-\u9fff]", first_char_of_handler)
                    if first_char_of_handler
                    else False
                )
                # Join with a space for Latin-script text; Chinese text is
                # joined directly, since it does not use word spacing.
                if not (last_is_chinese_char or first_is_chinese_char):
                    markdown_texts += " " + res["markdown_texts"]
                else:
                    markdown_texts += res["markdown_texts"]
            else:
                markdown_texts += "\n\n" + res["markdown_texts"]
            previous_page_last_element_paragraph_end_flag = (
                page_last_element_paragraph_end_flag
            )

        return markdown_texts
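

# Illustrative usage sketch (not part of the original module; assumes a valid
# PP-StructureV3 pipeline config dict named `config`). In practice this class
# is normally constructed through PaddleX's pipeline factory rather than
# instantiated directly:
#
#     pipeline = LayoutParsingPipelineV2(config=config)
#     for res in pipeline.predict("document.png", use_table_recognition=True):
#         print(res)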