pipeline.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from typing import Any, Dict, Optional

from ..base import BasePipeline
from .result import VisualInfoResult
from ....utils import logging

########## [TODO] The import path below needs to be updated later.
from ...components.transforms import ReadImage


class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
    """PP-ChatOCRv3-doc Pipeline"""

    entities = "PP-ChatOCRv3-doc"

    def __init__(
        self,
        config,
        device=None,
        pp_option=None,
        use_hpip: bool = False,
        hpi_params: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
        )
        self.inintial_predictor(config)
        self.img_reader = ReadImage(format="BGR")

    def inintial_predictor(self, config):
        """Create the sub-pipeline and sub-modules used by this pipeline."""
        layout_parsing_config = config["SubPipelines"]["LayoutParser"]
        self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
        chat_bot_config = config["SubModules"]["LLM_Chat"]
        self.chat_bot = self.create_chat_bot(chat_bot_config)
        retriever_config = config["SubModules"]["LLM_Retriever"]
        self.retriever = self.create_retriever(retriever_config)
        text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
        self.text_pe = self.create_prompt_engeering(text_pe_config)
        table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
        self.table_pe = self.create_prompt_engeering(table_pe_config)
        return

    def decode_visual_result(self, layout_parsing_result):
        """Collect the OCR, seal, and table outputs into a VisualInfoResult."""
        text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
        seal_res_list = layout_parsing_result["seal_res_list"]
        normal_text_dict = {}
        layout_type = "text"
        for text in text_paragraphs_ocr_res["rec_text"]:
            if layout_type not in normal_text_dict:
                normal_text_dict[layout_type] = text
            else:
                normal_text_dict[layout_type] += f"\n {text}"
        layout_type = "seal"
        for seal_res in seal_res_list:
            for text in seal_res["rec_text"]:
                if layout_type not in normal_text_dict:
                    normal_text_dict[layout_type] = text
                else:
                    normal_text_dict[layout_type] += f"\n {text}"
        table_res_list = layout_parsing_result["table_res_list"]
        table_text_list = []
        table_html_list = []
        for table_res in table_res_list:
            table_html_list.append(table_res["pred_html"])
            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
            table_text_list.append(single_table_text)
        visual_info = {}
        visual_info["normal_text_dict"] = normal_text_dict
        visual_info["table_text_list"] = table_text_list
        visual_info["table_html_list"] = table_html_list
        return VisualInfoResult(visual_info)

    def visual_predict(
        self,
        input,
        use_doc_orientation_classify=True,
        use_doc_unwarping=True,
        use_common_ocr=True,
        use_seal_recognition=True,
        use_table_recognition=True,
        **kwargs,
    ):
        """Run layout parsing on each input image and yield the parsing result
        together with the decoded visual information."""
        if not isinstance(input, list):
            input_list = [input]
        else:
            input_list = input
        img_id = 1
        for input in input_list:
            if isinstance(input, str):
                image_array = next(self.img_reader(input))[0]["img"]
            else:
                image_array = input
            assert len(image_array.shape) == 3
            layout_parsing_result = next(
                self.layout_parsing_pipeline.predict(
                    image_array,
                    use_doc_orientation_classify=use_doc_orientation_classify,
                    use_doc_unwarping=use_doc_unwarping,
                    use_common_ocr=use_common_ocr,
                    use_seal_recognition=use_seal_recognition,
                    use_table_recognition=use_table_recognition,
                )
            )
            visual_info = self.decode_visual_result(layout_parsing_result)
            visual_predict_res = {
                "layout_parsing_result": layout_parsing_result,
                "visual_info": visual_info,
            }
            yield visual_predict_res

    def save_visual_info_list(self, visual_info, save_path):
        """Serialize the visual information as a single JSON line at save_path."""
        if not isinstance(visual_info, list):
            visual_info_list = [visual_info]
        else:
            visual_info_list = visual_info
        with open(save_path, "w") as fout:
            fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
        return

    def load_visual_info_list(self, data_path):
        """Load the visual information saved by save_visual_info_list."""
        with open(data_path, "r") as fin:
            data = fin.readline()
            visual_info_list = json.loads(data)
        return visual_info_list

    def merge_visual_info_list(self, visual_info_list):
        """Merge per-image visual information into flat lists."""
        all_normal_text_list = []
        all_table_text_list = []
        all_table_html_list = []
        for single_visual_info in visual_info_list:
            normal_text_dict = single_visual_info["normal_text_dict"]
            table_text_list = single_visual_info["table_text_list"]
            table_html_list = single_visual_info["table_html_list"]
            all_normal_text_list.append(normal_text_dict)
            all_table_text_list.extend(table_text_list)
            all_table_html_list.extend(table_html_list)
        return all_normal_text_list, all_table_text_list, all_table_html_list

    def build_vector(self, visual_info, min_characters=3500, llm_request_interval=1.0):
        """Build a vector database from the visual information; when the text is
        too short, keep the raw items instead of building a database."""
        if not isinstance(visual_info, list):
            visual_info_list = [visual_info]
        else:
            visual_info_list = visual_info
        all_visual_info = self.merge_visual_info_list(visual_info_list)
        all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
        all_normal_text_str = "".join(
            ["\n".join(e.values()) for e in all_normal_text_list]
        )
        vector_info = {}
        all_items = []
        for i, normal_text_dict in enumerate(all_normal_text_list):
            for type, text in normal_text_dict.items():
                all_items += [f"{type}:{text}"]
        if len(all_normal_text_str) > min_characters:
            vector_info["flag_too_short_text"] = False
            vector_info["vector"] = self.retriever.generate_vector_database(all_items)
        else:
            vector_info["flag_too_short_text"] = True
            vector_info["vector"] = all_items
        return vector_info

    def format_key(self, key_list):
        """Normalize key_list into a list of keys."""
        if key_list == "":
            return []
        if isinstance(key_list, list):
            return key_list
        if isinstance(key_list, str):
            key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
            # Replace full-width commas with ASCII commas before splitting.
            key_list = key_list.replace(",", ",").split(",")
            return key_list
        return []

    def fix_llm_result_format(self, llm_result):
        """Normalize the raw LLM response into a flat key-value dict."""
        if not llm_result:
            return {}
        if "json" in llm_result or "```" in llm_result:
            llm_result = (
                llm_result.replace("```", "").replace("json", "").replace("/n", "")
            )
            llm_result = llm_result.replace("[", "").replace("]", "")
        try:
            llm_result = json.loads(llm_result)
            llm_result_final = {}
            for key in llm_result:
                value = llm_result[key]
                if isinstance(value, list):
                    if len(value) > 0:
                        llm_result_final[key] = value[0]
                else:
                    llm_result_final[key] = value
            return llm_result_final
        except Exception:
            # Fall back to regex extraction of "key": "value" pairs when the
            # response is not valid JSON.
            results = (
                llm_result.replace("\n", "")
                .replace(" ", "")
                .replace("{", "")
                .replace("}", "")
            )
            if not results.endswith('"'):
                results = results + '"'
            # Match "key": "value" pairs, tolerating missing spaces after the colon.
            pattern = r'"(.*?)":\s*"([^"]*)"'
            matches = re.findall(pattern, str(results))
            if len(matches) > 0:
                llm_result = {k: v for k, v in matches}
                return llm_result
            else:
                return {}

    def generate_and_merge_chat_results(
        self, prompt, key_list, final_results, failed_results
    ):
        """Query the chat model and merge valid answers into final_results,
        removing resolved keys from key_list."""
        llm_result = self.chat_bot.generate_chat_results(prompt)
        llm_result = self.fix_llm_result_format(llm_result)
        for key, value in llm_result.items():
            if value not in failed_results and key in key_list:
                key_list.remove(key)
                final_results[key] = value
        return

    def chat(
        self,
        visual_info,
        key_list,
        vector_info,
        text_task_description=None,
        text_output_format=None,
        text_rules_str=None,
        text_few_shot_demo_text_content=None,
        text_few_shot_demo_key_value_list=None,
        table_task_description=None,
        table_output_format=None,
        table_rules_str=None,
        table_few_shot_demo_text_content=None,
        table_few_shot_demo_key_value_list=None,
    ):
        """Extract the values of key_list from the visual information, first by
        prompting the LLM with table content and then with retrieved text."""
        key_list = self.format_key(key_list)
        if len(key_list) == 0:
            # "The input key_list is invalid!"
            return {"chat_res": "输入的key_list无效!"}
        if not isinstance(visual_info, list):
            visual_info_list = [visual_info]
        else:
            visual_info_list = visual_info
        all_visual_info = self.merge_visual_info_list(visual_info_list)
        all_normal_text_list, all_table_text_list, all_table_html_list = all_visual_info
        final_results = {}
        # LLM answers treated as failures: "LLM call failed", "unknown",
        # "key information not found", "None", and the empty string.
        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
        for all_table_info in [all_table_html_list, all_table_text_list]:
            for table_info in all_table_info:
                if len(key_list) == 0:
                    continue
                prompt = self.table_pe.generate_prompt(
                    table_info,
                    key_list,
                    task_description=table_task_description,
                    output_format=table_output_format,
                    rules_str=table_rules_str,
                    few_shot_demo_text_content=table_few_shot_demo_text_content,
                    few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
                )
                self.generate_and_merge_chat_results(
                    prompt, key_list, final_results, failed_results
                )
        if len(key_list) > 0:
            # Retrieval questions in Chinese: "Extract the key information: {key}".
            question_key_list = [f"抽取关键信息:{key}" for key in key_list]
            vector = vector_info["vector"]
            if not vector_info["flag_too_short_text"]:
                related_text = self.retriever.similarity_retrieval(
                    question_key_list, vector
                )
            else:
                related_text = " ".join(vector)
            prompt = self.text_pe.generate_prompt(
                related_text,
                key_list,
                task_description=text_task_description,
                output_format=text_output_format,
                rules_str=text_rules_str,
                few_shot_demo_text_content=text_few_shot_demo_text_content,
                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
            )
            self.generate_and_merge_chat_results(
                prompt, key_list, final_results, failed_results
            )
        return final_results

    def predict(self, *args, **kwargs):
        logging.error(
            "The PP-ChatOCRv3-doc pipeline does not support calling `predict()` directly. "
            "Please invoke `visual_predict`, `build_vector`, and `chat` in sequence to obtain the result."
        )
        return
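

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline API).
# It follows the visual_predict -> build_vector -> chat sequence described in
# predict() above. The config path "pp_chatocrv3_doc.yaml", the image path
# "example_doc.jpg", and the key list are hypothetical placeholders; the real
# config schema is defined by the pipeline's configuration files.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import yaml

    # Hypothetical config file; it must contain the SubModules/SubPipelines
    # entries read by inintial_predictor().
    with open("pp_chatocrv3_doc.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    pipeline = PP_ChatOCRv3_doc_Pipeline(config)

    # Step 1: visual prediction for each input image.
    visual_info_list = []
    for res in pipeline.visual_predict("example_doc.jpg"):
        visual_info_list.append(res["visual_info"])

    # Step 2: build the retrieval vector (raw items are kept for short texts).
    vector_info = pipeline.build_vector(visual_info_list)

    # Step 3: key information extraction via the LLM (placeholder keys: name, date).
    chat_res = pipeline.chat(visual_info_list, "名称,日期", vector_info)
    print(chat_res)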