
support openai API chatbot (#2675)

changdazhou · 10 months ago
commit ee6533aac9

+ 7 - 0
paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml

@@ -18,6 +18,13 @@ SubModules:
     ak: "api_key" # Set this to a real API key
     sk: "secret_key"  # Set this to a real secret key
 
+  MLLM_Chat:
+    module_name: chat_bot
+    model_name: PP-DocBee
+    base_url: "http://127.0.0.1/v1/chat/completions"
+    api_type: openai
+    api_key: "api_key"
+
   PromptEngneering:
     KIE_CommonText:
       module_name: prompt_engneering
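
The new MLLM_Chat submodule points the pipeline at any OpenAI-compatible endpoint serving PP-DocBee. A minimal smoke test of such an endpoint (a hedged sketch: the URL and key mirror the placeholders above, and whether base_url should carry the /v1/chat/completions suffix depends on the serving stack):

    from openai import OpenAI

    # Placeholder endpoint and key, copied verbatim from the YAML above.
    client = OpenAI(base_url="http://127.0.0.1/v1/chat/completions", api_key="api_key")
    resp = client.chat.completions.create(
        model="PP-DocBee",
        messages=[{"role": "user", "content": "ping"}],
    )
    print(resp.choices[0].message.content)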

+ 1 - 1
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -547,7 +547,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
                 logging.debug(prompt)
                 res = self.get_llm_result(llm_api, prompt)
                 # TODO: why use one html but the whole table_text in next step
-                if list(res.values())[0] in failed_results:
+                if not res or list(res.values())[0] in failed_results:
                     logging.debug(
                         "table html sequence is too much longer, using ocr directly!"
                     )
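
get_llm_result can come back as None or an empty dict when the LLM call fails, and the old list(res.values())[0] then raised before the OCR fallback could run. A minimal illustration of the guard (the failure markers below are placeholders, not the pipeline's actual failed_results):

    res = {}  # simulated failed LLM call
    failed_results = ["", "未知"]  # placeholder failure markers
    # list(res.values())[0] alone would raise IndexError here
    if not res or list(res.values())[0] in failed_results:
        print("fall back to OCR")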

+ 6 - 6
paddlex/inference/pipelines_new/__init__.py

@@ -135,8 +135,8 @@ def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
     Returns:
-        BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
+        BaseChat: An instance of the chat bot class corresponding to the 'api_type' in the config.
     """
-    model_name = config["model_name"]
-    chat_bot = BaseChat.get(model_name)(config)
+    api_type = config["api_type"]
+    chat_bot = BaseChat.get(api_type)(config)
     return chat_bot
 
 
@@ -156,8 +156,8 @@ def create_retriever(
     Returns:
-        BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
+        BaseRetriever: An instance of a retriever class corresponding to the 'api_type' in the config.
     """
-    model_name = config["model_name"]
-    retriever = BaseRetriever.get(model_name)(config)
+    api_type = config["api_type"]
+    retriever = BaseRetriever.get(api_type)(config)
     return retriever
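
Dispatch is now keyed on api_type rather than model_name, so one backend class serves a whole API family regardless of the concrete model. A rough walkthrough of the lookup, using the placeholder values from the YAML above:

    config = {
        "module_name": "chat_bot",
        "model_name": "PP-DocBee",
        "api_type": "openai",  # registry key: BaseChat.get("openai") -> OpenAIBotChat
        "base_url": "http://127.0.0.1/v1/chat/completions",
        "api_key": "api_key",
    }
    chat_bot = create_chat_bot(config)  # an OpenAIBotChat instance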
 
 

+ 1 - 0
paddlex/inference/pipelines_new/components/chat_server/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .ernie_bot_chat import ErnieBotChat
+from .openai_bot_chat import OpenAIBotChat

+ 33 - 3
paddlex/inference/pipelines_new/components/chat_server/ernie_bot_chat.py

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict
 import re
 import json
 import erniebot
+from typing import Dict
 from .....utils import logging
 from .base import BaseChat
 
@@ -24,6 +24,11 @@ class ErnieBotChat(BaseChat):
     """Ernie Bot Chat"""
 
     entities = [
+        "aistudio",
+        "qianfan",
+    ]
+
+    MODELS = [
         "ernie-4.0",
         "ernie-3.5",
         "ernie-3.5-8k",
@@ -53,8 +58,8 @@ class ErnieBotChat(BaseChat):
         sk = config.get("sk", None)
         access_token = config.get("access_token", None)
 
-        if model_name not in self.entities:
-            raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
+        if model_name not in self.MODELS:
+            raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotChat.")
 
         if api_type not in ["aistudio", "qianfan"]:
             raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
@@ -127,6 +132,12 @@ class ErnieBotChat(BaseChat):
             return {}
 
         if "json" in llm_result or "```" in llm_result:
+            index = llm_result.find("{")
+            if index != -1:
+                llm_result = llm_result[index:]
+            index = llm_result.rfind("}")
+            if index != -1:
+                llm_result = llm_result[: index + 1]
             llm_result = (
                 llm_result.replace("```", "").replace("json", "").replace("/n", "")
             )
@@ -135,6 +146,15 @@ class ErnieBotChat(BaseChat):
         try:
             llm_result = json.loads(llm_result)
             llm_result_final = {}
+            if "问题" in llm_result.keys() and "答案" in llm_result.keys():
+                key = llm_result["问题"]
+                value = llm_result["答案"]
+                if isinstance(value, list):
+                    if len(value) > 0:
+                        llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
+                else:
+                    llm_result_final[key] = value.strip(f"{key}:").strip(key)
+                return llm_result_final
             for key in llm_result:
                 value = llm_result[key]
                 if isinstance(value, list):
@@ -157,6 +177,16 @@ class ErnieBotChat(BaseChat):
             matches = re.findall(pattern, str(results))
             if len(matches) > 0:
                 llm_result = {k: v for k, v in matches}
+                if "问题" in llm_result.keys() and "答案" in llm_result.keys():
+                    llm_result_final = {}
+                    key = llm_result["问题"]
+                    value = llm_result["答案"]
+                    if isinstance(value, list):
+                        if len(value) > 0:
+                            llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
+                    else:
+                        llm_result_final[key] = value.strip(f"{key}:").strip(key)
+                    return llm_result_final
                 return llm_result
             else:
                 return {}
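
The new 问题/答案 branch collapses a reply of the shape {"问题": key, "答案": value} into {key: value}, stripping a leading "key:" prefix from the answer. Expected behavior, sketched from the code above with bot an ErnieBotChat instance:

    bot.fix_llm_result_format('{"问题": "发票号码", "答案": ["发票号码:12345"]}')
    # -> {"发票号码": "12345"}
    bot.fix_llm_result_format('{"发票号码": ["12345"]}')
    # -> {"发票号码": "12345"}  (generic path, unchanged behavior)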

+ 203 - 0
paddlex/inference/pipelines_new/components/chat_server/openai_bot_chat.py

@@ -0,0 +1,203 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import json
+from typing import Dict
+from .....utils import logging
+from .base import BaseChat
+
+
+class OpenAIBotChat(BaseChat):
+    """OpenAI Bot Chat"""
+
+    entities = [
+        "openai",
+    ]
+
+    def __init__(self, config: Dict) -> None:
+        """Initializes the OpenAIBotChat with given configuration.
+
+        Args:
+            config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key.
+
+        Raises:
+            ValueError: If api_type is not 'openai',
+            or if base_url is None when api_type is 'openai',
+            or if api_key is None when api_type is 'openai'.
+        """
+        super().__init__()
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        api_key = config.get("api_key", None)
+        base_url = config.get("base_url", None)
+
+        if api_type not in ["openai"]:
+            raise ValueError("api_type must be one of ['openai']")
+
+        if api_type == "openai" and api_key is None:
+            raise ValueError("api_key cannot be empty when api_type is openai.")
+
+        if base_url is None:
+            raise ValueError("base_url cannot be empty when api_type is openai.")
+
+        try:
+            from openai import OpenAI
+        except ImportError:
+            raise Exception("openai is not installed, please install it first.")
+
+        self.client = OpenAI(base_url=base_url, api_key=api_key)
+
+        self.model_name = model_name
+        self.config = config
+
+    def generate_chat_results(
+        self,
+        prompt: str,
+        image: str = None,
+        temperature: float = 0.001,
+        max_retries: int = 1,
+    ) -> str:
+        """
+        Generate chat results using the specified model and configuration.
+
+        Args:
+            prompt (str): The user's input prompt.
+            image (str): The base64-encoded input image for the MLLM, defaults to None.
+            temperature (float, optional): The temperature parameter for llms, defaults to 0.001.
+            max_retries (int, optional): The maximum number of retries for llms API calls, defaults to 1.
+
+        Returns:
+            str: The model's text response, or None if the call fails.
+        """
+        try:
+            if image:
+                chat_completion = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=[
+                        {
+                            "role": "system",
+                            # XXX: a generic default system prompt
+                            "content": "You are a helpful assistant.",
+                        },
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": prompt},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {
+                                        "url": f"data:image/jpeg;base64,{image}"
+                                    },
+                                },
+                            ],
+                        },
+                    ],
+                    stream=False,
+                    temperature=temperature,
+                    top_p=0.001,
+                )
+                llm_result = chat_completion.choices[0].message.content
+                return llm_result
+            else:
+                chat_completion = self.client.completions.create(
+                    model=self.model_name,
+                    prompt=prompt,
+                    max_tokens=self.config.get("max_tokens", 1024),
+                    temperature=float(temperature),
+                    stream=False,
+                )
+                if isinstance(chat_completion, str):
+                    chat_completion = json.loads(chat_completion)
+                    llm_result = chat_completion["choices"][0]["text"]
+                else:
+                    llm_result = chat_completion.choices[0].text
+                return llm_result
+        except Exception as e:
+            logging.error(e)
+            self.ERROR_MASSAGE = "大模型调用失败"
+        return None
+
+    def fix_llm_result_format(self, llm_result: str) -> dict:
+        """
+        Fix the format of the LLM result.
+
+        Args:
+            llm_result (str): The result from the LLM (Large Language Model).
+
+        Returns:
+            dict: A fixed format dictionary from the LLM result.
+        """
+        if not llm_result:
+            return {}
+
+        if "json" in llm_result or "```" in llm_result:
+            index = llm_result.find("{")
+            if index != -1:
+                llm_result = llm_result[index:]
+            index = llm_result.rfind("}")
+            if index != -1:
+                llm_result = llm_result[: index + 1]
+            llm_result = (
+                llm_result.replace("```", "").replace("json", "").replace("/n", "")
+            )
+            llm_result = llm_result.replace("[", "").replace("]", "")
+
+        try:
+            llm_result = json.loads(llm_result)
+            llm_result_final = {}
+            if "问题" in llm_result.keys() and "答案" in llm_result.keys():
+                key = llm_result["问题"]
+                value = llm_result["答案"]
+                if isinstance(value, list):
+                    if len(value) > 0:
+                        llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
+                else:
+                    llm_result_final[key] = value.strip(f"{key}:").strip(key)
+                return llm_result_final
+            for key in llm_result:
+                value = llm_result[key]
+                if isinstance(value, list):
+                    if len(value) > 0:
+                        llm_result_final[key] = value[0]
+                else:
+                    llm_result_final[key] = value
+            return llm_result_final
+
+        except Exception:
+            results = (
+                llm_result.replace("\n", "")
+                .replace("    ", "")
+                .replace("{", "")
+                .replace("}", "")
+            )
+            if not results.endswith('"'):
+                results = results + '"'
+            pattern = r'"(.*?)": "([^"]*)"'
+            matches = re.findall(pattern, str(results))
+            if len(matches) > 0:
+                llm_result = {k: v for k, v in matches}
+                if "问题" in llm_result.keys() and "答案" in llm_result.keys():
+                    llm_result_final = {}
+                    key = llm_result["问题"]
+                    value = llm_result["答案"]
+                    if isinstance(value, list):
+                        if len(value) > 0:
+                            llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
+                    else:
+                        llm_result_final[key] = value.strip(f"{key}:").strip(key)
+                    return llm_result_final
+                return llm_result
+            else:
+                return {}
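
A minimal usage sketch for the new class (endpoint, key, and image file are placeholders; assumes an OpenAI-compatible server is reachable and the import path from the __init__.py change above):

    import base64

    from paddlex.inference.pipelines_new.components.chat_server import OpenAIBotChat

    config = {
        "model_name": "PP-DocBee",
        "api_type": "openai",
        "api_key": "api_key",  # placeholder
        "base_url": "http://127.0.0.1/v1/chat/completions",  # placeholder
    }
    bot = OpenAIBotChat(config)

    with open("doc.jpg", "rb") as f:  # placeholder image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    raw = bot.generate_chat_results(prompt="发票号码是多少?", image=image_b64)
    print(bot.fix_llm_result_format(raw))  # {} if the call failed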

+ 1 - 0
paddlex/inference/pipelines_new/components/retriever/__init__.py

@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .ernie_bot_retriever import ErnieBotRetriever
+from .openai_bot_retriever import OpenAIBotRetriever

+ 8 - 3
paddlex/inference/pipelines_new/components/retriever/ernie_bot_retriever.py

@@ -28,6 +28,11 @@ class ErnieBotRetriever(BaseRetriever):
     """Ernie Bot Retriever"""
 
     entities = [
+        "aistudio",
+        "qianfan",
+    ]
+
+    MODELS = [
         "ernie-4.0",
         "ernie-3.5",
         "ernie-3.5-8k",
@@ -45,7 +50,7 @@ class ErnieBotRetriever(BaseRetriever):
         Args:
             config (Dict): A dictionary containing configuration settings.
                 - model_name (str): The name of the model to use.
-                - api_type (str): The type of API to use ('aistudio' or 'qianfan').
+                - api_type (str): The type of API to use ('aistudio' or 'qianfan'; 'openai' is handled by OpenAIBotRetriever).
                 - ak (str, optional): The access key for 'qianfan' API.
                 - sk (str, optional): The secret key for 'qianfan' API.
                 - access_token (str, optional): The access token for 'aistudio' API.
@@ -64,8 +69,8 @@ class ErnieBotRetriever(BaseRetriever):
         sk = config.get("sk", None)
         access_token = config.get("access_token", None)
 
-        if model_name not in self.entities:
-            raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
+        if model_name not in self.MODELS:
+            raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotChat.")
 
         if api_type not in ["aistudio", "qianfan"]:
             raise ValueError("api_type must be one of ['aistudio', 'qianfan']")

+ 174 - 0
paddlex/inference/pipelines_new/components/retriever/openai_bot_retriever.py

@@ -0,0 +1,174 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseRetriever
+
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import FAISS
+from langchain_community import vectorstores
+
+import time
+
+from typing import Dict
+
+
+class OpenAIBotRetriever(BaseRetriever):
+    """OpenAI Bot Retriever"""
+
+    entities = [
+        "openai",
+    ]
+
+    def __init__(self, config: Dict) -> None:
+        """
+        Initializes the OpenAIBotRetriever instance with the provided configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+                - model_name (str): The name of the model to use.
+                - api_type (str): The type of API to use; must be 'openai' for this class.
+                - api_key (str, optional): The API key for 'openai' API.
+                - base_url (str, optional): The base URL for 'openai' API.
+
+        Raises:
+            ValueError: If api_type is not 'openai',
+            or if base_url is None when api_type is 'openai',
+            or if api_key is None when api_type is 'openai'.
+        """
+        super().__init__()
+
+        model_name = config.get("model_name", None)
+        api_type = config.get("api_type", None)
+        api_key = config.get("api_key", None)
+        base_url = config.get("base_url", None)
+        tiktoken_enabled = config.get("tiktoken_enabled", False)
+
+        if api_type not in ["openai"]:
+            raise ValueError("api_type must be one of ['openai']")
+
+        if api_type == "openai" and api_key is None:
+            raise ValueError("api_key cannot be empty when api_type is openai.")
+
+        if base_url is None:
+            raise ValueError("base_url cannot be empty when api_type is openai.")
+
+        try:
+            from langchain_openai import OpenAIEmbeddings
+        except ImportError:
+            raise Exception(
+                "langchain-openai is not installed, please install it first."
+            )
+
+        self.embedding = OpenAIEmbeddings(
+            model=model_name,
+            api_key=api_key,
+            base_url=base_url,
+            tiktoken_enabled=tiktoken_enabled,
+        )
+
+        self.model_name = model_name
+        self.config = config
+
+    def generate_vector_database(
+        self,
+        text_list: list[str],
+        block_size: int = 300,
+        separators: list[str] = ["\t", "\n", "。", "\n\n", ""],
+        sleep_time: float = 0.5,
+    ) -> FAISS:
+        """
+        Generates a vector database from a list of texts.
+
+        Args:
+            text_list (list[str]): A list of texts to generate the vector database from.
+            block_size (int): The size of each chunk to split the text into.
+            separators (list[str]): A list of separators to use when splitting the text.
+            sleep_time (float): The time to sleep between embedding generations to avoid rate limiting.
+
+        Returns:
+            FAISS: The generated vector database.
+        """
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=block_size, chunk_overlap=20, separators=separators
+        )
+        texts = text_splitter.split_text("\t".join(text_list))
+        all_splits = [Document(page_content=text) for text in texts]
+
+        vectorstore = FAISS.from_documents(
+            documents=all_splits, embedding=self.embedding
+        )
+
+        return vectorstore
+
+    def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str:
+        """
+        Serialize the vector store to bytes and encode the result.
+
+        Args:
+            vectorstore (FAISS): The vector store to be serialized and encoded.
+
+        Returns:
+            str: The encoded vector store.
+        """
+        vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
+        return vectorstore
+
+    def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS:
+        """
+        Decode a vector store from its serialized string form.
+
+        Args:
+            vectorstore (str): The serialized vector store string.
+
+        Returns:
+            FAISS: Deserialized vector store object.
+
+        Raises:
+            ValueError: If the retrieved vector store is not for PaddleX
+            or if an unsupported API type is specified.
+        """
+        if not self.is_vector_store(vectorstore):
+            raise ValueError("The retrieved vectorstore is not for PaddleX.")
+
+        vector = vectorstores.FAISS.deserialize_from_bytes(
+            self.decode_vector_store(vectorstore), self.embedding
+        )
+        return vector
+
+    def similarity_retrieval(
+        self, query_text_list: list[str], vectorstore: FAISS, sleep_time: float = 0.5
+    ) -> str:
+        """
+        Retrieve similar contexts based on a list of query texts.
+
+        Args:
+            query_text_list (list[str]): A list of query texts to search for similar contexts.
+            vectorstore (FAISS): The vector store where to perform the similarity search.
+            sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5.
+
+        Returns:
+            str: A concatenated string of all unique contexts found.
+        """
+        C = []
+        for query_text in query_text_list:
+            QUESTION = query_text
+            time.sleep(sleep_time)
+            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2)
+            context = [(document.page_content, score) for document, score in docs]
+            context = sorted(context, key=lambda x: x[1])
+            C.extend([x[0] for x in context[::-1]])
+        C = list(set(C))
+        all_C = " ".join(C)
+        return all_C
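
A hedged usage sketch (embedding model name, endpoint, and key are placeholders; assumes the server exposes an OpenAI-style embeddings route and that langchain-openai and faiss are installed):

    from paddlex.inference.pipelines_new.components.retriever import OpenAIBotRetriever

    config = {
        "model_name": "embedding-model",  # placeholder
        "api_type": "openai",
        "api_key": "api_key",  # placeholder
        "base_url": "http://127.0.0.1/v1",  # placeholder
    }
    retriever = OpenAIBotRetriever(config)

    vs = retriever.generate_vector_database(["发票号码:12345", "开票日期:2024-01-01"])
    print(retriever.similarity_retrieval(["发票号码"], vs))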

+ 1 - 1
paddlex/inference/pipelines_new/pp_chatocr/pipeline_v3.py

@@ -480,7 +480,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         key_list = self.format_key(key_list)
         key_list_ori = key_list.copy()
         if len(key_list) == 0:
-            return {"error": "输入的key_list无效!"}
+            return {"chat_res": "Error:输入的key_list无效!"}
 
         if not isinstance(visual_info, list):
             visual_info_list = [visual_info]

+ 59 - 5
paddlex/inference/pipelines_new/pp_chatocr/pipeline_v4.py

@@ -14,7 +14,9 @@
 
 from typing import Any, Dict, Optional
 import re
+import cv2
 import json
+import base64
 import numpy as np
 import copy
 from .pipeline_base import PP_ChatOCR_Pipeline
@@ -99,6 +101,8 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         if "use_mllm_predict" in config:
             self.use_mllm_predict = config["use_mllm_predict"]
         if self.use_mllm_predict:
+            mllm_chat_bot_config = config["SubModules"]["MLLM_Chat"]
+            self.mllm_chat_bot = create_chat_bot(mllm_chat_bot_config)
             ensemble_pe_config = config["SubModules"]["PromptEngneering"]["Ensemble"]
             self.ensemble_pe = create_prompt_engeering(ensemble_pe_config)
         return
@@ -380,6 +384,47 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
 
         return []
 
+    def mllm_pred(
+        self,
+        input: str | np.ndarray,
+        key_list,
+        **kwargs,
+    ) -> dict:
+        key_list = self.format_key(key_list)
+        if len(key_list) == 0:
+            return {"mllm_res": "Error:输入的key_list无效!"}
+
+        if isinstance(input, list):
+            logging.error("Input is a list, but it's not supported here.")
+            return {"mllm_res": "Error:Input is a list, but it's not supported here!"}
+        image_array_list = self.img_reader([input])
+        if (
+            isinstance(input, str)
+            and input.endswith(".pdf")
+            and len(image_array_list) > 1
+        ):
+            logging.error("The input with PDF should have only one page.")
+            return {"mllm_res": "Error:The input with PDF should have only one page!"}
+
+        for image_array in image_array_list:
+
+            assert len(image_array.shape) == 3
+            image_string = cv2.imencode(".jpg", image_array)[1].tobytes()
+            image_base64 = base64.b64encode(image_string).decode("utf-8")
+            result = {}
+            for key in key_list:
+                prompt = (
+                    str(key)
+                    + "\n请用图片中完整出现的内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,并保持格式、单位、符号和标点都与图片中的文字内容完全一致。"
+                )
+                mllm_chat_bot_result = self.mllm_chat_bot.generate_chat_results(
+                    prompt=prompt, image=image_base64
+                )
+                if mllm_chat_bot_result is None:
+                    return {"mllm_res": "大模型调用失败"}
+                result[key] = mllm_chat_bot_result
+            return {"mllm_res": result}
+
     def generate_and_merge_chat_results(
         self, prompt: str, key_list: list, final_results: dict, failed_results: list
     ) -> None:
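
A rough sketch of invoking the new mllm_pred step on a single-page input (pipeline construction elided; file name and key are placeholders):

    # pipeline = ...  # a PP-ChatOCRv4 pipeline built from PP-ChatOCRv4-doc.yaml
    mllm_res = pipeline.mllm_pred(input="invoice.jpg", key_list=["发票号码"])
    # -> {"mllm_res": {"发票号码": "..."}} on success,
    #    {"mllm_res": "大模型调用失败"} if the MLLM call fails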
@@ -524,6 +569,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         table_few_shot_demo_text_content: str = None,
         table_few_shot_demo_key_value_list: str = None,
         mllm_predict_dict: dict = None,
+        mllm_integration_strategy: str = "integration",
     ) -> dict:
         """
         Generates chat results based on the provided key list and visual information.
@@ -545,6 +591,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             table_few_shot_demo_text_content (str): The text content for table few-shot demos.
             table_few_shot_demo_key_value_list (str): The key-value list for table few-shot demos.
             mllm_predict_dict (dict): The dictionary of mLLM predicts.
+            mllm_integration_strategy (str): The strategy for combining MLLM and LLM results; defaults to "integration", options are "integration", "llm_only" and "mllm_only".
         Returns:
             dict: A dictionary containing the chat results.
         """
@@ -552,7 +599,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         key_list = self.format_key(key_list)
         key_list_ori = key_list.copy()
         if len(key_list) == 0:
-            return {"error": "输入的key_list无效!"}
+            return {"chat_res": "Error:输入的key_list无效!"}
 
         if not isinstance(visual_info, list):
             visual_info_list = [visual_info]
@@ -620,10 +667,17 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
                                 prompt, key_list, final_results, failed_results
                             )
 
-        if self.use_mllm_predict:
-            final_predict_dict = self.ensemble_ocr_llm_mllm(
-                key_list_ori, final_results, mllm_predict_dict
-            )
+        if self.use_mllm_predict and mllm_integration_strategy != "llm_only":
+            if mllm_integration_strategy == "integration":
+                final_predict_dict = self.ensemble_ocr_llm_mllm(
+                    key_list_ori, final_results, mllm_predict_dict
+                )
+            elif mllm_integration_strategy == "mllm_only":
+                final_predict_dict = mllm_predict_dict
+            else:
+                return {
+                    "chat_res": f"Error: Unsupported mllm_integration_strategy {mllm_integration_strategy}, only 'integration', 'llm_only' and 'mllm_only' are supported!"
+                }
         else:
             final_predict_dict = final_results
         return {"chat_res": final_predict_dict}