pipeline.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. from typing import Any, Dict, List, Optional, Tuple, Union
  16. import numpy as np
  17. from ....utils import logging
  18. from ....utils.deps import pipeline_requires_extra
  19. from ...common.batch_sampler import MarkDownBatchSampler
  20. from ...utils.hpi import HPIConfig
  21. from ...utils.pp_option import PaddlePredictorOption
  22. from ..base import BasePipeline
  23. from .result import MarkdownResult
  24. @pipeline_requires_extra("trans")
  25. class PP_DocTranslation_Pipeline(BasePipeline):
  26. entities = ["PP-DocTranslation"]
  27. def __init__(
  28. self,
  29. config: Dict,
  30. device: str = None,
  31. pp_option: PaddlePredictorOption = None,
  32. use_hpip: bool = False,
  33. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  34. initial_predictor: bool = False,
  35. ) -> None:
  36. """Initializes the PP_Translation_Pipeline.
  37. Args:
  38. config (Dict): Configuration dictionary containing various settings.
  39. device (str, optional): Device to run the predictions on. Defaults to None.
  40. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  41. use_hpip (bool, optional): Whether to use the high-performance
  42. inference plugin (HPIP) by default. Defaults to False.
  43. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
  44. The default high-performance inference configuration dictionary.
  45. Defaults to None.
  46. initial_predictor (bool, optional): Whether to initialize the predictor. Defaults to True.
  47. """
  48. super().__init__(
  49. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
  50. )
  51. self.pipeline_name = config["pipeline_name"]
  52. self.config = config
  53. self.use_layout_parser = config.get("use_layout_parser", True)
  54. self.layout_parsing_pipeline = None
  55. self.chat_bot = None
  56. if initial_predictor:
  57. self.inintial_visual_predictor(config)
  58. self.inintial_chat_predictor(config)
  59. self.markdown_batch_sampler = MarkDownBatchSampler()
  60. def inintial_visual_predictor(self, config: dict) -> None:
  61. """
  62. Initializes the visual predictor with the given configuration.
  63. Args:
  64. config (dict): The configuration dictionary containing the necessary
  65. parameters for initializing the predictor.
  66. Returns:
  67. None
  68. """
  69. self.use_layout_parser = config.get("use_layout_parser", True)
  70. if self.use_layout_parser:
  71. layout_parsing_config = config.get("SubPipelines", {}).get(
  72. "LayoutParser",
  73. {"pipeline_config_error": "config error for layout_parsing_pipeline!"},
  74. )
  75. self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
  76. return
  77. def inintial_chat_predictor(self, config: dict) -> None:
  78. """
  79. Initializes the chat predictor with the given configuration.
  80. Args:
  81. config (dict): The configuration dictionary containing the necessary
  82. parameters for initializing the predictor.
  83. Returns:
  84. None
  85. """
  86. from .. import create_chat_bot
  87. chat_bot_config = config.get("SubModules", {}).get(
  88. "LLM_Chat",
  89. {"chat_bot_config_error": "config error for llm chat bot!"},
  90. )
  91. self.chat_bot = create_chat_bot(chat_bot_config)
  92. from .. import create_prompt_engineering
  93. translate_pe_config = (
  94. config.get("SubModules", {})
  95. .get("PromptEngneering", {})
  96. .get(
  97. "Translate_CommonText",
  98. {"pe_config_error": "config error for translate_pe_config!"},
  99. )
  100. )
  101. self.translate_pe = create_prompt_engineering(translate_pe_config)
  102. return
  103. def predict(self, *args, **kwargs) -> None:
  104. logging.error(
  105. "PP-Translation Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
  106. )
  107. return
  108. def visual_predict(
  109. self,
  110. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  111. use_doc_orientation_classify: Optional[bool] = None,
  112. use_doc_unwarping: Optional[bool] = None,
  113. use_textline_orientation: Optional[bool] = None,
  114. use_seal_recognition: Optional[bool] = None,
  115. use_table_recognition: Optional[bool] = None,
  116. layout_threshold: Optional[Union[float, dict]] = None,
  117. layout_nms: Optional[bool] = None,
  118. layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
  119. layout_merge_bboxes_mode: Optional[str] = None,
  120. text_det_limit_side_len: Optional[int] = None,
  121. text_det_limit_type: Optional[str] = None,
  122. text_det_thresh: Optional[float] = None,
  123. text_det_box_thresh: Optional[float] = None,
  124. text_det_unclip_ratio: Optional[float] = None,
  125. text_rec_score_thresh: Optional[float] = None,
  126. seal_det_limit_side_len: Optional[int] = None,
  127. seal_det_limit_type: Optional[str] = None,
  128. seal_det_thresh: Optional[float] = None,
  129. seal_det_box_thresh: Optional[float] = None,
  130. seal_det_unclip_ratio: Optional[float] = None,
  131. seal_rec_score_thresh: Optional[float] = None,
  132. **kwargs,
  133. ) -> dict:
  134. """
  135. This function takes an input image or a list of images and performs various visual
  136. prediction tasks such as document orientation classification, document unwarping,
  137. general OCR, seal recognition, and table recognition based on the provided flags.
  138. Args:
  139. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image path, list of image paths,
  140. numpy array of an image, or list of numpy arrays.
  141. use_doc_orientation_classify (bool): Flag to use document orientation classification.
  142. use_doc_unwarping (bool): Flag to use document unwarping.
  143. use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
  144. use_seal_recognition (bool): Flag to use seal recognition.
  145. use_table_recognition (bool): Flag to use table recognition.
  146. layout_threshold (Optional[float]): The threshold value to filter out low-confidence predictions. Default is None.
  147. layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to False.
  148. layout_unclip_ratio (Optional[Union[float, Tuple[float, float]]], optional): The ratio of unclipping the bounding box.
  149. Defaults to None.
  150. If it's a single number, then both width and height are used.
  151. If it's a tuple of two numbers, then they are used separately for width and height respectively.
  152. If it's None, then no unclipping will be performed.
  153. layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
  154. text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
  155. text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
  156. text_det_thresh (Optional[float]): Threshold for text detection.
  157. text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
  158. text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
  159. text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
  160. seal_det_limit_side_len (Optional[int]): Maximum side length for seal detection.
  161. seal_det_limit_type (Optional[str]): Type of limit to apply for seal detection.
  162. seal_det_thresh (Optional[float]): Threshold for seal detection.
  163. seal_det_box_thresh (Optional[float]): Threshold for seal detection boxes.
  164. seal_det_unclip_ratio (Optional[float]): Ratio for unclipping seal detection boxes.
  165. seal_rec_score_thresh (Optional[float]): Score threshold for seal recognition.
  166. **kwargs: Additional keyword arguments.
  167. Returns:
  168. dict: A dictionary containing the layout parsing result and visual information.
  169. """
  170. if self.use_layout_parser == False:
  171. logging.error("The models for layout parser are not initialized.")
  172. yield {"error": "The models for layout parser are not initialized."}
  173. if self.layout_parsing_pipeline is None:
  174. logging.warning(
  175. "The layout parsing pipeline is not initialized, will initialize it now."
  176. )
  177. self.inintial_visual_predictor(self.config)
  178. for layout_parsing_result in self.layout_parsing_pipeline.predict(
  179. input,
  180. use_doc_orientation_classify=use_doc_orientation_classify,
  181. use_doc_unwarping=use_doc_unwarping,
  182. use_textline_orientation=use_textline_orientation,
  183. use_seal_recognition=use_seal_recognition,
  184. use_table_recognition=use_table_recognition,
  185. layout_threshold=layout_threshold,
  186. layout_nms=layout_nms,
  187. layout_unclip_ratio=layout_unclip_ratio,
  188. layout_merge_bboxes_mode=layout_merge_bboxes_mode,
  189. text_det_limit_side_len=text_det_limit_side_len,
  190. text_det_limit_type=text_det_limit_type,
  191. text_det_thresh=text_det_thresh,
  192. text_det_box_thresh=text_det_box_thresh,
  193. text_det_unclip_ratio=text_det_unclip_ratio,
  194. text_rec_score_thresh=text_rec_score_thresh,
  195. seal_det_box_thresh=seal_det_box_thresh,
  196. seal_det_limit_side_len=seal_det_limit_side_len,
  197. seal_det_limit_type=seal_det_limit_type,
  198. seal_det_thresh=seal_det_thresh,
  199. seal_det_unclip_ratio=seal_det_unclip_ratio,
  200. seal_rec_score_thresh=seal_rec_score_thresh,
  201. ):
  202. visual_predict_res = {
  203. "layout_parsing_result": layout_parsing_result,
  204. }
  205. yield visual_predict_res
  206. def load_from_markdown(self, input):
  207. markdown_info_list = []
  208. for markdown_sample in self.markdown_batch_sampler.sample(input):
  209. markdown_content = markdown_sample.instances[0]
  210. input_path = markdown_sample.input_paths[0]
  211. markdown_info = {
  212. "input_path": input_path,
  213. "page_index": None,
  214. "markdown_texts": markdown_content,
  215. "page_continuation_flags": (True, True),
  216. }
  217. markdown_info_list.append(MarkdownResult(markdown_info))
  218. return markdown_info_list
  219. def split_markdown(self, md_text, chunk_size):
  220. from bs4 import BeautifulSoup
  221. if (
  222. not isinstance(md_text, str)
  223. or not isinstance(chunk_size, int)
  224. or chunk_size <= 0
  225. ):
  226. raise ValueError("Invalid input parameters.")
  227. chunks = []
  228. current_chunk = []
  229. # 如果整体文本小于chunk_size,直接返回
  230. if len(md_text) < chunk_size:
  231. return [md_text]
  232. # 段落分割,两个及以上换行符视为分段
  233. paragraphs = re.split(r"\n{2,}", md_text)
  234. def split_table_to_chunks(table_html):
  235. # 使用 BeautifulSoup 解析表格
  236. soup = BeautifulSoup(table_html, "html.parser")
  237. table = soup.find("table")
  238. if not table:
  239. return [table_html] # 如果没有找到表格,直接返回原始内容
  240. # 提取所有<tr>行
  241. trs = table.find_all("tr")
  242. # 按行累加,确保每个chunk长度<=chunk_size,且不破坏<tr>的完整性
  243. table_chunks = []
  244. current_rows = []
  245. current_len = len("<table></table>") # 基础长度
  246. for tr in trs:
  247. tr_str = str(tr)
  248. row_len = len(tr_str)
  249. if current_rows and current_len + row_len > chunk_size:
  250. # 打包当前chunk
  251. content = "<table>" + "".join(current_rows) + "</table>"
  252. table_chunks.append(content)
  253. current_rows = [] # 重置当前行列表
  254. current_len = len("<table></table>") + row_len
  255. current_rows.append(tr_str)
  256. current_len += row_len
  257. if current_rows:
  258. content = "<table>" + "".join(current_rows) + "</table>"
  259. table_chunks.append(content)
  260. return table_chunks
  261. # 句子分割,英文句号需区分小数点
  262. sentence_pattern = re.compile(
  263. r"(?<=[。!?!?])|(?<=\.)\s+(?=[A-Z])|(?<=\.)\s*$"
  264. )
  265. for paragraph in paragraphs:
  266. paragraph = paragraph.strip()
  267. if not paragraph:
  268. continue
  269. # 使用 BeautifulSoup 检查是否为完整表格
  270. soup = BeautifulSoup(paragraph, "html.parser")
  271. table = soup.find("table")
  272. if table:
  273. table_html = str(table)
  274. if len(table_html) <= chunk_size:
  275. if current_chunk:
  276. chunks.append("\n\n".join(current_chunk))
  277. current_chunk = []
  278. chunks.append(table_html)
  279. else:
  280. # 表格太大,行分段
  281. if current_chunk:
  282. chunks.append("\n\n".join(current_chunk))
  283. current_chunk = []
  284. table_chunks = split_table_to_chunks(table_html)
  285. chunks.extend(table_chunks)
  286. continue
  287. # 普通文本处理
  288. if sum(len(s) for s in current_chunk) + len(paragraph) <= chunk_size:
  289. current_chunk.append(paragraph)
  290. elif len(paragraph) <= chunk_size:
  291. if current_chunk:
  292. chunks.append("\n\n".join(current_chunk))
  293. current_chunk = [paragraph]
  294. else:
  295. # 段落太长,按句子切分
  296. sentences = [
  297. s for s in sentence_pattern.split(paragraph) if s and s.strip()
  298. ]
  299. for sentence in sentences:
  300. sentence = sentence.strip()
  301. if not sentence:
  302. continue
  303. if len(sentence) > chunk_size:
  304. print(sentence)
  305. raise ValueError("A sentence exceeds the chunk size limit.")
  306. if sum(len(s) for s in current_chunk) + len(sentence) > chunk_size:
  307. if current_chunk:
  308. chunks.append("\n\n".join(current_chunk))
  309. current_chunk = [sentence]
  310. else:
  311. current_chunk.append(sentence)
  312. if sum(len(s) for s in current_chunk) >= chunk_size:
  313. chunks.append("\n\n".join(current_chunk))
  314. current_chunk = []
  315. if current_chunk:
  316. chunks.append("\n\n".join(current_chunk))
  317. return [c for c in chunks if c.strip()]
  318. def translate(
  319. self,
  320. ori_md_info_list: List[Dict],
  321. target_language: str = "zh",
  322. chunk_size: int = 5000,
  323. task_description: str = None,
  324. output_format: str = None,
  325. rules_str: str = None,
  326. few_shot_demo_text_content: str = None,
  327. few_shot_demo_key_value_list: str = None,
  328. chat_bot_config=None,
  329. **kwargs,
  330. ):
  331. """
  332. Translate the given original text into the specified target language using the configured translation model.
  333. Args:
  334. original_text (str): The original text to be translated.
  335. target_language (str): The desired target language code.
  336. **kwargs: Additional keyword arguments passed to the translation model.
  337. Returns:
  338. str: The translated text in the target language.
  339. """
  340. if self.chat_bot is None:
  341. logging.warning(
  342. "The LLM chat bot is not initialized,will initialize it now."
  343. )
  344. self.inintial_chat_predictor(self.config)
  345. if chat_bot_config is not None:
  346. from .. import create_chat_bot
  347. chat_bot = create_chat_bot(chat_bot_config)
  348. else:
  349. chat_bot = self.chat_bot
  350. if (
  351. isinstance(ori_md_info_list, list)
  352. and ori_md_info_list[0].get("page_index") is not None
  353. ):
  354. # for multi page pdf
  355. ori_md_info_list = [self.concatenate_markdown_pages(ori_md_info_list)]
  356. for ori_md in ori_md_info_list:
  357. original_texts = ori_md["markdown_texts"]
  358. chunks = self.split_markdown(original_texts, chunk_size)
  359. target_language_chunks = []
  360. if len(chunks) > 1:
  361. logging.info(
  362. f"Get the markdown text, it's length is {len(original_texts)}, will split it into {len(chunks)} parts."
  363. )
  364. logging.info(
  365. "Starting to translate the markdown text, will take a while. please wait..."
  366. )
  367. for idx, chunk in enumerate(chunks):
  368. logging.info(f"Translating the {idx+1}/{len(chunks)} part.")
  369. prompt = self.translate_pe.generate_prompt(
  370. original_text=chunk,
  371. language=target_language,
  372. task_description=task_description,
  373. output_format=output_format,
  374. rules_str=rules_str,
  375. few_shot_demo_text_content=few_shot_demo_text_content,
  376. few_shot_demo_key_value_list=few_shot_demo_key_value_list,
  377. )
  378. target_language_chunk = chat_bot.generate_chat_results(
  379. prompt=prompt
  380. ).get("content", "")
  381. target_language_chunks.append(target_language_chunk)
  382. target_language_texts = "\n\n".join(target_language_chunks)
  383. yield MarkdownResult(
  384. {
  385. "language": target_language,
  386. "input_path": ori_md["input_path"],
  387. "page_index": ori_md["page_index"],
  388. "page_continuation_flags": ori_md["page_continuation_flags"],
  389. "markdown_texts": target_language_texts,
  390. }
  391. )
  392. def concatenate_markdown_pages(self, markdown_list: list) -> tuple:
  393. """
  394. Concatenate Markdown content from multiple pages into a single document.
  395. Args:
  396. markdown_list (list): A list containing Markdown data for each page.
  397. Returns:
  398. tuple: A tuple containing the processed Markdown text.
  399. """
  400. markdown_texts = ""
  401. previous_page_last_element_paragraph_end_flag = True
  402. if len(markdown_list) == 0:
  403. raise ValueError("The length of markdown_list is zero.")
  404. for res in markdown_list:
  405. # Get the paragraph flags for the current page
  406. page_first_element_paragraph_start_flag: bool = res[
  407. "page_continuation_flags"
  408. ][0]
  409. page_last_element_paragraph_end_flag: bool = res["page_continuation_flags"][
  410. 1
  411. ]
  412. # Determine whether to add a space or a newline
  413. if (
  414. not page_first_element_paragraph_start_flag
  415. and not previous_page_last_element_paragraph_end_flag
  416. ):
  417. last_char_of_markdown = markdown_texts[-1] if markdown_texts else ""
  418. first_char_of_handler = (
  419. res["markdown_texts"][0] if res["markdown_texts"] else ""
  420. )
  421. # Check if the last character and the first character are Chinese characters
  422. last_is_chinese_char = (
  423. re.match(r"[\u4e00-\u9fff]", last_char_of_markdown)
  424. if last_char_of_markdown
  425. else False
  426. )
  427. first_is_chinese_char = (
  428. re.match(r"[\u4e00-\u9fff]", first_char_of_handler)
  429. if first_char_of_handler
  430. else False
  431. )
  432. if not (last_is_chinese_char or first_is_chinese_char):
  433. markdown_texts += " " + res["markdown_texts"]
  434. else:
  435. markdown_texts += res["markdown_texts"]
  436. else:
  437. markdown_texts += "\n\n" + res["markdown_texts"]
  438. previous_page_last_element_paragraph_end_flag = (
  439. page_last_element_paragraph_end_flag
  440. )
  441. concatenate_result = {
  442. "input_path": markdown_list[0]["input_path"],
  443. "page_index": None,
  444. "page_continuation_flags": (True, True),
  445. "markdown_texts": markdown_texts,
  446. }
  447. return MarkdownResult(concatenate_result)