|
|
@@ -0,0 +1,204 @@
|
|
|
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+
|
|
|
+import re
|
|
|
+import json
|
|
|
+import base64
|
|
|
+from typing import Dict
|
|
|
+from .....utils import logging
|
|
|
+from .base import BaseChat
|
|
|
+
|
|
|
+
|
|
|
+class OpenAIBotChat(BaseChat):
|
|
|
+ """OpenAI Bot Chat"""
|
|
|
+
|
|
|
+ entities = [
|
|
|
+ "openai",
|
|
|
+ ]
|
|
|
+
|
|
|
+ def __init__(self, config: Dict) -> None:
|
|
|
+ """Initializes the OpenAIBotChat with given configuration.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ValueError: If api_type is not one of ['openai'],
|
|
|
+ base_url is None for api_type is openai,
|
|
|
+ api_key is None for api_type is openai.
|
|
|
+ """
|
|
|
+ super().__init__()
|
|
|
+ model_name = config.get("model_name", None)
|
|
|
+ api_type = config.get("api_type", None)
|
|
|
+ api_key = config.get("api_key", None)
|
|
|
+ base_url = config.get("base_url", None)
|
|
|
+
|
|
|
+ if api_type not in ["openai"]:
|
|
|
+ raise ValueError("api_type must be one of ['openai']")
|
|
|
+
|
|
|
+ if api_type == "openai" and api_key is None:
|
|
|
+ raise ValueError("api_key cannot be empty when api_type is openai.")
|
|
|
+
|
|
|
+ if base_url is None:
|
|
|
+ raise ValueError("base_url cannot be empty when api_type is openai.")
|
|
|
+
|
|
|
+ try:
|
|
|
+ from openai import OpenAI
|
|
|
+ except:
|
|
|
+ raise Exception("openai is not installed, please install it first.")
|
|
|
+
|
|
|
+ self.client = OpenAI(base_url=base_url, api_key=api_key)
|
|
|
+
|
|
|
+ self.model_name = model_name
|
|
|
+ self.config = config
|
|
|
+
|
|
|
+ def generate_chat_results(
|
|
|
+ self,
|
|
|
+ prompt: str,
|
|
|
+ image: base64 = None,
|
|
|
+ temperature: float = 0.001,
|
|
|
+ max_retries: int = 1,
|
|
|
+ ) -> Dict:
|
|
|
+ """
|
|
|
+ Generate chat results using the specified model and configuration.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ prompt (str): The user's input prompt.
|
|
|
+ image (base64): The user's input image for MLLM, defaults to None.
|
|
|
+ temperature (float, optional): The temperature parameter for llms, defaults to 0.001.
|
|
|
+ max_retries (int, optional): The maximum number of retries for llms API calls, defaults to 1.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Dict: The chat completion result from the model.
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ if image:
|
|
|
+ chat_completion = self.client.chat.completions.create(
|
|
|
+ model=self.model_name,
|
|
|
+ messages=[
|
|
|
+ {
|
|
|
+ "role": "system",
|
|
|
+ # XXX: give a basic prompt for common
|
|
|
+ "content": "You are a helpful assistant.",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "role": "user",
|
|
|
+ "content": [
|
|
|
+ {"type": "text", "text": prompt},
|
|
|
+ {
|
|
|
+ "type": "image_url",
|
|
|
+ "image_url": {
|
|
|
+ "url": f"data:image/jpeg;base64,{image}"
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ stream=False,
|
|
|
+ temperature=temperature,
|
|
|
+ top_p=0.001,
|
|
|
+ )
|
|
|
+ llm_result = chat_completion.choices[0].message.content
|
|
|
+ return llm_result
|
|
|
+ else:
|
|
|
+ chat_completion = self.client.completions.create(
|
|
|
+ model=self.model_name,
|
|
|
+ prompt=prompt,
|
|
|
+ max_tokens=self.config.get("max_tokens", 1024),
|
|
|
+ temperature=float(temperature),
|
|
|
+ stream=False,
|
|
|
+ )
|
|
|
+ if isinstance(chat_completion, str):
|
|
|
+ chat_completion = json.loads(chat_completion)
|
|
|
+ llm_result = chat_completion["choices"][0]["text"]
|
|
|
+ else:
|
|
|
+ llm_result = chat_completion.choices[0].text
|
|
|
+ return llm_result
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(e)
|
|
|
+ self.ERROR_MASSAGE = "大模型调用失败"
|
|
|
+ return None
|
|
|
+
|
|
|
+ def fix_llm_result_format(self, llm_result: str) -> dict:
|
|
|
+ """
|
|
|
+ Fix the format of the LLM result.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ llm_result (str): The result from the LLM (Large Language Model).
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ dict: A fixed format dictionary from the LLM result.
|
|
|
+ """
|
|
|
+ if not llm_result:
|
|
|
+ return {}
|
|
|
+
|
|
|
+ if "json" in llm_result or "```" in llm_result:
|
|
|
+ index = llm_result.find("{")
|
|
|
+ if index != -1:
|
|
|
+ llm_result = llm_result[index:]
|
|
|
+ index = llm_result.rfind("}")
|
|
|
+ if index != -1:
|
|
|
+ llm_result = llm_result[: index + 1]
|
|
|
+ llm_result = (
|
|
|
+ llm_result.replace("```", "").replace("json", "").replace("/n", "")
|
|
|
+ )
|
|
|
+ llm_result = llm_result.replace("[", "").replace("]", "")
|
|
|
+
|
|
|
+ try:
|
|
|
+ llm_result = json.loads(llm_result)
|
|
|
+ llm_result_final = {}
|
|
|
+ if "问题" in llm_result.keys() and "答案" in llm_result.keys():
|
|
|
+ key = llm_result["问题"]
|
|
|
+ value = llm_result["答案"]
|
|
|
+ if isinstance(value, list):
|
|
|
+ if len(value) > 0:
|
|
|
+ llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
|
|
|
+ else:
|
|
|
+ llm_result_final[key] = value.strip(f"{key}:").strip(key)
|
|
|
+ return llm_result_final
|
|
|
+ for key in llm_result:
|
|
|
+ value = llm_result[key]
|
|
|
+ if isinstance(value, list):
|
|
|
+ if len(value) > 0:
|
|
|
+ llm_result_final[key] = value[0]
|
|
|
+ else:
|
|
|
+ llm_result_final[key] = value
|
|
|
+ return llm_result_final
|
|
|
+
|
|
|
+ except:
|
|
|
+ results = (
|
|
|
+ llm_result.replace("\n", "")
|
|
|
+ .replace(" ", "")
|
|
|
+ .replace("{", "")
|
|
|
+ .replace("}", "")
|
|
|
+ )
|
|
|
+ if not results.endswith('"'):
|
|
|
+ results = results + '"'
|
|
|
+ pattern = r'"(.*?)": "([^"]*)"'
|
|
|
+ matches = re.findall(pattern, str(results))
|
|
|
+ if len(matches) > 0:
|
|
|
+ llm_result = {k: v for k, v in matches}
|
|
|
+ if "问题" in llm_result.keys() and "答案" in llm_result.keys():
|
|
|
+ llm_result_final = {}
|
|
|
+ key = llm_result["问题"]
|
|
|
+ value = llm_result["答案"]
|
|
|
+ if isinstance(value, list):
|
|
|
+ if len(value) > 0:
|
|
|
+ llm_result_final[key] = value[0].strip(f"{key}:").strip(key)
|
|
|
+ else:
|
|
|
+ llm_result_final[key] = value.strip(f"{key}:").strip(key)
|
|
|
+ return llm_result_final
|
|
|
+ return llm_result
|
|
|
+ else:
|
|
|
+ return {}
|