Răsfoiți Sursa

update qianfan embedding API to V2

zhouchangda 9 luni în urmă
părinte
comite
58bbc91e68

+ 7 - 7
paddlex/configs/pipelines/PP-ChatOCRv3-doc.yaml

@@ -6,17 +6,17 @@ use_layout_parser: True
 SubModules:
   LLM_Chat:
     module_name: chat_bot
-    model_name: ernie-3.5
-    api_type: qianfan
-    ak: "api_key" # Set this to a real API key
-    sk: "secret_key"  # Set this to a real secret key
+    model_name: ernie-3.5-8k
+    base_url: "https://qianfan.baidubce.com/v2"
+    api_type: openai
+    api_key: "api_key" # Set this to a real API key
 
   LLM_Retriever:
     module_name: retriever
-    model_name: ernie-3.5
+    model_name: embedding-v1
+    base_url: "https://qianfan.baidubce.com/v2"
     api_type: qianfan
-    ak: "api_key" # Set this to a real API key
-    sk: "secret_key"  # Set this to a real secret key
+    api_key: "api_key" # Set this to a real API key
 
 
   PromptEngneering:

+ 8 - 8
paddlex/configs/pipelines/PP-ChatOCRv4-doc.yaml

@@ -8,22 +8,22 @@ use_mllm_predict: True
 SubModules:
   LLM_Chat:
     module_name: chat_bot
-    model_name: ernie-3.5
-    api_type: qianfan
-    ak: "api_key" # Set this to a real API key
-    sk: "secret_key"  # Set this to a real secret key
+    model_name: ernie-3.5-8k
+    base_url: "https://qianfan.baidubce.com/v2"
+    api_type: openai
+    api_key: "api_key" # Set this to a real API key
 
   LLM_Retriever:
     module_name: retriever
-    model_name: ernie-3.5
+    model_name: embedding-v1
+    base_url: "https://qianfan.baidubce.com/v2"
     api_type: qianfan
-    ak: "api_key" # Set this to a real API key
-    sk: "secret_key"  # Set this to a real secret key
+    api_key: "api_key" # Set this to a real API key
 
   MLLM_Chat:
     module_name: chat_bot
     model_name: PP-DocBee
-    base_url: "http://127.0.0.1/v1/chat/completions"
+    base_url: "http://127.0.0.1:8080/v1/chat/completions"
     api_type: openai
     api_key: "api_key"
 

+ 0 - 1
paddlex/inference/pipelines/components/chat_server/__init__.py

@@ -13,5 +13,4 @@
 # limitations under the License.
 
 from .base import BaseChat
-from .ernie_bot_chat import ErnieBotChat
 from .openai_bot_chat import OpenAIBotChat

+ 0 - 192
paddlex/inference/pipelines/components/chat_server/ernie_bot_chat.py

