| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995 |
- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import base64
- import copy
- import json
- import os
- import re
- from typing import Any, Dict, List, Optional, Tuple, Union
- import numpy as np
- from ....utils import logging
- from ....utils.deps import (
- function_requires_deps,
- is_dep_available,
- pipeline_requires_extra,
- )
- from ....utils.file_interface import custom_open
- from ...common.batch_sampler import ImageBatchSampler
- from ...common.reader import ReadImage
- from ...utils.hpi import HPIConfig
- from ...utils.pp_option import PaddlePredictorOption
- from ..components.chat_server import BaseChat
- from ..layout_parsing.result import LayoutParsingResult
- from .pipeline_base import PP_ChatOCR_Pipeline
- if is_dep_available("opencv-contrib-python"):
- import cv2
- @pipeline_requires_extra("ie")
- class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
- """PP-ChatOCRv4 Pipeline"""
- entities = ["PP-ChatOCRv4-doc"]
def __init__(
    self,
    config: Dict,
    device: str = None,
    pp_option: PaddlePredictorOption = None,
    use_hpip: bool = False,
    hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
    initial_predictor: bool = True,
) -> None:
    """Initializes the PP-ChatOCRv4-doc pipeline.

    Args:
        config (Dict): Pipeline configuration dictionary. It must contain
            ``"pipeline_name"``; ``"use_layout_parser"`` and ``"use_mllm_predict"``
            default to True when absent.
        device (str, optional): Device to run the predictions on. Defaults to None.
        pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
        use_hpip (bool, optional): Whether to use the high-performance
            inference plugin (HPIP) by default. Defaults to False.
        hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
            The default high-performance inference configuration dictionary.
            Defaults to None.
        initial_predictor (bool, optional): Whether to build the visual, LLM chat,
            retriever and MLLM predictors right away. When False they stay None
            and are created on first use by the methods that need them.
            Defaults to True.
    """
    super().__init__(
        device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
    )
    self.pipeline_name = config["pipeline_name"]
    # Kept so sub-predictors can be (re)initialized lazily later on.
    self.config = config
    self.use_layout_parser = config.get("use_layout_parser", True)
    self.use_mllm_predict = config.get("use_mllm_predict", True)

    # Sub-predictors; None means "not initialized yet".
    self.layout_parsing_pipeline = None
    self.chat_bot = None
    self.retriever = None
    self.mllm_chat_bot = None

    if initial_predictor:
        self.inintial_visual_predictor(config)
        self.inintial_chat_predictor(config)
        self.inintial_retriever_predictor(config)
        self.inintial_mllm_predictor(config)

    self.batch_sampler = ImageBatchSampler(batch_size=1)
    self.img_reader = ReadImage(format="BGR")
    # Tables whose HTML is longer than (min_characters - this value) are routed
    # into the retrieval corpus by build_vector.
    self.table_structure_len_max = 500
def inintial_visual_predictor(self, config: dict) -> None:
    """
    Build the layout-parsing sub-pipeline from ``config``.

    Re-reads ``use_layout_parser`` from the config (default True). When it is
    enabled, the ``SubPipelines.LayoutParser`` section is handed to
    ``self.create_pipeline``. A placeholder error dict is passed instead if that
    section is missing.

    Args:
        config (dict): The pipeline configuration dictionary.
    Returns:
        None
    """
    enabled = config.get("use_layout_parser", True)
    self.use_layout_parser = enabled
    if not enabled:
        return
    fallback = {"pipeline_config_error": "config error for layout_parsing_pipeline!"}
    sub_pipelines = config.get("SubPipelines", {})
    self.layout_parsing_pipeline = self.create_pipeline(
        sub_pipelines.get("LayoutParser", fallback)
    )
def inintial_retriever_predictor(self, config: dict) -> None:
    """
    Create the LLM retriever and store it on ``self.retriever``.

    Uses the ``SubModules.LLM_Retriever`` section of ``config``. When that
    section is absent, a placeholder error dict is passed to
    ``create_retriever`` instead.

    Args:
        config (dict): The pipeline configuration dictionary.
    Returns:
        None
    """
    from .. import create_retriever

    fallback = {"retriever_config_error": "config error for llm retriever!"}
    sub_modules = config.get("SubModules", {})
    self.retriever = create_retriever(sub_modules.get("LLM_Retriever", fallback))
def inintial_chat_predictor(self, config: dict) -> None:
    """
    Create the LLM chat bot and the text/table prompt builders.

    Reads ``SubModules.LLM_Chat`` for ``self.chat_bot``, and
    ``SubModules.PromptEngneering.KIE_CommonText`` / ``.KIE_Table`` for
    ``self.text_pe`` / ``self.table_pe``. Any missing section is replaced by a
    placeholder error dict before being passed to its factory.

    Args:
        config (dict): The pipeline configuration dictionary.
    Returns:
        None
    """
    from .. import create_chat_bot, create_prompt_engineering

    sub_modules = config.get("SubModules", {})
    self.chat_bot = create_chat_bot(
        sub_modules.get(
            "LLM_Chat",
            {"chat_bot_config_error": "config error for llm chat bot!"},
        )
    )

    prompt_sections = sub_modules.get("PromptEngneering", {})
    self.text_pe = create_prompt_engineering(
        prompt_sections.get(
            "KIE_CommonText",
            {"pe_config_error": "config error for text_pe!"},
        )
    )
    self.table_pe = create_prompt_engineering(
        prompt_sections.get(
            "KIE_Table",
            {"pe_config_error": "config error for table_pe!"},
        )
    )
