@@ -582,7 +582,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             )
             mllm_chat_bot_result = mllm_chat_bot.generate_chat_results(
                 prompt=prompt, image=image_base64
-            )
+            )["content"]
             if mllm_chat_bot_result is None:
                 return {"mllm_res": "大模型调用失败"}
             result[key] = mllm_chat_bot_result
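
The `["content"]` indexing above assumes `generate_chat_results` now returns a dict carrying both the answer and any reasoning trace, rather than a bare string, and that a failed call still yields a dict (with `content` set to None) instead of None itself, which would make the subscript raise. A minimal sketch of that assumed return shape; the names below are illustrative, not the actual chat-bot implementation:

    from typing import Optional, TypedDict

    class ChatResult(TypedDict):
        # Assumed shape, inferred from this diff; the real chat-bot classes
        # are not shown here.
        content: Optional[str]            # final answer; None when the call fails
        reasoning_content: Optional[str]  # reasoning trace; None for non-reasoning models
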
@@ -610,16 +610,25 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         """
 
         llm_result = chat_bot.generate_chat_results(prompt)
-        if llm_result is None:
+        llm_result_content = llm_result["content"]
+        llm_result_reasoning_content = llm_result["reasoning_content"]
+
+        if llm_result_reasoning_content is not None:
+            if "reasoning_content" not in final_results:
+                final_results["reasoning_content"] = [llm_result_reasoning_content]
+            else:
+                final_results["reasoning_content"].append(llm_result_reasoning_content)
+
+        if llm_result_content is None:
             logging.error(
                 "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
-                % (prompt, self.chat_bot.ERROR_MASSAGE)
+                % (prompt, chat_bot.ERROR_MASSAGE)
             )
             return
 
-        llm_result = self.chat_bot.fix_llm_result_format(llm_result)
+        llm_result_content = chat_bot.fix_llm_result_format(llm_result_content)
 
-        for key, value in llm_result.items():
+        for key, value in llm_result_content.items():
             if value not in failed_results and key in key_list:
                 key_list.remove(key)
                 final_results[key] = value
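
The block added above accumulates one reasoning trace per LLM call into `final_results["reasoning_content"]`, creating the list on first use and appending afterwards. The if/else pair is behaviorally the same as a `setdefault` call; a self-contained sketch with made-up traces:

    from typing import Optional

    final_results: dict = {}

    def merge_reasoning(results: dict, reasoning: Optional[str]) -> None:
        # Same effect as the if/else in the hunk above: keep one trace per
        # call and skip models that return no reasoning content.
        if reasoning is not None:
            results.setdefault("reasoning_content", []).append(reasoning)

    merge_reasoning(final_results, "first, locate the date field ...")  # made-up trace
    merge_reasoning(final_results, None)  # non-reasoning model: no-op
    assert final_results == {"reasoning_content": ["first, locate the date field ..."]}
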
@@ -692,7 +701,11 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         return related_text
 
     def ensemble_ocr_llm_mllm(
-        self, key_list: List[str], ocr_llm_predict_dict: dict, mllm_predict_dict: dict
+        self,
+        chat_bot: BaseChat,
+        key_list: List[str],
+        ocr_llm_predict_dict: dict,
+        mllm_predict_dict: dict,
     ) -> dict:
         """
         Ensemble OCR_LLM and LMM predictions based on given key list.
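
The reworked signature threads the caller's chat bot through as an argument instead of reading `self.chat_bot`, so `chat()` can hand in whichever bot it was given per call; the `BaseChat` annotation presumably matches an import already at the top of the module, which this hunk does not show. A rough stub of the interface this method relies on, inferred solely from the calls visible in this diff:

    from abc import ABC, abstractmethod

    class BaseChat(ABC):
        # Placeholder value; only the attribute name (as spelled in the diff)
        # is taken from the source.
        ERROR_MASSAGE = "chat bot request failed"

        @abstractmethod
        def generate_chat_results(self, prompt: str) -> dict:
            """Assumed to return {"content": ..., "reasoning_content": ...}."""

        @abstractmethod
        def fix_llm_result_format(self, llm_result: str) -> dict:
            """Assumed to parse a raw answer into a {key: value} dict."""
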
@@ -719,11 +732,24 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
                 prompt = self.ensemble_pe.generate_prompt(
                     key, ocr_llm_predict, mllm_predict
                 )
-                llm_result = self.chat_bot.generate_chat_results(prompt)
-                if llm_result is not None:
-                    llm_result = self.chat_bot.fix_llm_result_format(llm_result)
-                    if key in llm_result:
-                        tmp = llm_result[key]
+                llm_result = chat_bot.generate_chat_results(prompt)
+                llm_result_content = llm_result["content"]
+                llm_result_reasoning_content = llm_result["reasoning_content"]
+                if llm_result_reasoning_content is not None:
+                    if "reasoning_content" not in final_predict_dict:
+                        final_predict_dict["reasoning_content"] = [
+                            llm_result_reasoning_content
+                        ]
+                    else:
+                        final_predict_dict["reasoning_content"].append(
+                            llm_result_reasoning_content
+                        )
+                if llm_result_content is not None:
+                    llm_result_content = chat_bot.fix_llm_result_format(
+                        llm_result_content
+                    )
+                    if key in llm_result_content:
+                        tmp = llm_result_content[key]
                         if "B" in tmp:
                             predict = mllm_predict
                         else:
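
The `"B" in tmp` test implies the ensemble prompt asks the judge LLM to choose between two labeled candidates, with "B" standing for the MLLM prediction and anything else falling back to the OCR+LLM prediction. A toy illustration of that selection rule; the values and labels below are made up:

    def pick(judge_answer: str, ocr_llm_predict: str, mllm_predict: str) -> str:
        # Any answer containing "B" selects the multimodal prediction;
        # everything else keeps the OCR+LLM prediction.
        return mllm_predict if "B" in judge_answer else ocr_llm_predict

    assert pick("B", "2023-01-01", "2023-01-02") == "2023-01-02"
    assert pick("A", "2023-01-01", "2023-01-02") == "2023-01-01"
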
@@ -884,7 +910,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         ):
             if mllm_integration_strategy == "integration":
                 final_predict_dict = self.ensemble_ocr_llm_mllm(
-                    key_list_ori, final_results, mllm_predict_info
+                    chat_bot, key_list_ori, final_results, mllm_predict_info
                 )
             elif mllm_integration_strategy == "mllm_only":
                 final_predict_dict = mllm_predict_info
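
Since `ensemble_ocr_llm_mllm` now takes the chat bot as its first argument, this call site passes the `chat_bot` local (presumably the bot created earlier in the same method); any other caller would have to supply one the same way. Note also that after these changes `final_predict_dict` may carry a `reasoning_content` list alongside the extracted key/value pairs, so downstream code that iterates over its keys should expect that extra entry.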