@@ -1,192 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import re
-import json
-import erniebot
-from typing import Dict
-from .....utils import logging
-from .base import BaseChat
-
-
-class ErnieBotChat(BaseChat):
-    """Ernie Bot Chat"""
-
-    entities = [
-        "aistudio",
-        "qianfan",
-    ]
-
-    MODELS = [
-        "ernie-4.0",
-        "ernie-3.5",
-        "ernie-3.5-8k",
-        "ernie-lite",
-        "ernie-tiny-8k",
-        "ernie-speed",
-        "ernie-speed-128k",
-        "ernie-char-8k",
-    ]
-
-    def __init__(self, config: Dict) -> None:
-        """Initializes the ErnieBotChat with given configuration.
-
-        Args:
-            config (Dict): Configuration dictionary containing model_name, api_type, ak, sk, and access_token.
-
-        Raises:
-            ValueError: If model_name is not in the predefined entities,
-            api_type is not one of ['aistudio', 'qianfan'],
-            access_token is None for 'aistudio' api_type,
-            or ak and sk are None for 'qianfan' api_type.
-        """
-        super().__init__()
-        model_name = config.get("model_name", None)
-        api_type = config.get("api_type", None)
-        ak = config.get("ak", None)
-        sk = config.get("sk", None)
-        access_token = config.get("access_token", None)
-
-        if model_name not in self.MODELS:
-            raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotChat.")
-
-        if api_type not in ["aistudio", "qianfan"]:
-            raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
-
-        if api_type == "aistudio" and access_token is None:
-            raise ValueError("access_token cannot be empty when api_type is aistudio.")
-
-        if api_type == "qianfan" and (ak is None or sk is None):
-            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")
-
-        self.model_name = model_name
-        self.config = config
-
-    def generate_chat_results(
-        self, prompt: str, temperature: float = 0.001, max_retries: int = 1
-    ) -> Dict:
-        """
-        Generate chat results using the specified model and configuration.
-
-        Args:
-            prompt (str): The user's input prompt.
-            temperature (float, optional): The temperature parameter for llms, defaults to 0.001.
-            max_retries (int, optional): The maximum number of retries for llms API calls, defaults to 1.
-
-        Returns:
-            Dict: The chat completion result from the model.
-        """
-        try:
-            cur_config = {
-                "api_type": self.config["api_type"],
-                "max_retries": max_retries,
-            }
-            if self.config["api_type"] == "aistudio":
-                cur_config["access_token"] = self.config["access_token"]
-            elif self.config["api_type"] == "qianfan":
-                cur_config["ak"] = self.config["ak"]
-                cur_config["sk"] = self.config["sk"]
-            chat_completion = erniebot.ChatCompletion.create(
-                _config_=cur_config,
-                model=self.model_name,
-                messages=[{"role": "user", "content": prompt}],
-                temperature=float(temperature),
-            )
-            llm_result = chat_completion.get_result()
-            return llm_result
-        except Exception as e:
-            if len(e.args) < 1:
-                self.ERROR_MASSAGE = "暂无权限访问ErnieBot服务,请检查访问令牌。"
-            elif (
-                e.args[-1]
-                == "暂无权限使用,请在 AI Studio 正确获取访问令牌(access token)使用"
-            ):
-                self.ERROR_MASSAGE = "暂无权限访问ErnieBot服务,请检查访问令牌。"
-            else:
-                logging.error(e)
-                self.ERROR_MASSAGE = "大模型调用失败"
-        return None
-
-    def fix_llm_result_format(self, llm_result: str) -> dict:
-        """
-        Fix the format of the LLM result.
-
-        Args:
-            llm_result (str): The result from the LLM (Large Language Model).
-
-        Returns:
-            dict: A fixed format dictionary from the LLM result.
-        """
-        if not llm_result:
-            return {}
-
-        if "json" in llm_result or "```" in llm_result:
-            index = llm_result.find("{")
-            if index != -1:
-                llm_result = llm_result[index:]
-            index = llm_result.rfind("}")
-            if index != -1:
-                llm_result = llm_result[: index + 1]
-            llm_result = (
-                llm_result.replace("```", "").replace("json", "").replace("/n", "")
-            )
-            llm_result = llm_result.replace("[", "").replace("]", "")
-
-        try:
-            llm_result = json.loads(llm_result)
-            llm_result_final = {}
-            if "问题" in llm_result.keys() and "答案" in llm_result.keys():
-                key = llm_result["问题"]
-                value = llm_result["答案"]
-                if isinstance(value, list):
-                    if len(value) > 0:
-                        llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
-                else:
-                    llm_result_final[key] = value.strip(f"{key}:").strip(key)
-                return llm_result_final
-            for key in llm_result:
-                value = llm_result[key]
-                if isinstance(value, list):
-                    if len(value) > 0:
-                        llm_result_final[key] = value[0]
-                else:
-                    llm_result_final[key] = value
-            return llm_result_final
-
-        except:
-            results = (
-                llm_result.replace("\n", "")
-                .replace("    ", "")
-                .replace("{", "")
-                .replace("}", "")
-            )
-            if not results.endswith('"'):
-                results = results + '"'
-            pattern = r'"(.*?)": "([^"]*)"'
-            matches = re.findall(pattern, str(results))
-            if len(matches) > 0:
-                llm_result = {k: v for k, v in matches}
-                if "问题" in llm_result.keys() and "答案" in llm_result.keys():
-                    llm_result_final = {}
-                    key = llm_result["问题"]
-                    value = llm_result["答案"]
-                    if isinstance(value, list):
-                        if len(value) > 0:
-                            llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
-                    else:
-                        llm_result_final[key] = value.strip(f"{key}:").strip(key)
-                    return llm_result_final
-                return llm_result
-            else:
-                return {}

+ 3 - 0
paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py

@@ -41,6 +41,9 @@ class OpenAIBotChat(BaseChat):
         """
         super().__init__()
         model_name = config.get("model_name", None)
+        # compatible with historical model name
+        if model_name == "ernie-3.5":
+            model_name = "ernie-3.5-8k"
         api_type = config.get("api_type", None)
         api_key = config.get("api_key", None)
         base_url = config.get("base_url", None)

+ 1 - 1
paddlex/inference/pipelines/components/retriever/__init__.py

@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .ernie_bot_retriever import ErnieBotRetriever
+from .qianfan_bot_retriever import QianFanBotRetriever
 from .openai_bot_retriever import OpenAIBotRetriever

+ 143 - 2
paddlex/inference/pipelines/components/retriever/base.py

@@ -11,10 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Dict, List
 from abc import ABC, abstractmethod
-import inspect
+
+import time
 import base64
+
+from langchain.docstore.document import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import FAISS
+from langchain_community import vectorstores
+
+from paddlex.utils import logging
+
 from .....utils.subclass_register import AutoRegisterABCMetaClass
 
 
@@ -28,6 +37,8 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
     def __init__(self):
         """Initializes an instance of base retriever."""
         super().__init__()
+        self.model_name = None
+        self.embedding = None
 
     @abstractmethod
     def generate_vector_database(self):
@@ -49,6 +60,15 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
             "The method `similarity_retrieval` has not been implemented yet."
         )
 
+    def get_model_name(self) -> str:
+        """
+        Get the model name used for generating vectors.
+
+        Returns:
+            str: The model name.
+        """
+        return self.model_name
+
     def is_vector_store(self, s: str) -> bool:
         """
         Check if the given string starts with the vector store prefix.