def inintial_mllm_predictor(self, config: dict) -> None:
    """
    Create the multimodal (MLLM) chat bot and the ensemble prompt builder.

    Re-reads ``use_mllm_predict`` from the config (default True). When it is
    disabled, nothing else is created. Otherwise ``SubModules.MLLM_Chat`` feeds
    ``self.mllm_chat_bot`` and ``SubModules.PromptEngneering.Ensemble`` feeds
    ``self.ensemble_pe``. Missing sections are replaced by placeholder error
    dicts.

    Args:
        config (dict): The pipeline configuration dictionary.
    Returns:
        None
    """
    from .. import create_chat_bot, create_prompt_engineering

    self.use_mllm_predict = config.get("use_mllm_predict", True)
    if not self.use_mllm_predict:
        return

    sub_modules = config.get("SubModules", {})
    self.mllm_chat_bot = create_chat_bot(
        sub_modules.get(
            "MLLM_Chat",
            {"mllm_chat_bot_config": "config error for mllm chat bot!"},
        )
    )
    self.ensemble_pe = create_prompt_engineering(
        sub_modules.get("PromptEngneering", {}).get(
            "Ensemble",
            {"pe_config_error": "config error for ensemble_pe!"},
        )
    )
def decode_visual_result(self, layout_parsing_result: LayoutParsingResult) -> dict:
    """
    Flatten a layout-parsing result into plain text/table collections.

    Non-table, non-formula blocks are grouped under ``"words in <label>"``.
    Repeated labels are concatenated with ``"\\n "``. For every entry of
    ``table_res_list``, its HTML, its space-joined OCR texts and its neighbor
    texts are collected into parallel lists.

    Args:
        layout_parsing_result (LayoutParsingResult): The result of layout parsing.
    Returns:
        dict: Keys ``normal_text_dict``, ``table_text_list``, ``table_html_list``
        and ``table_nei_text_list``.
    """
    normal_text_dict = {}
    for block in layout_parsing_result["parsing_res_list"]:
        label, content = block["block_label"], block["block_content"]
        if label in ("table", "formula"):
            continue
        group = f"words in {label}"
        if group in normal_text_dict:
            normal_text_dict[group] += f"\n {content}"
        else:
            normal_text_dict[group] = content

    table_text_list, table_html_list, table_nei_text_list = [], [], []
    for table in layout_parsing_result["table_res_list"]:
        table_html_list.append(table["pred_html"])
        table_text_list.append(" ".join(table["table_ocr_pred"]["rec_texts"]))
        table_nei_text_list.append(table["neighbor_texts"])

    return {
        "normal_text_dict": normal_text_dict,
        "table_text_list": table_text_list,
        "table_html_list": table_html_list,
        "table_nei_text_list": table_nei_text_list,
    }
