# pipeline.py
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. from typing import Any, Dict, List, Optional, Tuple, Union
  16. import numpy as np
  17. from ....utils import logging
  18. from ....utils.deps import pipeline_requires_extra
  19. from ...common.batch_sampler import MarkDownBatchSampler
  20. from ...utils.hpi import HPIConfig
  21. from ...utils.pp_option import PaddlePredictorOption
  22. from ..base import BasePipeline
  23. from .result import MarkdownResult
@pipeline_requires_extra("ie")
class PP_DocTranslation_Pipeline(BasePipeline):
    """Document translation pipeline.

    Combines an optional layout-parsing sub-pipeline (used by
    ``visual_predict``) with LLM-based translation of Markdown text (used by
    ``translate``).
    """

    # Presumably the pipeline name(s) this class is looked up by — confirm
    # against how BasePipeline consumes ``entities``.
    entities = ["PP-DocTranslation"]
  27. def __init__(
  28. self,
  29. config: Dict,
  30. device: str = None,
  31. pp_option: PaddlePredictorOption = None,
  32. use_hpip: bool = False,
  33. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  34. initial_predictor: bool = False,
  35. ) -> None:
  36. """Initializes the PP_Translation_Pipeline.
  37. Args:
  38. config (Dict): Configuration dictionary containing various settings.
  39. device (str, optional): Device to run the predictions on. Defaults to None.
  40. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  41. use_hpip (bool, optional): Whether to use the high-performance
  42. inference plugin (HPIP) by default. Defaults to False.
  43. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
  44. The default high-performance inference configuration dictionary.
  45. Defaults to None.
  46. initial_predictor (bool, optional): Whether to initialize the predictor. Defaults to True.
  47. """
  48. super().__init__(
  49. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
  50. )
  51. self.pipeline_name = config["pipeline_name"]
  52. self.config = config
  53. self.use_layout_parser = config.get("use_layout_parser", True)
  54. self.layout_parsing_pipeline = None
  55. self.chat_bot = None
  56. if initial_predictor:
  57. self.inintial_visual_predictor(config)
  58. self.inintial_chat_predictor(config)
  59. self.markdown_batch_sampler = MarkDownBatchSampler()
  60. def inintial_visual_predictor(self, config: dict) -> None:
  61. """
  62. Initializes the visual predictor with the given configuration.
  63. Args:
  64. config (dict): The configuration dictionary containing the necessary
  65. parameters for initializing the predictor.
  66. Returns:
  67. None
  68. """
  69. self.use_layout_parser = config.get("use_layout_parser", True)
  70. if self.use_layout_parser:
  71. layout_parsing_config = config.get("SubPipelines", {}).get(
  72. "LayoutParser",
  73. {"pipeline_config_error": "config error for layout_parsing_pipeline!"},
  74. )
  75. self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
  76. return
  77. def inintial_chat_predictor(self, config: dict) -> None:
  78. """
  79. Initializes the chat predictor with the given configuration.
  80. Args:
  81. config (dict): The configuration dictionary containing the necessary
  82. parameters for initializing the predictor.
  83. Returns:
  84. None
  85. """
  86. from .. import create_chat_bot
  87. chat_bot_config = config.get("SubModules", {}).get(
  88. "LLM_Chat",
  89. {"chat_bot_config_error": "config error for llm chat bot!"},
  90. )
  91. self.chat_bot = create_chat_bot(chat_bot_config)
  92. from .. import create_prompt_engineering
  93. translate_pe_config = (
  94. config.get("SubModules", {})
  95. .get("PromptEngneering", {})
  96. .get(
  97. "Translate_CommonText",
  98. {"pe_config_error": "config error for translate_pe_config!"},
  99. )
  100. )
  101. self.translate_pe = create_prompt_engineering(translate_pe_config)
  102. return
  103. def predict(self, *args, **kwargs) -> None:
  104. logging.error(
  105. "PP-Translation Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
  106. )
  107. return
  108. def visual_predict(
  109. self,
  110. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  111. use_doc_orientation_classify: Optional[bool] = None,
  112. use_doc_unwarping: Optional[bool] = None,
  113. use_textline_orientation: Optional[bool] = None,
  114. use_seal_recognition: Optional[bool] = None,
  115. use_table_recognition: Optional[bool] = None,
  116. layout_threshold: Optional[Union[float, dict]] = None,
  117. layout_nms: Optional[bool] = None,
  118. layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
  119. layout_merge_bboxes_mode: Optional[str] = None,
  120. text_det_limit_side_len: Optional[int] = None,
  121. text_det_limit_type: Optional[str] = None,
  122. text_det_thresh: Optional[float] = None,
  123. text_det_box_thresh: Optional[float] = None,
  124. text_det_unclip_ratio: Optional[float] = None,
  125. text_rec_score_thresh: Optional[float] = None,
  126. seal_det_limit_side_len: Optional[int] = None,
  127. seal_det_limit_type: Optional[str] = None,
  128. seal_det_thresh: Optional[float] = None,
  129. seal_det_box_thresh: Optional[float] = None,
  130. seal_det_unclip_ratio: Optional[float] = None,
  131. seal_rec_score_thresh: Optional[float] = None,
  132. **kwargs,
  133. ) -> dict:
  134. """
  135. This function takes an input image or a list of images and performs various visual
  136. prediction tasks such as document orientation classification, document unwarping,
  137. general OCR, seal recognition, and table recognition based on the provided flags.
  138. Args:
  139. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image path, list of image paths,
  140. numpy array of an image, or list of numpy arrays.
  141. use_doc_orientation_classify (bool): Flag to use document orientation classification.
  142. use_doc_unwarping (bool): Flag to use document unwarping.
  143. use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
  144. use_seal_recognition (bool): Flag to use seal recognition.
  145. use_table_recognition (bool): Flag to use table recognition.
  146. layout_threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None.
  147. layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
  148. layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
  149. Defaults to None.
  150. If it's a single number, then both width and height are used.
  151. If it's a tuple of two numbers, then they are used separately for width and height respectively.
  152. If it's None, then no unclipping will be performed.
  153. layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
  154. text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
  155. text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
  156. text_det_thresh (Optional[float]): Threshold for text detection.
  157. text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
  158. text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
  159. text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
  160. seal_det_limit_side_len (Optional[int]): Maximum side length for seal detection.
  161. seal_det_limit_type (Optional[str]): Type of limit to apply for seal detection.
  162. seal_det_thresh (Optional[float]): Threshold for seal detection.
  163. seal_det_box_thresh (Optional[float]): Threshold for seal detection boxes.
  164. seal_det_unclip_ratio (Optional[float]): Ratio for unclipping seal detection boxes.
  165. seal_rec_score_thresh (Optional[float]): Score threshold for seal recognition.
  166. **kwargs: Additional keyword arguments.
  167. Returns:
  168. dict: A dictionary containing the layout parsing result and visual information.
  169. """
  170. if self.use_layout_parser == False:
  171. logging.error("The models for layout parser are not initialized.")
  172. yield {"error": "The models for layout parser are not initialized."}
  173. if self.layout_parsing_pipeline is None:
  174. logging.warning(
  175. "The layout parsing pipeline is not initialized, will initialize it now."
  176. )
  177. self.inintial_visual_predictor(self.config)
  178. for layout_parsing_result in self.layout_parsing_pipeline.predict(
  179. input,
  180. use_doc_orientation_classify=use_doc_orientation_classify,
  181. use_doc_unwarping=use_doc_unwarping,
  182. use_textline_orientation=use_textline_orientation,
  183. use_seal_recognition=use_seal_recognition,
  184. use_table_recognition=use_table_recognition,
  185. layout_threshold=layout_threshold,
  186. layout_nms=layout_nms,
  187. layout_unclip_ratio=layout_unclip_ratio,
  188. layout_merge_bboxes_mode=layout_merge_bboxes_mode,
  189. text_det_limit_side_len=text_det_limit_side_len,
  190. text_det_limit_type=text_det_limit_type,
  191. text_det_thresh=text_det_thresh,
  192. text_det_box_thresh=text_det_box_thresh,
  193. text_det_unclip_ratio=text_det_unclip_ratio,
  194. text_rec_score_thresh=text_rec_score_thresh,
  195. seal_det_box_thresh=seal_det_box_thresh,
  196. seal_det_limit_side_len=seal_det_limit_side_len,
  197. seal_det_limit_type=seal_det_limit_type,
  198. seal_det_thresh=seal_det_thresh,
  199. seal_det_unclip_ratio=seal_det_unclip_ratio,
  200. seal_rec_score_thresh=seal_rec_score_thresh,
  201. ):
  202. visual_predict_res = {
  203. "layout_parsing_result": layout_parsing_result,
  204. }
  205. yield visual_predict_res
  206. def load_from_markdown(self, input):
  207. markdown_info_list = []
  208. for markdown_sample in self.markdown_batch_sampler.sample(input):
  209. markdown_content = markdown_sample.instances[0]
  210. input_path = markdown_sample.input_paths[0]
  211. markdown_info = {
  212. "input_path": input_path,
  213. "page_index": None,
  214. "markdown_texts": markdown_content,
  215. "page_continuation_flags": (True, True),
  216. }
  217. markdown_info_list.append(MarkdownResult(markdown_info))
  218. return markdown_info_list
  219. def split_markdown(self, md_text, chunk_size):
  220. if (
  221. not isinstance(md_text, str)
  222. or not isinstance(chunk_size, int)
  223. or chunk_size <= 0
  224. ):
  225. raise ValueError("Invalid input parameters.")
  226. chunks = []
  227. current_chunk = []
  228. # if md_text less than chunk_size, return the md_text
  229. if len(md_text) < chunk_size:
  230. chunks.append(md_text)
  231. return chunks
  232. # split the md_text into paragraphs
  233. paragraphs = md_text.split("\n\n")
  234. for paragraph in paragraphs:
  235. if len(paragraph) == 0:
  236. # 空行直接跳过
  237. continue
  238. if len(paragraph) <= chunk_size:
  239. current_chunk.append(paragraph)
  240. else:
  241. # if the paragraph is too long, split it into sentences
  242. sentences = re.split(r"(?<=[。.!?])", paragraph)
  243. for sentence in sentences:
  244. if len(sentence) == 0:
  245. continue
  246. if len(sentence) > chunk_size:
  247. raise ValueError("A sentence exceeds the chunk size limit.")
  248. # if the current chunk is too long, store it and start a new one
  249. if sum(len(s) for s in current_chunk) + len(sentence) > chunk_size:
  250. chunks.append("\n\n".join(current_chunk))
  251. current_chunk = [sentence]
  252. else:
  253. current_chunk.append(sentence)
  254. if sum(len(s) for s in current_chunk) >= chunk_size:
  255. chunks.append("\n\n".join(current_chunk))
  256. current_chunk = []
  257. if current_chunk:
  258. chunks.append("\n\n".join(current_chunk))
  259. return chunks
  260. def translate(
  261. self,
  262. ori_md_info_list: List[Dict],
  263. target_language: str = "zh",
  264. chunk_size: int = 5000,
  265. task_description: str = None,
  266. output_format: str = None,
  267. rules_str: str = None,
  268. few_shot_demo_text_content: str = None,
  269. few_shot_demo_key_value_list: str = None,
  270. chat_bot_config=None,
  271. **kwargs,
  272. ):
  273. """
  274. Translate the given original text into the specified target language using the configured translation model.
  275. Args:
  276. original_text (str): The original text to be translated.
  277. target_language (str): The desired target language code.
  278. **kwargs: Additional keyword arguments passed to the translation model.
  279. Returns:
  280. str: The translated text in the target language.
  281. """
  282. if self.chat_bot is None:
  283. logging.warning(
  284. "The LLM chat bot is not initialized,will initialize it now."
  285. )
  286. self.inintial_chat_predictor(self.config)
  287. if chat_bot_config is not None:
  288. from .. import create_chat_bot
  289. chat_bot = create_chat_bot(chat_bot_config)
  290. else:
  291. chat_bot = self.chat_bot
  292. if (
  293. isinstance(ori_md_info_list, list)
  294. and ori_md_info_list[0].get("page_index") is not None
  295. ):
  296. # for multi page pdf
  297. ori_md_info_list = [self.concatenate_markdown_pages(ori_md_info_list)]
  298. for ori_md in ori_md_info_list:
  299. original_texts = ori_md["markdown_texts"]
  300. chunks = self.split_markdown(original_texts, chunk_size)
  301. target_language_chunks = []
  302. if len(chunks) > 1:
  303. logging.info(
  304. f"Get the markdown text, it's length is {len(original_texts)}, will split it into {len(chunks)} parts."
  305. )
  306. logging.info(
  307. "Starting to translate the markdown text, will take a while. please wait..."
  308. )
  309. for idx, chunk in enumerate(chunks):
  310. logging.info(f"Translating the {idx+1}/{len(chunks)} part.")
  311. prompt = self.translate_pe.generate_prompt(
  312. original_text=chunk,
  313. language=target_language,
  314. task_description=task_description,
  315. output_format=output_format,
  316. rules_str=rules_str,
  317. few_shot_demo_text_content=few_shot_demo_text_content,
  318. few_shot_demo_key_value_list=few_shot_demo_key_value_list,
  319. )
  320. target_language_chunk = chat_bot.generate_chat_results(
  321. prompt=prompt
  322. ).get("content", "")
  323. target_language_chunks.append(target_language_chunk)
  324. target_language_texts = "\n\n".join(target_language_chunks)
  325. yield MarkdownResult(
  326. {
  327. "language": target_language,
  328. "input_path": ori_md["input_path"],
  329. "page_index": ori_md["page_index"],
  330. "page_continuation_flags": ori_md["page_continuation_flags"],
  331. "markdown_texts": target_language_texts,
  332. }
  333. )
  334. def concatenate_markdown_pages(self, markdown_list: list) -> tuple:
  335. """
  336. Concatenate Markdown content from multiple pages into a single document.
  337. Args:
  338. markdown_list (list): A list containing Markdown data for each page.
  339. Returns:
  340. tuple: A tuple containing the processed Markdown text.
  341. """
  342. markdown_texts = ""
  343. previous_page_last_element_paragraph_end_flag = True
  344. if len(markdown_list) == 0:
  345. raise ValueError("The length of markdown_list is zero.")
  346. for res in markdown_list:
  347. # Get the paragraph flags for the current page
  348. page_first_element_paragraph_start_flag: bool = res[
  349. "page_continuation_flags"
  350. ][0]
  351. page_last_element_paragraph_end_flag: bool = res["page_continuation_flags"][
  352. 1
  353. ]
  354. # Determine whether to add a space or a newline
  355. if (
  356. not page_first_element_paragraph_start_flag
  357. and not previous_page_last_element_paragraph_end_flag
  358. ):
  359. last_char_of_markdown = markdown_texts[-1] if markdown_texts else ""
  360. first_char_of_handler = (
  361. res["markdown_texts"][0] if res["markdown_texts"] else ""
  362. )
  363. # Check if the last character and the first character are Chinese characters
  364. last_is_chinese_char = (
  365. re.match(r"[\u4e00-\u9fff]", last_char_of_markdown)
  366. if last_char_of_markdown
  367. else False
  368. )
  369. first_is_chinese_char = (
  370. re.match(r"[\u4e00-\u9fff]", first_char_of_handler)
  371. if first_char_of_handler
  372. else False
  373. )
  374. if not (last_is_chinese_char or first_is_chinese_char):
  375. markdown_texts += " " + res["markdown_texts"]
  376. else:
  377. markdown_texts += res["markdown_texts"]
  378. else:
  379. markdown_texts += "\n\n" + res["markdown_texts"]
  380. previous_page_last_element_paragraph_end_flag = (
  381. page_last_element_paragraph_end_flag
  382. )
  383. concatenate_result = {
  384. "input_path": markdown_list[0]["input_path"],
  385. "page_index": None,
  386. "page_continuation_flags": (True, True),
  387. "markdown_texts": markdown_texts,
  388. }
  389. return MarkdownResult(concatenate_result)