| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
import json
import re
from typing import Any, Dict, Iterator, Optional

# import cv2
# import numpy as np
import numpy as np

from ....utils import logging

# TODO: the import path of ReadImage below needs to be updated later.
from ...components.transforms import ReadImage
from ...utils.pp_option import PaddlePredictorOption
from ..base import BasePipeline
from ..layout_parsing.result import LayoutParsingResult
from .result import VisualInfoResult
- class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
- """PP-ChatOCRv3-doc Pipeline"""
- entities = "PP-ChatOCRv3-doc"
def __init__(
    self,
    config: Dict,
    device: str = None,
    pp_option: PaddlePredictorOption = None,
    use_hpip: bool = False,
    hpi_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Set up the PP-ChatOCRv3-doc pipeline.

    Args:
        config (Dict): Pipeline configuration; it is handed to
            `inintial_predictor` to build the sub-pipeline and sub-modules.
        device (str, optional): Device to run the predictions on.
            Defaults to None.
        pp_option (PaddlePredictorOption, optional): Predictor options.
            Defaults to None.
        use_hpip (bool, optional): Whether to use high-performance inference
            (hpip). Defaults to False.
        hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters.
            Defaults to None.
    """
    base_kwargs = {
        "device": device,
        "pp_option": pp_option,
        "use_hpip": use_hpip,
        "hpi_params": hpi_params,
    }
    super().__init__(**base_kwargs)

    self.inintial_predictor(config)

    # Reader used by `visual_predict` to turn image paths into arrays.
    self.img_reader = ReadImage(format="BGR")

    # Subtracted from `min_characters` when deciding whether a table's HTML
    # is short enough to be sent to the LLM directly.
    self.table_structure_len_max = 500
def inintial_predictor(self, config: dict) -> None:
    """Build the layout sub-pipeline and the LLM helper modules from `config`.

    Args:
        config (dict): Configuration dictionary. The entries read here are
            `SubPipelines/LayoutParser`, `SubModules/LLM_Chat`,
            `SubModules/LLM_Retriever` and, under
            `SubModules/PromptEngneering`, `KIE_CommonText` and `KIE_Table`.

    Returns:
        None
    """
    self.layout_parsing_pipeline = self.create_pipeline(
        config["SubPipelines"]["LayoutParser"]
    )

    from .. import create_chat_bot

    self.chat_bot = create_chat_bot(config["SubModules"]["LLM_Chat"])

    from .. import create_retriever

    self.retriever = create_retriever(config["SubModules"]["LLM_Retriever"])

    from .. import create_prompt_engeering

    # One prompt builder for free text, one for tables.
    prompt_configs = config["SubModules"]["PromptEngneering"]
    self.text_pe = create_prompt_engeering(prompt_configs["KIE_CommonText"])
    self.table_pe = create_prompt_engeering(prompt_configs["KIE_Table"])
def decode_visual_result(
    self, layout_parsing_result: LayoutParsingResult
) -> VisualInfoResult:
    """Extract the text needed for chatting from a layout parsing result.

    Args:
        layout_parsing_result (LayoutParsingResult): The result of layout
            parsing. The keys `text_paragraphs_ocr_res`, `seal_res_list`
            and `table_res_list` are read.

    Returns:
        VisualInfoResult: Wraps a dict with `normal_text_dict` (seal and
            paragraph text grouped by layout type), `table_text_list` (the
            recognised text of each table joined with spaces) and
            `table_html_list` (the predicted HTML of each table).
    """
    paragraph_ocr = layout_parsing_result["text_paragraphs_ocr_res"]
    seal_results = layout_parsing_result["seal_res_list"]

    normal_text_dict = {}

    def add_text(layout_type, text):
        # The first entry is stored as given; later ones go on a new line.
        if layout_type in normal_text_dict:
            normal_text_dict[layout_type] += f"\n {text}"
        else:
            normal_text_dict[layout_type] = text

    for seal_result in seal_results:
        for seal_text in seal_result["rec_text"]:
            add_text("印章", f"{seal_text}")

    for paragraph_text in paragraph_ocr["rec_text"]:
        add_text("words in text block", paragraph_text)

    table_text_list, table_html_list = [], []
    for table_result in layout_parsing_result["table_res_list"]:
        table_html_list.append(table_result["pred_html"])
        table_text_list.append(" ".join(table_result["table_ocr_pred"]["rec_text"]))

    return VisualInfoResult(
        {
            "normal_text_dict": normal_text_dict,
            "table_text_list": table_text_list,
            "table_html_list": table_html_list,
        }
    )
- # Function to perform visual prediction on input images
def visual_predict(
    self,
    input: str | list[str] | np.ndarray | list[np.ndarray],
    use_doc_orientation_classify: bool = False,
    use_doc_unwarping: bool = False,
    use_common_ocr: bool = True,
    use_seal_recognition: bool = True,
    use_table_recognition: bool = True,
    **kwargs,
) -> Iterator[dict]:
    """Run layout parsing on every input image and yield the decoded results.

    This is a generator: nothing is computed until it is iterated, and one
    result is produced per input image, in input order.

    Args:
        input (str | list[str] | np.ndarray | list[np.ndarray]): Image path,
            list of image paths, image array, or list of image arrays.
        use_doc_orientation_classify (bool): Flag to use document
            orientation classification.
        use_doc_unwarping (bool): Flag to use document unwarping.
        use_common_ocr (bool): Flag to use common OCR.
        use_seal_recognition (bool): Flag to use seal recognition.
        use_table_recognition (bool): Flag to use table recognition.
        **kwargs: Accepted for interface compatibility; not used here.

    Yields:
        dict: `{"layout_parsing_result": ..., "visual_info": ...}` for one
            image, where `visual_info` is the output of
            `decode_visual_result` for that image.

    Raises:
        AssertionError: If an image array does not have exactly 3 dimensions.
    """
    input_list = input if isinstance(input, list) else [input]

    for single_input in input_list:
        if isinstance(single_input, str):
            # The reader is consumed as an iterator; the image array sits
            # under "img" in the first element of the first item it produces.
            image_array = next(self.img_reader(single_input))[0]["img"]
        else:
            image_array = single_input

        assert len(image_array.shape) == 3, (
            "visual_predict expects images with 3 dimensions, "
            f"got shape {image_array.shape}"
        )

        layout_parsing_result = next(
            self.layout_parsing_pipeline.predict(
                image_array,
                use_doc_orientation_classify=use_doc_orientation_classify,
                use_doc_unwarping=use_doc_unwarping,
                use_common_ocr=use_common_ocr,
                use_seal_recognition=use_seal_recognition,
                use_table_recognition=use_table_recognition,
            )
        )

        yield {
            "layout_parsing_result": layout_parsing_result,
            "visual_info": self.decode_visual_result(layout_parsing_result),
        }
def save_visual_info_list(
    self, visual_info: VisualInfoResult, save_path: str
) -> None:
    """Save one or more visual info results to a JSON file.

    The results are written as a single JSON array on one line, followed by
    a newline. A single (non-list) result is wrapped into a one-element list.

    Args:
        visual_info (VisualInfoResult): The visual info result, which can be
            a single object or a list of objects.
        save_path (str): The file path to write to. An existing file is
            overwritten.

    Returns:
        None
    """
    visual_info_list = visual_info if isinstance(visual_info, list) else [visual_info]

    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # encoding is pinned to UTF-8 instead of depending on the platform locale.
    with open(save_path, "w", encoding="utf-8") as fout:
        fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
def load_visual_info_list(self, data_path: str) -> list[VisualInfoResult]:
    """Load a visual info list from a JSON file.

    Only the first line of the file is read and parsed, which matches the
    single-line layout written by `save_visual_info_list`.

    Args:
        data_path (str): The path to the UTF-8 encoded JSON file.

    Returns:
        list[VisualInfoResult]: The decoded JSON content. The entries are the
            plain objects produced by `json.loads`; they are not converted
            into `VisualInfoResult` instances here.

    Raises:
        json.JSONDecodeError: If the first line is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(data_path, "r", encoding="utf-8") as fin:
        first_line = fin.readline()
    return json.loads(first_line)
def merge_visual_info_list(
    self, visual_info_list: list[VisualInfoResult]
) -> tuple[list, list, list]:
    """Combine several visual info results into three flat collections.

    Args:
        visual_info_list (list[VisualInfoResult]): Visual info results, each
            providing `normal_text_dict`, `table_text_list` and
            `table_html_list`.

    Returns:
        tuple[list, list, list]: The normal text dicts (one per input result),
            all table texts concatenated, and all table HTML strings
            concatenated, in input order.
    """
    merged_text_dicts, merged_table_texts, merged_table_htmls = [], [], []

    for info in visual_info_list:
        text_dict, table_texts, table_htmls = (
            info["normal_text_dict"],
            info["table_text_list"],
            info["table_html_list"],
        )
        # Text dicts stay one-per-page; table data is flattened.
        merged_text_dicts.append(text_dict)
        merged_table_texts.extend(table_texts)
        merged_table_htmls.extend(table_htmls)

    return merged_text_dicts, merged_table_texts, merged_table_htmls
def build_vector(
    self,
    visual_info: VisualInfoResult,
    min_characters: int = 3500,
    llm_request_interval: float = 1.0,
) -> dict:
    """Prepare the retrieval data for the text found in `visual_info`.

    Args:
        visual_info (VisualInfoResult): The visual information input, either
            a single instance or a list of instances.
        min_characters (int): Character threshold. A vector database is only
            built when the collected text is longer than this. Defaults to
            3500.
        llm_request_interval (float): Not used by this method; kept in the
            signature. Defaults to 1.0.

    Returns:
        dict: `flag_too_short_text` is True when the collected text does not
            exceed `min_characters`; in that case `vector` holds the raw list
            of text items. Otherwise `vector` holds whatever
            `self.retriever.generate_vector_database` returns for them.
    """
    visual_info_list = visual_info if isinstance(visual_info, list) else [visual_info]
    normal_texts, table_texts, table_htmls = self.merge_visual_info_list(
        visual_info_list
    )

    text_items = [
        f"{layout_type}:{text}\n"
        for normal_text_dict in normal_texts
        for layout_type, text in normal_text_dict.items()
    ]
    # Only tables whose HTML exceeds the budget are added here, as plain text.
    text_items.extend(
        f"table:{table_text}\n"
        for table_html, table_text in zip(table_htmls, table_texts)
        if len(table_html) > min_characters - self.table_structure_len_max
    )

    if len("".join(text_items)) > min_characters:
        return {
            "flag_too_short_text": False,
            "vector": self.retriever.generate_vector_database(text_items),
        }
    return {"flag_too_short_text": True, "vector": text_items}
def format_key(self, key_list: str | list[str]) -> list[str]:
    """Normalise the requested keys into a list of strings.

    Args:
        key_list (str | list[str]): Either a list of keys, or a single string
            with keys separated by ASCII commas or fullwidth commas
            (U+FF0C).

    Returns:
        list[str]: For a list input, a shallow copy of it (so later removals
            do not touch the caller's list). For a string input, the
            separated keys with tab/newline-like characters removed,
            surrounding spaces stripped and empty entries dropped. For any
            other input, or when no key remains, an empty list.
    """
    if isinstance(key_list, list):
        return list(key_list)
    if not isinstance(key_list, str):
        return []

    cleaned = re.sub(r"[\t\n\r\f\v]", "", key_list)
    # The fullwidth comma is written as an escape so it cannot be silently
    # normalised into an ASCII comma, which would make this a no-op.
    cleaned = cleaned.replace("\uff0c", ",")

    stripped_keys = (part.strip() for part in cleaned.split(","))
    return [key for key in stripped_keys if key]
def fix_llm_result_format(self, llm_result: str) -> dict:
    """Turn raw LLM output into a flat key/value dictionary.

    The text is first cleaned of Markdown code-fence markers, then parsed as
    JSON (as-is, and if that does not give an object, again with all square
    brackets removed). If no JSON object can be parsed, `"key": "value"`
    pairs are extracted with a regular expression, which also recovers
    output that was cut off in the middle of the last value.

    Args:
        llm_result (str): The result from the LLM (Large Language Model).

    Returns:
        dict: The extracted key/value pairs. List values are replaced by
            their first element, and keys whose value is an empty list are
            omitted. An empty dict is returned when nothing can be extracted.
    """
    if not llm_result:
        return {}

    if "json" in llm_result or "```" in llm_result:
        # NOTE(review): "/n" looks like it targets a mistyped newline that
        # some models emit; kept as in the original clean-up.
        llm_result = (
            llm_result.replace("```", "").replace("json", "").replace("/n", "")
        )
    without_brackets = llm_result.replace("[", "").replace("]", "")

    # Try the text unchanged first so that list values survive parsing and
    # can be unwrapped; fall back to the bracket-free text otherwise.
    for candidate in (llm_result, without_brackets):
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            continue
        if not isinstance(parsed, dict):
            continue

        fixed = {}
        for key, value in parsed.items():
            while isinstance(value, list) and value:
                value = value[0]
            if isinstance(value, list):
                # Empty list: there is no value to report for this key.
                continue
            fixed[key] = value
        return fixed

    # Regex fallback for output that is not valid JSON.
    text = (
        without_brackets.replace("\n", "").replace("{", "").replace("}", "").strip()
    )
    if not text.endswith('"'):
        # Close a value that was truncated before its final quote.
        text += '"'
    matches = re.findall(r'"(.*?)"\s*:\s*"([^"]*)"', text)
    return dict(matches)
def generate_and_merge_chat_results(
    self, prompt: str, key_list: list, final_results: dict, failed_results: list
) -> None:
    """Query the chat bot once and merge the usable answers into `final_results`.

    Both `key_list` and `final_results` are modified in place: every key that
    receives a usable answer is removed from `key_list` and stored in
    `final_results`.

    Args:
        prompt (str): The input prompt for the chat bot.
        key_list (list): Keys that are still unanswered.
        final_results (dict): The dictionary collecting the answers.
        failed_results (list): Values that mean "no answer"; a key whose
            value is one of them (or is None) stays in `key_list`.

    Returns:
        None
    """
    llm_result = self.chat_bot.generate_chat_results(prompt)
    if llm_result is None:
        logging.warning(
            "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
            % (prompt, self.chat_bot.ERROR_MASSAGE)
        )
        return

    parsed_result = self.fix_llm_result_format(llm_result)
    for key, value in parsed_result.items():
        # A JSON null carries no information, just like the "None" marker in
        # failed_results, so the key is left pending for a later attempt.
        if value is None or value in failed_results:
            continue
        if key in key_list:
            key_list.remove(key)
            final_results[key] = value
def chat(
    self,
    key_list: str | list[str],
    visual_info: VisualInfoResult,
    use_vector_retrieval: bool = True,
    vector_info: dict = None,
    min_characters: int = 3500,
    text_task_description: str = None,
    text_output_format: str = None,
    text_rules_str: str = None,
    text_few_shot_demo_text_content: str = None,
    text_few_shot_demo_key_value_list: str = None,
    table_task_description: str = None,
    table_output_format: str = None,
    table_rules_str: str = None,
    table_few_shot_demo_text_content: str = None,
    table_few_shot_demo_key_value_list: str = None,
) -> dict:
    """Extract the values of the requested keys with the chat bot.

    Tables are queried first: every table whose HTML fits within
    `min_characters - self.table_structure_len_max` is sent to the chat bot,
    once as HTML and once as plain text, while keys remain unanswered. The
    remaining keys are then queried against the text content, taken either
    from `vector_info` or from the normal text in `visual_info`.

    Args:
        key_list (str | list[str]): A single key string or a list of keys to
            extract. A list passed in is not modified.
        visual_info (VisualInfoResult): The visual information result, or a
            list of them.
        use_vector_retrieval (bool): Whether to use vector retrieval. Only
            takes effect when `vector_info` is given.
        vector_info (dict): The vector information for retrieval, with the
            keys `vector` and `flag_too_short_text`.
        min_characters (int): Character budget used for the table size check
            and for the "too long" warning.
        text_task_description (str): The description of the text task.
        text_output_format (str): The output format for text results.
        text_rules_str (str): The rules for generating text results.
        text_few_shot_demo_text_content (str): The text content for few-shot
            demos.
        text_few_shot_demo_key_value_list (str): The key-value list for
            few-shot demos.
        table_task_description (str): The description of the table task.
        table_output_format (str): The output format for table results.
        table_rules_str (str): The rules for generating table results.
        table_few_shot_demo_text_content (str): The text content for table
            few-shot demos.
        table_few_shot_demo_key_value_list (str): The key-value list for
            table few-shot demos.

    Returns:
        dict: `{"chat_res": {...}}` with the answered keys, or
            `{"error": ...}` when `key_list` contains no valid key.
    """
    # Work on a private copy: answered keys are removed from this list, and
    # format_key may hand back the very list object the caller passed in.
    key_list = list(self.format_key(key_list))
    if len(key_list) == 0:
        return {"error": "输入的key_list无效!"}

    visual_info_list = visual_info if isinstance(visual_info, list) else [visual_info]
    all_normal_text_list, all_table_text_list, all_table_html_list = (
        self.merge_visual_info_list(visual_info_list)
    )

    final_results = {}
    failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]

    # Stage 1: ask about each sufficiently small table, HTML first, then text.
    for table_html, table_text in zip(all_table_html_list, all_table_text_list):
        if len(table_html) > min_characters - self.table_structure_len_max:
            continue
        for table_info in (table_html, table_text):
            if len(key_list) == 0:
                # Keys are only ever removed, so nothing is left to ask.
                break
            prompt = self.table_pe.generate_prompt(
                table_info,
                key_list,
                task_description=table_task_description,
                output_format=table_output_format,
                rules_str=table_rules_str,
                few_shot_demo_text_content=table_few_shot_demo_text_content,
                few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
            )
            self.generate_and_merge_chat_results(
                prompt, key_list, final_results, failed_results
            )

    # Stage 2: ask about the remaining keys using the text content.
    if len(key_list) > 0:
        if use_vector_retrieval and vector_info is not None:
            question_key_list = [f"抽取关键信息:{key}" for key in key_list]
            vector = vector_info["vector"]
            if not vector_info["flag_too_short_text"]:
                related_text = self.retriever.similarity_retrieval(
                    question_key_list, vector
                )
            elif len(vector) > 0:
                # Short text: `vector` is the raw list of text items.
                related_text = "".join(vector)
            else:
                related_text = ""
        else:
            all_items = [
                f"{layout_type}:{text}\n"
                for normal_text_dict in all_normal_text_list
                for layout_type, text in normal_text_dict.items()
            ]
            related_text = "".join(all_items)
            if len(related_text) > min_characters:
                logging.warning(
                    "The input text content is too long, the large language model may truncate it."
                )

        if len(related_text) > 0:
            prompt = self.text_pe.generate_prompt(
                related_text,
                key_list,
                task_description=text_task_description,
                output_format=text_output_format,
                rules_str=text_rules_str,
                few_shot_demo_text_content=text_few_shot_demo_text_content,
                few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
            )
            self.generate_and_merge_chat_results(
                prompt, key_list, final_results, failed_results
            )

    return {"chat_res": final_results}
def predict(self, *args, **kwargs) -> None:
    """Reject direct prediction calls.

    This pipeline is driven through `visual_predict`, `build_vector` and
    `chat`; calling `predict` only logs an error and does nothing else.

    Args:
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        None
    """
    message = (
        "PP-ChatOCRv3-doc Pipeline do not support to call `predict()` directly! "
        "Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
    )
    logging.error(message)
    return None
|