# Run layout parsing on the input image(s) and yield per-page visual info.
def visual_predict(
    self,
    input: Union[str, List[str], np.ndarray, List[np.ndarray]],
    use_doc_orientation_classify: Optional[bool] = None,
    use_doc_unwarping: Optional[bool] = None,
    use_textline_orientation: Optional[bool] = None,
    use_seal_recognition: Optional[bool] = None,
    use_table_recognition: Optional[bool] = None,
    layout_threshold: Optional[Union[float, dict]] = None,
    layout_nms: Optional[bool] = None,
    layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
    layout_merge_bboxes_mode: Optional[str] = None,
    text_det_limit_side_len: Optional[int] = None,
    text_det_limit_type: Optional[str] = None,
    text_det_thresh: Optional[float] = None,
    text_det_box_thresh: Optional[float] = None,
    text_det_unclip_ratio: Optional[float] = None,
    text_rec_score_thresh: Optional[float] = None,
    seal_det_limit_side_len: Optional[int] = None,
    seal_det_limit_type: Optional[str] = None,
    seal_det_thresh: Optional[float] = None,
    seal_det_box_thresh: Optional[float] = None,
    seal_det_unclip_ratio: Optional[float] = None,
    seal_rec_score_thresh: Optional[float] = None,
    **kwargs,
) -> dict:
    """
    Run the layout-parsing sub-pipeline and yield its results with decoded visual info.

    This is a generator. The layout parser performs the individual tasks (document
    orientation classification, unwarping, OCR, seal and table recognition)
    according to the flags below.

    If the layout parser is disabled, or cannot be initialized, a single
    ``{"error": ...}`` dict is yielded and the generator stops.

    Args:
        input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image path, list of image paths,
            numpy array of an image, or list of numpy arrays.
        use_doc_orientation_classify (Optional[bool]): Flag to use document orientation classification.
        use_doc_unwarping (Optional[bool]): Flag to use document unwarping.
        use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
        use_seal_recognition (Optional[bool]): Flag to use seal recognition.
        use_table_recognition (Optional[bool]): Flag to use table recognition.
        layout_threshold (Optional[Union[float, dict]]): Threshold used to filter out low-confidence layout boxes.
        layout_nms (Optional[bool]): Whether to use layout-aware NMS.
        layout_unclip_ratio (Optional[Union[float, Tuple[float, float], dict]]): Unclip ratio of layout boxes.
            A single number applies to both width and height; a tuple gives
            (width, height) separately.
        layout_merge_bboxes_mode (Optional[str]): The mode for merging bounding boxes.
        text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
        text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
        text_det_thresh (Optional[float]): Threshold for text detection.
        text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
        text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
        text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
        seal_det_limit_side_len (Optional[int]): Maximum side length for seal detection.
        seal_det_limit_type (Optional[str]): Type of limit to apply for seal detection.
        seal_det_thresh (Optional[float]): Threshold for seal detection.
        seal_det_box_thresh (Optional[float]): Threshold for seal detection boxes.
        seal_det_unclip_ratio (Optional[float]): Ratio for unclipping seal detection boxes.
        seal_rec_score_thresh (Optional[float]): Score threshold for seal recognition.
        **kwargs: Accepted for interface compatibility; not forwarded.
    Yields:
        dict: ``{"layout_parsing_result": ..., "visual_info": ...}`` for each result
        produced by the layout parser, or a single ``{"error": ...}`` dict.
    """
    error_msg = "The models for layout parser are not initialized."
    if self.use_layout_parser == False:
        logging.error(error_msg)
        yield {"error": error_msg}
        # Stop here: continuing would call .predict on a missing pipeline.
        return

    if self.layout_parsing_pipeline is None:
        logging.warning(
            "The layout parsing pipeline is not initialized, will initialize it now."
        )
        self.inintial_visual_predictor(self.config)
        # inintial_visual_predictor re-reads use_layout_parser from the config and
        # leaves the pipeline unset when it is disabled there.
        if self.layout_parsing_pipeline is None:
            logging.error(error_msg)
            yield {"error": error_msg}
            return

    for layout_parsing_result in self.layout_parsing_pipeline.predict(
        input,
        use_doc_orientation_classify=use_doc_orientation_classify,
        use_doc_unwarping=use_doc_unwarping,
        use_textline_orientation=use_textline_orientation,
        use_seal_recognition=use_seal_recognition,
        use_table_recognition=use_table_recognition,
        layout_threshold=layout_threshold,
        layout_nms=layout_nms,
        layout_unclip_ratio=layout_unclip_ratio,
        layout_merge_bboxes_mode=layout_merge_bboxes_mode,
        text_det_limit_side_len=text_det_limit_side_len,
        text_det_limit_type=text_det_limit_type,
        text_det_thresh=text_det_thresh,
        text_det_box_thresh=text_det_box_thresh,
        text_det_unclip_ratio=text_det_unclip_ratio,
        text_rec_score_thresh=text_rec_score_thresh,
        seal_det_box_thresh=seal_det_box_thresh,
        seal_det_limit_side_len=seal_det_limit_side_len,
        seal_det_limit_type=seal_det_limit_type,
        seal_det_thresh=seal_det_thresh,
        seal_det_unclip_ratio=seal_det_unclip_ratio,
        seal_rec_score_thresh=seal_rec_score_thresh,
    ):
        visual_info = self.decode_visual_result(layout_parsing_result)
        yield {
            "layout_parsing_result": layout_parsing_result,
            "visual_info": visual_info,
        }
def save_visual_info_list(self, visual_info: dict, save_path: str) -> None:
    """
    Save visual info to ``save_path`` as a single UTF-8 JSON line.

    Args:
        visual_info (dict | list[dict]): A single visual info result or a list of
            them. A single result is wrapped in a one-element list before saving.
        save_path (str): Destination file path. Missing parent directories are
            created.
    Returns:
        None
    """
    visual_info_list = visual_info if isinstance(visual_info, list) else [visual_info]

    directory = os.path.dirname(save_path)
    # dirname is "" for a bare file name, and os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)

    # ensure_ascii=False writes raw non-ASCII text (e.g. Chinese OCR output), so
    # the encoding is pinned instead of relying on the platform default.
    with open(save_path, "w", encoding="utf-8") as fout:
        fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
def load_visual_info_list(self, data_path: str) -> List[dict]:
    """
    Load a visual info list previously written by ``save_visual_info_list``.

    Only the first line of the file is parsed, because the list is stored as
    one JSON line.

    Args:
        data_path (str): The path to the JSON file containing visual info.
    Returns:
        list[dict]: The parsed list of visual info dicts.
    """
    # Read as UTF-8 to match the writer, which emits non-ASCII characters verbatim.
    with open(data_path, "r", encoding="utf-8") as fin:
        return json.loads(fin.readline())