@@ -86,3 +106,124 @@ class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
             bytes: The decoded vector store data.
         """
         return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX) :])
+
+    def generate_vector_database(
+        self,
+        text_list: List[str],
+        block_size: int = 300,
+        separators: List[str] = ["\t", "\n", "。", "\n\n", ""],
+    ) -> FAISS:
+        """
+        Generates a vector database from a list of texts.
+
+        Args:
+            text_list (list[str]): A list of texts to generate the vector database from.
+            block_size (int): The size of each chunk to split the text into.
+            separators (list[str]): A list of separators to use when splitting the text.
+
+        Returns:
+            FAISS: The generated vector database.
+
+        Raises:
+            ValueError: If an unsupported API type is configured.
+        """
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=block_size, chunk_overlap=20, separators=separators
+        )
+        texts = text_splitter.split_text("\t".join(text_list))
+        all_splits = [Document(page_content=text) for text in texts]
+
+        try:
+            vectorstore = FAISS.from_documents(
+                documents=all_splits, embedding=self.embedding
+            )
+        except ValueError as e:
+            print(e)
+            vectorstore = None
+
+        return vectorstore
+
+    def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str:
+        """
+        Encode the vector store serialized to bytes.
+
+        Args:
+            vectorstore (FAISS): The vector store to be serialized and encoded.
+
+        Returns:
+            str: The encoded vector store.
+        """
+        if vectorstore is None:
+            vectorstore = self.VECTOR_STORE_PREFIX
+        else:
+            vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
+        return vectorstore
+
+    def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS:
+        """
+        Decode a vector store from bytes according to the specified API type.
+
+        Args:
+            vectorstore (str): The serialized vector store string.
+
+        Returns:
+            FAISS: Deserialized vector store object.
+
+        Raises:
+            ValueError: If the retrieved vector store is not for PaddleX
+            or if an unsupported API type is specified.
+        """
+        if not self.is_vector_store(vectorstore):
+            raise ValueError("The retrieved vectorstore is not for PaddleX.")
+
+        vectorstore = self.decode_vector_store(vectorstore)
+
+        if vectorstore == b"":
+            logging.warning("The retrieved vectorstore is empty,will empty vector.")
+            return None
+
+        print(vectorstore)
+
+        vector = vectorstores.FAISS.deserialize_from_bytes(
+            vectorstore,
+            embeddings=self.embedding,
+            allow_dangerous_deserialization=True,
+        )
+        return vector
+
+    def similarity_retrieval(
+        self,
+        query_text_list: List[str],
+        vectorstore: FAISS,
+        sleep_time: float = 0.5,
+        topk: int = 2,
+        min_characters: int = 3500,
+    ) -> str:
+        """
+        Retrieve similar contexts based on a list of query texts.
+
+        Args:
+            query_text_list (list[str]): A list of query texts to search for similar contexts.
+            vectorstore (FAISS): The vector store where to perform the similarity search.
+            sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5.
+            topk (int): The number of results to retrieve per query. Default is 2.
+            min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
+        Returns:
+            str: A concatenated string of all unique contexts found.
+        """
+        C = []
+        all_C = ""
+        if vectorstore is None:
+            return all_C
+        for query_text in query_text_list:
+            QUESTION = query_text
+            time.sleep(sleep_time)
+            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=topk)
+            context = [(document.page_content, score) for document, score in docs]
+            context = sorted(context, key=lambda x: x[1])
+            for text, score in context[::-1]:
+                if score >= -0.1:
+                    if len(all_C) + len(text) > min_characters:
+                        break
+                    all_C += text
+        return all_C

+ 0 - 227
paddlex/inference/pipelines/components/retriever/ernie_bot_retriever.py

@@ -1,227 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict, List
-import time
-import os
-from langchain.docstore.document import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.embeddings import QianfanEmbeddingsEndpoint
-from langchain_community.vectorstores import FAISS
-from langchain_community import vectorstores
-from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings
-from .base import BaseRetriever
-
-
-class ErnieBotRetriever(BaseRetriever):
-    """Ernie Bot Retriever"""
-
-    entities = [
-        "aistudio",
-        "qianfan",
-    ]
-
-    MODELS = [
-        "ernie-4.0",
-        "ernie-3.5",
-        "ernie-3.5-8k",
-        "ernie-lite",
-        "ernie-tiny-8k",
-        "ernie-speed",
-        "ernie-speed-128k",
-        "ernie-char-8k",
-    ]
-
-    def __init__(self, config: Dict) -> None:
-        """
-        Initializes the ErnieBotRetriever instance with the provided configuration.
-
-        Args:
-            config (Dict): A dictionary containing configuration settings.
-                - model_name (str): The name of the model to use.
-                - api_type (str): The type of API to use ('aistudio', 'qianfan' or 'openai').
-                - ak (str, optional): The access key for 'qianfan' API.
-                - sk (str, optional): The secret key for 'qianfan' API.
-                - access_token (str, optional): The access token for 'aistudio' API.
-
-        Raises:
-            ValueError: If model_name is not in self.entities,
-                api_type is not 'aistudio' or 'qianfan',
-                access_token is missing for 'aistudio' API,
-                or ak and sk are missing for 'qianfan' API.
-        """
-        super().__init__()
-
-        model_name = config.get("model_name", None)
-        api_type = config.get("api_type", None)
-        ak = config.get("ak", None)
-        sk = config.get("sk", None)
-        access_token = config.get("access_token", None)
-
-        if model_name not in self.MODELS:
-            raise ValueError(f"model_name must be in {self.MODELS} of ErnieBotChat.")
-
-        if api_type not in ["aistudio", "qianfan"]:
-            raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
-
-        if api_type == "aistudio" and access_token is None:
-            raise ValueError("access_token cannot be empty when api_type is aistudio.")
-
-        if api_type == "qianfan" and (ak is None or sk is None):
-            raise ValueError("ak and sk cannot be empty when api_type is qianfan.")
-
-        self.model_name = model_name
-        self.config = config
-
-    # Generates a vector database from a list of texts using different embeddings based on the configured API type.
-
-    def generate_vector_database(
-        self,
-        text_list: List[str],
-        block_size: int = 300,
-        separators: List[str] = ["\t", "\n", "。", "\n\n", ""],
-        sleep_time: float = 0.5,
-    ) -> FAISS:
-        """
-        Generates a vector database from a list of texts.
-
-        Args:
-            text_list (list[str]): A list of texts to generate the vector database from.
-            block_size (int): The size of each chunk to split the text into.
-            separators (list[str]): A list of separators to use when splitting the text.
-            sleep_time (float): The time to sleep between embedding generations to avoid rate limiting.
-
-        Returns:
-            FAISS: The generated vector database.
-
-        Raises:
-            ValueError: If an unsupported API type is configured.
-        """
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=block_size, chunk_overlap=20, separators=separators
-        )
-        texts = text_splitter.split_text("\t".join(text_list))
-        all_splits = [Document(page_content=text) for text in texts]
-        api_type = self.config["api_type"]
-        if api_type == "qianfan":
-            os.environ["QIANFAN_AK"] = os.environ.get("EB_AK", self.config["ak"])
-            os.environ["QIANFAN_SK"] = os.environ.get("EB_SK", self.config["sk"])
-            user_ak = os.environ.get("EB_AK", self.config["ak"])
-            user_id = hash(user_ak)
-            vectorstore = FAISS.from_documents(
-                documents=all_splits, embedding=QianfanEmbeddingsEndpoint()
-            )
-        elif api_type == "aistudio":
-            token = self.config["access_token"]
-            vectorstore = FAISS.from_documents(
-                documents=all_splits[0:1],
-                embedding=ErnieEmbeddings(aistudio_access_token=token),
-            )
-            #### ErnieEmbeddings.chunk_size = 16
-            step = min(16, len(all_splits) - 1)
-            for shot_splits in [
-                all_splits[i : i + step] for i in range(1, len(all_splits), step)
-            ]:
-                time.sleep(sleep_time)
-                vectorstore_slice = FAISS.from_documents(
-                    documents=shot_splits,
-                    embedding=ErnieEmbeddings(aistudio_access_token=token),
-                )
-                vectorstore.merge_from(vectorstore_slice)
-        else:
-            raise ValueError(f"Unsupported api_type: {api_type}")
-
-        return vectorstore
-
-    def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str:
-        """
-        Encode the vector store serialized to bytes.
-
-        Args:
-            vectorstore (FAISS): The vector store to be serialized and encoded.
-
-        Returns:
-            str: The encoded vector store.
-        """
-        vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
-        return vectorstore
-
-    def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS:
-        """
-        Decode a vector store from bytes according to the specified API type.
-
-        Args:
-            vectorstore (str): The serialized vector store string.
-
-        Returns:
-            FAISS: Deserialized vector store object.
-
-        Raises:
-            ValueError: If the retrieved vector store is not for PaddleX
-            or if an unsupported API type is specified.
-        """
-        if not self.is_vector_store(vectorstore):
-            raise ValueError("The retrieved vectorstore is not for PaddleX.")
-
-        api_type = self.config["api_type"]
-
-        if api_type == "aistudio":
-            access_token = self.config["access_token"]
-            embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
-        elif api_type == "qianfan":
-            ak = self.config["ak"]
-            sk = self.config["sk"]
-            embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
-        else:
-            raise ValueError(f"Unsupported api_type: {api_type}")
-
-        vector = vectorstores.FAISS.deserialize_from_bytes(
-            self.decode_vector_store(vectorstore), embeddings
-        )
-        return vector
-
-    def similarity_retrieval(
-        self,
-        query_text_list: List[str],
-        vectorstore: FAISS,
-        sleep_time: float = 0.5,
-        topk: int = 2,
-        min_characters: int = 3500,
-    ) -> str:
-        """
-        Retrieve similar contexts based on a list of query texts.
-
-        Args:
-            query_text_list (list[str]): A list of query texts to search for similar contexts.
-            vectorstore (FAISS): The vector store where to perform the similarity search.
-            sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5.
-            topk (int): The number of results to retrieve per query. Default is 2.
-            min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
-        Returns:
-            str: A concatenated string of all unique contexts found.
-        """
-        C = []
-        all_C = ""
-        for query_text in query_text_list:
-            QUESTION = query_text
-            time.sleep(sleep_time)
-            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=topk)
-            context = [(document.page_content, score) for document, score in docs]
-            context = sorted(context, key=lambda x: x[1])
-            for text, score in context[::-1]:
-                if score >= -0.1:
-                    if len(all_C) + len(text) > min_characters:
-                        break
-                    all_C += text
-        return all_C

+ 13 - 117
paddlex/inference/pipelines/components/retriever/openai_bot_retriever.py

@@ -11,18 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, List
 
 from .base import BaseRetriever
 
-from langchain.docstore.document import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import FAISS
-from langchain_community import vectorstores
-
-import time
-
-from typing import Dict, List
-
 
 class OpenAIBotRetriever(BaseRetriever):
     """OpenAI Bot Retriever"""
@@ -31,6 +23,13 @@ class OpenAIBotRetriever(BaseRetriever):
         "openai",
     ]
 
+    MODELS = [
+        "tao-8k",
+        "embedding-v1",
+        "bge-large-zh",
+        "bge-large-en",
+    ]
+
     def __init__(self, config: Dict) -> None:
         """
         Initializes the OpenAIBotRetriever instance with the provided configuration.
