pipeline.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..base import BasePipeline
  15. from typing import Any, Dict, Optional
  16. # import numpy as np
  17. # import cv2
  18. from .result import VisualInfoResult
  19. import re
  20. # TODO: the import path below needs to be updated later.
  21. from ...components.transforms import ReadImage
  22. import json
  23. from ....utils import logging
  24. from ...utils.pp_option import PaddlePredictorOption
  25. from ..layout_parsing.result import LayoutParsingResult
  26. import numpy as np
class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
    """PP-ChatOCRv3-doc Pipeline"""

    # NOTE(review): presumably the name this pipeline is registered/looked up
    # under by the base pipeline machinery -- confirm against BasePipeline.
    entities = "PP-ChatOCRv3-doc"
  30. def __init__(
  31. self,
  32. config: Dict,
  33. device: str = None,
  34. pp_option: PaddlePredictorOption = None,
  35. use_hpip: bool = False,
  36. hpi_params: Optional[Dict[str, Any]] = None,
  37. ) -> None:
  38. """Initializes the pp-chatocrv3-doc pipeline.
  39. Args:
  40. config (Dict): Configuration dictionary containing various settings.
  41. device (str, optional): Device to run the predictions on. Defaults to None.
  42. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  43. use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
  44. hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
  45. """
  46. super().__init__(
  47. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
  48. )
  49. self.inintial_predictor(config)
  50. self.img_reader = ReadImage(format="BGR")
  51. self.table_structure_len_max = 500
  52. def inintial_predictor(self, config: dict) -> None:
  53. """
  54. Initializes the predictor with the given configuration.
  55. Args:
  56. config (dict): The configuration dictionary containing the necessary
  57. parameters for initializing the predictor.
  58. Returns:
  59. None
  60. """
  61. layout_parsing_config = config["SubPipelines"]["LayoutParser"]
  62. self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
  63. from .. import create_chat_bot
  64. chat_bot_config = config["SubModules"]["LLM_Chat"]
  65. self.chat_bot = create_chat_bot(chat_bot_config)
  66. from .. import create_retriever
  67. retriever_config = config["SubModules"]["LLM_Retriever"]
  68. self.retriever = create_retriever(retriever_config)
  69. from .. import create_prompt_engeering
  70. text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
  71. self.text_pe = create_prompt_engeering(text_pe_config)
  72. table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
  73. self.table_pe = create_prompt_engeering(table_pe_config)
  74. return
  75. def decode_visual_result(
  76. self, layout_parsing_result: LayoutParsingResult
  77. ) -> VisualInfoResult:
  78. """
  79. Decodes the visual result from the layout parsing result.
  80. Args:
  81. layout_parsing_result (LayoutParsingResult): The result of layout parsing.
  82. Returns:
  83. VisualInfoResult: The decoded visual information.
  84. """
  85. text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
  86. seal_res_list = layout_parsing_result["seal_res_list"]
  87. normal_text_dict = {}
  88. for seal_res in seal_res_list:
  89. for text in seal_res["rec_text"]:
  90. layout_type = "印章"
  91. if layout_type not in normal_text_dict:
  92. normal_text_dict[layout_type] = f"{text}"
  93. else:
  94. normal_text_dict[layout_type] += f"\n {text}"
  95. for text in text_paragraphs_ocr_res["rec_text"]:
  96. layout_type = "words in text block"
  97. if layout_type not in normal_text_dict:
  98. normal_text_dict[layout_type] = text
  99. else:
  100. normal_text_dict[layout_type] += f"\n {text}"
  101. table_res_list = layout_parsing_result["table_res_list"]
  102. table_text_list = []
  103. table_html_list = []
  104. for table_res in table_res_list:
  105. table_html_list.append(table_res["pred_html"])
  106. single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
  107. table_text_list.append(single_table_text)
  108. visual_info = {}
  109. visual_info["normal_text_dict"] = normal_text_dict
  110. visual_info["table_text_list"] = table_text_list
  111. visual_info["table_html_list"] = table_html_list
  112. return VisualInfoResult(visual_info)
  113. # Function to perform visual prediction on input images
  114. def visual_predict(
  115. self,
  116. input: str | list[str] | np.ndarray | list[np.ndarray],
  117. use_doc_orientation_classify: bool = False, # Whether to use document orientation classification
  118. use_doc_unwarping: bool = False, # Whether to use document unwarping
  119. use_common_ocr: bool = True, # Whether to use common OCR
  120. use_seal_recognition: bool = True, # Whether to use seal recognition
  121. use_table_recognition: bool = True, # Whether to use table recognition
  122. **kwargs,
  123. ) -> dict:
  124. """
  125. This function takes an input image or a list of images and performs various visual
  126. prediction tasks such as document orientation classification, document unwarping,
  127. common OCR, seal recognition, and table recognition based on the provided flags.
  128. Args:
  129. input (str | list[str] | np.ndarray | list[np.ndarray]): Input image path, list of image paths,
  130. numpy array of an image, or list of numpy arrays.
  131. use_doc_orientation_classify (bool): Flag to use document orientation classification.
  132. use_doc_unwarping (bool): Flag to use document unwarping.
  133. use_common_ocr (bool): Flag to use common OCR.
  134. use_seal_recognition (bool): Flag to use seal recognition.
  135. use_table_recognition (bool): Flag to use table recognition.
  136. **kwargs: Additional keyword arguments.
  137. Returns:
  138. dict: A dictionary containing the layout parsing result and visual information.
  139. """
  140. if not isinstance(input, list):
  141. input_list = [input]
  142. else:
  143. input_list = input
  144. img_id = 1
  145. for input in input_list:
  146. if isinstance(input, str):
  147. image_array = next(self.img_reader(input))[0]["img"]
  148. else:
  149. image_array = input
  150. assert len(image_array.shape) == 3
  151. layout_parsing_result = next(
  152. self.layout_parsing_pipeline.predict(
  153. image_array,
  154. use_doc_orientation_classify=use_doc_orientation_classify,
  155. use_doc_unwarping=use_doc_unwarping,
  156. use_common_ocr=use_common_ocr,
  157. use_seal_recognition=use_seal_recognition,
  158. use_table_recognition=use_table_recognition,
  159. )
  160. )
  161. visual_info = self.decode_visual_result(layout_parsing_result)
  162. visual_predict_res = {
  163. "layout_parsing_result": layout_parsing_result,
  164. "visual_info": visual_info,
  165. }
  166. yield visual_predict_res
  167. def save_visual_info_list(
  168. self, visual_info: VisualInfoResult, save_path: str
  169. ) -> None:
  170. """
  171. Save the visual info list to the specified file path.
  172. Args:
  173. visual_info (VisualInfoResult): The visual info result, which can be a single object or a list of objects.
  174. save_path (str): The file path to save the visual info list.
  175. Returns:
  176. None
  177. """
  178. if not isinstance(visual_info, list):
  179. visual_info_list = [visual_info]
  180. else:
  181. visual_info_list = visual_info
  182. with open(save_path, "w") as fout:
  183. fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
  184. return
  185. def load_visual_info_list(self, data_path: str) -> list[VisualInfoResult]:
  186. """
  187. Loads visual info list from a JSON file.
  188. Args:
  189. data_path (str): The path to the JSON file containing visual info.
  190. Returns:
  191. list[VisualInfoResult]: A list of VisualInfoResult objects parsed from the JSON file.
  192. """
  193. with open(data_path, "r") as fin:
  194. data = fin.readline()
  195. visual_info_list = json.loads(data)
  196. return visual_info_list
  197. def merge_visual_info_list(
  198. self, visual_info_list: list[VisualInfoResult]
  199. ) -> tuple[list, list, list]:
  200. """
  201. Merge visual info lists.
  202. Args:
  203. visual_info_list (list[VisualInfoResult]): A list of visual info results.
  204. Returns:
  205. tuple[list, list, list]: A tuple containing three lists, one for normal text dicts,
  206. one for table text lists, and one for table HTML lists.
  207. """
  208. all_normal_text_list = []
  209. all_table_text_list = []
  210. all_table_html_list = []
  211. for single_visual_info in visual_info_list:
  212. normal_text_dict = single_visual_info["normal_text_dict"]
  213. table_text_list = single_visual_info["table_text_list"]
  214. table_html_list = single_visual_info["table_html_list"]
  215. all_normal_text_list.append(normal_text_dict)
  216. all_table_text_list.extend(table_text_list)
  217. all_table_html_list.extend(table_html_list)
  218. return all_normal_text_list, all_table_text_list, all_table_html_list
  219. def build_vector(
  220. self,
  221. visual_info: VisualInfoResult,
  222. min_characters: int = 3500,
  223. llm_request_interval: float = 1.0,
  224. ) -> dict:
  225. """
  226. Build a vector representation from visual information.
  227. Args:
  228. visual_info (VisualInfoResult): The visual information input, can be a single instance or a list of instances.
  229. min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
  230. llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
  231. Returns:
  232. dict: A dictionary containing the vector info and a flag indicating if the text is too short.
  233. """
  234. if not isinstance(visual_info, list):
  235. visual_info_list = [visual_info]
  236. else:
  237. visual_info_list = visual_info
  238. all_visual_info = self.merge_visual_info_list(visual_info_list)
  239. all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
  240. vector_info = {}
  241. all_items = []
  242. for i, normal_text_dict in enumerate(all_normal_text_list):
  243. for type, text in normal_text_dict.items():
  244. all_items += [f"{type}:{text}\n"]
  245. for table_html, table_text in zip(all_table_html_list, all_table_text_list):
  246. if len(table_html) > min_characters - self.table_structure_len_max:
  247. all_items += [f"table:{table_text}\n"]
  248. all_text_str = "".join(all_items)
  249. if len(all_text_str) > min_characters:
  250. vector_info["flag_too_short_text"] = False
  251. vector_info["vector"] = self.retriever.generate_vector_database(all_items)
  252. else:
  253. vector_info["flag_too_short_text"] = True
  254. vector_info["vector"] = all_items
  255. return vector_info
  256. def format_key(self, key_list: str | list[str]) -> list[str]:
  257. """
  258. Formats the key list.
  259. Args:
  260. key_list (str|list[str]): A string or a list of strings representing the keys.
  261. Returns:
  262. list[str]: A list of formatted keys.
  263. """
  264. if key_list == "":
  265. return []
  266. if isinstance(key_list, list):
  267. return key_list
  268. if isinstance(key_list, str):
  269. key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
  270. key_list = key_list.replace(",", ",").split(",")
  271. return key_list
  272. return []
  273. def fix_llm_result_format(self, llm_result: str) -> dict:
  274. """
  275. Fix the format of the LLM result.
  276. Args:
  277. llm_result (str): The result from the LLM (Large Language Model).
  278. Returns:
  279. dict: A fixed format dictionary from the LLM result.
  280. """
  281. if not llm_result:
  282. return {}
  283. if "json" in llm_result or "```" in llm_result:
  284. llm_result = (
  285. llm_result.replace("```", "").replace("json", "").replace("/n", "")
  286. )
  287. llm_result = llm_result.replace("[", "").replace("]", "")
  288. try:
  289. llm_result = json.loads(llm_result)
  290. llm_result_final = {}
  291. for key in llm_result:
  292. value = llm_result[key]
  293. if isinstance(value, list):
  294. if len(value) > 0:
  295. llm_result_final[key] = value[0]
  296. else:
  297. llm_result_final[key] = value
  298. return llm_result_final
  299. except:
  300. results = (
  301. llm_result.replace("\n", "")
  302. .replace(" ", "")
  303. .replace("{", "")
  304. .replace("}", "")
  305. )
  306. if not results.endswith('"'):
  307. results = results + '"'
  308. pattern = r'"(.*?)": "([^"]*)"'
  309. matches = re.findall(pattern, str(results))
  310. if len(matches) > 0:
  311. llm_result = {k: v for k, v in matches}
  312. return llm_result
  313. else:
  314. return {}
  315. def generate_and_merge_chat_results(
  316. self, prompt: str, key_list: list, final_results: dict, failed_results: dict
  317. ) -> None:
  318. """
  319. Generate and merge chat results into the final results dictionary.
  320. Args:
  321. prompt (str): The input prompt for the chat bot.
  322. key_list (list): A list of keys to track which results to merge.
  323. final_results (dict): The dictionary to store the final merged results.
  324. failed_results (dict): A dictionary of failed results to avoid merging.
  325. Returns:
  326. None
  327. """
  328. llm_result = self.chat_bot.generate_chat_results(prompt)
  329. if llm_result is None:
  330. logging.warning(
  331. "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
  332. % (prompt, self.chat_bot.ERROR_MASSAGE)
  333. )
  334. return
  335. llm_result = self.fix_llm_result_format(llm_result)
  336. for key, value in llm_result.items():
  337. if value not in failed_results and key in key_list:
  338. key_list.remove(key)
  339. final_results[key] = value
  340. return
  341. def chat(
  342. self,
  343. key_list: str | list[str],
  344. visual_info: VisualInfoResult,
  345. use_vector_retrieval: bool = True,
  346. vector_info: dict = None,
  347. min_characters: int = 3500,
  348. text_task_description: str = None,
  349. text_output_format: str = None,
  350. text_rules_str: str = None,
  351. text_few_shot_demo_text_content: str = None,
  352. text_few_shot_demo_key_value_list: str = None,
  353. table_task_description: str = None,
  354. table_output_format: str = None,
  355. table_rules_str: str = None,
  356. table_few_shot_demo_text_content: str = None,
  357. table_few_shot_demo_key_value_list: str = None,
  358. ) -> dict:
  359. """
  360. Generates chat results based on the provided key list and visual information.
  361. Args:
  362. key_list (str | list[str]): A single key or a list of keys to extract information.
  363. visual_info (VisualInfoResult): The visual information result.
  364. use_vector_retrieval (bool): Whether to use vector retrieval.
  365. vector_info (dict): The vector information for retrieval.
  366. min_characters (int): The minimum number of characters required.
  367. text_task_description (str): The description of the text task.
  368. text_output_format (str): The output format for text results.
  369. text_rules_str (str): The rules for generating text results.
  370. text_few_shot_demo_text_content (str): The text content for few-shot demos.
  371. text_few_shot_demo_key_value_list (str): The key-value list for few-shot demos.
  372. table_task_description (str): The description of the table task.
  373. table_output_format (str): The output format for table results.
  374. table_rules_str (str): The rules for generating table results.
  375. table_few_shot_demo_text_content (str): The text content for table few-shot demos.
  376. table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
  377. Returns:
  378. dict: A dictionary containing the chat results.
  379. """
  380. key_list = self.format_key(key_list)
  381. if len(key_list) == 0:
  382. return {"error": "输入的key_list无效!"}
  383. if not isinstance(visual_info, list):
  384. visual_info_list = [visual_info]
  385. else:
  386. visual_info_list = visual_info
  387. all_visual_info = self.merge_visual_info_list(visual_info_list)
  388. all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
  389. final_results = {}
  390. failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
  391. for table_html, table_text in zip(all_table_html_list, all_table_text_list):
  392. if len(table_html) <= min_characters - self.table_structure_len_max:
  393. for table_info in [table_html, table_text]:
  394. if len(key_list) > 0:
  395. prompt = self.table_pe.generate_prompt(
  396. table_info,
  397. key_list,
  398. task_description=table_task_description,
  399. output_format=table_output_format,
  400. rules_str=table_rules_str,
  401. few_shot_demo_text_content=table_few_shot_demo_text_content,
  402. few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
  403. )
  404. self.generate_and_merge_chat_results(
  405. prompt, key_list, final_results, failed_results
  406. )
  407. if len(key_list) > 0:
  408. if use_vector_retrieval and vector_info is not None:
  409. question_key_list = [f"抽取关键信息:{key}" for key in key_list]
  410. vector = vector_info["vector"]
  411. if not vector_info["flag_too_short_text"]:
  412. related_text = self.retriever.similarity_retrieval(
  413. question_key_list, vector
  414. )
  415. # print(question_key_list, related_text)
  416. else:
  417. if len(vector) > 0:
  418. related_text = "".join(vector)
  419. else:
  420. related_text = ""
  421. else:
  422. all_items = []
  423. for i, normal_text_dict in enumerate(all_normal_text_list):
  424. for type, text in normal_text_dict.items():
  425. all_items += [f"{type}:{text}\n"]
  426. related_text = "".join(all_items)
  427. if len(related_text) > min_characters:
  428. logging.warning(
  429. "The input text content is too long, the large language model may truncate it."
  430. )
  431. if len(related_text) > 0:
  432. prompt = self.text_pe.generate_prompt(
  433. related_text,
  434. key_list,
  435. task_description=text_task_description,
  436. output_format=text_output_format,
  437. rules_str=text_rules_str,
  438. few_shot_demo_text_content=text_few_shot_demo_text_content,
  439. few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
  440. )
  441. # print(prompt)
  442. self.generate_and_merge_chat_results(
  443. prompt, key_list, final_results, failed_results
  444. )
  445. return {"chat_res": final_results}
  446. def predict(self, *args, **kwargs) -> None:
  447. logging.error(
  448. "PP-ChatOCRv3-doc Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
  449. )
  450. return