def merge_visual_info_list(
    self, visual_info_list: List[dict]
) -> Tuple[list, list, list, list]:
    """
    Merge several visual info results into four flat collections.

    Newlines are removed from every normal-text value. The input dicts are
    left untouched; fresh dicts are built instead.

    Args:
        visual_info_list (list[dict]): A list of visual info results.
    Returns:
        tuple[list, list, list, list]: In order:
            - one normal-text dict per input result,
            - all table texts,
            - all table HTML strings,
            - all table neighbor texts.
    """
    all_normal_text_list = []
    all_table_text_list = []
    all_table_html_list = []
    all_table_nei_text_list = []
    for single_visual_info in visual_info_list:
        # Build a new dict instead of rewriting the caller's normal_text_dict in place.
        all_normal_text_list.append(
            {
                key: text.replace("\n", "")
                for key, text in single_visual_info["normal_text_dict"].items()
            }
        )
        all_table_text_list.extend(single_visual_info["table_text_list"])
        all_table_html_list.extend(single_visual_info["table_html_list"])
        all_table_nei_text_list.extend(single_visual_info["table_nei_text_list"])
    return (
        all_normal_text_list,
        all_table_text_list,
        all_table_html_list,
        all_table_nei_text_list,
    )
def build_vector(
    self,
    visual_info: dict,
    min_characters: int = 3500,
    block_size: int = 300,
    flag_save_bytes_vector: bool = False,
    retriever_config: dict = None,
) -> dict:
    """
    Turn visual info into either a retrieval vector store or the raw text items.

    The text items are built as follows:
    - Each normal-text entry becomes ``"<label>:<text>\\n"``.
    - Each table whose HTML is longer than ``min_characters -
      self.table_structure_len_max`` becomes ``"table:<text>\\t<neighbor text>"``.

    If the joined items are longer than ``min_characters``, the retriever builds
    a vector database from them (optionally encoded to bytes). Otherwise the raw
    items are returned and ``flag_too_short_text`` is True.

    Args:
        visual_info (dict | list[dict]): One visual info result or a list of them.
        min_characters (int): Length threshold above which a vector store is built. Defaults to 3500.
        block_size (int): Chunk size passed to the retriever.
        flag_save_bytes_vector (bool): Whether to encode the vector store to bytes. Defaults to False.
        retriever_config (dict): Optional config for a dedicated retriever. Defaults to None,
            which means ``self.retriever`` is used and lazily initialized if needed.
    Returns:
        dict: ``flag_save_bytes_vector``, ``flag_too_short_text`` and ``vector``,
        plus ``model_name`` and ``block_size`` when a vector store was built.
    """
    visual_info_list = visual_info if isinstance(visual_info, list) else [visual_info]

    if retriever_config is None:
        if self.retriever is None:
            logging.warning(
                "The retriever is not initialized,will initialize it now."
            )
            self.inintial_retriever_predictor(self.config)
        retriever = self.retriever
    else:
        from .. import create_retriever

        retriever = create_retriever(retriever_config)

    normal_dicts, table_texts, table_htmls, table_neighbors = (
        self.merge_visual_info_list(visual_info_list)
    )

    items = [
        f"{label}:{text}\n"
        for text_dict in normal_dicts
        for label, text in text_dict.items()
    ]
    # Only long tables go into the corpus. Presumably shorter ones are handled
    # directly by the table prompt elsewhere (TODO confirm against chat()).
    html_limit = min_characters - self.table_structure_len_max
    items.extend(
        f"table:{text}\t{neighbors}"
        for html, text, neighbors in zip(table_htmls, table_texts, table_neighbors)
        if len(html) > html_limit
    )

    result = {"flag_save_bytes_vector": False}
    if len("".join(items)) <= min_characters:
        result["flag_too_short_text"] = True
        result["vector"] = items
        return result

    result["flag_too_short_text"] = False
    result["model_name"] = retriever.model_name
    result["block_size"] = block_size
    store = retriever.generate_vector_database(items, block_size=block_size)
    if flag_save_bytes_vector:
        store = retriever.encode_vector_store_to_bytes(store)
        result["flag_save_bytes_vector"] = True
    result["vector"] = store
    return result
def save_vector(
    self, vector_info: dict, save_path: str, retriever_config: dict = None
) -> None:
    """
    Serialize vector info (as produced by ``build_vector``) to a single JSON line.

    If the info holds a real vector store that is not yet bytes-encoded, the
    store is encoded with the retriever's ``encode_vector_store_to_bytes`` first.
    The caller's ``vector_info`` is not modified.

    Args:
        vector_info (dict): Vector info containing at least ``flag_too_short_text``,
            ``flag_save_bytes_vector`` and ``vector``.
        save_path (str): Destination file path. Missing parent directories are created.
        retriever_config (dict, optional): Config for a dedicated retriever used
            for encoding. Defaults to None, which means ``self.retriever`` is used.
    Returns:
        None
    """
    directory = os.path.dirname(save_path)
    # dirname is "" for a bare file name, and os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)

    # A shallow copy is enough: only the "vector" entry is replaced below, so
    # the (possibly large or non-copyable) vector store is never deep-copied.
    vector_info_data = dict(vector_info)
    if (
        not vector_info["flag_too_short_text"]
        and not vector_info["flag_save_bytes_vector"]
    ):
        # The retriever is only needed here, so it is resolved lazily.
        if retriever_config is not None:
            from .. import create_retriever

            retriever = create_retriever(retriever_config)
        else:
            if self.retriever is None:
                logging.warning(
                    "The retriever is not initialized,will initialize it now."
                )
                self.inintial_retriever_predictor(self.config)
            retriever = self.retriever
        vector_info_data["vector"] = retriever.encode_vector_store_to_bytes(
            vector_info_data["vector"]
        )
        vector_info_data["flag_save_bytes_vector"] = True

    with custom_open(save_path, "w") as fout:
        fout.write(json.dumps(vector_info_data, ensure_ascii=False) + "\n")