@@ -38,27 +37,23 @@ class OpenAIBotRetriever(BaseRetriever):
         Args:
             config (Dict): A dictionary containing configuration settings.
                 - model_name (str): The name of the model to use.
-                - api_type (str): The type of API to use ('aistudio', 'qianfan' or 'openai').
-                - api_key (str, optional): The API key for 'openai' API.
-                - base_url (str, optional): The base URL for 'openai' API.
+                - api_type (str): The type of API to use ('qianfan' or 'openai').
+                - api_key (str): The API key for 'openai' API.
+                - base_url (str): The base URL for 'openai' API.
 
         Raises:
-            ValueError: If api_type is not one of ['openai'],
+            ValueError: If api_type is not one of ['qianfan','openai'],
             base_url is None for api_type is openai,
             api_key is None for api_type is openai.
         """
         super().__init__()
 
         model_name = config.get("model_name", None)
-        api_type = config.get("api_type", None)
         api_key = config.get("api_key", None)
         base_url = config.get("base_url", None)
         tiktoken_enabled = config.get("tiktoken_enabled", False)
 
-        if api_type not in ["openai"]:
-            raise ValueError("api_type must be one of ['openai']")
-
-        if api_type == "openai" and api_key is None:
+        if api_key is None:
             raise ValueError("api_key cannot be empty when api_type is openai.")
 
         if base_url is None:
@@ -80,102 +75,3 @@ class OpenAIBotRetriever(BaseRetriever):
 
         self.model_name = model_name
         self.config = config
-
-    # Generates a vector database from a list of texts using different embeddings based on the configured API type.
-
-    def generate_vector_database(
-        self,
-        text_list: List[str],
-        block_size: int = 300,
-        separators: List[str] = ["\t", "\n", "。", "\n\n", ""],
-        sleep_time: float = 0.5,
-    ) -> FAISS:
-        """
-        Generates a vector database from a list of texts.
-
-        Args:
-            text_list (list[str]): A list of texts to generate the vector database from.
-            block_size (int): The size of each chunk to split the text into.
-            separators (list[str]): A list of separators to use when splitting the text.
-            sleep_time (float): The time to sleep between embedding generations to avoid rate limiting.
-
-        Returns:
-            FAISS: The generated vector database.
-
-        Raises:
-            ValueError: If an unsupported API type is configured.
-        """
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=block_size, chunk_overlap=20, separators=separators
-        )
-        texts = text_splitter.split_text("\t".join(text_list))
-        all_splits = [Document(page_content=text) for text in texts]
-
-        api_type = self.config["api_type"]
-
-        vectorstore = FAISS.from_documents(
-            documents=all_splits, embedding=self.embedding
-        )
-
-        return vectorstore
-
-    def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str:
-        """
-        Encode the vector store serialized to bytes.
-
-        Args:
-            vectorstore (FAISS): The vector store to be serialized and encoded.
-
-        Returns:
-            str: The encoded vector store.
-        """
-        vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
-        return vectorstore
-
-    def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS:
-        """
-        Decode a vector store from bytes according to the specified API type.
-
-        Args:
-            vectorstore (str): The serialized vector store string.
-
-        Returns:
-            FAISS: Deserialized vector store object.
-
-        Raises:
-            ValueError: If the retrieved vector store is not for PaddleX
-            or if an unsupported API type is specified.
-        """
-        if not self.is_vector_store(vectorstore):
-            raise ValueError("The retrieved vectorstore is not for PaddleX.")
-
-        vector = vectorstores.FAISS.deserialize_from_bytes(
-            self.decode_vector_store(vectorstore), self.embedding
-        )
-        return vector
-
-    def similarity_retrieval(
-        self, query_text_list: List[str], vectorstore: FAISS, sleep_time: float = 0.5
-    ) -> str:
-        """
-        Retrieve similar contexts based on a list of query texts.
-
-        Args:
-            query_text_list (list[str]): A list of query texts to search for similar contexts.
-            vectorstore (FAISS): The vector store where to perform the similarity search.
-            sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5.
-
-        Returns:
-            str: A concatenated string of all unique contexts found.
-        """
-        C = []
-        for query_text in query_text_list:
-            QUESTION = query_text
-            time.sleep(sleep_time)
-            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2)
-            context = [(document.page_content, score) for document, score in docs]
-            context = sorted(context, key=lambda x: x[1])
-            C.extend([x[0] for x in context[::-1]])
-        C = list(set(C))
-        all_C = " ".join(C)
-        return all_C

+ 163 - 0
paddlex/inference/pipelines/components/retriever/qianfan_bot_retriever.py

@@ -0,0 +1,163 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from operator import le
+from typing import Dict, List
+
+import json
+import requests
+from langchain_core.embeddings import Embeddings
+
+from paddlex.utils import logging
+from .base import BaseRetriever
+
+
+class QianFanBotRetriever(BaseRetriever):
+    """QianFan Bot Retriever"""
+
+    entities = [
+        "qianfan",
+    ]
+
+    MODELS = [
+        "tao-8k",
+        "embedding-v1",
+        "bge-large-zh",
+        "bge-large-en",
+    ]
+
+    def __init__(self, config: Dict) -> None:
+        """
+        Initializes the ErnieBotRetriever instance with the provided configuration.
+
+        Args:
+            config (Dict): A dictionary containing configuration settings.
+                - model_name (str): The name of the model to use.
+                - api_type (str): The type of API to use ('qianfan' or 'openai').
+                - api_key (str): The API key for 'qianfan' API.
+                - base_url (str): The base URL for 'qianfan' API.
+
+        Raises:
+            ValueError: If api_type is not one of ['qianfan','openai'],
+                base_url is None for api_type is qianfan,
+                api_key is None for api_type is qianfan.
+        """
+        super().__init__()
+
+        model_name = config.get("model_name", None)
+        api_key = config.get("api_key", None)
+        base_url = config.get("base_url", None)
+
+        if model_name not in self.MODELS:
+            raise ValueError(
+                f"model_name must be in {self.MODELS} of QianFanBotRetriever."
+            )
+
+        if api_key is None:
+            raise ValueError("api_key cannot be empty when api_type is qianfan.")
+
+        if base_url is None:
+            raise ValueError("base_url cannot be empty when api_type is qianfan.")
+
+        self.embedding = QianfanEmbeddings(
+            model=model_name,
+            base_url=base_url,
+            api_key=api_key,
+        )
+
+        self.model_name = model_name
+        self.config = config
+
+
+class QianfanEmbeddings(Embeddings):
+    """`Baidu Qianfan Embeddings` embedding models."""
+
+    def __init__(
+        self,
+        api_key: str,
+        base_url: str = "https://qianfan.baidubce.com/v2",
+        model: str = "embedding-v1",
+        **kwargs,
+    ):
+        """
+        Initialize the Baidu Qianfan Embeddings class.
+
+        Args:
+            api_key (str): The Qianfan API key.
+            base_url (str): The base URL for 'qianfan' API.
+            model (str): Model name. Default is "embedding-v1",select in ["tao-8k","embedding-v1","bge-large-en","bge-large-zh"].
+            kwargs (dict): Additional keyword arguments passed to the base Embeddings class.
+        """
+        super().__init__(**kwargs)
+        chunk_size_map = {
+            "tao-8k": 1,
+            "embedding-v1": 16,
+            "bge-large-en": 16,
+            "bge-large-zh": 16,
+        }
+        self.api_key = api_key
+        self.base_url = base_url
+        self.model = model
+        self.chunk_size = chunk_size_map.get(model, 1)
+
+    def embed(self, texts: str, **kwargs) -> List[float]:
+        url = f"{self.base_url}/embeddings"
+        payload = json.dumps(
+            {"model": kwargs.get("model", self.model), "input": [f"{texts}"]}
+        )
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+        }
+
+        response = requests.request("POST", url, headers=headers, data=payload)
+        if response.status_code != 200:
+            logging.error(
+                f"Failed to call Qianfan API. Status code: {response.status_code}, Response content: {response}"
+            )
+
+        return response.json()
+
+    def embed_query(self, text: str) -> List[float]:
+        resp = self.embed_documents([text])
+        return resp[0]
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Embeds a list of text documents using the AutoVOT algorithm.
+
+        Args:
+            texts (List[str]): A list of text documents to embed.
+
+        Returns:
+            List[List[float]]: A list of embeddings for each document in the input list.
+                            Each embedding is represented as a list of float values.
+        """
+        lst = []
+        for chunk in texts:
+            resp = self.embed(texts=chunk)
+            lst.extend([res["embedding"] for res in resp["data"]])
+        return lst
+
+    async def aembed_query(self, text: str) -> List[float]:
+        embeddings = await self.aembed_documents([text])
+        return embeddings[0]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        lst = []
+        for chunk in texts:
+            resp = await self.embed(texts=chunk)
+            for res in resp["data"]:
+                lst.extend([res["embedding"]])
+        return lst

