|
|
@@ -31,18 +31,20 @@ class OpenAIBotChat(BaseChat):
|
|
|
"""Initializes the OpenAIBotChat with given configuration.
|
|
|
|
|
|
Args:
|
|
|
- config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key.
|
|
|
+ config (Dict): Configuration dictionary containing model_name, api_type, base_url, api_key, end_point.
|
|
|
|
|
|
Raises:
|
|
|
ValueError: If api_type is not one of ['openai'],
|
|
|
base_url is None for api_type is openai,
|
|
|
api_key is None for api_type is openai.
|
|
|
+ ValueError: If end_point is not one of ['completion', 'chat_completion'].
|
|
|
"""
|
|
|
super().__init__()
|
|
|
model_name = config.get("model_name", None)
|
|
|
api_type = config.get("api_type", None)
|
|
|
api_key = config.get("api_key", None)
|
|
|
base_url = config.get("base_url", None)
|
|
|
+ end_point = config.get("end_point", "chat_completion")
|
|
|
|
|
|
if api_type not in ["openai"]:
|
|
|
raise ValueError("api_type must be one of ['openai']")
|
|
|
@@ -53,6 +55,11 @@ class OpenAIBotChat(BaseChat):
|
|
|
if base_url is None:
|
|
|
raise ValueError("base_url cannot be empty when api_type is openai.")
|
|
|
|
|
|
+ if end_point not in ["completion", "chat_completion"]:
|
|
|
+ raise ValueError(
|
|
|
+ "end_point must be one of ['completion', 'chat_completion']"
|
|
|
+ )
|
|
|
+
|
|
|
try:
|
|
|
from openai import OpenAI
|
|
|
except:
|
|
|
@@ -111,6 +118,21 @@ class OpenAIBotChat(BaseChat):
|
|
|
)
|
|
|
llm_result = chat_completion.choices[0].message.content
|
|
|
return llm_result
|
|
|
+ elif self.config.get("end_point", "chat_completion") == "chat_completion":
|
|
|
+ chat_completion = self.client.chat.completions.create(
|
|
|
+ model=self.model_name,
|
|
|
+ messages=[
|
|
|
+ {
|
|
|
+ "role": "user",
|
|
|
+ "content": prompt,
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ stream=False,
|
|
|
+ temperature=temperature,
|
|
|
+ top_p=0.001,
|
|
|
+ )
|
|
|
+ llm_result = chat_completion.choices[0].message.content
|
|
|
+ return llm_result
|
|
|
else:
|
|
|
chat_completion = self.client.completions.create(
|
|
|
model=self.model_name,
|