pipeline_v4.py

# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional
import re
import cv2
import json
import base64
import numpy as np
import copy

from .pipeline_base import PP_ChatOCR_Pipeline
from .result import VisualInfoResult
from ...common.reader import ReadImage
from ...common.batch_sampler import ImageBatchSampler
from ....utils import logging
from ...utils.pp_option import PaddlePredictorOption
from ..layout_parsing.result import LayoutParsingResult


class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
    """PP-ChatOCRv4 Pipeline"""

    entities = ["PP-ChatOCRv4-doc"]

    def __init__(
        self,
        config: Dict,
        device: str = None,
        pp_option: PaddlePredictorOption = None,
        use_hpip: bool = False,
        hpi_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initializes the PP-ChatOCRv4-doc pipeline.

        Args:
            config (Dict): Configuration dictionary containing various settings.
            device (str, optional): Device to run the predictions on. Defaults to None.
            pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
            use_hpip (bool, optional): Whether to use high-performance inference (HPIP) for prediction. Defaults to False.
            hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
        """
        super().__init__(
            device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
        )

        self.pipeline_name = config["pipeline_name"]
        self.inintial_predictor(config)

        self.batch_sampler = ImageBatchSampler(batch_size=1)
        self.img_reader = ReadImage(format="BGR")

        # Length budget reserved for table structure tags when deciding whether a
        # table's HTML fits directly into a prompt (see `build_vector` and `chat`).
        self.table_structure_len_max = 500

    def inintial_predictor(self, config: dict) -> None:
        """
        Initializes the predictor with the given configuration.

        Args:
            config (dict): The configuration dictionary containing the necessary
                parameters for initializing the predictor.
        Returns:
            None
        """
        self.use_layout_parser = True
        if "use_layout_parser" in config:
            self.use_layout_parser = config["use_layout_parser"]

        if self.use_layout_parser:
            layout_parsing_config = config["SubPipelines"]["LayoutParser"]
            self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)

        from .. import create_chat_bot

        chat_bot_config = config["SubModules"]["LLM_Chat"]
        self.chat_bot = create_chat_bot(chat_bot_config)

        from .. import create_retriever

        retriever_config = config["SubModules"]["LLM_Retriever"]
        self.retriever = create_retriever(retriever_config)

        from .. import create_prompt_engeering

        text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
        self.text_pe = create_prompt_engeering(text_pe_config)

        table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
        self.table_pe = create_prompt_engeering(table_pe_config)

        self.use_mllm_predict = False
        if "use_mllm_predict" in config:
            self.use_mllm_predict = config["use_mllm_predict"]
        if self.use_mllm_predict:
            mllm_chat_bot_config = config["SubModules"]["MLLM_Chat"]
            self.mllm_chat_bot = create_chat_bot(mllm_chat_bot_config)
            ensemble_pe_config = config["SubModules"]["PromptEngneering"]["Ensemble"]
            self.ensemble_pe = create_prompt_engeering(ensemble_pe_config)

        return
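
    # For reference, `inintial_predictor` expects the configuration to carry the keys
    # read above; a rough outline (the nested values are illustrative placeholders,
    # not the actual schema):
    #
    #   pipeline_name: PP-ChatOCRv4-doc
    #   use_layout_parser: True           # optional, defaults to True
    #   use_mllm_predict: False           # optional, defaults to False
    #   SubPipelines:
    #     LayoutParser: {...}
    #   SubModules:
    #     LLM_Chat: {...}
    #     LLM_Retriever: {...}
    #     MLLM_Chat: {...}                # only read when use_mllm_predict is True
    #     PromptEngneering:
    #       KIE_CommonText: {...}
    #       KIE_Table: {...}
    #       Ensemble: {...}               # only read when use_mllm_predict is True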

    def decode_visual_result(
        self, layout_parsing_result: LayoutParsingResult
    ) -> VisualInfoResult:
        """
        Decodes the visual result from the layout parsing result.

        Args:
            layout_parsing_result (LayoutParsingResult): The result of layout parsing.
        Returns:
            VisualInfoResult: The decoded visual information.
        """
        text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
        seal_res_list = layout_parsing_result["seal_res_list"]
        normal_text_dict = {}

        for seal_res in seal_res_list:
            for text in seal_res["rec_text"]:
                layout_type = "印章"  # i.e. "seal"
                if layout_type not in normal_text_dict:
                    normal_text_dict[layout_type] = f"{text}"
                else:
                    normal_text_dict[layout_type] += f"\n {text}"

        for text in text_paragraphs_ocr_res["rec_text"]:
            layout_type = "words in text block"
            if layout_type not in normal_text_dict:
                normal_text_dict[layout_type] = text
            else:
                normal_text_dict[layout_type] += f"\n {text}"

        table_res_list = layout_parsing_result["table_res_list"]
        table_text_list = []
        table_html_list = []
        table_nei_text_list = []
        for table_res in table_res_list:
            table_html_list.append(table_res["pred_html"])
            single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
            table_text_list.append(single_table_text)
            table_nei_text_list.append(table_res["neighbor_text"])

        visual_info = {}
        visual_info["normal_text_dict"] = normal_text_dict
        visual_info["table_text_list"] = table_text_list
        visual_info["table_html_list"] = table_html_list
        visual_info["table_nei_text_list"] = table_nei_text_list
        return VisualInfoResult(visual_info)
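
    # The VisualInfoResult returned above wraps a plain dict with this shape
    # (field names taken from the assignments in `decode_visual_result`):
    #
    #   {
    #       "normal_text_dict": {layout_type: concatenated text, ...},
    #       "table_text_list": [plain OCR text per table, ...],
    #       "table_html_list": [predicted HTML per table, ...],
    #       "table_nei_text_list": [text neighboring each table, ...],
    #   }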

    # Function to perform visual prediction on input images
    def visual_predict(
        self,
        input: str | list[str] | np.ndarray | list[np.ndarray],
        use_doc_orientation_classify: bool = False,  # Whether to use document orientation classification
        use_doc_unwarping: bool = False,  # Whether to use document unwarping
        use_general_ocr: bool = True,  # Whether to use general OCR
        use_seal_recognition: bool = True,  # Whether to use seal recognition
        use_table_recognition: bool = True,  # Whether to use table recognition
        **kwargs,
    ) -> dict:
        """
        This function takes an input image or a list of images and performs various visual
        prediction tasks such as document orientation classification, document unwarping,
        general OCR, seal recognition, and table recognition based on the provided flags.

        Args:
            input (str | list[str] | np.ndarray | list[np.ndarray]): Input image path, list of image paths,
                numpy array of an image, or list of numpy arrays.
            use_doc_orientation_classify (bool): Flag to use document orientation classification.
            use_doc_unwarping (bool): Flag to use document unwarping.
            use_general_ocr (bool): Flag to use general OCR.
            use_seal_recognition (bool): Flag to use seal recognition.
            use_table_recognition (bool): Flag to use table recognition.
            **kwargs: Additional keyword arguments.

        Yields:
            dict: A dictionary containing the layout parsing result and visual information.
        """
        if not self.use_layout_parser:
            logging.error("The models for layout parser are not initialized.")
            yield None
            # Stop early: the layout parsing pipeline was never created.
            return

        for layout_parsing_result in self.layout_parsing_pipeline.predict(
            input,
            use_doc_orientation_classify=use_doc_orientation_classify,
            use_doc_unwarping=use_doc_unwarping,
            use_general_ocr=use_general_ocr,
            use_seal_recognition=use_seal_recognition,
            use_table_recognition=use_table_recognition,
        ):
            visual_info = self.decode_visual_result(layout_parsing_result)

            visual_predict_res = {
                "layout_parsing_result": layout_parsing_result,
                "visual_info": visual_info,
            }
            yield visual_predict_res
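
    # A minimal usage sketch for the generator above (the input path and the
    # `pipeline` variable are illustrative assumptions):
    #
    #   visual_info_list = []
    #   for res in pipeline.visual_predict("./demo_contract.pdf"):
    #       visual_info_list.append(res["visual_info"])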

    def save_visual_info_list(
        self, visual_info: VisualInfoResult, save_path: str
    ) -> None:
        """
        Save the visual info list to the specified file path.

        Args:
            visual_info (VisualInfoResult): The visual info result, which can be a single object or a list of objects.
            save_path (str): The file path to save the visual info list.
        Returns:
            None
        """
        if not isinstance(visual_info, list):
            visual_info_list = [visual_info]
        else:
            visual_info_list = visual_info

        with open(save_path, "w") as fout:
            fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
        return

    def load_visual_info_list(self, data_path: str) -> list[VisualInfoResult]:
        """
        Loads visual info list from a JSON file.

        Args:
            data_path (str): The path to the JSON file containing visual info.
        Returns:
            list[VisualInfoResult]: A list of VisualInfoResult objects parsed from the JSON file.
        """
        with open(data_path, "r") as fin:
            data = fin.readline()
            visual_info_list = json.loads(data)
        return visual_info_list

    def merge_visual_info_list(
        self, visual_info_list: list[VisualInfoResult]
    ) -> tuple[list, list, list, list]:
        """
        Merge visual info lists.

        Args:
            visual_info_list (list[VisualInfoResult]): A list of visual info results.
        Returns:
            tuple[list, list, list, list]: A tuple containing four lists: normal text dicts,
                table text lists, table HTML lists, and table neighbor texts.
        """
        all_normal_text_list = []
        all_table_text_list = []
        all_table_html_list = []
        all_table_nei_text_list = []
        for single_visual_info in visual_info_list:
            normal_text_dict = single_visual_info["normal_text_dict"]
            for key in normal_text_dict:
                normal_text_dict[key] = normal_text_dict[key].replace("\n", "")
            table_text_list = single_visual_info["table_text_list"]
            table_html_list = single_visual_info["table_html_list"]
            table_nei_text_list = single_visual_info["table_nei_text_list"]
            all_normal_text_list.append(normal_text_dict)
            all_table_text_list.extend(table_text_list)
            all_table_html_list.extend(table_html_list)
            all_table_nei_text_list.extend(table_nei_text_list)

        return (
            all_normal_text_list,
            all_table_text_list,
            all_table_html_list,
            all_table_nei_text_list,
        )

    def build_vector(
        self,
        visual_info: VisualInfoResult,
        min_characters: int = 3500,
        llm_request_interval: float = 1.0,
    ) -> dict:
        """
        Build a vector representation from visual information.

        Args:
            visual_info (VisualInfoResult): The visual information input, can be a single instance or a list of instances.
            min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
            llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
        Returns:
            dict: A dictionary containing the vector info and a flag indicating if the text is too short.
        """
        if not isinstance(visual_info, list):
            visual_info_list = [visual_info]
        else:
            visual_info_list = visual_info

        all_visual_info = self.merge_visual_info_list(visual_info_list)
        (
            all_normal_text_list,
            all_table_text_list,
            all_table_html_list,
            all_table_nei_text_list,
        ) = all_visual_info

        vector_info = {}

        all_items = []
        for i, normal_text_dict in enumerate(all_normal_text_list):
            for type, text in normal_text_dict.items():
                all_items += [f"{type}:{text}\n"]

        for table_html, table_text, table_nei_text in zip(
            all_table_html_list, all_table_text_list, all_table_nei_text_list
        ):
            if len(table_html) > min_characters - self.table_structure_len_max:
                all_items += [f"table:{table_text}\t{table_nei_text}"]

        all_text_str = "".join(all_items)

        if len(all_text_str) > min_characters:
            vector_info["flag_too_short_text"] = False
            vector_info["vector"] = self.retriever.generate_vector_database(all_items)
        else:
            vector_info["flag_too_short_text"] = True
            vector_info["vector"] = all_items
        return vector_info
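
    # `build_vector` returns a dict of the form
    #   {"flag_too_short_text": bool, "vector": ...}
    # where "vector" is a retriever vector database when enough text is available,
    # and otherwise the raw text items themselves.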

    def save_vector(self, vector_info: dict, save_path: str) -> None:
        if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
            logging.error("Invalid vector info.")
            return
        save_vector_info = {}
        save_vector_info["flag_too_short_text"] = vector_info["flag_too_short_text"]
        if not vector_info["flag_too_short_text"]:
            save_vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
                vector_info["vector"]
            )
        else:
            save_vector_info["vector"] = vector_info["vector"]

        with open(save_path, "w") as fout:
            fout.write(json.dumps(save_vector_info, ensure_ascii=False) + "\n")
        return

    def load_vector(self, data_path: str) -> dict:
        vector_info = None
        with open(data_path, "r") as fin:
            data = fin.readline()
            vector_info = json.loads(data)
            if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
                logging.error("Invalid vector info.")
                return
            if not vector_info["flag_too_short_text"]:
                vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
                    vector_info["vector"]
                )
        return vector_info
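
    # A minimal save/load round trip for the vector info (the file path and the
    # `pipeline` variable are illustrative assumptions):
    #
    #   vector_info = pipeline.build_vector(visual_info_list)
    #   pipeline.save_vector(vector_info, "./vector.json")
    #   vector_info = pipeline.load_vector("./vector.json")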

    def format_key(self, key_list: str | list[str]) -> list[str]:
        """
        Formats the key list.

        Args:
            key_list (str|list[str]): A string or a list of strings representing the keys.
        Returns:
            list[str]: A list of formatted keys.
        """
        if key_list == "":
            return []

        if isinstance(key_list, list):
            key_list = [key.replace("\xa0", " ") for key in key_list]
            return key_list

        if isinstance(key_list, str):
            key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
            # Normalize full-width commas before splitting on commas.
            key_list = key_list.replace("，", ",").split(",")
            return key_list

        return []
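
    # Expected behavior of `format_key` (examples are illustrative):
    #   format_key("Party A,Party B")  -> ["Party A", "Party B"]
    #   format_key(["Name", "Date"])   -> ["Name", "Date"]
    #   format_key("")                 -> []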

    def mllm_pred(
        self,
        input: str | np.ndarray,
        key_list,
        **kwargs,
    ) -> dict:
        """
        Runs multimodal LLM prediction on a single image (or single-page PDF) for the given keys.

        Args:
            input (str | np.ndarray): Image path, single-page PDF path, or image array.
            key_list: A key or list of keys to extract.
        Returns:
            dict: A dictionary with the MLLM results under "mllm_res".
        """
        key_list = self.format_key(key_list)
        if len(key_list) == 0:
            return {"mllm_res": "Error:输入的key_list无效!"}

        if isinstance(input, list):
            logging.error("Input is a list, but it's not supported here.")
            return {"mllm_res": "Error:Input is a list, but it's not supported here!"}
        image_array_list = self.img_reader([input])
        if (
            isinstance(input, str)
            and input.endswith(".pdf")
            and len(image_array_list) > 1
        ):
            logging.error("The input with PDF should have only one page.")
            return {"mllm_res": "Error:The input with PDF should have only one page!"}

        for image_array in image_array_list:
            assert len(image_array.shape) == 3
            # `tobytes()` replaces the deprecated ndarray.tostring().
            image_string = cv2.imencode(".jpg", image_array)[1].tobytes()
            image_base64 = base64.b64encode(image_string).decode("utf-8")
            result = {}
            for key in key_list:
                # Prompt (zh): answer with content that appears verbatim in the image
                # (a word, phrase, or sentence); be as detailed and complete as possible,
                # and keep format, units, symbols, and punctuation identical to the image.
                prompt = (
                    str(key)
                    + "\n请用图片中完整出现的内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,并保持格式、单位、符号和标点都与图片中的文字内容完全一致。"
                )
                mllm_chat_bot_result = self.mllm_chat_bot.generate_chat_results(
                    prompt=prompt, image=image_base64
                )
                if mllm_chat_bot_result is None:
                    return {"mllm_res": "大模型调用失败"}
                result[key] = mllm_chat_bot_result
            return {"mllm_res": result}

    def generate_and_merge_chat_results(
        self, prompt: str, key_list: list, final_results: dict, failed_results: list
    ) -> None:
        """
        Generate and merge chat results into the final results dictionary.

        Args:
            prompt (str): The input prompt for the chat bot.
            key_list (list): A list of keys to track which results to merge.
            final_results (dict): The dictionary to store the final merged results.
            failed_results (list): A list of failed results to avoid merging.
        Returns:
            None
        """
        llm_result = self.chat_bot.generate_chat_results(prompt)
        if llm_result is None:
            logging.error(
                "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
                % (prompt, self.chat_bot.ERROR_MASSAGE)
            )
            return

        llm_result = self.chat_bot.fix_llm_result_format(llm_result)

        for key, value in llm_result.items():
            if value not in failed_results and key in key_list:
                key_list.remove(key)
                final_results[key] = value
        return

    def get_related_normal_text(
        self,
        use_vector_retrieval: bool,
        vector_info: dict,
        key_list: list[str],
        all_normal_text_list: list,
        min_characters: int,
    ) -> str:
        """
        Retrieve related normal text based on vector retrieval or all normal text list.

        Args:
            use_vector_retrieval (bool): Whether to use vector retrieval.
            vector_info (dict): Dictionary containing vector information.
            key_list (list[str]): List of keys to generate question keys.
            all_normal_text_list (list): List of normal text.
            min_characters (int): The minimum number of characters required for text processing.

        Returns:
            str: Related normal text.
        """
        if use_vector_retrieval and vector_info is not None:
            question_key_list = [f"{key}" for key in key_list]
            vector = vector_info["vector"]
            if not vector_info["flag_too_short_text"]:
                related_text = self.retriever.similarity_retrieval(
                    question_key_list, vector, topk=50, min_characters=min_characters
                )
            else:
                if len(vector) > 0:
                    related_text = "".join(vector)
                else:
                    related_text = ""
        else:
            all_items = []
            for i, normal_text_dict in enumerate(all_normal_text_list):
                for type, text in normal_text_dict.items():
                    all_items += [f"{type}:{text}\n"]
            related_text = "".join(all_items)
            if len(related_text) > min_characters:
                logging.warning(
                    "The input text content is too long, the large language model may truncate it."
                )
        return related_text

    def ensemble_ocr_llm_mllm(
        self, key_list: list[str], ocr_llm_predict_dict: dict, mllm_predict_dict: dict
    ) -> dict:
        """
        Ensemble OCR+LLM and MLLM predictions based on the given key list.

        Args:
            key_list (list[str]): List of keys to retrieve predictions.
            ocr_llm_predict_dict (dict): Dictionary containing OCR+LLM predictions.
            mllm_predict_dict (dict): Dictionary containing MLLM predictions.

        Returns:
            dict: A dictionary with final predictions.
        """
        final_predict_dict = {}

        for key in key_list:
            predict = ""
            ocr_llm_predict = ""
            mllm_predict = ""
            if key in ocr_llm_predict_dict:
                ocr_llm_predict = ocr_llm_predict_dict[key]
            if key in mllm_predict_dict:
                mllm_predict = mllm_predict_dict[key]
            if ocr_llm_predict != "" and mllm_predict != "":
                prompt = self.ensemble_pe.generate_prompt(
                    key, ocr_llm_predict, mllm_predict
                )
                llm_result = self.chat_bot.generate_chat_results(prompt)
                if llm_result is not None:
                    llm_result = self.chat_bot.fix_llm_result_format(llm_result)
                    if key in llm_result:
                        tmp = llm_result[key]
                        if "B" in tmp:
                            predict = mllm_predict
                        else:
                            predict = ocr_llm_predict
                    else:
                        predict = ocr_llm_predict
            elif key in ocr_llm_predict_dict:
                predict = ocr_llm_predict_dict[key]
            elif key in mllm_predict_dict:
                predict = mllm_predict_dict[key]

            if predict != "":
                final_predict_dict[key] = predict
        return final_predict_dict
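
    # Note: the check for "B" in `ensemble_ocr_llm_mllm` assumes the ensemble prompt
    # asks the LLM to pick between option A (the OCR+LLM answer) and option B (the
    # MLLM answer); the exact wording lives in the Ensemble prompt-engineering module.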

    def chat(
        self,
        key_list: str | list[str],
        visual_info: VisualInfoResult,
        use_vector_retrieval: bool = True,
        vector_info: dict = None,
        min_characters: int = 3500,
        text_task_description: str = None,
        text_output_format: str = None,
        text_rules_str: str = None,
        text_few_shot_demo_text_content: str = None,
        text_few_shot_demo_key_value_list: str = None,
        table_task_description: str = None,
        table_output_format: str = None,
        table_rules_str: str = None,
        table_few_shot_demo_text_content: str = None,
        table_few_shot_demo_key_value_list: str = None,
        mllm_predict_dict: dict = None,
        mllm_integration_strategy: str = "integration",
    ) -> dict:
        """
        Generates chat results based on the provided key list and visual information.

        Args:
            key_list (str | list[str]): A single key or a list of keys to extract information.
            visual_info (VisualInfoResult): The visual information result.
            use_vector_retrieval (bool): Whether to use vector retrieval.
            vector_info (dict): The vector information for retrieval.
            min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
            text_task_description (str): The description of the text task.
            text_output_format (str): The output format for text results.
            text_rules_str (str): The rules for generating text results.
            text_few_shot_demo_text_content (str): The text content for few-shot demos.
            text_few_shot_demo_key_value_list (str): The key-value list for few-shot demos.
            table_task_description (str): The description of the table task.
            table_output_format (str): The output format for table results.
            table_rules_str (str): The rules for generating table results.
            table_few_shot_demo_text_content (str): The text content for table few-shot demos.
            table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
            mllm_predict_dict (dict): The dictionary of MLLM predictions.
            mllm_integration_strategy (str): The integration strategy of MLLM and LLM, defaults to "integration";
                options are "integration", "llm_only" and "mllm_only".

        Returns:
            dict: A dictionary containing the chat results.
        """
        key_list = self.format_key(key_list)
        key_list_ori = key_list.copy()
        if len(key_list) == 0:
            return {"chat_res": "Error:输入的key_list无效!"}

        if not isinstance(visual_info, list):
            visual_info_list = [visual_info]
        else:
            visual_info_list = visual_info

        all_visual_info = self.merge_visual_info_list(visual_info_list)
        (
            all_normal_text_list,
            all_table_text_list,
            all_table_html_list,
            all_table_nei_text_list,
        ) = all_visual_info

        final_results = {}
        # Answers treated as failures and never merged into the final results
        # ("LLM call failed", "unknown", "key information not found", "None", "").
        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]

        if len(key_list) > 0:
            related_text = self.get_related_normal_text(
                use_vector_retrieval,
                vector_info,
                key_list,
                all_normal_text_list,
                min_characters,
            )

            if len(related_text) > 0:
                prompt = self.text_pe.generate_prompt(
                    related_text,
                    key_list,
                    task_description=text_task_description,
                    output_format=text_output_format,
                    rules_str=text_rules_str,
                    few_shot_demo_text_content=text_few_shot_demo_text_content,
                    few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
                )
                self.generate_and_merge_chat_results(
                    prompt, key_list, final_results, failed_results
                )

        if len(key_list) > 0:
            for table_html, table_text, table_nei_text in zip(
                all_table_html_list, all_table_text_list, all_table_nei_text_list
            ):
                if len(table_html) <= min_characters - self.table_structure_len_max:
                    for table_info in [table_html]:
                        if len(key_list) > 0:
                            if len(table_nei_text) > 0:
                                table_info = (
                                    table_info + "\n 表格周围文字:" + table_nei_text
                                )
                            prompt = self.table_pe.generate_prompt(
                                table_info,
                                key_list,
                                task_description=table_task_description,
                                output_format=table_output_format,
                                rules_str=table_rules_str,
                                few_shot_demo_text_content=table_few_shot_demo_text_content,
                                few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
                            )
                            self.generate_and_merge_chat_results(
                                prompt, key_list, final_results, failed_results
                            )

        if self.use_mllm_predict and mllm_integration_strategy != "llm_only":
            if mllm_integration_strategy == "integration":
                final_predict_dict = self.ensemble_ocr_llm_mllm(
                    key_list_ori, final_results, mllm_predict_dict
                )
            elif mllm_integration_strategy == "mllm_only":
                final_predict_dict = mllm_predict_dict
            else:
                return {
                    "chat_res": f"Error:Unsupported mllm_integration_strategy {mllm_integration_strategy}, only support 'integration', 'llm_only' and 'mllm_only'!"
                }
        else:
            final_predict_dict = final_results
        return {"chat_res": final_predict_dict}

    def predict(self, *args, **kwargs) -> None:
        logging.error(
            "The PP-ChatOCRv4-doc pipeline does not support calling `predict()` directly! "
            "Please invoke `visual_predict`, `build_vector`, and `chat` sequentially to obtain the result."
        )
        return
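

# Illustrative end-to-end usage (a minimal sketch; `create_pipeline`, the sample
# file path, and the keys below are assumptions, not defined in this module):
#
#   from paddlex import create_pipeline
#
#   pipeline = create_pipeline(pipeline="PP-ChatOCRv4-doc")
#
#   visual_info_list = []
#   for res in pipeline.visual_predict("./demo_contract.pdf"):
#       visual_info_list.append(res["visual_info"])
#
#   vector_info = pipeline.build_vector(visual_info_list)
#   chat_res = pipeline.chat(["乙方", "联系电话"], visual_info_list, vector_info=vector_info)
#   print(chat_res["chat_res"])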