qianfan_bot_retriever.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. from typing import Dict, List
  16. import requests
  17. from paddlex.utils import logging
  18. from .....utils.deps import is_dep_available
  19. from .base import BaseRetriever
  20. class QianFanBotRetriever(BaseRetriever):
  21. """QianFan Bot Retriever"""
  22. entities = [
  23. "qianfan",
  24. ]
  25. MODELS = [
  26. "tao-8k",
  27. "embedding-v1",
  28. "bge-large-zh",
  29. "bge-large-en",
  30. ]
  31. def __init__(self, config: Dict) -> None:
  32. """
  33. Initializes the ErnieBotRetriever instance with the provided configuration.
  34. Args:
  35. config (Dict): A dictionary containing configuration settings.
  36. - model_name (str): The name of the model to use.
  37. - api_type (str): The type of API to use ('qianfan' or 'openai').
  38. - api_key (str): The API key for 'qianfan' API.
  39. - base_url (str): The base URL for 'qianfan' API.
  40. Raises:
  41. ValueError: If api_type is not one of ['qianfan','openai'],
  42. base_url is None for api_type is qianfan,
  43. api_key is None for api_type is qianfan.
  44. """
  45. super().__init__()
  46. model_name = config.get("model_name", None)
  47. api_key = config.get("api_key", None)
  48. base_url = config.get("base_url", None)
  49. if model_name not in self.MODELS:
  50. raise ValueError(
  51. f"model_name must be in {self.MODELS} of QianFanBotRetriever."
  52. )
  53. if api_key is None:
  54. raise ValueError("api_key cannot be empty when api_type is qianfan.")
  55. if base_url is None:
  56. raise ValueError("base_url cannot be empty when api_type is qianfan.")
  57. self.embedding = QianfanEmbeddings(
  58. model=model_name,
  59. base_url=base_url,
  60. api_key=api_key,
  61. )
  62. self.model_name = model_name
  63. self.config = config
  64. if is_dep_available("langchain-core"):
  65. from langchain_core.embeddings import Embeddings
  66. class QianfanEmbeddings(Embeddings):
  67. """`Baidu Qianfan Embeddings` embedding models."""
  68. def __init__(
  69. self,
  70. api_key: str,
  71. base_url: str = "https://qianfan.baidubce.com/v2",
  72. model: str = "embedding-v1",
  73. **kwargs,
  74. ):
  75. """
  76. Initialize the Baidu Qianfan Embeddings class.
  77. Args:
  78. api_key (str): The Qianfan API key.
  79. base_url (str): The base URL for 'qianfan' API.
  80. model (str): Model name. Default is "embedding-v1",select in ["tao-8k","embedding-v1","bge-large-en","bge-large-zh"].
  81. kwargs (dict): Additional keyword arguments passed to the base Embeddings class.
  82. """
  83. super().__init__(**kwargs)
  84. chunk_size_map = {
  85. "tao-8k": 1,
  86. "embedding-v1": 16,
  87. "bge-large-en": 16,
  88. "bge-large-zh": 16,
  89. }
  90. self.api_key = api_key
  91. self.base_url = base_url
  92. self.model = model
  93. self.chunk_size = chunk_size_map.get(model, 1)
  94. def embed(self, texts: str, **kwargs) -> List[float]:
  95. url = f"{self.base_url}/embeddings"
  96. payload = json.dumps(
  97. {"model": kwargs.get("model", self.model), "input": [f"{texts}"]}
  98. )
  99. headers = {
  100. "Content-Type": "application/json",
  101. "Authorization": f"Bearer {self.api_key}",
  102. }
  103. response = requests.request("POST", url, headers=headers, data=payload)
  104. if response.status_code != 200:
  105. logging.error(
  106. f"Failed to call Qianfan API. Status code: {response.status_code}, Response content: {response}"
  107. )
  108. return response.json()
  109. def embed_query(self, text: str) -> List[float]:
  110. resp = self.embed_documents([text])
  111. return resp[0]
  112. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  113. """
  114. Embeds a list of text documents using the AutoVOT algorithm.
  115. Args:
  116. texts (List[str]): A list of text documents to embed.
  117. Returns:
  118. List[List[float]]: A list of embeddings for each document in the input list.
  119. Each embedding is represented as a list of float values.
  120. """
  121. lst = []
  122. for chunk in texts:
  123. resp = self.embed(texts=chunk)
  124. lst.extend([res["embedding"] for res in resp["data"]])
  125. return lst
  126. async def aembed_query(self, text: str) -> List[float]:
  127. embeddings = await self.aembed_documents([text])
  128. return embeddings[0]
  129. async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
  130. lst = []
  131. for chunk in texts:
  132. resp = await self.embed(texts=chunk)
  133. for res in resp["data"]:
  134. lst.extend([res["embedding"]])
  135. return lst