openai_bot_retriever.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, List
  15. from .base import BaseRetriever
  16. class OpenAIBotRetriever(BaseRetriever):
  17. """OpenAI Bot Retriever"""
  18. entities = [
  19. "openai",
  20. ]
  21. def __init__(self, config: Dict) -> None:
  22. """
  23. Initializes the OpenAIBotRetriever instance with the provided configuration.
  24. Args:
  25. config (Dict): A dictionary containing configuration settings.
  26. - model_name (str): The name of the model to use.
  27. - api_type (str): The type of API to use ('qianfan' or 'openai').
  28. - api_key (str): The API key for 'openai' API.
  29. - base_url (str): The base URL for 'openai' API.
  30. Raises:
  31. ValueError: If api_type is not one of ['qianfan','openai'],
  32. base_url is None for api_type is openai,
  33. api_key is None for api_type is openai.
  34. """
  35. super().__init__()
  36. model_name = config.get("model_name", None)
  37. api_key = config.get("api_key", None)
  38. base_url = config.get("base_url", None)
  39. tiktoken_enabled = config.get("tiktoken_enabled", False)
  40. if api_key is None:
  41. raise ValueError("api_key cannot be empty when api_type is openai.")
  42. if base_url is None:
  43. raise ValueError("base_url cannot be empty when api_type is openai.")
  44. try:
  45. from langchain_openai import OpenAIEmbeddings
  46. except:
  47. raise Exception(
  48. "langchain-openai is not installed, please install it first."
  49. )
  50. self.embedding = OpenAIEmbeddings(
  51. model=model_name,
  52. api_key=api_key,
  53. base_url=base_url,
  54. tiktoken_enabled=tiktoken_enabled,
  55. )
  56. self.model_name = model_name
  57. self.config = config