# openai_bot_chat.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. import json
  16. import base64
  17. from typing import Dict
  18. from .....utils import logging
  19. from .base import BaseChat
class OpenAIBotChat(BaseChat):
    """Chat backend for OpenAI-compatible APIs.

    Wraps an ``openai.OpenAI`` client and exposes text / multimodal
    generation plus a helper that normalizes LLM output into a dict.
    """

    # Registry keys under which this backend is looked up.
    entities = [
        "openai",
    ]
  25. def __init__(self, config: Dict) -> None:
  26. """Initializes the OpenAIBotChat with given configuration.
  27. Args:
  28. config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key, end_point.
  29. Raises:
  30. ValueError: If api_type is not one of ['openai'],
  31. base_url is None for api_type is openai,
  32. api_key is None for api_type is openai.
  33. ValueError: If end_point is not one of ['completion', 'chat_completion'].
  34. """
  35. super().__init__()
  36. model_name = config.get("model_name", None)
  37. # compatible with historical model name
  38. if model_name == "ernie-3.5":
  39. model_name = "ernie-3.5-8k"
  40. api_type = config.get("api_type", None)
  41. api_key = config.get("api_key", None)
  42. base_url = config.get("base_url", None)
  43. end_point = config.get("end_point", "chat_completion")
  44. if api_type not in ["openai"]:
  45. raise ValueError("api_type must be one of ['openai']")
  46. if api_type == "openai" and api_key is None:
  47. raise ValueError("api_key cannot be empty when api_type is openai.")
  48. if base_url is None:
  49. raise ValueError("base_url cannot be empty when api_type is openai.")
  50. if end_point not in ["completion", "chat_completion"]:
  51. raise ValueError(
  52. "end_point must be one of ['completion', 'chat_completion']"
  53. )
  54. try:
  55. from openai import OpenAI
  56. except:
  57. raise Exception("openai is not installed, please install it first.")
  58. self.client = OpenAI(base_url=base_url, api_key=api_key)
  59. self.model_name = model_name
  60. self.config = config
  61. def generate_chat_results(
  62. self,
  63. prompt: str,
  64. image: base64 = None,
  65. temperature: float = 0.001,
  66. max_retries: int = 1,
  67. ) -> Dict:
  68. """
  69. Generate chat results using the specified model and configuration.
  70. Args:
  71. prompt (str): The user's input prompt.
  72. image (base64): The user's input image for MLLM, defaults to None.
  73. temperature (float, optional): The temperature parameter for llms, defaults to 0.001.
  74. max_retries (int, optional): The maximum number of retries for llms API calls, defaults to 1.
  75. Returns:
  76. Dict: The chat completion result from the model.
  77. """
  78. llm_result = {"content": None, "reasoning_content": None}
  79. try:
  80. if image:
  81. chat_completion = self.client.chat.completions.create(
  82. model=self.model_name,
  83. messages=[
  84. {
  85. "role": "system",
  86. # XXX: give a basic prompt for common
  87. "content": "You are a helpful assistant.",
  88. },
  89. {
  90. "role": "user",
  91. "content": [
  92. {"type": "text", "text": prompt},
  93. {
  94. "type": "image_url",
  95. "image_url": {
  96. "url": f"data:image/jpeg;base64,{image}"
  97. },
  98. },
  99. ],
  100. },
  101. ],
  102. stream=False,
  103. temperature=temperature,
  104. top_p=0.001,
  105. )
  106. llm_result["content"] = chat_completion.choices[0].message.content
  107. return llm_result
  108. elif self.config.get("end_point", "chat_completion") == "chat_completion":
  109. chat_completion = self.client.chat.completions.create(
  110. model=self.model_name,
  111. messages=[
  112. {
  113. "role": "user",
  114. "content": prompt,
  115. },
  116. ],
  117. stream=False,
  118. temperature=temperature,
  119. top_p=0.001,
  120. )
  121. llm_result["content"] = chat_completion.choices[0].message.content
  122. try:
  123. llm_result["reasoning_content"] = chat_completion.choices[
  124. 0
  125. ].message.reasoning_content
  126. except:
  127. pass
  128. return llm_result
  129. else:
  130. chat_completion = self.client.completions.create(
  131. model=self.model_name,
  132. prompt=prompt,
  133. max_tokens=self.config.get("max_tokens", 1024),
  134. temperature=float(temperature),
  135. stream=False,
  136. )
  137. if isinstance(chat_completion, str):
  138. chat_completion = json.loads(chat_completion)
  139. llm_result = chat_completion["choices"][0]["text"]
  140. else:
  141. llm_result["content"] = chat_completion.choices[0].text
  142. return llm_result
  143. except Exception as e:
  144. logging.error(e)
  145. self.ERROR_MASSAGE = "大模型调用失败"
  146. return llm_result
  147. def fix_llm_result_format(self, llm_result: str) -> dict:
  148. """
  149. Fix the format of the LLM result.
  150. Args:
  151. llm_result (str): The result from the LLM (Large Language Model).
  152. Returns:
  153. dict: A fixed format dictionary from the LLM result.
  154. """
  155. if not llm_result:
  156. return {}
  157. if "json" in llm_result or "```" in llm_result:
  158. index = llm_result.find("{")
  159. if index != -1:
  160. llm_result = llm_result[index:]
  161. index = llm_result.rfind("}")
  162. if index != -1:
  163. llm_result = llm_result[: index + 1]
  164. llm_result = (
  165. llm_result.replace("```", "").replace("json", "").replace("/n", "")
  166. )
  167. llm_result = llm_result.replace("[", "").replace("]", "")
  168. try:
  169. llm_result = json.loads(llm_result)
  170. llm_result_final = {}
  171. if "问题" in llm_result.keys() and "答案" in llm_result.keys():
  172. key = llm_result["问题"]
  173. value = llm_result["答案"]
  174. if isinstance(value, list):
  175. if len(value) > 0:
  176. llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
  177. else:
  178. llm_result_final[key] = value.strip(f"{key}:").strip(key)
  179. return llm_result_final
  180. for key in llm_result:
  181. value = llm_result[key]
  182. if isinstance(value, list):
  183. if len(value) > 0:
  184. llm_result_final[key] = value[0]
  185. else:
  186. llm_result_final[key] = value
  187. return llm_result_final
  188. except:
  189. results = (
  190. llm_result.replace("\n", "")
  191. .replace(" ", "")
  192. .replace("{", "")
  193. .replace("}", "")
  194. )
  195. if not results.endswith('"'):
  196. results = results + '"'
  197. pattern = r'"(.*?)": "([^"]*)"'
  198. matches = re.findall(pattern, str(results))
  199. if len(matches) > 0:
  200. llm_result = {k: v for k, v in matches}
  201. if "问题" in llm_result.keys() and "答案" in llm_result.keys():
  202. llm_result_final = {}
  203. key = llm_result["问题"]
  204. value = llm_result["答案"]
  205. if isinstance(value, list):
  206. if len(value) > 0:
  207. llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
  208. else:
  209. llm_result_final[key] = value.strip(f"{key}:").strip(key)
  210. return llm_result_final
  211. return llm_result
  212. else:
  213. return {}