def load_vector(self, data_path: str, retriever_config: dict = None) -> dict:
    """
    Load vector info previously written by ``save_vector``.

    Only the first line of the file is parsed. If the stored vector store is
    bytes-encoded, it is decoded with the retriever and ``flag_save_bytes_vector``
    is reset to False.

    Args:
        data_path (str): Path of the JSON file to read.
        retriever_config (dict, optional): Config for a dedicated retriever used
            for decoding. Defaults to None, which means ``self.retriever`` is used.
    Returns:
        dict: The vector info, or ``{"error": ...}`` when the file does not hold
        a dict with ``flag_too_short_text``, ``flag_save_bytes_vector`` and
        ``vector``.
    """
    # Read as UTF-8: the writer uses ensure_ascii=False, and short-text vectors
    # contain raw (often non-ASCII) text.
    with open(data_path, "r", encoding="utf-8") as fin:
        vector_info = json.loads(fin.readline())

    required_keys = ("flag_too_short_text", "flag_save_bytes_vector", "vector")
    if not isinstance(vector_info, dict) or any(
        key not in vector_info for key in required_keys
    ):
        logging.error("Invalid vector info.")
        return {"error": "Invalid vector info when load vector!"}

    if vector_info["flag_save_bytes_vector"]:
        # The retriever is only needed for decoding, so it is resolved lazily.
        if retriever_config is not None:
            from .. import create_retriever

            retriever = create_retriever(retriever_config)
        else:
            if self.retriever is None:
                logging.warning(
                    "The retriever is not initialized,will initialize it now."
                )
                self.inintial_retriever_predictor(self.config)
            retriever = self.retriever
        vector_info["vector"] = retriever.decode_vector_store_from_bytes(
            vector_info["vector"]
        )
        vector_info["flag_save_bytes_vector"] = False
    return vector_info
def format_key(self, key_list: Union[str, List[str]]) -> List[str]:
    """
    Normalize user-supplied keys into a list of key strings.

    - A list is returned with non-breaking spaces (``\\xa0``) replaced by spaces.
    - A string is split on ASCII or full-width (``\\uff0c``) commas, after tab,
      newline, carriage-return, form-feed and vertical-tab characters are
      removed. Each key is then stripped, and empty keys are dropped.
    - Any other input yields an empty list.

    Args:
        key_list (str|list[str]): A string or a list of strings representing the keys.
    Returns:
        list[str]: A list of formatted keys.
    """
    if isinstance(key_list, list):
        return [key.replace("\xa0", " ") for key in key_list]
    if isinstance(key_list, str):
        cleaned = re.sub(r"[\t\n\r\f\v]", "", key_list)
        # Accept the full-width comma commonly typed in Chinese input as a separator.
        cleaned = cleaned.replace("\uff0c", ",").replace("\xa0", " ")
        # Strip surrounding spaces so "a, b" yields "b" (which can match LLM output
        # keys), and drop empties caused by trailing or doubled commas.
        return [key.strip() for key in cleaned.split(",") if key.strip()]
    return []
@function_requires_deps("opencv-contrib-python")
def mllm_pred(
    self,
    input: Union[str, np.ndarray],
    key_list: Union[str, List[str]],
    mllm_chat_bot_config=None,
) -> dict:
    """
    Ask the multimodal LLM for each key, using the input image as context.

    The image is JPEG-encoded and base64-encoded. For every key, one prompt is
    sent to the MLLM chat bot.

    Args:
        input (Union[str, np.ndarray]): Input image path, or numpy array of an image.
            Lists and PDF paths are rejected.
        key_list (Union[str, list[str]]): A single key or a list of keys to extract
            information for (normalized with ``format_key``).
        mllm_chat_bot_config (dict, optional): Config for a dedicated MLLM chat bot
            (api_type, api_key, ...; see the pipeline config file). Defaults to None,
            which means ``self.mllm_chat_bot`` is used and lazily initialized.
    Returns:
        dict: ``{"mllm_res": {key: answer, ...}}`` on success, or
        ``{"mllm_res": <error string>}`` on failure.
    """
    if self.use_mllm_predict == False:
        logging.error("MLLM prediction is disabled.")
        return {"mllm_res": "Error:MLLM prediction is disabled!"}
    key_list = self.format_key(key_list)
    if len(key_list) == 0:
        return {"mllm_res": "Error:输入的key_list无效!"}

    if isinstance(input, list):
        logging.error("Input is a list, but it's not supported here.")
        return {"mllm_res": "Error:Input is a list, but it's not supported here!"}
    if isinstance(input, str) and input.lower().endswith(".pdf"):
        logging.error("MLMM prediction does not support PDF currently!")
        return {"mllm_res": "Error:MLMM prediction does not support PDF currently!"}

    if mllm_chat_bot_config is not None:
        from .. import create_chat_bot

        mllm_chat_bot = create_chat_bot(mllm_chat_bot_config)
    else:
        # Only initialize the default bot when it is actually going to be used.
        if self.mllm_chat_bot is None:
            logging.warning(
                "The MLLM chat bot is not initialized,will initialize it now."
            )
            self.inintial_mllm_predictor(self.config)
        mllm_chat_bot = self.mllm_chat_bot
    if mllm_chat_bot is None:
        # inintial_mllm_predictor leaves the bot unset when the config disables MLLM.
        logging.error("MLLM prediction is disabled.")
        return {"mllm_res": "Error:MLLM prediction is disabled!"}

    result = {}
    for image_array in self.img_reader([input]):
        image_string = cv2.imencode(".jpg", image_array)[1].tobytes()
        image_base64 = base64.b64encode(image_string).decode("utf-8")
        for key in key_list:
            prompt = (
                str(key)
                + "\n请用图片中完整出现的内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,并保持格式、单位、符号和标点都与图片中的文字内容完全一致。"
            )
            mllm_chat_bot_result = mllm_chat_bot.generate_chat_results(
                prompt=prompt, image=image_base64
            )["content"]
            if mllm_chat_bot_result is None:
                return {"mllm_res": "大模型调用失败"}
            result[key] = mllm_chat_bot_result
    return {"mllm_res": result}
