erniebot.py 7.5 KB

# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import json
from pathlib import Path

import erniebot
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community import vectorstores
from langchain_community.embeddings import QianfanEmbeddingsEndpoint
from langchain_community.vectorstores import FAISS
from erniebot_agent.extensions.langchain.embeddings import ErnieEmbeddings

from .base import BaseLLM
from ....utils import logging
from ....utils.func_register import FuncRegister

__all__ = ["ErnieBot"]


class ErnieBot(BaseLLM):
    INPUT_KEYS = ["prompts"]
    OUTPUT_KEYS = ["cls_res"]
    DEAULT_INPUTS = {"prompts": "prompts"}
    DEAULT_OUTPUTS = {"cls_pred": "cls_pred"}
    API_TYPE = "aistudio"

    entities = [
        "ernie-4.0",
        "ernie-3.5",
        "ernie-3.5-8k",
        "ernie-lite",
        "ernie-tiny-8k",
        "ernie-speed",
        "ernie-speed-128k",
        "ernie-char-8k",
    ]

    _FUNC_MAP = {}
    register = FuncRegister(_FUNC_MAP)

    def __init__(self, model_name="ernie-4.0", params={}):
        super().__init__()
        access_token = params.get("access_token")
        ak = params.get("ak")
        sk = params.get("sk")
        api_type = params.get("api_type")
        max_retries = params.get("max_retries")
        assert model_name in self.entities, f"model_name must be in {self.entities}"
        assert any([access_token, ak, sk]), "access_token or ak and sk must be set"
        self.model_name = model_name
        self.config = {
            "api_type": api_type,
            "max_retries": max_retries,
        }
        if access_token:
            self.config["access_token"] = access_token
        else:
            self.config["ak"] = ak
            self.config["sk"] = sk

    def pred(self, prompt, temperature=0.001):
        """Run a single-turn chat completion and return the model's reply text."""
        try:
            chat_completion = erniebot.ChatCompletion.create(
                _config_=self.config,
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=float(temperature),
            )
            llm_result = chat_completion.get_result()
            return llm_result
        except Exception as e:
            if len(e.args) < 1:
                # The AI Studio backend call failed without details; ask the user to check the token.
                self.ERROR_MASSAGE = (
                    "当前选择后端为AI Studio,千帆调用失败,请检查token"
                )
            elif (
                e.args[-1]
                == "暂无权限使用,请在 AI Studio 正确获取访问令牌(access token)使用"
            ):
                # AI Studio rejected the request: the access token is missing or invalid.
                self.ERROR_MASSAGE = (
                    "当前选择后端为AI Studio,请正确获取访问令牌(access token)使用"
                )
            elif e.args[-1] == "the max length of current question is 4800":
                # The prompt exceeded the 4800-character limit.
                self.ERROR_MASSAGE = "大模型调用失败"
            else:
                logging.error(e)
                self.ERROR_MASSAGE = "大模型调用失败"
        return None

    def get_vector(
        self,
        ocr_result,
        sleep_time=0.5,
        block_size=300,
        separators=["\t", "\n", "。", "\n\n", ""],
    ):
        """Split the OCR results into chunks, embed them, and return an encoded FAISS vector store."""
        all_items = []
        for i, ocr_res in enumerate(ocr_result):
            for type, text in ocr_res.items():
                all_items += [f"第{i}页{type}:{text}"]
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=block_size, chunk_overlap=20, separators=separators
        )
        texts = text_splitter.split_text("\t".join(all_items))
        all_splits = [Document(page_content=text) for text in texts]
        api_type = self.config["api_type"]
        if api_type == "qianfan":
            os.environ["QIANFAN_AK"] = os.environ.get("EB_AK", self.config["ak"])
            os.environ["QIANFAN_SK"] = os.environ.get("EB_SK", self.config["sk"])
            user_ak = os.environ.get("EB_AK", self.config["ak"])
            user_id = hash(user_ak)
            vectorstore = FAISS.from_documents(
                documents=all_splits, embedding=QianfanEmbeddingsEndpoint()
            )
        elif api_type == "aistudio":
            token = self.config["access_token"]
            vectorstore = FAISS.from_documents(
                documents=all_splits[0:1],
                embedding=ErnieEmbeddings(aistudio_access_token=token),
            )
            # ErnieEmbeddings.chunk_size is 16, so embed at most 16 chunks per
            # request; keep step >= 1 so range() never receives a zero step.
            step = min(16, max(len(all_splits) - 1, 1))
            for shot_splits in [
                all_splits[i : i + step] for i in range(1, len(all_splits), step)
            ]:
                time.sleep(sleep_time)
                vectorstore_slice = FAISS.from_documents(
                    documents=shot_splits,
                    embedding=ErnieEmbeddings(aistudio_access_token=token),
                )
                vectorstore.merge_from(vectorstore_slice)
        else:
            raise ValueError(f"Unsupported api_type: {api_type}")
        vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
        return vectorstore

    def caculate_similar(self, vector, key_list, llm_params=None, sleep_time=0.5):
        """Retrieve the document chunks most relevant to each key and join them into one context string."""
        if not self.is_vector_store(vector):
            logging.warning(
                "The retrieved vectorstore is not for PaddleX and will return vectorstore directly"
            )
            return vector
        # XXX: The initialization parameters are hard-coded.
        if llm_params:
            api_type = llm_params.get("api_type")
            access_token = llm_params.get("access_token")
            ak = llm_params.get("ak")
            sk = llm_params.get("sk")
        else:
            api_type = self.config["api_type"]
            access_token = self.config.get("access_token")
            ak = self.config.get("ak")
            sk = self.config.get("sk")
        if api_type == "aistudio":
            embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
        elif api_type == "qianfan":
            embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
        else:
            raise ValueError(f"Unsupported api_type: {api_type}")
        vectorstore = vectorstores.FAISS.deserialize_from_bytes(
            self.decode_vector_store(vector), embeddings
        )
        # Match the most relevant context chunks to each question.
        Q = []
        C = []
        for key in key_list:
            QUESTION = f"抽取关键信息:{key}"
            Q.append(QUESTION)
            time.sleep(sleep_time)
            docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=2)
            context = [(document.page_content, score) for document, score in docs]
            context = sorted(context, key=lambda x: x[1])
            C.extend([x[0] for x in context[::-1]])
        C = list(set(C))
        all_C = " ".join(C)
        summary_prompt = all_C
        return summary_prompt
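

# Usage sketch (illustrative only, not part of the upstream file). It shows how
# the class above might be driven once imported through the PaddleX package;
# the model name, access token, OCR payload, and key list below are placeholder
# assumptions, and the module's relative imports mean it cannot be run directly.
#
#     llm = ErnieBot(
#         model_name="ernie-3.5",
#         params={"api_type": "aistudio", "access_token": "<YOUR_ACCESS_TOKEN>"},
#     )
#     # Single-turn prediction.
#     answer = llm.pred("Summarize the key points of this contract.")
#     # Build an encoded vector store from per-page OCR results, then retrieve
#     # the context chunks most relevant to each key.
#     vector = llm.get_vector([{"text": "Party A: ACME Corp. Party B: J. Doe."}])
#     context = llm.caculate_similar(vector, key_list=["Party A", "Party B"])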