openai_bot_retriever.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, List
  15. from .base import BaseRetriever
  16. class OpenAIBotRetriever(BaseRetriever):
  17. """OpenAI Bot Retriever"""
  18. entities = [
  19. "openai",
  20. ]
  21. MODELS = [
  22. "tao-8k",
  23. "embedding-v1",
  24. "bge-large-zh",
  25. "bge-large-en",
  26. ]
  27. def __init__(self, config: Dict) -> None:
  28. """
  29. Initializes the OpenAIBotRetriever instance with the provided configuration.
  30. Args:
  31. config (Dict): A dictionary containing configuration settings.
  32. - model_name (str): The name of the model to use.
  33. - api_type (str): The type of API to use ('qianfan' or 'openai').
  34. - api_key (str): The API key for 'openai' API.
  35. - base_url (str): The base URL for 'openai' API.
  36. Raises:
  37. ValueError: If api_type is not one of ['qianfan','openai'],
  38. base_url is None for api_type is openai,
  39. api_key is None for api_type is openai.
  40. """
  41. super().__init__()
  42. model_name = config.get("model_name", None)
  43. api_key = config.get("api_key", None)
  44. base_url = config.get("base_url", None)
  45. tiktoken_enabled = config.get("tiktoken_enabled", False)
  46. if api_key is None:
  47. raise ValueError("api_key cannot be empty when api_type is openai.")
  48. if base_url is None:
  49. raise ValueError("base_url cannot be empty when api_type is openai.")
  50. try:
  51. from langchain_openai import OpenAIEmbeddings
  52. except:
  53. raise Exception(
  54. "langchain-openai is not installed, please install it first."
  55. )
  56. self.embedding = OpenAIEmbeddings(
  57. model=model_name,
  58. api_key=api_key,
  59. base_url=base_url,
  60. tiktoken_enabled=tiktoken_enabled,
  61. )
  62. self.model_name = model_name
  63. self.config = config