+ 31 - 7
paddlex/inference/pipelines/pp_chatocr/pipeline_v3.py

@@ -15,9 +15,9 @@
 from typing import Any, Dict, Optional, Union, List, Tuple
 import os
 import re
+import copy
 import json
 import numpy as np
-import copy
 from .pipeline_base import PP_ChatOCR_Pipeline
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
@@ -65,6 +65,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         if initial_predictor:
             self.inintial_visual_predictor(config)
             self.inintial_chat_predictor(config)
+            self.inintial_retriever_predictor(config)
 
         self.batch_sampler = ImageBatchSampler(batch_size=1)
         self.img_reader = ReadImage(format="BGR")
@@ -344,7 +345,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         self,
         visual_info: dict,
         min_characters: int = 3500,
-        llm_request_interval: float = 1.0,
+        block_size: int = 300,
         flag_save_bytes_vector: bool = False,
         retriever_config: dict = None,
     ) -> dict:
@@ -354,7 +355,7 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         Args:
             visual_info (dict): The visual information input, can be a single instance or a list of instances.
             min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
-            llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
+            block_size (int): The size of each chunk to split the text into.
             flag_save_bytes_vector (bool): Whether to save the vector as bytes, defaults to False.
             retriever_config (dict): The configuration for the retriever, defaults to None.
 
