
add reasoning_content to chat results

zhouchangda 9 months ago
Parent
Commit
e751b04ecf
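
In short: OpenAIBotChat.generate_chat_results now returns a dict with "content" and "reasoning_content" keys instead of a bare string (or None on failure), and the PP-ChatOCRv3/v4 pipelines are updated to unpack that shape. A minimal sketch of the new return value, assuming a reasoning model whose OpenAI-compatible endpoint populates message.reasoning_content (values here are illustrative):

    llm_result = {
        "content": '{"InvoiceNo": "12345"}',           # answer text, or None if the call failed
        "reasoning_content": "The field appears ...",  # reasoning trace, or None if the model emits none
    }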

+ 11 - 4
paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py

@@ -92,6 +92,7 @@ class OpenAIBotChat(BaseChat):
         Returns:
             Dict: The chat completion result from the model.
         """
+        llm_result = {"content": None, "reasoning_content": None}
         try:
             if image:
                 chat_completion = self.client.chat.completions.create(
@@ -119,7 +120,7 @@ class OpenAIBotChat(BaseChat):
                     temperature=temperature,
                     top_p=0.001,
                 )
-                llm_result = chat_completion.choices[0].message.content
+                llm_result["content"] = chat_completion.choices[0].message.content
                 return llm_result
             elif self.config.get("end_point", "chat_completion") == "chat_completion":
                 chat_completion = self.client.chat.completions.create(
@@ -134,7 +135,13 @@ class OpenAIBotChat(BaseChat):
                     temperature=temperature,
                     top_p=0.001,
                 )
-                llm_result = chat_completion.choices[0].message.content
+                llm_result["content"] = chat_completion.choices[0].message.content
+                try:
+                    llm_result["reasoning_content"] = chat_completion.choices[
+                        0
+                    ].message.reasoning_content
+                except:
+                    pass
                 return llm_result
             else:
                 chat_completion = self.client.completions.create(
@@ -148,12 +155,12 @@ class OpenAIBotChat(BaseChat):
                     chat_completion = json.loads(chat_completion)
                     llm_result = chat_completion["choices"][0]["text"]
                 else:
-                    llm_result = chat_completion.choices[0].text
+                    llm_result["content"] = chat_completion.choices[0].text
                 return llm_result
         except Exception as e:
             logging.error(e)
             self.ERROR_MASSAGE = "大模型调用失败"
-        return None
+        return llm_result
 
     def fix_llm_result_format(self, llm_result: str) -> dict:
         """

+ 13 - 5
paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py

@@ -502,16 +502,25 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         """
 
         llm_result = chat_bot.generate_chat_results(prompt)
-        if llm_result is None:
+        llm_result_content = llm_result["content"]
+        llm_result_reasoning_content = llm_result["reasoning_content"]
+
+        if llm_result_reasoning_content is not None:
+            if "reasoning_content" not in final_results:
+                final_results["reasoning_content"] = [llm_result_reasoning_content]
+            else:
+                final_results["reasoning_content"].append(llm_result_reasoning_content)
+
+        if llm_result_content is None:
             logging.error(
                 "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
-                % (prompt, self.chat_bot.ERROR_MASSAGE)
+                % (prompt, chat_bot.ERROR_MASSAGE)
             )
             return
 
-        llm_result = self.chat_bot.fix_llm_result_format(llm_result)
+        llm_result_content = chat_bot.fix_llm_result_format(llm_result_content)
 
-        for key, value in llm_result.items():
+        for key, value in llm_result_content.items():
             if value not in failed_results and key in key_list:
                 key_list.remove(key)
                 final_results[key] = value
@@ -629,7 +638,6 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         """
 
         key_list = self.format_key(key_list)
-        key_list_ori = key_list.copy()
         if len(key_list) == 0:
             return {"chat_res": "Error:输入的key_list无效!"}
 

+ 38 - 12
paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py

@@ -582,7 +582,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
                 )
                 mllm_chat_bot_result = mllm_chat_bot.generate_chat_results(
                     prompt=prompt, image=image_base64
-                )
+                )["content"]
                 if mllm_chat_bot_result is None:
                     return {"mllm_res": "大模型调用失败"}
                 result[key] = mllm_chat_bot_result
@@ -610,16 +610,25 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         """
 
         llm_result = chat_bot.generate_chat_results(prompt)
-        if llm_result is None:
+        llm_result_content = llm_result["content"]
+        llm_result_reasoning_content = llm_result["reasoning_content"]
+
+        if llm_result_reasoning_content is not None:
+            if "reasoning_content" not in final_results:
+                final_results["reasoning_content"] = [llm_result_reasoning_content]
+            else:
+                final_results["reasoning_content"].append(llm_result_reasoning_content)
+
+        if llm_result_content is None:
             logging.error(
                 "chat bot error: \n [prompt:]\n %s\n [result:] %s\n"
-                % (prompt, self.chat_bot.ERROR_MASSAGE)
+                % (prompt, chat_bot.ERROR_MASSAGE)
             )
             return
 
-        llm_result = self.chat_bot.fix_llm_result_format(llm_result)
+        llm_result_content = chat_bot.fix_llm_result_format(llm_result_content)
 
-        for key, value in llm_result.items():
+        for key, value in llm_result_content.items():
             if value not in failed_results and key in key_list:
                 key_list.remove(key)
                 final_results[key] = value
@@ -692,7 +701,11 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         return related_text
 
     def ensemble_ocr_llm_mllm(
-        self, key_list: List[str], ocr_llm_predict_dict: dict, mllm_predict_dict: dict
+        self,
+        chat_bot: BaseChat,
+        key_list: List[str],
+        ocr_llm_predict_dict: dict,
+        mllm_predict_dict: dict,
     ) -> dict:
         """
         Ensemble OCR_LLM and LMM predictions based on given key list.
@@ -719,11 +732,24 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
                 prompt = self.ensemble_pe.generate_prompt(
                     key, ocr_llm_predict, mllm_predict
                 )
-                llm_result = self.chat_bot.generate_chat_results(prompt)
-                if llm_result is not None:
-                    llm_result = self.chat_bot.fix_llm_result_format(llm_result)
-                if key in llm_result:
-                    tmp = llm_result[key]
+                llm_result = chat_bot.generate_chat_results(prompt)
+                llm_result_content = llm_result["content"]
+                llm_result_reasoning_content = llm_result["reasoning_content"]
+                if llm_result_reasoning_content is not None:
+                    if "reasoning_content" not in final_predict_dict:
+                        final_predict_dict["reasoning_content"] = [
+                            llm_result_reasoning_content
+                        ]
+                    else:
+                        final_predict_dict["reasoning_content"].append(
+                            llm_result_reasoning_content
+                        )
+                if llm_result_content is not None:
+                    llm_result_content = chat_bot.fix_llm_result_format(
+                        llm_result_content
+                    )
+                if key in llm_result_content:
+                    tmp = llm_result_content[key]
                     if "B" in tmp:
                         predict = mllm_predict
                     else:
@@ -884,7 +910,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         ):
             if mllm_integration_strategy == "integration":
                 final_predict_dict = self.ensemble_ocr_llm_mllm(
-                    key_list_ori, final_results, mllm_predict_info
+                    chat_bot, key_list_ori, final_results, mllm_predict_info
                 )
             elif mllm_integration_strategy == "mllm_only":
                 final_predict_dict = mllm_predict_info
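
With the added chat_bot parameter, ensemble_ocr_llm_mllm no longer reads self.chat_bot, so the caller passes the bot explicitly. A call-site sketch mirroring the diff (variable names as in pipeline_v4.py):

    final_predict_dict = self.ensemble_ocr_llm_mllm(
        chat_bot,           # BaseChat instance, passed in instead of self.chat_bot
        key_list_ori,       # keys still to be extracted
        final_results,      # OCR + LLM predictions
        mllm_predict_info,  # multimodal model predictions
    )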