|
|
@@ -47,6 +47,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
pp_option: PaddlePredictorOption = None,
|
|
|
use_hpip: bool = False,
|
|
|
hpi_params: Optional[Dict[str, Any]] = None,
|
|
|
+ use_layout_parsing: bool = True,
|
|
|
) -> None:
|
|
|
"""Initializes the pp-chatocrv3-doc pipeline.
|
|
|
|
|
|
@@ -62,6 +63,8 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
|
|
|
)
|
|
|
|
|
|
+ self.use_layout_parsing = use_layout_parsing
|
|
|
+
|
|
|
self.inintial_predictor(config)
|
|
|
|
|
|
self.img_reader = ReadImage(format="BGR")
|
|
|
@@ -78,8 +81,10 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
Returns:
|
|
|
None
|
|
|
"""
|
|
|
- layout_parsing_config = config["SubPipelines"]["LayoutParser"]
|
|
|
- self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
|
|
|
+
|
|
|
+ if self.use_layout_parsing:
|
|
|
+ layout_parsing_config = config["SubPipelines"]["LayoutParser"]
|
|
|
+ self.layout_parsing_pipeline = self.create_pipeline(layout_parsing_config)
|
|
|
|
|
|
from .. import create_chat_bot
|
|
|
|
|
|
@@ -152,7 +157,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
input: str | list[str] | np.ndarray | list[np.ndarray],
|
|
|
use_doc_orientation_classify: bool = False, # Whether to use document orientation classification
|
|
|
use_doc_unwarping: bool = False, # Whether to use document unwarping
|
|
|
- use_common_ocr: bool = True, # Whether to use common OCR
|
|
|
+ use_general_ocr: bool = True, # Whether to use general OCR
|
|
|
use_seal_recognition: bool = True, # Whether to use seal recognition
|
|
|
use_table_recognition: bool = True, # Whether to use table recognition
|
|
|
**kwargs,
|
|
|
@@ -160,14 +165,14 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
"""
|
|
|
This function takes an input image or a list of images and performs various visual
|
|
|
prediction tasks such as document orientation classification, document unwarping,
|
|
|
- common OCR, seal recognition, and table recognition based on the provided flags.
|
|
|
+ general OCR, seal recognition, and table recognition based on the provided flags.
|
|
|
|
|
|
Args:
|
|
|
input (str | list[str] | np.ndarray | list[np.ndarray]): Input image path, list of image paths,
|
|
|
numpy array of an image, or list of numpy arrays.
|
|
|
use_doc_orientation_classify (bool): Flag to use document orientation classification.
|
|
|
use_doc_unwarping (bool): Flag to use document unwarping.
|
|
|
- use_common_ocr (bool): Flag to use common OCR.
|
|
|
+ use_general_ocr (bool): Flag to use general OCR.
|
|
|
use_seal_recognition (bool): Flag to use seal recognition.
|
|
|
use_table_recognition (bool): Flag to use table recognition.
|
|
|
**kwargs: Additional keyword arguments.
|
|
|
@@ -176,6 +181,9 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
dict: A dictionary containing the layout parsing result and visual information.
|
|
|
"""
|
|
|
|
|
|
+ if not self.use_layout_parsing:
|
|
|
+ raise ValueError("The models for layout parsing are not initialized.")
|
|
|
+
|
|
|
if not isinstance(input, list):
|
|
|
input_list = [input]
|
|
|
else:
|
|
|
@@ -195,7 +203,7 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
image_array,
|
|
|
use_doc_orientation_classify=use_doc_orientation_classify,
|
|
|
use_doc_unwarping=use_doc_unwarping,
|
|
|
- use_common_ocr=use_common_ocr,
|
|
|
+ use_general_ocr=use_general_ocr,
|
|
|
use_seal_recognition=use_seal_recognition,
|
|
|
use_table_recognition=use_table_recognition,
|
|
|
)
|
|
|
@@ -264,6 +272,8 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
all_table_html_list = []
|
|
|
for single_visual_info in visual_info_list:
|
|
|
normal_text_dict = single_visual_info["normal_text_dict"]
|
|
|
+ for key in normal_text_dict:
|
|
|
+ normal_text_dict[key] = normal_text_dict[key].replace("\n", "")
|
|
|
table_text_list = single_visual_info["table_text_list"]
|
|
|
table_html_list = single_visual_info["table_html_list"]
|
|
|
all_normal_text_list.append(normal_text_dict)
|
|
|
@@ -308,6 +318,9 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
if len(table_html) > min_characters - self.table_structure_len_max:
|
|
|
all_items += [f"table:{table_text}\n"]
|
|
|
|
|
|
+ # if len(table_html) > min_characters - self.table_structure_len_max:
|
|
|
+ # all_items += [f"table:{table_text}\n"]
|
|
|
+
|
|
|
all_text_str = "".join(all_items)
|
|
|
|
|
|
if len(all_text_str) > min_characters:
|
|
|
@@ -413,7 +426,10 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
)
|
|
|
return
|
|
|
|
|
|
+ # print(prompt, llm_result)
|
|
|
+
|
|
|
llm_result = self.fix_llm_result_format(llm_result)
|
|
|
+
|
|
|
for key, value in llm_result.items():
|
|
|
if value not in failed_results and key in key_list:
|
|
|
key_list.remove(key)
|
|
|
@@ -477,27 +493,10 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
final_results = {}
|
|
|
failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
|
|
|
|
|
|
- for table_html, table_text in zip(all_table_html_list, all_table_text_list):
|
|
|
- if len(table_html) <= min_characters - self.table_structure_len_max:
|
|
|
- for table_info in [table_html, table_text]:
|
|
|
- if len(key_list) > 0:
|
|
|
- prompt = self.table_pe.generate_prompt(
|
|
|
- table_info,
|
|
|
- key_list,
|
|
|
- task_description=table_task_description,
|
|
|
- output_format=table_output_format,
|
|
|
- rules_str=table_rules_str,
|
|
|
- few_shot_demo_text_content=table_few_shot_demo_text_content,
|
|
|
- few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
|
|
|
- )
|
|
|
-
|
|
|
- self.generate_and_merge_chat_results(
|
|
|
- prompt, key_list, final_results, failed_results
|
|
|
- )
|
|
|
-
|
|
|
if len(key_list) > 0:
|
|
|
if use_vector_retrieval and vector_info is not None:
|
|
|
- question_key_list = [f"抽取关键信息:{key}" for key in key_list]
|
|
|
+ # question_key_list = [f"抽取关键信息:{key}" for key in key_list]
|
|
|
+ question_key_list = [f"待回答问题:{key}" for key in key_list]
|
|
|
vector = vector_info["vector"]
|
|
|
if not vector_info["flag_too_short_text"]:
|
|
|
related_text = self.retriever.similarity_retrieval(
|
|
|
@@ -530,11 +529,29 @@ class PP_ChatOCRv3_doc_Pipeline(BasePipeline):
|
|
|
few_shot_demo_text_content=text_few_shot_demo_text_content,
|
|
|
few_shot_demo_key_value_list=text_few_shot_demo_key_value_list,
|
|
|
)
|
|
|
- # print(prompt)
|
|
|
self.generate_and_merge_chat_results(
|
|
|
prompt, key_list, final_results, failed_results
|
|
|
)
|
|
|
|
|
|
+ if len(key_list) > 0:
|
|
|
+ for table_html, table_text in zip(all_table_html_list, all_table_text_list):
|
|
|
+ if len(table_html) <= min_characters - self.table_structure_len_max:
|
|
|
+ for table_info in [table_html]:
|
|
|
+ if len(key_list) > 0:
|
|
|
+ prompt = self.table_pe.generate_prompt(
|
|
|
+ table_info,
|
|
|
+ key_list,
|
|
|
+ task_description=table_task_description,
|
|
|
+ output_format=table_output_format,
|
|
|
+ rules_str=table_rules_str,
|
|
|
+ few_shot_demo_text_content=table_few_shot_demo_text_content,
|
|
|
+ few_shot_demo_key_value_list=table_few_shot_demo_key_value_list,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.generate_and_merge_chat_results(
|
|
|
+ prompt, key_list, final_results, failed_results
|
|
|
+ )
|
|
|
+
|
|
|
return {"chat_res": final_results}
|
|
|
|
|
|
def predict(self, *args, **kwargs) -> None:
|