@@ -401,9 +402,13 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         vector_info["flag_save_bytes_vector"] = False
         if len(all_text_str) > min_characters:
             vector_info["flag_too_short_text"] = False
-            vector_info["vector"] = retriever.generate_vector_database(all_items)
+            vector_info["model_name"] = retriever.model_name
+            vector_info["block_size"] = block_size
+            vector_info["vector"] = retriever.generate_vector_database(
+                all_items, block_size=block_size
+            )
             if flag_save_bytes_vector:
-                vector_info["vector"] = self.retriever.encode_vector_store_to_bytes(
+                vector_info["vector"] = retriever.encode_vector_store_to_bytes(
                     vector_info["vector"]
                 )
                 vector_info["flag_save_bytes_vector"] = True
@@ -416,9 +421,22 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
         directory = os.path.dirname(save_path)
         if not os.path.exists(directory):
             os.makedirs(directory)
+        if self.retriever is None:
+            logging.warning("The retriever is not initialized,will initialize it now.")
+            self.inintial_retriever_predictor(self.config)
+
+        vector_info_data = copy.deepcopy(vector_info)
+        if (
+            not vector_info["flag_too_short_text"]
+            and not vector_info["flag_save_bytes_vector"]
+        ):
+            vector_info_data["vector"] = self.retriever.encode_vector_store_to_bytes(
+                vector_info_data["vector"]
+            )
+            vector_info_data["flag_save_bytes_vector"] = True
 
         with custom_open(save_path, "w") as fout:
-            fout.write(json.dumps(vector_info, ensure_ascii=False) + "\n")
+            fout.write(json.dumps(vector_info_data, ensure_ascii=False) + "\n")
         return
 
     def load_vector(self, data_path: str) -> dict:
@@ -437,11 +455,12 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             ):
                 logging.error("Invalid vector info.")
                 return {"error": "Invalid vector info when load vector!"}
