pipeline_v3.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Any, Dict, Optional, Union, List, Tuple
  15. import os
  16. import re
  17. import copy
  18. import json
  19. import numpy as np
  20. from .pipeline_base import PP_ChatOCR_Pipeline
  21. from ...common.reader import ReadImage
  22. from ...common.batch_sampler import ImageBatchSampler
  23. from ....utils import logging
  24. from ....utils.file_interface import custom_open
  25. from ...utils.pp_option import PaddlePredictorOption
  26. from ..layout_parsing.result import LayoutParsingResult
  27. from ..components.chat_server import BaseChat
  28. class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
  29. """PP-ChatOCR Pipeline"""
  30. entities = ["PP-ChatOCRv3-doc"]
  31. def __init__(
  32. self,
  33. config: Dict,
  34. device: str = None,
  35. pp_option: PaddlePredictorOption = None,
  36. use_hpip: bool = False,
  37. initial_predictor: bool = True,
  38. ) -> None:
  39. """Initializes the pp-chatocrv3-doc pipeline.
  40. Args:
  41. config (Dict): Configuration dictionary containing various settings.
  42. device (str, optional): Device to run the predictions on. Defaults to None.
  43. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  44. use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
  45. use_layout_parsing (bool, optional): Whether to use layout parsing. Defaults to True.
  46. initial_predictor (bool, optional): Whether to initialize the predictor. Defaults to True.
  47. """
  48. super().__init__(device=device, pp_option=pp_option, use_hpip=use_hpip)
  49. self.pipeline_name = config["pipeline_name"]
  50. self.config = config
  51. self.use_layout_parser = config.get("use_layout_parser", True)
  52. self.layout_parsing_pipeline = None
  53. self.chat_bot = None
  54. self.retriever = None
  55. if initial_predictor:
  56. self.inintial_visual_predictor(config)
  57. self.inintial_chat_predictor(config)
  58. self.inintial_retriever_predictor(config)
  59. self.batch_sampler = ImageBatchSampler(batch_size=1)
  60. self.img_reader = ReadImage(format="BGR")
  61. self.table_structure_len_max = 500
  62. def inintial_visual_predictor(self, config: dict) -> None:
  63. """
  64. Initializes the visual predictor with the given configuration.
  65. Args:
  66. config (dict): The configuration dictionary containing the necessary
  67. parameters for initializing the predictor.
  68. Returns:
  69. None
  70. """
  71. self.use_layout_parser = config.get("use_layout_parser", True)
  72. if self.use_layout_parser:
  73. layout_parsing_config = config.get("SubPipelines", {}).get(
  74. "LayoutParser",
  75. {"pipeline_config_error": "config error for layout_parsing_pipeline!"},
  76. )
  77. self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
  78. return
  79. def inintial_retriever_predictor(self, config: dict) -> None:
  80. """
  81. Initializes the retriever predictor with the given configuration.
  82. Args:
  83. config (dict): The configuration dictionary containing the necessary
  84. parameters for initializing the predictor.
  85. Returns:
  86. None
  87. """
  88. from .. import create_retriever
  89. retriever_config = config.get("SubModules", {}).get(
  90. "LLM_Retriever",
  91. {"retriever_config_error": "config error for llm retriever!"},
  92. )
  93. self.retriever = create_retriever(retriever_config)
  94. def inintial_chat_predictor(self, config: dict) -> None:
  95. """
  96. Initializes the chat predictor with the given configuration.
  97. Args:
  98. config (dict): The configuration dictionary containing the necessary
  99. parameters for initializing the predictor.
  100. Returns:
  101. None
  102. """
  103. from .. import create_chat_bot
  104. chat_bot_config = config.get("SubModules", {}).get(
  105. "LLM_Chat",
  106. {"chat_bot_config_error": "config error for llm chat bot!"},
  107. )
  108. self.chat_bot = create_chat_bot(chat_bot_config)
  109. from .. import create_prompt_engineering
  110. text_pe_config = (
  111. config.get("SubModules", {})
  112. .get("PromptEngneering", {})
  113. .get(
  114. "KIE_CommonText",
  115. {"pe_config_error": "config error for text_pe!"},
  116. )
  117. )
  118. self.text_pe = create_prompt_engineering(text_pe_config)
  119. table_pe_config = (
  120. config.get("SubModules", {})
  121. .get("PromptEngneering", {})
  122. .get(
  123. "KIE_Table",
  124. {"pe_config_error": "config error for table_pe!"},
  125. )
  126. )
  127. self.table_pe = create_prompt_engineering(table_pe_config)
  128. return
  129. def decode_visual_result(self, layout_parsing_result: LayoutParsingResult) -> dict:
  130. """
  131. Decodes the visual result from the layout parsing result.
  132. Args:
  133. layout_parsing_result (LayoutParsingResult): The result of layout parsing.
  134. Returns:
  135. dict: The decoded visual information.
  136. """
  137. text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
  138. seal_res_list = layout_parsing_result["seal_res_list"]
  139. normal_text_dict = {}
  140. for seal_res in seal_res_list:
  141. for text in seal_res["rec_texts"]:
  142. layout_type = "印章"
  143. if layout_type not in normal_text_dict:
  144. normal_text_dict[layout_type] = f"{text}"
  145. else:
  146. normal_text_dict[layout_type] += f"\n {text}"
  147. for text in text_paragraphs_ocr_res["rec_texts"]:
  148. layout_type = "words in text block"
  149. if layout_type not in normal_text_dict:
  150. normal_text_dict[layout_type] = text
  151. else:
  152. normal_text_dict[layout_type] += f"\n {text}"
  153. table_res_list = layout_parsing_result["table_res_list"]
  154. table_text_list = []
  155. table_html_list = []
  156. for table_res in table_res_list:
  157. table_html_list.append(table_res["pred_html"])
  158. single_table_text = " ".join(table_res["table_ocr_pred"]["rec_texts"])
  159. table_text_list.append(single_table_text)
  160. visual_info = {}
  161. visual_info["normal_text_dict"] = normal_text_dict
  162. visual_info["table_text_list"] = table_text_list
  163. visual_info["table_html_list"] = table_html_list
  164. return visual_info
  165. # Function to perform visual prediction on input images
  166. def visual_predict(
  167. self,
  168. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  169. use_doc_orientation_classify: Optional[bool] = None,
  170. use_doc_unwarping: Optional[bool] = None,
  171. use_general_ocr: Optional[bool] = None,
  172. use_seal_recognition: Optional[bool] = None,
  173. use_table_recognition: Optional[bool] = None,
  174. text_det_limit_side_len: Optional[int] = None,
  175. text_det_limit_type: Optional[str] = None,
  176. text_det_thresh: Optional[float] = None,
  177. text_det_box_thresh: Optional[float] = None,
  178. text_det_unclip_ratio: Optional[float] = None,
  179. text_rec_score_thresh: Optional[float] = None,
  180. seal_det_limit_side_len: Optional[int] = None,
  181. seal_det_limit_type: Optional[str] = None,
  182. seal_det_thresh: Optional[float] = None,
  183. seal_det_box_thresh: Optional[float] = None,
  184. seal_det_unclip_ratio: Optional[float] = None,
  185. seal_rec_score_thresh: Optional[float] = None,
  186. **kwargs,
  187. ) -> dict:
  188. """
  189. This function takes an input image or a list of images and performs various visual
  190. prediction tasks such as document orientation classification, document unwarping,
  191. general OCR, seal recognition, and table recognition based on the provided flags.
  192. Args:
  193. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image path, list of image paths,
  194. numpy array of an image, or list of numpy arrays.
  195. use_doc_orientation_classify (bool): Flag to use document orientation classification.
  196. use_doc_unwarping (bool): Flag to use document unwarping.
  197. use_general_ocr (bool): Flag to use general OCR.
  198. use_seal_recognition (bool): Flag to use seal recognition.
  199. use_table_recognition (bool): Flag to use table recognition.
  200. **kwargs: Additional keyword arguments.
  201. Returns:
  202. dict: A dictionary containing the layout parsing result and visual information.
  203. """
  204. if self.use_layout_parser == False:
  205. logging.error("The models for layout parser are not initialized.")
  206. yield {"error": "The models for layout parser are not initialized."}
  207. if self.layout_parsing_pipeline is None:
  208. logging.warning(
  209. "The layout parsing pipeline is not initialized, will initialize it now."
  210. )
  211. self.inintial_visual_predictor(self.config)
  212. for layout_parsing_result in self.layout_parsing_pipeline.predict(
  213. input,
  214. use_doc_orientation_classify=use_doc_orientation_classify,
  215. use_doc_unwarping=use_doc_unwarping,
  216. use_general_ocr=use_general_ocr,
  217. use_seal_recognition=use_seal_recognition,
  218. use_table_recognition=use_table_recognition,
  219. text_det_limit_side_len=text_det_limit_side_len,
  220. text_det_limit_type=text_det_limit_type,
  221. text_det_thresh=text_det_thresh,
  222. text_det_box_thresh=text_det_box_thresh,
  223. text_det_unclip_ratio=text_det_unclip_ratio,
  224. text_rec_score_thresh=text_rec_score_thresh,
  225. seal_det_box_thresh=seal_det_box_thresh,
  226. seal_det_limit_side_len=seal_det_limit_side_len,
  227. seal_det_limit_type=seal_det_limit_type,
  228. seal_det_thresh=seal_det_thresh,
  229. seal_det_unclip_ratio=seal_det_unclip_ratio,
  230. seal_rec_score_thresh=seal_rec_score_thresh,
  231. ):
  232. visual_info = self.decode_visual_result(layout_parsing_result)
  233. visual_predict_res = {
  234. "layout_parsing_result": layout_parsing_result,
  235. "visual_info": visual_info,
  236. }
  237. yield visual_predict_res
  238. def save_visual_info_list(self, visual_info: dict, save_path: str) -> None:
  239. """
  240. Save the visual info list to the specified file path.
  241. Args:
  242. visual_info (dict): The visual info result, which can be a single object or a list of objects.
  243. save_path (str): The file path to save the visual info list.
  244. Returns:
  245. None
  246. """
  247. if not isinstance(visual_info, list):
  248. visual_info_list = [visual_info]
  249. else:
  250. visual_info_list = visual_info
  251. directory = os.path.dirname(save_path)
  252. if not os.path.exists(directory):
  253. os.makedirs(directory)
  254. with custom_open(save_path, "w") as fout:
  255. fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
  256. return
  257. def load_visual_info_list(self, data_path: str) -> List[dict]:
  258. """
  259. Loads visual info list from a JSON file.
  260. Args:
  261. data_path (str): The path to the JSON file containing visual info.
  262. Returns:
  263. list[dict]: A list of dict objects parsed from the JSON file.
  264. """
  265. with custom_open(data_path, "r") as fin:
  266. data = fin.readline()
  267. visual_info_list = json.loads(data)
  268. return visual_info_list
  269. def merge_visual_info_list(
  270. self, visual_info_list: List[dict]
  271. ) -> Tuple[list, list, list]:
  272. """
  273. Merge visual info lists.
  274. Args:
  275. visual_info_list (list[dict]): A list of visual info results.
  276. Returns:
  277. tuple[list, list, list]: A tuple containing four lists, one for normal text dicts,
  278. one for table text lists, one for table HTML lists.
  279. """
  280. all_normal_text_list = []
  281. all_table_text_list = []
  282. all_table_html_list = []
  283. for single_visual_info in visual_info_list:
  284. normal_text_dict = single_visual_info["normal_text_dict"]
  285. for key in normal_text_dict:
  286. normal_text_dict[key] = normal_text_dict[key].replace("\n", "")
  287. table_text_list = single_visual_info["table_text_list"]
  288. table_html_list = single_visual_info["table_html_list"]
  289. all_normal_text_list.append(normal_text_dict)
  290. all_table_text_list.extend(table_text_list)
  291. all_table_html_list.extend(table_html_list)
  292. return (all_normal_text_list, all_table_text_list, all_table_html_list)
  293. def build_vector(
  294. self,
  295. visual_info: dict,
  296. min_characters: int = 3500,
  297. block_size: int = 300,
  298. flag_save_bytes_vector: bool = False,
  299. retriever_config: dict = None,
  300. ) -> dict:
  301. """
  302. Build a vector representation from visual information.
  303. Args:
  304. visual_info (dict): The visual information input, can be a single instance or a list of instances.
  305. min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
  306. block_size (int): The size of each chunk to split the text into.
  307. flag_save_bytes_vector (bool): Whether to save the vector as bytes, defaults to False.
  308. retriever_config (dict): The configuration for the retriever, defaults to None.
  309. Returns:
  310. dict: A dictionary containing the vector info and a flag indicating if the text is too short.
  311. """
  312. if not isinstance(visual_info, list):
  313. visual_info_list = [visual_info]
  314. else:
  315. visual_info_list = visual_info
  316. if retriever_config is not None:
  317. from .. import create_retriever
  318. retriever = create_retriever(retriever_config)
  319. else:
  320. if self.retriever is None:
  321. logging.warning(
  322. "The retriever is not initialized,will initialize it now."
  323. )
  324. self.inintial_retriever_predictor(self.config)
  325. retriever = self.retriever
  326. all_visual_info = self.merge_visual_info_list(visual_info_list)
  327. (
  328. all_normal_text_list,
  329. all_table_text_list,
  330. all_table_html_list,
  331. ) = all_visual_info
  332. vector_info = {}
  333. all_items = []
  334. for i, normal_text_dict in enumerate(all_normal_text_list):
  335. for type, text in normal_text_dict.items():
  336. all_items += [f"{type}:{text}\n"]
  337. for table_html, table_text in zip(all_table_html_list, all_table_text_list):
  338. if len(table_html) > min_characters - self.table_structure_len_max:
  339. all_items += [f"table:{table_text}"]
  340. all_text_str = "".join(all_items)
  341. vector_info["flag_save_bytes_vector"] = False
  342. if len(all_text_str) > min_characters:
  343. vector_info["flag_too_short_text"] = False
  344. vector_info["model_name"] = retriever.model_name
  345. vector_info["block_size"] = block_size
  346. vector_info["vector"] = retriever.generate_vector_database(
  347. all_items, block_size=block_size
  348. )
  349. if flag_save_bytes_vector:
  350. vector_info["vector"] = retriever.encode_vector_store_to_bytes(
  351. vector_info["vector"]
  352. )
  353. vector_info["flag_save_bytes_vector"] = True
  354. else:
  355. vector_info["flag_too_short_text"] = True
  356. vector_info["vector"] = all_items
  357. return vector_info
  358. def save_vector(self, vector_info: dict, save_path: str) -> None:
  359. directory = os.path.dirname(save_path)
  360. if not os.path.exists(directory):
  361. os.makedirs(directory)
  362. if self.retriever is None:
  363. logging.warning("The retriever is not initialized,will initialize it now.")
  364. self.inintial_retriever_predictor(self.config)
  365. vector_info_data = copy.deepcopy(vector_info)
  366. if (
  367. not vector_info["flag_too_short_text"]
  368. and not vector_info["flag_save_bytes_vector"]
  369. ):
  370. vector_info_data["vector"] = self.retriever.encode_vector_store_to_bytes(
  371. vector_info_data["vector"]
  372. )
  373. vector_info_data["flag_save_bytes_vector"] = True
  374. with custom_open(save_path, "w") as fout:
  375. fout.write(json.dumps(vector_info_data, ensure_ascii=False) + "\n")
  376. return
  377. def load_vector(self, data_path: str) -> dict:
  378. vector_info = None
  379. if self.retriever is None:
  380. logging.warning("The retriever is not initialized,will initialize it now.")
  381. self.inintial_retriever_predictor(self.config)
  382. with open(data_path, "r") as fin:
  383. data = fin.readline()
  384. vector_info = json.loads(data)
  385. if (
  386. "flag_too_short_text" not in vector_info
  387. or "flag_save_bytes_vector" not in vector_info
  388. or "vector" not in vector_info
  389. ):
  390. logging.error("Invalid vector info.")
  391. return {"error": "Invalid vector info when load vector!"}
  392. if vector_info["flag_save_bytes_vector"]:
  393. vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
  394. vector_info["vector"]
  395. )
  396. vector_info["flag_save_bytes_vector"] = False
  397. return vector_info
  398. def format_key(self, key_list: Union[str, List[str]]) -> List[str]:
  399. """
  400. Formats the key list.
  401. Args:
  402. key_list (str|list[str]): A string or a list of strings representing the keys.
  403. Returns:
  404. list[str]: A list of formatted keys.
  405. """
  406. if key_list == "":
  407. return []
  408. if isinstance(key_list, list):
  409. return key_list
  410. if isinstance(key_list, str):
  411. key_list = re.sub(r"[\t\n\r\f\v]", "", key_list)
  412. key_list = key_list.replace(",", ",").split(",")
  413. return key_list
  414. return []
  415. def generate_and_merge_chat_results(
  416. self,
  417. chat_bot: BaseChat,
  418. prompt: str,
  419. key_list: list,
  420. final_results: dict,
  421. failed_results: list,
  422. ) -> None:
  423. """
  424. Generate and merge chat results into the final results dictionary.
  425. Args:
  426. prompt (str): The input prompt for the chat bot.
  427. key_list (list): A list of keys to track which results to merge.
  428. final_results (dict): The dictionary to store the final merged results.
  429. failed_results (list): A list of failed results to avoid merging.
  430. Returns:
  431. None
  432. """
  433. llm_result = chat_bot.generate_chat_results(prompt)
  434. if llm_result is None:
  435. logging.error(
  436. "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
  437. % (prompt, self.chat_bot.ERROR_MASSAGE)
  438. )
  439. return
  440. llm_result = self.chat_bot.fix_llm_result_format(llm_result)
  441. for key, value in llm_result.items():
  442. if value not in failed_results and key in key_list:
  443. key_list.remove(key)
  444. final_results[key] = value
  445. return
  446. def get_related_normal_text(
  447. self,
  448. retriever_config: dict,
  449. use_vector_retrieval: bool,
  450. vector_info: dict,
  451. key_list: List[str],
  452. all_normal_text_list: list,
  453. min_characters: int,
  454. ) -> str:
  455. """
  456. Retrieve related normal text based on vector retrieval or all normal text list.
  457. Args:
  458. retriever_config (dict): Configuration for the retriever.
  459. use_vector_retrieval (bool): Whether to use vector retrieval.
  460. vector_info (dict): Dictionary containing vector information.
  461. key_list (list[str]): List of keys to generate question keys.
  462. all_normal_text_list (list): List of normal text.
  463. min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
  464. Returns:
  465. str: Related normal text.
  466. """
  467. if use_vector_retrieval and vector_info is not None:
  468. if retriever_config is not None:
  469. from .. import create_retriever
  470. retriever = create_retriever(retriever_config)
  471. else:
  472. if self.retriever is None:
  473. logging.warning(
  474. "The retriever is not initialized,will initialize it now."
  475. )
  476. self.inintial_retriever_predictor(self.config)
  477. retriever = self.retriever
  478. question_key_list = [f"{key}" for key in key_list]
  479. vector = vector_info["vector"]
  480. if not vector_info["flag_too_short_text"]:
  481. assert (
  482. vector_info["model_name"] == retriever.model_name
  483. ), f"The vector model name ({vector_info['model_name']}) does not match the retriever model name ({retriever.model_name}). Please check your retriever config."
  484. if vector_info["flag_save_bytes_vector"]:
  485. vector = retriever.decode_vector_store_from_bytes(vector)
  486. related_text = retriever.similarity_retrieval(
  487. question_key_list, vector, topk=50, min_characters=min_characters
  488. )
  489. else:
  490. if len(vector) > 0:
  491. related_text = "".join(vector)
  492. else:
  493. related_text = ""
  494. else:
  495. all_items = []
  496. for i, normal_text_dict in enumerate(all_normal_text_list):
  497. for type, text in normal_text_dict.items():
  498. all_items += [f"{type}:{text}\n"]
  499. related_text = "".join(all_items)
  500. if len(related_text) > min_characters:
  501. logging.warning(
  502. "The input text content is too long, the large language model may truncate it."
  503. )
  504. return related_text
  505. def chat(
  506. self,
  507. key_list: Union[str, List[str]],
  508. visual_info: List[dict],
  509. use_vector_retrieval: bool = True,
  510. vector_info: dict = None,
  511. min_characters: int = 3500,
  512. text_task_description: str = None,
  513. text_output_format: str = None,
  514. text_rules_str: str = None,
  515. text_few_shot_demo_text_content: str = None,
  516. text_few_shot_demo_key_value_list: str = None,
  517. table_task_description: str = None,
  518. table_output_format: str = None,
  519. table_rules_str: str = None,
  520. table_few_shot_demo_text_content: str = None,
  521. table_few_shot_demo_key_value_list: str = None,
  522. chat_bot_config: dict = None,
  523. retriever_config: dict = None,
  524. ) -> dict:
  525. """
  526. Generates chat results based on the provided key list and visual information.
  527. Args:
  528. key_list (Union[str, list[str]]): A single key or a list of keys to extract information.
  529. visual_info (dict): The visual information result.
  530. use_vector_retrieval (bool): Whether to use vector retrieval.
  531. vector_info (dict): The vector information for retrieval.
  532. min_characters (int): The minimum number of characters required.
  533. text_task_description (str): The description of the text task.
  534. text_output_format (str): The output format for text results.
  535. text_rules_str (str): The rules for generating text results.
  536. text_few_shot_demo_text_content (str): The text content for few-shot demos.
  537. text_few_shot_demo_key_value_list (str): The key-value list for few-shot demos.
  538. table_task_description (str): The description of the table task.
  539. table_output_format (str): The output format for table results.
  540. table_rules_str (str): The rules for generating table results.
  541. table_few_shot_demo_text_content (str): The text content for table few-shot demos.
  542. table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
  543. chat_bot_config(dict): The parameters for LLM chatbot, including api_type, api_key... refer to config file for more details.
  544. retriever_config (dict): The parameters for LLM retriever, including api_type, api_key... refer to config file for more details.
  545. Returns:
  546. dict: A dictionary containing the chat results.
  547. """
  548. key_list = self.format_key(key_list)
  549. key_list_ori = key_list.copy()
  550. if len(key_list) == 0:
  551. return {"chat_res": "Error:输入的key_list无效!"}
  552. if not isinstance(visual_info, list):
  553. visual_info_list = [visual_info]
  554. else:
  555. visual_info_list = visual_info
  556. if self.chat_bot is None:
  557. logging.warning(
  558. "The LLM chat bot is not initialized,will initialize it now."
  559. )
  560. self.inintial_chat_predictor(self.config)
  561. if chat_bot_config is not None:
  562. from .. import create_chat_bot
  563. chat_bot = create_chat_bot(chat_bot_config)
  564. else:
  565. chat_bot = self.chat_bot
  566. all_visual_info = self.merge_visual_info_list(visual_info_list)
  567. (
  568. all_normal_text_list,
  569. all_table_text_list,
  570. all_table_html_list,
  571. ) = all_visual_info
  572. final_results = {}
  573. failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
  574. if len(key_list) > 0:
  575. for table_html, table_text in zip(all_table_html_list, all_table_text_list):
  576. if len(table_html) <= min_characters - self.table_structure_len_max:
  577. for table_info in [table_html]:
  578. if len(key_list) > 0:
  579. prompt = self.table_pe.generate_prompt(
  580. table_info,
  581. key_list,
  582. task_description=table_task_description,
  583. output_format=table_output_format,
  584. rules_str=table_rules_str,
  585. few_shot_demo_text_content=table_few_shot_demo_text_content,
  586. few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
  587. )
  588. self.generate_and_merge_chat_results(
  589. chat_bot,
  590. prompt,
  591. key_list,
  592. final_results,
  593. failed_results,
  594. )
  595. if len(key_list) > 0:
  596. related_text = self.get_related_normal_text(
  597. retriever_config,
  598. use_vector_retrieval,
  599. vector_info,
  600. key_list,
  601. all_normal_text_list,
  602. min_characters,
  603. )
  604. if len(related_text) > 0:
  605. prompt = self.text_pe.generate_prompt(
  606. related_text,
  607. key_list,
  608. task_description=text_task_description,
  609. output_format=text_output_format,
  610. rules_str=text_rules_str,
  611. few_shot_demo_text_content=text_few_shot_demo_text_content,
  612. few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
  613. )
  614. self.generate_and_merge_chat_results(
  615. chat_bot, prompt, key_list, final_results, failed_results
  616. )
  617. return {"chat_res": final_results}
  618. def predict(self, *args, **kwargs) -> None:
  619. logging.error(
  620. "PP-ChatOCRv3-doc Pipeline do not support to call `predict()` directly! Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
  621. )
  622. return