def generate_and_merge_chat_results(
    self,
    chat_bot: BaseChat,
    prompt: str,
    key_list: list,
    final_results: dict,
    failed_results: list,
) -> None:
    """
    Query the chat bot once and fold its answers into ``final_results``.

    Any reasoning content is appended to ``final_results["reasoning_content"]``.
    When the answer content is None, an error is logged and nothing else
    changes. Otherwise, each answered key that is still in ``key_list`` and whose
    value is not in ``failed_results`` is:
    - stored in ``final_results``, and
    - removed from ``key_list`` (both arguments are mutated in place).

    Args:
        chat_bot (BaseChat): The chat bot used to generate the answer.
        prompt (str): The input prompt for the chat bot.
        key_list (list): Keys still awaiting an answer; answered keys are removed.
        final_results (dict): Accumulator for accepted answers.
        failed_results (list): Values that count as "no answer" and are rejected.
    Returns:
        None
    """
    response = chat_bot.generate_chat_results(prompt)
    content = response["content"]
    reasoning = response["reasoning_content"]

    if reasoning is not None:
        final_results.setdefault("reasoning_content", []).append(reasoning)

    if content is None:
        logging.error(
            "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
            % (prompt, chat_bot.ERROR_MASSAGE)
        )
        return

    parsed = chat_bot.fix_llm_result_format(content)
    for answered_key, answer in parsed.items():
        if answer not in failed_results and answered_key in key_list:
            key_list.remove(answered_key)
            final_results[answered_key] = answer
def get_related_normal_text(
    self,
    retriever_config: dict,
    use_vector_retrieval: bool,
    vector_info: dict,
    key_list: List[str],
    all_normal_text_list: list,
    min_characters: int,
) -> str:
    """
    Collect the normal text to feed the LLM, via vector retrieval or directly.

    How the text is obtained:
    - With ``use_vector_retrieval`` and a ``vector_info``:
      - if the text was too short to vectorize, the stored text items are joined;
      - otherwise the retriever's ``similarity_retrieval`` is queried with the
        keys (topk=50).
    - In all other cases, every normal-text entry is joined as
      ``"<label>:<text>\\n"``.

    Args:
        retriever_config (dict): Optional config for a dedicated retriever. None
            means ``self.retriever`` is used.
        use_vector_retrieval (bool): Whether to use vector retrieval.
        vector_info (dict): Vector info as produced by ``build_vector`` / ``load_vector``.
        key_list (list[str]): Keys used as retrieval queries.
        all_normal_text_list (list): Normal-text dicts used when retrieval is off.
        min_characters (int): Character budget passed to retrieval; in the
            non-retrieval path, longer text only triggers a warning.
    Returns:
        str: Related normal text.
    Raises:
        AssertionError: If the stored vector was built with a different
            retriever model.
    """
    if use_vector_retrieval and vector_info is not None:
        vector = vector_info["vector"]
        if vector_info["flag_too_short_text"]:
            # The "vector" is just the raw text items here; no retriever needed.
            return "".join(vector)

        # The retriever is only needed on this path, so it is resolved lazily.
        if retriever_config is not None:
            from .. import create_retriever

            retriever = create_retriever(retriever_config)
        else:
            if self.retriever is None:
                logging.warning(
                    "The retriever is not initialized,will initialize it now."
                )
                self.inintial_retriever_predictor(self.config)
            retriever = self.retriever

        if vector_info["model_name"] != retriever.model_name:
            # Raised explicitly (rather than via `assert`) so the check survives
            # `python -O`; the exception type is kept for existing callers.
            raise AssertionError(
                f"The vector model name ({vector_info['model_name']}) does not match the retriever model name ({retriever.model_name}). Please check your retriever config."
            )
        if vector_info["flag_save_bytes_vector"]:
            vector = retriever.decode_vector_store_from_bytes(vector)
        question_key_list = [f"{key}" for key in key_list]
        return retriever.similarity_retrieval(
            question_key_list, vector, topk=50, min_characters=min_characters
        )

    related_text = "".join(
        f"{label}:{text}\n"
        for normal_text_dict in all_normal_text_list
        for label, text in normal_text_dict.items()
    )
    if len(related_text) > min_characters:
        logging.warning(
            "The input text content is too long, the large language model may truncate it."
        )
    return related_text