-
             if vector_info["flag_save_bytes_vector"]:
                 vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
                     vector_info["vector"]
                 )
+                vector_info["flag_save_bytes_vector"] = False
+
         return vector_info
 
     def format_key(self, key_list: Union[str, List[str]]) -> List[str]:
@@ -545,6 +564,11 @@ class PP_ChatOCRv3_Pipeline(PP_ChatOCR_Pipeline):
             question_key_list = [f"{key}" for key in key_list]
             vector = vector_info["vector"]
             if not vector_info["flag_too_short_text"]:
+                assert (
+                    vector_info["model_name"] == retriever.model_name
+                ), f"The vector model name ({vector_info['model_name']}) does not match the retriever model name ({retriever.model_name}). Please check your retriever config."
+                if vector_info["flag_save_bytes_vector"]:
+                    vector = retriever.decode_vector_store_from_bytes(vector)
                 related_text = retriever.similarity_retrieval(
                     question_key_list, vector, topk=50, min_characters=min_characters
                 )

+ 57 - 10
paddlex/inference/pipelines/pp_chatocr/pipeline_v4.py

@@ -13,16 +13,18 @@
 # limitations under the License.
 
 from typing import Any, Dict, Optional, Union, List, Tuple
+import os
 import re
 import cv2
+import copy
 import json
 import base64
 import numpy as np
-import copy
 from .pipeline_base import PP_ChatOCR_Pipeline
 from ...common.reader import ReadImage
 from ...common.batch_sampler import ImageBatchSampler
 from ....utils import logging
+from ....utils.file_interface import custom_open
 from ...utils.pp_option import PaddlePredictorOption
 from ..layout_parsing.result import LayoutParsingResult
 from ..components.chat_server import BaseChat
@@ -67,6 +69,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         if initial_predictor:
             self.inintial_visual_predictor(config)
             self.inintial_chat_predictor(config)
+            self.inintial_retriever_predictor(config)
             self.inintial_mllm_predictor(config)
 
         self.batch_sampler = ImageBatchSampler(batch_size=1)
@@ -384,7 +387,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         self,
         visual_info: dict,
         min_characters: int = 3500,
-        llm_request_interval: float = 1.0,
+        block_size: int = 300,
         flag_save_bytes_vector: bool = False,
         retriever_config: dict = None,
     ) -> dict:
@@ -394,7 +397,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         Args:
             visual_info (dict): The visual information input, can be a single instance or a list of instances.
             min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
-            llm_request_interval (float): The interval between LLM requests, defaults to 1.0.
+            block_size (int): The size of each chunk to split the text into.
             flag_save_bytes_vector (bool): Whether to save the vector as bytes, defaults to False.
             retriever_config (dict): The configuration for the retriever, defaults to None.
 
@@ -444,7 +447,11 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         vector_info["flag_save_bytes_vector"] = False
         if len(all_text_str) > min_characters:
             vector_info["flag_too_short_text"] = False
