瀏覽代碼

support Ctrl chatbot end_point

zhouchangda 9 月之前
父節點
當前提交
a30e03be98
共有 1 個文件被更改,包括 23 次插入、1 次刪除
  1. 23 1
      paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py

+ 23 - 1
paddlex/inference/pipelines/components/chat_server/openai_bot_chat.py

@@ -31,18 +31,20 @@ class OpenAIBotChat(BaseChat):
         """Initializes the OpenAIBotChat with given configuration.
 
         Args:
-            config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key.
+            config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key, end_point.
 
         Raises:
             ValueError: If api_type is not one of ['openai'],
             base_url is None for api_type is openai,
             api_key is None for api_type is openai.
+            ValueError: If end_point is not one of ['completion', 'chat_completion'].
         """
         super().__init__()
         model_name = config.get("model_name", None)
         api_type = config.get("api_type", None)
         api_key = config.get("api_key", None)
         base_url = config.get("base_url", None)
+        end_point = config.get("end_point", "chat_completion")
 
         if api_type not in ["openai"]:
             raise ValueError("api_type must be one of ['openai']")
@@ -53,6 +55,11 @@ class OpenAIBotChat(BaseChat):
         if base_url is None:
             raise ValueError("base_url cannot be empty when api_type is openai.")
 
+        if end_point not in ["completion", "chat_completion"]:
+            raise ValueError(
+                "end_point must be one of ['completion', 'chat_completion']"
+            )
+
         try:
             from openai import OpenAI
         except:
@@ -111,6 +118,21 @@ class OpenAIBotChat(BaseChat):
                 )
                 llm_result = chat_completion.choices[0].message.content
                 return llm_result
+            elif self.config.get("end_point", "chat_completion") == "chat_completion":
+                chat_completion = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": prompt,
+                        },
+                    ],
+                    stream=False,
+                    temperature=temperature,
+                    top_p=0.001,
+                )
+                llm_result = chat_completion.choices[0].message.content
+                return llm_result
             else:
                 chat_completion = self.client.completions.create(
                     model=self.model_name,