# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import re
  15. import json
  16. import base64
  17. from typing import Dict
  18. from .....utils import logging
  19. from .base import BaseChat
  20. class OpenAIBotChat(BaseChat):
  21. """OpenAI Bot Chat"""
  22. entities = [
  23. "openai",
  24. ]
  25. def __init__(self, config: Dict) -> None:
  26. """Initializes the OpenAIBotChat with given configuration.
  27. Args:
  28. config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key, end_point.
  29. Raises:
  30. ValueError: If api_type is not one of ['openai'],
  31. base_url is None for api_type is openai,
  32. api_key is None for api_type is openai.
  33. ValueError: If end_point is not one of ['completion', 'chat_completion'].
  34. """
  35. super().__init__()
  36. model_name = config.get("model_name", None)
  37. api_type = config.get("api_type", None)
  38. api_key = config.get("api_key", None)
  39. base_url = config.get("base_url", None)
  40. end_point = config.get("end_point", "chat_completion")
  41. if api_type not in ["openai"]:
  42. raise ValueError("api_type must be one of ['openai']")
  43. if api_type == "openai" and api_key is None:
  44. raise ValueError("api_key cannot be empty when api_type is openai.")
  45. if base_url is None:
  46. raise ValueError("base_url cannot be empty when api_type is openai.")
  47. if end_point not in ["completion", "chat_completion"]:
  48. raise ValueError(
  49. "end_point must be one of ['completion', 'chat_completion']"
  50. )
  51. try:
  52. from openai import OpenAI
  53. except:
  54. raise Exception("openai is not installed, please install it first.")
  55. self.client = OpenAI(base_url=base_url, api_key=api_key)
  56. self.model_name = model_name
  57. self.config = config
  58. def generate_chat_results(
  59. self,
  60. prompt: str,
  61. image: base64 = None,
  62. temperature: float = 0.001,
  63. max_retries: int = 1,
  64. ) -> Dict:
  65. """
  66. Generate chat results using the specified model and configuration.
  67. Args:
  68. prompt (str): The user's input prompt.
  69. image (base64): The user's input image for MLLM, defaults to None.
  70. temperature (float, optional): The temperature parameter for llms, defaults to 0.001.
  71. max_retries (int, optional): The maximum number of retries for llms API calls, defaults to 1.
  72. Returns:
  73. Dict: The chat completion result from the model.
  74. """
  75. try:
  76. if image:
  77. chat_completion = self.client.chat.completions.create(
  78. model=self.model_name,
  79. messages=[
  80. {
  81. "role": "system",
  82. # XXX: give a basic prompt for common
  83. "content": "You are a helpful assistant.",
  84. },
  85. {
  86. "role": "user",
  87. "content": [
  88. {"type": "text", "text": prompt},
  89. {
  90. "type": "image_url",
  91. "image_url": {
  92. "url": f"data:image/jpeg;base64,{image}"
  93. },
  94. },
  95. ],
  96. },
  97. ],
  98. stream=False,
  99. temperature=temperature,
  100. top_p=0.001,
  101. )
  102. llm_result = chat_completion.choices[0].message.content
  103. return llm_result
  104. elif self.config.get("end_point", "chat_completion") == "chat_completion":
  105. chat_completion = self.client.chat.completions.create(
  106. model=self.model_name,
  107. messages=[
  108. {
  109. "role": "user",
  110. "content": prompt,
  111. },
  112. ],
  113. stream=False,
  114. temperature=temperature,
  115. top_p=0.001,
  116. )
  117. llm_result = chat_completion.choices[0].message.content
  118. return llm_result
  119. else:
  120. chat_completion = self.client.completions.create(
  121. model=self.model_name,
  122. prompt=prompt,
  123. max_tokens=self.config.get("max_tokens", 1024),
  124. temperature=float(temperature),
  125. stream=False,
  126. )
  127. if isinstance(chat_completion, str):
  128. chat_completion = json.loads(chat_completion)
  129. llm_result = chat_completion["choices"][0]["text"]
  130. else:
  131. llm_result = chat_completion.choices[0].text
  132. return llm_result
  133. except Exception as e:
  134. logging.error(e)
  135. self.ERROR_MASSAGE = "大模型调用失败"
  136. return None
  137. def fix_llm_result_format(self, llm_result: str) -> dict:
  138. """
  139. Fix the format of the LLM result.
  140. Args:
  141. llm_result (str): The result from the LLM (Large Language Model).
  142. Returns:
  143. dict: A fixed format dictionary from the LLM result.
  144. """
  145. if not llm_result:
  146. return {}
  147. if "json" in llm_result or "```" in llm_result:
  148. index = llm_result.find("{")
  149. if index != -1:
  150. llm_result = llm_result[index:]
  151. index = llm_result.rfind("}")
  152. if index != -1:
  153. llm_result = llm_result[: index + 1]
  154. llm_result = (
  155. llm_result.replace("```", "").replace("json", "").replace("/n", "")
  156. )
  157. llm_result = llm_result.replace("[", "").replace("]", "")
  158. try:
  159. llm_result = json.loads(llm_result)
  160. llm_result_final = {}
  161. if "问题" in llm_result.keys() and "答案" in llm_result.keys():
  162. key = llm_result["问题"]
  163. value = llm_result["答案"]
  164. if isinstance(value, list):
  165. if len(value) > 0:
  166. llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
  167. else:
  168. llm_result_final[key] = value.strip(f"{key}:").strip(key)
  169. return llm_result_final
  170. for key in llm_result:
  171. value = llm_result[key]
  172. if isinstance(value, list):
  173. if len(value) > 0:
  174. llm_result_final[key] = value[0]
  175. else:
  176. llm_result_final[key] = value
  177. return llm_result_final
  178. except:
  179. results = (
  180. llm_result.replace("\n", "")
  181. .replace(" ", "")
  182. .replace("{", "")
  183. .replace("}", "")
  184. )
  185. if not results.endswith('"'):
  186. results = results + '"'
  187. pattern = r'"(.*?)": "([^"]*)"'
  188. matches = re.findall(pattern, str(results))
  189. if len(matches) > 0:
  190. llm_result = {k: v for k, v in matches}
  191. if "问题" in llm_result.keys() and "答案" in llm_result.keys():
  192. llm_result_final = {}
  193. key = llm_result["问题"]
  194. value = llm_result["答案"]
  195. if isinstance(value, list):
  196. if len(value) > 0:
  197. llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
  198. else:
  199. llm_result_final[key] = value.strip(f"{key}:").strip(key)
  200. return llm_result_final
  201. return llm_result
  202. else:
  203. return {}