-            vector_info["vector"] = retriever.generate_vector_database(all_items)
+            vector_info["model_name"] = retriever.model_name
+            vector_info["block_size"] = block_size
+            vector_info["vector"] = retriever.generate_vector_database(
+                all_items, block_size=block_size
+            )
             if flag_save_bytes_vector:
                 vector_info["vector"] = retriever.encode_vector_store_to_bytes(
                     vector_info["vector"]
@@ -456,8 +463,25 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
         return vector_info
 
     def save_vector(self, vector_info: dict, save_path: str) -> None:
-        with open(save_path, "w") as fout:
-            fout.write(json.dumps(vector_info, ensure_ascii=False) + "\n")
+        directory = os.path.dirname(save_path)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+        if self.retriever is None:
+            logging.warning("The retriever is not initialized,will initialize it now.")
+            self.inintial_retriever_predictor(self.config)
+
+        vector_info_data = copy.deepcopy(vector_info)
+        if (
+            not vector_info["flag_too_short_text"]
+            and not vector_info["flag_save_bytes_vector"]
+        ):
+            vector_info_data["vector"] = self.retriever.encode_vector_store_to_bytes(
+                vector_info_data["vector"]
+            )
+            vector_info_data["flag_save_bytes_vector"] = True
+
+        with custom_open(save_path, "w") as fout:
+            fout.write(json.dumps(vector_info_data, ensure_ascii=False) + "\n")
         return
 
     def load_vector(self, data_path: str) -> dict:
@@ -476,11 +500,12 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             ):
                 logging.error("Invalid vector info.")
                 return {"error": "Invalid vector info when load vector!"}
-
             if vector_info["flag_save_bytes_vector"]:
                 vector_info["vector"] = self.retriever.decode_vector_store_from_bytes(
                     vector_info["vector"]
                 )
+                vector_info["flag_save_bytes_vector"] = False
+
         return vector_info
 
     def format_key(self, key_list: Union[str, List[str]]) -> List[str]:
@@ -510,9 +535,19 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
     def mllm_pred(
         self,
         input: Union[str, np.ndarray],
-        key_list,
-        **kwargs,
+        key_list: Union[str, List[str]],
+        mllm_chat_bot_config=None,
     ) -> dict:
+        """
+        Generates MLLM results based on the provided key list and input image.
+
+        Args:
+            input (Union[str, np.ndarray]): Input image path, or numpy array of an image.
+            key_list (Union[str, list[str]]): A single key or a list of keys to extract information.
+            chat_bot_config (dict): The parameters for LLM chatbot, including api_type, api_key... refer to config file for more details.
+        Returns:
+            dict: A dictionary containing the chat results.
+        """
         if self.use_mllm_predict == False:
             logging.error("MLLM prediction is disabled.")
             return {"mllm_res": "Error:MLLM prediction is disabled!"}
@@ -539,6 +574,13 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             )
             self.inintial_mllm_predictor(self.config)
 
+        if mllm_chat_bot_config is not None:
+            from .. import create_chat_bot
+
+            mllm_chat_bot = create_chat_bot(mllm_chat_bot_config)
+        else:
+            mllm_chat_bot = self.mllm_chat_bot
+
         for image_array in image_array_list:
 
             assert len(image_array.shape) == 3
@@ -550,7 +592,7 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
                     str(key)
                     + "\n请用图片中完整出现的内容回答,可以是单词、短语或句子,针对问题回答尽可能详细和完整,并保持格式、单位、符号和标点都与图片中的文字内容完全一致。"
                 )
-                mllm_chat_bot_result = self.mllm_chat_bot.generate_chat_results(
+                mllm_chat_bot_result = mllm_chat_bot.generate_chat_results(
                     prompt=prompt, image=image_base64
                 )
                 if mllm_chat_bot_result is None:
@@ -636,6 +678,11 @@ class PP_ChatOCRv4_Pipeline(PP_ChatOCR_Pipeline):
             question_key_list = [f"{key}" for key in key_list]
             vector = vector_info["vector"]
             if not vector_info["flag_too_short_text"]:
+                assert (
+                    vector_info["model_name"] == retriever.model_name
+                ), f"The vector model name ({vector_info['model_name']}) does not match the retriever model name ({retriever.model_name}). Please check your retriever config."
+                if vector_info["flag_save_bytes_vector"]:
+                    vector = retriever.decode_vector_store_from_bytes(vector)
                 related_text = retriever.similarity_retrieval(
                     question_key_list, vector, topk=50, min_characters=min_characters
                 )

+ 4 - 5
paddlex/repo_manager/requirements.txt

@@ -9,11 +9,10 @@ openpyxl
 premailer
 python-docx
 ######## For Chatocrv3 #######
-qianfan==0.0.3
-langchain==0.1.5
-langchain-community==0.0.17
-erniebot == 0.5.9
-erniebot-agent == 0.5.2
+langchain==0.2.17
+langchain-community==0.2.17
+langchain-text-splitters==0.2.4
+transformers==4.40.0
 unstructured
 networkx
 numpy==1.24.4; python_version<"3.12"

+ 4 - 5
requirements.txt

@@ -34,11 +34,10 @@ ujson
 Pillow
 importlib_resources>=6.4
 ######## For Chatocrv3 #######
-qianfan==0.0.3
-langchain==0.1.5
-langchain-community==0.0.17
-erniebot == 0.5.9
-erniebot-agent == 0.5.2
+langchain==0.2.17
+langchain-community==0.2.17
+langchain-text-splitters==0.2.4
+transformers==4.40.0
 unstructured
 networkx
 faiss-cpu