def ensemble_ocr_llm_mllm(
    self,
    chat_bot: BaseChat,
    key_list: List[str],
    ocr_llm_predict_dict: dict,
    mllm_predict_dict: dict,
) -> dict:
    """
    Combine OCR+LLM and MLLM answers per key into one prediction.

    How each key is resolved:
    - When both sources have a non-empty answer, the ensemble prompt
      (``self.ensemble_pe``) is sent to ``chat_bot``. The MLLM answer is chosen
      if the bot's verdict for that key contains ``"B"``; otherwise the OCR+LLM
      answer is kept. It is also kept when the bot gives no usable verdict.
    - When only one source has a non-empty answer, that answer is used.
    - Keys with no answer are omitted.

    Reasoning content returned by the bot is collected under
    ``"reasoning_content"``.

    Args:
        chat_bot (BaseChat): The chat bot used to arbitrate between answers.
        key_list (list[str]): Keys to produce predictions for.
        ocr_llm_predict_dict (dict): OCR+LLM predictions, keyed by key.
        mllm_predict_dict (dict): MLLM predictions, keyed by key. Non-dict values
            (e.g. the error strings ``mllm_pred`` can return) are treated as empty.
    Returns:
        dict: A dictionary with the final predictions.
    """
    # Error paths of the prediction helpers return strings instead of dicts;
    # treat those as "no predictions" rather than crashing on lookup.
    if not isinstance(ocr_llm_predict_dict, dict):
        ocr_llm_predict_dict = {}
    if not isinstance(mllm_predict_dict, dict):
        mllm_predict_dict = {}

    final_predict_dict = {}
    for key in key_list:
        ocr_llm_predict = ocr_llm_predict_dict.get(key, "")
        mllm_predict = mllm_predict_dict.get(key, "")

        if ocr_llm_predict != "" and mllm_predict != "":
            prompt = self.ensemble_pe.generate_prompt(
                key, ocr_llm_predict, mllm_predict
            )
            llm_result = chat_bot.generate_chat_results(prompt)
            llm_result_reasoning_content = llm_result["reasoning_content"]
            if llm_result_reasoning_content is not None:
                final_predict_dict.setdefault("reasoning_content", []).append(
                    llm_result_reasoning_content
                )

            # Default to the OCR+LLM answer. Presumably the ensemble prompt labels
            # the MLLM answer as option "B" (confirm against the Ensemble prompt config).
            predict = ocr_llm_predict
            llm_result_content = llm_result["content"]
            if llm_result_content is not None:
                llm_result_content = chat_bot.fix_llm_result_format(
                    llm_result_content
                )
                if (
                    isinstance(llm_result_content, dict)
                    and key in llm_result_content
                    and "B" in str(llm_result_content[key])
                ):
                    predict = mllm_predict
        else:
            # At most one side answered; use whichever is non-empty.
            predict = ocr_llm_predict if ocr_llm_predict != "" else mllm_predict

        if predict != "":
            final_predict_dict[key] = predict
    return final_predict_dict
