|
@@ -20,7 +20,6 @@ import base64
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import copy
|
|
import copy
|
|
|
from .pipeline_base import PP_ChatOCR_Pipeline
|
|
from .pipeline_base import PP_ChatOCR_Pipeline
|
|
|
-from .result import VisualInfoResult
|
|
|
|
|
from ...common.reader import ReadImage
|
|
from ...common.reader import ReadImage
|
|
|
from ...common.batch_sampler import ImageBatchSampler
|
|
from ...common.batch_sampler import ImageBatchSampler
|
|
|
from ....utils import logging
|
|
from ....utils import logging
|
|
@@ -72,45 +71,71 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
None
|
|
None
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
- self.use_layout_parser = True
|
|
|
|
|
- if "use_layout_parser" in config:
|
|
|
|
|
- self.use_layout_parser = config["use_layout_parser"]
|
|
|
|
|
-
|
|
|
|
|
|
|
+ self.use_layout_parser = config.get("use_layout_parser", True)
|
|
|
if self.use_layout_parser:
|
|
if self.use_layout_parser:
|
|
|
- layout_parsing_config = config["SubPipelines"]["LayoutParser"]
|
|
|
|
|
|
|
+ layout_parsing_config = config.get("SubPipelines", {}).get(
|
|
|
|
|
+ "LayoutParser",
|
|
|
|
|
+ {"pipeline_config_error": "config error for layout_parsing_pipeline!"},
|
|
|
|
|
+ )
|
|
|
self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
|
|
self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
|
|
|
|
|
|
|
|
from .. import create_chat_bot
|
|
from .. import create_chat_bot
|
|
|
|
|
|
|
|
- chat_bot_config = config["SubModules"]["LLM_Chat"]
|
|
|
|
|
|
|
+ chat_bot_config = config.get("SubModules", {}).get(
|
|
|
|
|
+ "LLM_Chat",
|
|
|
|
|
+ {"chat_bot_config_error": "config error for llm chat bot!"},
|
|
|
|
|
+ )
|
|
|
self.chat_bot = create_chat_bot(chat_bot_config)
|
|
self.chat_bot = create_chat_bot(chat_bot_config)
|
|
|
|
|
|
|
|
from .. import create_retriever
|
|
from .. import create_retriever
|
|
|
|
|
|
|
|
- retriever_config = config["SubModules"]["LLM_Retriever"]
|
|
|
|
|
|
|
+ retriever_config = config.get("SubModules", {}).get(
|
|
|
|
|
+ "LLM_Retriever",
|
|
|
|
|
+ {"retriever_config_error": "config error for llm retriever!"},
|
|
|
|
|
+ )
|
|
|
self.retriever = create_retriever(retriever_config)
|
|
self.retriever = create_retriever(retriever_config)
|
|
|
|
|
|
|
|
from .. import create_prompt_engeering
|
|
from .. import create_prompt_engeering
|
|
|
|
|
|
|
|
- text_pe_config = config["SubModules"]["PromptEngneering"]["KIE_CommonText"]
|
|
|
|
|
|
|
+ text_pe_config = (
|
|
|
|
|
+ config.get("SubModules", {})
|
|
|
|
|
+ .get("PromptEngneering", {})
|
|
|
|
|
+ .get(
|
|
|
|
|
+ "KIE_CommonText",
|
|
|
|
|
+ {"pe_config_error": "config error for text_pe!"},
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
self.text_pe = create_prompt_engeering(text_pe_config)
|
|
self.text_pe = create_prompt_engeering(text_pe_config)
|
|
|
|
|
|
|
|
- table_pe_config = config["SubModules"]["PromptEngneering"]["KIE_Table"]
|
|
|
|
|
|
|
+ table_pe_config = (
|
|
|
|
|
+ config.get("SubModules", {})
|
|
|
|
|
+ .get("PromptEngneering", {})
|
|
|
|
|
+ .get(
|
|
|
|
|
+ "KIE_Table",
|
|
|
|
|
+ {"pe_config_error": "config error for table_pe!"},
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
self.table_pe = create_prompt_engeering(table_pe_config)
|
|
self.table_pe = create_prompt_engeering(table_pe_config)
|
|
|
|
|
|
|
|
- self.use_mllm_predict = False
|
|
|
|
|
- if "use_mllm_predict" in config:
|
|
|
|
|
- self.use_mllm_predict = config["use_mllm_predict"]
|
|
|
|
|
|
|
+ self.use_mllm_predict = config.get("use_mllm_predict", True)
|
|
|
if self.use_mllm_predict:
|
|
if self.use_mllm_predict:
|
|
|
- mllm_chat_bot_config = config["SubModules"]["MLLM_Chat"]
|
|
|
|
|
|
|
+ mllm_chat_bot_config = config.get("SubModules", {}).get(
|
|
|
|
|
+ "MLLM_Chat",
|
|
|
|
|
+ {"mllm_chat_bot_config": "config error for mllm chat bot!"},
|
|
|
|
|
+ )
|
|
|
self.mllm_chat_bot = create_chat_bot(mllm_chat_bot_config)
|
|
self.mllm_chat_bot = create_chat_bot(mllm_chat_bot_config)
|
|
|
- ensemble_pe_config = config["SubModules"]["PromptEngneering"]["Ensemble"]
|
|
|
|
|
|
|
+ ensemble_pe_config = (
|
|
|
|
|
+ config.get("SubModules", {})
|
|
|
|
|
+ .get("PromptEngneering", {})
|
|
|
|
|
+ .get(
|
|
|
|
|
+ "Ensemble",
|
|
|
|
|
+ {"pe_config_error": "config error for ensemble_pe!"},
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
self.ensemble_pe = create_prompt_engeering(ensemble_pe_config)
|
|
self.ensemble_pe = create_prompt_engeering(ensemble_pe_config)
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
- def decode_visual_result(
|
|
|
|
|
- self, layout_parsing_result: LayoutParsingResult
|
|
|
|
|
- ) -> VisualInfoResult:
|
|
|
|
|
|
|
+ def decode_visual_result(self, layout_parsing_result: LayoutParsingResult) -> dict:
|
|
|
"""
|
|
"""
|
|
|
Decodes the visual result from the layout parsing result.
|
|
Decodes the visual result from the layout parsing result.
|
|
|
|
|
|
|
@@ -118,21 +143,21 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
layout_parsing_result (LayoutParsingResult): The result of layout parsing.
|
|
layout_parsing_result (LayoutParsingResult): The result of layout parsing.
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
- VisualInfoResult: The decoded visual information.
|
|
|
|
|
|
|
+ dict: The decoded visual information.
|
|
|
"""
|
|
"""
|
|
|
text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
|
|
text_paragraphs_ocr_res = layout_parsing_result["text_paragraphs_ocr_res"]
|
|
|
seal_res_list = layout_parsing_result["seal_res_list"]
|
|
seal_res_list = layout_parsing_result["seal_res_list"]
|
|
|
normal_text_dict = {}
|
|
normal_text_dict = {}
|
|
|
|
|
|
|
|
for seal_res in seal_res_list:
|
|
for seal_res in seal_res_list:
|
|
|
- for text in seal_res["rec_text"]:
|
|
|
|
|
|
|
+ for text in seal_res["rec_texts"]:
|
|
|
layout_type = "印章"
|
|
layout_type = "印章"
|
|
|
if layout_type not in normal_text_dict:
|
|
if layout_type not in normal_text_dict:
|
|
|
normal_text_dict[layout_type] = f"{text}"
|
|
normal_text_dict[layout_type] = f"{text}"
|
|
|
else:
|
|
else:
|
|
|
normal_text_dict[layout_type] += f"\n {text}"
|
|
normal_text_dict[layout_type] += f"\n {text}"
|
|
|
|
|
|
|
|
- for text in text_paragraphs_ocr_res["rec_text"]:
|
|
|
|
|
|
|
+ for text in text_paragraphs_ocr_res["rec_texts"]:
|
|
|
layout_type = "words in text block"
|
|
layout_type = "words in text block"
|
|
|
if layout_type not in normal_text_dict:
|
|
if layout_type not in normal_text_dict:
|
|
|
normal_text_dict[layout_type] = text
|
|
normal_text_dict[layout_type] = text
|
|
@@ -145,26 +170,38 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
table_nei_text_list = []
|
|
table_nei_text_list = []
|
|
|
for table_res in table_res_list:
|
|
for table_res in table_res_list:
|
|
|
table_html_list.append(table_res["pred_html"])
|
|
table_html_list.append(table_res["pred_html"])
|
|
|
- single_table_text = " ".join(table_res["table_ocr_pred"]["rec_text"])
|
|
|
|
|
|
|
+ single_table_text = " ".join(table_res["table_ocr_pred"]["rec_texts"])
|
|
|
table_text_list.append(single_table_text)
|
|
table_text_list.append(single_table_text)
|
|
|
- table_nei_text_list.append(table_res["neighbor_text"])
|
|
|
|
|
|
|
+ table_nei_text_list.append(table_res["neighbor_texts"])
|
|
|
|
|
|
|
|
visual_info = {}
|
|
visual_info = {}
|
|
|
visual_info["normal_text_dict"] = normal_text_dict
|
|
visual_info["normal_text_dict"] = normal_text_dict
|
|
|
visual_info["table_text_list"] = table_text_list
|
|
visual_info["table_text_list"] = table_text_list
|
|
|
visual_info["table_html_list"] = table_html_list
|
|
visual_info["table_html_list"] = table_html_list
|
|
|
visual_info["table_nei_text_list"] = table_nei_text_list
|
|
visual_info["table_nei_text_list"] = table_nei_text_list
|
|
|
- return VisualInfoResult(visual_info)
|
|
|
|
|
|
|
+ return visual_info
|
|
|
|
|
|
|
|
# Function to perform visual prediction on input images
|
|
# Function to perform visual prediction on input images
|
|
|
def visual_predict(
|
|
def visual_predict(
|
|
|
self,
|
|
self,
|
|
|
input: str | list[str] | np.ndarray | list[np.ndarray],
|
|
input: str | list[str] | np.ndarray | list[np.ndarray],
|
|
|
- use_doc_orientation_classify: bool = False, # Whether to use document orientation classification
|
|
|
|
|
- use_doc_unwarping: bool = False, # Whether to use document unwarping
|
|
|
|
|
- use_general_ocr: bool = True, # Whether to use general OCR
|
|
|
|
|
- use_seal_recognition: bool = True, # Whether to use seal recognition
|
|
|
|
|
- use_table_recognition: bool = True, # Whether to use table recognition
|
|
|
|
|
|
|
+ use_doc_orientation_classify: Optional[bool] = None,
|
|
|
|
|
+ use_doc_unwarping: Optional[bool] = None,
|
|
|
|
|
+ use_general_ocr: Optional[bool] = None,
|
|
|
|
|
+ use_seal_recognition: Optional[bool] = None,
|
|
|
|
|
+ use_table_recognition: Optional[bool] = None,
|
|
|
|
|
+ text_det_limit_side_len: Optional[int] = None,
|
|
|
|
|
+ text_det_limit_type: Optional[str] = None,
|
|
|
|
|
+ text_det_thresh: Optional[float] = None,
|
|
|
|
|
+ text_det_box_thresh: Optional[float] = None,
|
|
|
|
|
+ text_det_unclip_ratio: Optional[float] = None,
|
|
|
|
|
+ text_rec_score_thresh: Optional[float] = None,
|
|
|
|
|
+ seal_det_limit_side_len: Optional[int] = None,
|
|
|
|
|
+ seal_det_limit_type: Optional[str] = None,
|
|
|
|
|
+ seal_det_thresh: Optional[float] = None,
|
|
|
|
|
+ seal_det_box_thresh: Optional[float] = None,
|
|
|
|
|
+ seal_det_unclip_ratio: Optional[float] = None,
|
|
|
|
|
+ seal_rec_score_thresh: Optional[float] = None,
|
|
|
**kwargs,
|
|
**kwargs,
|
|
|
) -> dict:
|
|
) -> dict:
|
|
|
"""
|
|
"""
|
|
@@ -187,7 +224,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
"""
|
|
"""
|
|
|
if self.use_layout_parser == False:
|
|
if self.use_layout_parser == False:
|
|
|
logging.error("The models for layout parser are not initialized.")
|
|
logging.error("The models for layout parser are not initialized.")
|
|
|
- yield None
|
|
|
|
|
|
|
+ yield {"error": "The models for layout parser are not initialized."}
|
|
|
|
|
|
|
|
for layout_parsing_result in self.layout_parsing_pipeline.predict(
|
|
for layout_parsing_result in self.layout_parsing_pipeline.predict(
|
|
|
input,
|
|
input,
|
|
@@ -196,6 +233,18 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
use_general_ocr=use_general_ocr,
|
|
use_general_ocr=use_general_ocr,
|
|
|
use_seal_recognition=use_seal_recognition,
|
|
use_seal_recognition=use_seal_recognition,
|
|
|
use_table_recognition=use_table_recognition,
|
|
use_table_recognition=use_table_recognition,
|
|
|
|
|
+ text_det_limit_side_len=text_det_limit_side_len,
|
|
|
|
|
+ text_det_limit_type=text_det_limit_type,
|
|
|
|
|
+ text_det_thresh=text_det_thresh,
|
|
|
|
|
+ text_det_box_thresh=text_det_box_thresh,
|
|
|
|
|
+ text_det_unclip_ratio=text_det_unclip_ratio,
|
|
|
|
|
+ text_rec_score_thresh=text_rec_score_thresh,
|
|
|
|
|
+ seal_det_box_thresh=seal_det_box_thresh,
|
|
|
|
|
+ seal_det_limit_side_len=seal_det_limit_side_len,
|
|
|
|
|
+ seal_det_limit_type=seal_det_limit_type,
|
|
|
|
|
+ seal_det_thresh=seal_det_thresh,
|
|
|
|
|
+ seal_det_unclip_ratio=seal_det_unclip_ratio,
|
|
|
|
|
+ seal_rec_score_thresh=seal_rec_score_thresh,
|
|
|
):
|
|
):
|
|
|
|
|
|
|
|
visual_info = self.decode_visual_result(layout_parsing_result)
|
|
visual_info = self.decode_visual_result(layout_parsing_result)
|
|
@@ -206,14 +255,12 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
}
|
|
}
|
|
|
yield visual_predict_res
|
|
yield visual_predict_res
|
|
|
|
|
|
|
|
- def save_visual_info_list(
|
|
|
|
|
- self, visual_info: VisualInfoResult, save_path: str
|
|
|
|
|
- ) -> None:
|
|
|
|
|
|
|
+ def save_visual_info_list(self, visual_info: dict, save_path: str) -> None:
|
|
|
"""
|
|
"""
|
|
|
Save the visual info list to the specified file path.
|
|
Save the visual info list to the specified file path.
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
- visual_info (VisualInfoResult): The visual info result, which can be a single object or a list of objects.
|
|
|
|
|
|
|
+ visual_info (dict): The visual info result, which can be a single object or a list of objects.
|
|
|
save_path (str): The file path to save the visual info list.
|
|
save_path (str): The file path to save the visual info list.
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
@@ -228,7 +275,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
|
|
fout.write(json.dumps(visual_info_list, ensure_ascii=False) + "\n")
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
- def load_visual_info_list(self, data_path: str) -> list[VisualInfoResult]:
|
|
|
|
|
|
|
+ def load_visual_info_list(self, data_path: str) -> list[dict]:
|
|
|
"""
|
|
"""
|
|
|
Loads visual info list from a JSON file.
|
|
Loads visual info list from a JSON file.
|
|
|
|
|
|
|
@@ -236,7 +283,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
data_path (str): The path to the JSON file containing visual info.
|
|
data_path (str): The path to the JSON file containing visual info.
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
- list[VisualInfoResult]: A list of VisualInfoResult objects parsed from the JSON file.
|
|
|
|
|
|
|
+ list[dict]: A list of dict objects parsed from the JSON file.
|
|
|
"""
|
|
"""
|
|
|
with open(data_path, "r") as fin:
|
|
with open(data_path, "r") as fin:
|
|
|
data = fin.readline()
|
|
data = fin.readline()
|
|
@@ -244,13 +291,13 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
return visual_info_list
|
|
return visual_info_list
|
|
|
|
|
|
|
|
def merge_visual_info_list(
|
|
def merge_visual_info_list(
|
|
|
- self, visual_info_list: list[VisualInfoResult]
|
|
|
|
|
|
|
+ self, visual_info_list: list[dict]
|
|
|
) -> tuple[list, list, list, list]:
|
|
) -> tuple[list, list, list, list]:
|
|
|
"""
|
|
"""
|
|
|
Merge visual info lists.
|
|
Merge visual info lists.
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
- visual_info_list (list[VisualInfoResult]): A list of visual info results.
|
|
|
|
|
|
|
+ visual_info_list (list[dict]): A list of visual info results.
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
tuple[list, list, list, list]: A tuple containing four lists, one for normal text dicts,
|
|
tuple[list, list, list, list]: A tuple containing four lists, one for normal text dicts,
|
|
@@ -281,17 +328,19 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
|
|
|
|
|
def build_vector(
|
|
def build_vector(
|
|
|
self,
|
|
self,
|
|
|
- visual_info: VisualInfoResult,
|
|
|
|
|
|
|
+ visual_info: dict,
|
|
|
min_characters: int = 3500,
|
|
min_characters: int = 3500,
|
|
|
llm_request_interval: float = 1.0,
|
|
llm_request_interval: float = 1.0,
|
|
|
|
|
+ flag_save_bytes_vector: bool = False,
|
|
|
) -> dict:
|
|
) -> dict:
|
|
|
"""
|
|
"""
|
|
|
Build a vector representation from visual information.
|
|
Build a vector representation from visual information.
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
- visual_info (VisualInfoResult): The visual information input, can be a single instance or a list of instances.
|
|
|
|
|
|
|
+ visual_info (dict): The visual information input, can be a single instance or a list of instances.
|
|
|
min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
|
|
min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
|
|
|
llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
|
|
llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
|
|
|
|
|
+ flag_save_bytes_vector (bool): Whether to save the vector as bytes, defaults to False.
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
dict: A dictionary containing the vector info and a flag indicating if the text is too short.
|
|
dict: A dictionary containing the vector info and a flag indicating if the text is too short.
|
|
@@ -324,30 +373,23 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
all_items += [f"table:{table_text}\t{table_nei_text}"]
|
|
all_items += [f"table:{table_text}\t{table_nei_text}"]
|
|
|
|
|
|
|
|
all_text_str = "".join(all_items)
|
|
all_text_str = "".join(all_items)
|
|
|
-
|
|
|
|
|
|
|
+ vector_info["flag_save_bytes_vector"] = False
|
|
|
if len(all_text_str) > min_characters:
|
|
if len(all_text_str) > min_characters:
|
|
|
vector_info["flag_too_short_text"] = False
|
|
vector_info["flag_too_short_text"] = False
|
|
|
vector_info["vector"] = self.retriever.generate_vector_database(all_items)
|
|
vector_info["vector"] = self.retriever.generate_vector_database(all_items)
|
|
|
|
|
+ if flag_save_bytes_vector:
|
|
|
|
|
+ vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
|
|
|
|
|
+ vector_info["vector"]
|
|
|
|
|
+ )
|
|
|
|
|
+ vector_info["flag_save_bytes_vector"] = True
|
|
|
else:
|
|
else:
|
|
|
vector_info["flag_too_short_text"] = True
|
|
vector_info["flag_too_short_text"] = True
|
|
|
vector_info["vector"] = all_items
|
|
vector_info["vector"] = all_items
|
|
|
return vector_info
|
|
return vector_info
|
|
|
|
|
|
|
|
def save_vector(self, vector_info: dict, save_path: str) -> None:
|
|
def save_vector(self, vector_info: dict, save_path: str) -> None:
|
|
|
- if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
|
|
|
|
|
- logging.error("Invalid vector info.")
|
|
|
|
|
- return
|
|
|
|
|
- save_vector_info = {}
|
|
|
|
|
- save_vector_info["flag_too_short_text"] = vector_info["flag_too_short_text"]
|
|
|
|
|
- if not vector_info["flag_too_short_text"]:
|
|
|
|
|
- save_vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
|
|
|
|
|
- vector_info["vector"]
|
|
|
|
|
- )
|
|
|
|
|
- else:
|
|
|
|
|
- save_vector_info["vector"] = vector_info["vector"]
|
|
|
|
|
-
|
|
|
|
|
with open(save_path, "w") as fout:
|
|
with open(save_path, "w") as fout:
|
|
|
- fout.write(json.dumps(save_vector_info, ensure_ascii=False) + "\n")
|
|
|
|
|
|
|
+ fout.write(json.dumps(vector_info, ensure_ascii=False) + "\n")
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
def load_vector(self, data_path: str) -> dict:
|
|
def load_vector(self, data_path: str) -> dict:
|
|
@@ -355,10 +397,15 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
with open(data_path, "r") as fin:
|
|
with open(data_path, "r") as fin:
|
|
|
data = fin.readline()
|
|
data = fin.readline()
|
|
|
vector_info = json.loads(data)
|
|
vector_info = json.loads(data)
|
|
|
- if "flag_too_short_text" not in vector_info or "vector" not in vector_info:
|
|
|
|
|
|
|
+ if (
|
|
|
|
|
+ "flag_too_short_text" not in vector_info
|
|
|
|
|
+ or "flag_save_bytes_vector" not in vector_info
|
|
|
|
|
+ or "vector" not in vector_info
|
|
|
|
|
+ ):
|
|
|
logging.error("Invalid vector info.")
|
|
logging.error("Invalid vector info.")
|
|
|
- return
|
|
|
|
|
- if not vector_info["flag_too_short_text"]:
|
|
|
|
|
|
|
+ return {"error": "Invalid vector info when load vector!"}
|
|
|
|
|
+
|
|
|
|
|
+ if vector_info["flag_save_bytes_vector"]:
|
|
|
vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
|
|
vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
|
|
|
vector_info["vector"]
|
|
vector_info["vector"]
|
|
|
)
|
|
)
|
|
@@ -558,7 +605,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
def chat(
|
|
def chat(
|
|
|
self,
|
|
self,
|
|
|
key_list: str | list[str],
|
|
key_list: str | list[str],
|
|
|
- visual_info: VisualInfoResult,
|
|
|
|
|
|
|
+ visual_info: dict,
|
|
|
use_vector_retrieval: bool = True,
|
|
use_vector_retrieval: bool = True,
|
|
|
vector_info: dict = None,
|
|
vector_info: dict = None,
|
|
|
min_characters: int = 3500,
|
|
min_characters: int = 3500,
|
|
@@ -580,7 +627,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
key_list (str | list[str]): A single key or a list of keys to extract information.
|
|
key_list (str | list[str]): A single key or a list of keys to extract information.
|
|
|
- visual_info (VisualInfoResult): The visual information result.
|
|
|
|
|
|
|
+ visual_info (dict): The visual information result.
|
|
|
use_vector_retrieval (bool): Whether to use vector retrieval.
|
|
use_vector_retrieval (bool): Whether to use vector retrieval.
|
|
|
vector_info (dict): The vector information for retrieval.
|
|
vector_info (dict): The vector information for retrieval.
|
|
|
min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
|
|
min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
|