ernie_bot_chat.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict
  15. import re
  16. import json
  17. import erniebot
  18. from .....utils import logging
  19. from .base import BaseChat
  20. class ErnieBotChat(BaseChat):
  21. """Ernie Bot Chat"""
  22. entities = [
  23. "ernie-4.0",
  24. "ernie-3.5",
  25. "ernie-3.5-8k",
  26. "ernie-lite",
  27. "ernie-tiny-8k",
  28. "ernie-speed",
  29. "ernie-speed-128k",
  30. "ernie-char-8k",
  31. ]
  32. def __init__(self, config: Dict) -> None:
  33. """Initializes the ErnieBotChat with given configuration.
  34. Args:
  35. config (Dict): Configuration dictionary containing model_name, api_type, ak, sk, and access_token.
  36. Raises:
  37. ValueError: If model_name is not in the predefined entities,
  38. api_type is not one of ['aistudio', 'qianfan'],
  39. access_token is None for 'aistudio' api_type,
  40. or ak and sk are None for 'qianfan' api_type.
  41. """
  42. super().__init__()
  43. model_name = config.get("model_name", None)
  44. api_type = config.get("api_type", None)
  45. ak = config.get("ak", None)
  46. sk = config.get("sk", None)
  47. access_token = config.get("access_token", None)
  48. if model_name not in self.entities:
  49. raise ValueError(f"model_name must be in {self.entities} of ErnieBotChat.")
  50. if api_type not in ["aistudio", "qianfan"]:
  51. raise ValueError("api_type must be one of ['aistudio', 'qianfan']")
  52. if api_type == "aistudio" and access_token is None:
  53. raise ValueError("access_token cannot be empty when api_type is aistudio.")
  54. if api_type == "qianfan" and (ak is None or sk is None):
  55. raise ValueError("ak and sk cannot be empty when api_type is qianfan.")
  56. self.model_name = model_name
  57. self.config = config
  58. def generate_chat_results(
  59. self, prompt: str, temperature: float = 0.001, max_retries: int = 1
  60. ) -> Dict:
  61. """
  62. Generate chat results using the specified model and configuration.
  63. Args:
  64. prompt (str): The user's input prompt.
  65. temperature (float, optional): The temperature parameter for llms, defaults to 0.001.
  66. max_retries (int, optional): The maximum number of retries for llms API calls, defaults to 1.
  67. Returns:
  68. Dict: The chat completion result from the model.
  69. """
  70. try:
  71. cur_config = {
  72. "api_type": self.config["api_type"],
  73. "max_retries": max_retries,
  74. }
  75. if self.config["api_type"] == "aistudio":
  76. cur_config["access_token"] = self.config["access_token"]
  77. elif self.config["api_type"] == "qianfan":
  78. cur_config["ak"] = self.config["ak"]
  79. cur_config["sk"] = self.config["sk"]
  80. chat_completion = erniebot.ChatCompletion.create(
  81. _config_=cur_config,
  82. model=self.model_name,
  83. messages=[{"role": "user", "content": prompt}],
  84. temperature=float(temperature),
  85. )
  86. llm_result = chat_completion.get_result()
  87. return llm_result
  88. except Exception as e:
  89. if len(e.args) < 1:
  90. self.ERROR_MASSAGE = "暂无权限访问ErnieBot服务,请检查访问令牌。"
  91. elif (
  92. e.args[-1]
  93. == "暂无权限使用,请在 AI Studio 正确获取访问令牌(access token)使用"
  94. ):
  95. self.ERROR_MASSAGE = "暂无权限访问ErnieBot服务,请检查访问令牌。"
  96. else:
  97. logging.error(e)
  98. self.ERROR_MASSAGE = "大模型调用失败"
  99. return None
  100. def fix_llm_result_format(self, llm_result: str) -> dict:
  101. """
  102. Fix the format of the LLM result.
  103. Args:
  104. llm_result (str): The result from the LLM (Large Language Model).
  105. Returns:
  106. dict: A fixed format dictionary from the LLM result.
  107. """
  108. if not llm_result:
  109. return {}
  110. if "json" in llm_result or "```" in llm_result:
  111. llm_result = (
  112. llm_result.replace("```", "").replace("json", "").replace("/n", "")
  113. )
  114. llm_result = llm_result.replace("[", "").replace("]", "")
  115. try:
  116. llm_result = json.loads(llm_result)
  117. llm_result_final = {}
  118. for key in llm_result:
  119. value = llm_result[key]
  120. if isinstance(value, list):
  121. if len(value) > 0:
  122. llm_result_final[key] = value[0]
  123. else:
  124. llm_result_final[key] = value
  125. return llm_result_final
  126. except:
  127. results = (
  128. llm_result.replace("\n", "")
  129. .replace(" ", "")
  130. .replace("{", "")
  131. .replace("}", "")
  132. )
  133. if not results.endswith('"'):
  134. results = results + '"'
  135. pattern = r'"(.*?)": "([^"]*)"'
  136. matches = re.findall(pattern, str(results))
  137. if len(matches) > 0:
  138. llm_result = {k: v for k, v in matches}
  139. return llm_result
  140. else:
  141. return {}