def chat(
    self,
    key_list: Union[str, List[str]],
    visual_info: dict,
    use_vector_retrieval: bool = True,
    vector_info: dict = None,
    min_characters: int = 3500,
    text_task_description: str = None,
    text_output_format: str = None,
    text_rules_str: str = None,
    text_few_shot_demo_text_content: str = None,
    text_few_shot_demo_key_value_list: str = None,
    table_task_description: str = None,
    table_output_format: str = None,
    table_rules_str: str = None,
    table_few_shot_demo_text_content: str = None,
    table_few_shot_demo_key_value_list: str = None,
    mllm_predict_info: dict = None,
    mllm_integration_strategy: str = "integration",
    chat_bot_config: dict = None,
    retriever_config: dict = None,
) -> dict:
    """
    Generates chat results based on the provided key list and visual information.

    Keys are first queried against the (optionally vector-retrieved) normal text;
    keys still unresolved afterwards are queried table by table. When mLLM
    predictions are enabled and supplied, the LLM results are combined with them
    according to ``mllm_integration_strategy``. The strategy is validated, and the
    "mllm_only" case answered, before any LLM request is issued, since the LLM
    results would be discarded in both cases.

    Args:
        key_list (Union[str, list[str]]): A single key or a list of keys to extract information.
        visual_info (dict): The visual information result, or a list of such results.
        use_vector_retrieval (bool): Whether to use vector retrieval.
        vector_info (dict): The vector information for retrieval.
        min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
        text_task_description (str): The description of the text task.
        text_output_format (str): The output format for text results.
        text_rules_str (str): The rules for generating text results.
        text_few_shot_demo_text_content (str): The text content for few-shot demos.
        text_few_shot_demo_key_value_list (str): The key-value list for few-shot demos.
        table_task_description (str): The description of the table task.
        table_output_format (str): The output format for table results.
        table_rules_str (str): The rules for generating table results.
        table_few_shot_demo_text_content (str): The text content for table few-shot demos.
        table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
        mllm_predict_info (dict): The dictionary of mLLM predictions.
        mllm_integration_strategy (str): The integration strategy of mLLM and LLM, defaults to "integration",
            options are "integration", "llm_only" and "mllm_only".
        chat_bot_config (dict): The parameters for LLM chatbot, including api_type, api_key... refer to config file for more details.
        retriever_config (dict): The parameters for LLM retriever, including api_type, api_key... refer to config file for more details.

    Returns:
        dict: ``{"chat_res": value}`` where ``value`` is the dictionary of extracted
            results, or an error string when ``key_list`` is empty or the mLLM
            integration strategy is unsupported.
    """
    key_list = self.format_key(key_list)
    # Keep the full key set: the len(key_list) checks below imply that
    # generate_and_merge_chat_results shrinks key_list as keys get resolved
    # (presumably by removing them -- confirm), while the mLLM ensemble needs
    # every requested key.
    key_list_ori = key_list.copy()
    if len(key_list) == 0:
        return {"chat_res": "Error:输入的key_list无效!"}
    visual_info_list = visual_info if isinstance(visual_info, list) else [visual_info]

    if self.chat_bot is None:
        logging.warning(
            "The LLM chat bot is not initialized,will initialize it now."
        )
        self.inintial_chat_predictor(self.config)
    if chat_bot_config is not None:
        from .. import create_chat_bot

        chat_bot = create_chat_bot(chat_bot_config)
    else:
        chat_bot = self.chat_bot

    (
        all_normal_text_list,
        all_table_text_list,
        all_table_html_list,
        all_table_nei_text_list,
    ) = self.merge_visual_info_list(visual_info_list)

    use_mllm = (
        self.use_mllm_predict
        and mllm_integration_strategy != "llm_only"
        and mllm_predict_info is not None
    )
    if use_mllm and mllm_integration_strategy not in ("integration", "mllm_only"):
        # Reject up front instead of after spending every LLM call.
        return {
            "chat_res": f"Error:Unsupported mllm_integration_strategy {mllm_integration_strategy}, "
            "only support 'integration', 'llm_only' and 'mllm_only'!"
        }
    if use_mllm and mllm_integration_strategy == "mllm_only":
        # LLM results are never used in this mode, so skip the LLM rounds.
        return {"chat_res": mllm_predict_info}

    final_results = {}
    # LLM answers treated as "not found"; such keys stay pending for later rounds.
    failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]

    # Round 1: normal (non-table) text. key_list is non-empty at this point.
    related_text = self.get_related_normal_text(
        retriever_config,
        use_vector_retrieval,
        vector_info,
        key_list,
        all_normal_text_list,
        min_characters,
    )
    if len(related_text) > 0:
        prompt = self.text_pe.generate_prompt(
            related_text,
            key_list,
            task_description=text_task_description,
            output_format=text_output_format,
            rules_str=text_rules_str,
            few_shot_demo_text_content=text_few_shot_demo_text_content,
            few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
        )
        self.generate_and_merge_chat_results(
            chat_bot, prompt, key_list, final_results, failed_results
        )

    # Round 2: one prompt per table whose HTML fits in the character budget.
    # All three lists are zipped so iteration stops at the shortest one.
    for table_html, _table_text, table_nei_text in zip(
        all_table_html_list, all_table_text_list, all_table_nei_text_list
    ):
        if len(key_list) == 0:
            # Every key is resolved; no further call can happen, so stop early.
            break
        if len(table_html) > min_characters - self.table_structure_len_max:
            continue
        table_info = table_html
        if len(table_nei_text) > 0:
            table_info = table_info + "\n 表格周围文字:" + table_nei_text
        prompt = self.table_pe.generate_prompt(
            table_info,
            key_list,
            task_description=table_task_description,
            output_format=table_output_format,
            rules_str=table_rules_str,
            few_shot_demo_text_content=table_few_shot_demo_text_content,
            few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
        )
        self.generate_and_merge_chat_results(
            chat_bot,
            prompt,
            key_list,
            final_results,
            failed_results,
        )

    if use_mllm:
        # Only the "integration" strategy can reach this point.
        final_predict_dict = self.ensemble_ocr_llm_mllm(
            chat_bot, key_list_ori, final_results, mllm_predict_info
        )
        return {"chat_res": final_predict_dict}
    return {"chat_res": final_results}
def predict(self, *args, **kwargs) -> None:
    """Refuse direct prediction on this pipeline.

    Any positional or keyword arguments are accepted and ignored. The method
    only logs an error naming the call sequence to use instead
    (``visual_predict``, ``build_vector``, ``chat``) and returns ``None``.
    """
    message = (
        "PP-ChatOCRv4-doc Pipeline do not support to call `predict()` directly! "
        "Please invoke `visual_predict`, `build_vector`, `chat` sequentially to obtain the result."
    )
    logging.error(message)
    return None
|