@@ -14,7 +14,9 @@
 
 from typing import Any, Dict, Optional
 import re
+import cv2
 import json
+import base64
 import numpy as np
 import copy
 from .pipeline_base import PP_ChatOCR_Pipeline
@@ -99,6 +101,8 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         if "use_mllm_predict" in config:
             self.use_mllm_predict = config["use_mllm_predict"]
         if self.use_mllm_predict:
+            mllm_chat_bot_config = config["SubModules"]["MLLM_Chat"]
+            self.mllm_chat_bot = create_chat_bot(mllm_chat_bot_config)
         ensemble_pe_config = config["SubModules"]["PromptEngneering"]["Ensemble"]
         self.ensemble_pe = create_prompt_engeering(ensemble_pe_config)
         return
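For reference, the new `MLLM_Chat` entry is looked up under `config["SubModules"]` and handed straight to `create_chat_bot`, mirroring how the existing prompt-engineering sub-modules are built. A minimal sketch of what such a config fragment might look like follows; only `use_mllm_predict`, `SubModules`, and `MLLM_Chat` come from this diff, and every inner key is an illustrative assumption rather than a documented schema.

```python
# Hypothetical config fragment for wiring the multimodal chat bot.
# Only "use_mllm_predict", "SubModules", and "MLLM_Chat" appear in the diff;
# the inner keys are assumed placeholders for whatever create_chat_bot expects.
config = {
    "use_mllm_predict": True,
    "SubModules": {
        "MLLM_Chat": {
            "module_name": "chat_bot",  # assumed
            "model_name": "pp-docbee",  # assumed: name of the served multimodal model
            "base_url": "http://127.0.0.1:8080/v1/chat/completions",  # assumed endpoint
            "api_type": "openai",  # assumed
            "api_key": "api_key",  # assumed
        },
    },
}
```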
@@ -380,6 +384,47 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
 
         return []
 
+    def mllm_pred(
+        self,
+        input: str | np.ndarray,
+        key_list,
+        **kwargs,
+    ) -> dict:
+        key_list = self.format_key(key_list)
+        if len(key_list) == 0:
+            return {"mllm_res": "Error:输入的key_list无效!"}
+
+        if isinstance(input, list):
+            logging.error("Input is a list, but it's not supported here.")
+            return {"mllm_res": "Error:Input is a list, but it's not supported here!"}
+        image_array_list = self.img_reader([input])
+        if (
+            isinstance(input, str)
+            and input.endswith(".pdf")
+            and len(image_array_list) > 1
+        ):
+            logging.error("The input with PDF should have only one page.")
+            return {"mllm_res": "Error:The input with PDF should have only one page!"}
+
+        for image_array in image_array_list:
+
+            assert len(image_array.shape) == 3
+            image_string = cv2.imencode(".jpg", image_array)[1].tobytes()
+            image_base64 = base64.b64encode(image_string).decode("utf-8")
+            result = {}
+            for key in key_list:
+                prompt = (
+                    str(key)
+                    + "\n请用图片中完整出现的内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,并保持格式、单位、符号和标点都与图片中的文字内容完全一致。"
+                )
+                mllm_chat_bot_result = self.mllm_chat_bot.generate_chat_results(
+                    prompt=prompt, image=image_base64
+                )
+                if mllm_chat_bot_result is None:
+                    return {"mllm_res": "大模型调用失败"}
+                result[key] = mllm_chat_bot_result
+            return {"mllm_res": result}
+
     def generate_and_merge_chat_results(
         self, prompt: str, key_list: list, final_results: dict, failed_results: list
     ) -> None:
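The `mllm_pred` method added above is what produces the `mllm_predict_dict` consumed later in the chat step. A usage sketch, assuming an initialized pipeline object named `pipeline` (the instance name, sample path, and keys are illustrative only):

```python
# Illustrative call of mllm_pred; "pipeline" is assumed to be a
# PP_ChatOCRv4_Pipeline instance with use_mllm_predict enabled.
mllm_res = pipeline.mllm_pred(
    input="contract_page1.png",  # a single image, or a one-page PDF path
    key_list=["姓名", "日期"],   # keys to extract from the document
)
mllm_predict_dict = mllm_res["mllm_res"]  # {key: answer} on success, an error string otherwise
```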
@@ -524,6 +569,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         table_few_shot_demo_text_content: str = None,
         table_few_shot_demo_key_value_list: str = None,
         mllm_predict_dict: dict = None,
+        mllm_integration_strategy: str = "integration",
     ) -> dict:
         """
         Generates chat results based on the provided key list and visual information.
@@ -545,6 +591,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             table_few_shot_demo_text_content (str): The text content for table few-shot demos.
             table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
             mllm_predict_dict (dict): The dictionary of mLLM predicts.
+            mllm_integration_strategy (str): The strategy for integrating mLLM and LLM results; defaults to "integration". Options are "integration", "llm_only", and "mllm_only".
         Returns:
             dict: A dictionary containing the chat results.
         """
@@ -552,7 +599,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         key_list = self.format_key(key_list)
         key_list_ori = key_list.copy()
         if len(key_list) == 0:
-            return {"error": "输入的key_list无效!"}
+            return {"chat_res": "Error:输入的key_list无效!"}
 
         if not isinstance(visual_info, list):
             visual_info_list = [visual_info]
@@ -620,10 +667,17 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
                 prompt, key_list, final_results, failed_results
             )
 
-        if self.use_mllm_predict:
-            final_predict_dict = self.ensemble_ocr_llm_mllm(
-                key_list_ori, final_results, mllm_predict_dict
-            )
+        if self.use_mllm_predict and mllm_integration_strategy != "llm_only":
+            if mllm_integration_strategy == "integration":
+                final_predict_dict = self.ensemble_ocr_llm_mllm(
+                    key_list_ori, final_results, mllm_predict_dict
+                )
+            elif mllm_integration_strategy == "mllm_only":
+                final_predict_dict = mllm_predict_dict
+            else:
+                return {
+                    "chat_res": f"Error:Unsupported mllm_integration_strategy {mllm_integration_strategy}, only support 'integration', 'llm_only' and 'mllm_only'!"
+                }
         else:
             final_predict_dict = final_results
         return {"chat_res": final_predict_dict}
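Putting the two changes together: `mllm_pred` supplies `mllm_predict_dict`, and the new `mllm_integration_strategy` argument decides whether the chat step ensembles it with the LLM result ("integration"), ignores it ("llm_only"), or returns it directly ("mllm_only"). A rough end-to-end sketch is below; the method name `chat` and the origin of `visual_info` are assumed from the surrounding pipeline and are not part of this diff.

```python
# Sketch only; "chat" and visual_info belong to the rest of the pipeline.
key_list = ["姓名", "日期"]
visual_info = ...  # output of the pipeline's earlier visual analysis step (outside this diff)

mllm_predict_dict = pipeline.mllm_pred("contract_page1.png", key_list)["mllm_res"]

chat_res = pipeline.chat(
    key_list=key_list,
    visual_info=visual_info,
    mllm_predict_dict=mllm_predict_dict,
    mllm_integration_strategy="integration",  # or "llm_only" / "mllm_only"
)
print(chat_res["chat_res"])
```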