base.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import base64
import time
from abc import ABC, abstractmethod
from typing import List, Optional

from paddlex.utils import logging

from .....utils.deps import class_requires_deps, is_dep_available
from .....utils.subclass_register import AutoRegisterABCMetaClass
  21. if is_dep_available("langchain"):
  22. from langchain.docstore.document import Document
  23. from langchain.text_splitter import RecursiveCharacterTextSplitter
  24. if is_dep_available("langchain-community"):
  25. from langchain_community import vectorstores
  26. from langchain_community.vectorstores import FAISS
  27. @class_requires_deps("langchain", "langchain-community")
  28. class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
  29. """Base Retriever"""
  30. __is_base = True
  31. VECTOR_STORE_PREFIX = "PADDLEX_VECTOR_STORE"
  32. def __init__(self):
  33. """Initializes an instance of base retriever."""
  34. super().__init__()
  35. self.model_name = None
  36. self.embedding = None
  37. @abstractmethod
  38. def generate_vector_database(self):
  39. """
  40. Declaration of an abstract method. Subclasses are expected to
  41. provide a concrete implementation of generate_vector_database.
  42. """
  43. raise NotImplementedError(
  44. "The method `generate_vector_database` has not been implemented yet."
  45. )
  46. @abstractmethod
  47. def similarity_retrieval(self):
  48. """
  49. Declaration of an abstract method. Subclasses are expected to
  50. provide a concrete implementation of similarity_retrieval.
  51. """
  52. raise NotImplementedError(
  53. "The method `similarity_retrieval` has not been implemented yet."
  54. )
  55. def get_model_name(self) -> str:
  56. """
  57. Get the model name used for generating vectors.
  58. Returns:
  59. str: The model name.
  60. """
  61. return self.model_name
  62. def is_vector_store(self, s: str) -> bool:
  63. """
  64. Check if the given string starts with the vector store prefix.
  65. Args:
  66. s (str): The input string to check.
  67. Returns:
  68. bool: True if the string starts with the vector store prefix, False otherwise.
  69. """
  70. return s.startswith(self.VECTOR_STORE_PREFIX)
  71. def encode_vector_store(self, vector_store_bytes: bytes) -> str:
  72. """
  73. Encode the vector store bytes into a base64 string prefixed with a specific prefix.
  74. Args:
  75. vector_store_bytes (bytes): The bytes to encode.
  76. Returns:
  77. str: The encoded string with the prefix.
  78. """
  79. return self.VECTOR_STORE_PREFIX + base64.b64encode(vector_store_bytes).decode(
  80. "ascii"
  81. )
  82. def decode_vector_store(self, vector_store_str: str) -> bytes:
  83. """
  84. Decodes the vector store string by removing the prefix and decoding the base64 encoded string.
  85. Args:
  86. vector_store_str (str): The vector store string with a prefix.
  87. Returns:
  88. bytes: The decoded vector store data.
  89. """
  90. return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX) :])
  91. def generate_vector_database(
  92. self,
  93. text_list: List[str],
  94. block_size: int = 300,
  95. separators: List[str] = ["\t", "\n", "。", "\n\n", ""],
  96. ) -> "FAISS":
  97. """
  98. Generates a vector database from a list of texts.
  99. Args:
  100. text_list (list[str]): A list of texts to generate the vector database from.
  101. block_size (int): The size of each chunk to split the text into.
  102. separators (list[str]): A list of separators to use when splitting the text.
  103. Returns:
  104. FAISS: The generated vector database.
  105. Raises:
  106. ValueError: If an unsupported API type is configured.
  107. """
  108. text_splitter = RecursiveCharacterTextSplitter(
  109. chunk_size=block_size, chunk_overlap=20, separators=separators
  110. )
  111. texts = text_splitter.split_text("\t".join(text_list))
  112. all_splits = [Document(page_content=text) for text in texts]
  113. try:
  114. vectorstore = FAISS.from_documents(
  115. documents=all_splits, embedding=self.embedding
  116. )
  117. except ValueError:
  118. vectorstore = None
  119. return vectorstore
  120. def encode_vector_store_to_bytes(self, vectorstore: "FAISS") -> str:
  121. """
  122. Encode the vector store serialized to bytes.
  123. Args:
  124. vectorstore (FAISS): The vector store to be serialized and encoded.
  125. Returns:
  126. str: The encoded vector store.
  127. """
  128. if vectorstore is None:
  129. vectorstore = self.VECTOR_STORE_PREFIX
  130. else:
  131. vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
  132. return vectorstore
  133. def decode_vector_store_from_bytes(self, vectorstore: str) -> "FAISS":
  134. """
  135. Decode a vector store from bytes according to the specified API type.
  136. Args:
  137. vectorstore (str): The serialized vector store string.
  138. Returns:
  139. FAISS: Deserialized vector store object.
  140. Raises:
  141. ValueError: If the retrieved vector store is not for PaddleX
  142. or if an unsupported API type is specified.
  143. """
  144. if not self.is_vector_store(vectorstore):
  145. raise ValueError("The retrieved vectorstore is not for PaddleX.")
  146. vectorstore = self.decode_vector_store(vectorstore)
  147. if vectorstore == b"":
  148. logging.warning("The retrieved vectorstore is empty,will empty vector.")
  149. return None
  150. vector = vectorstores.FAISS.deserialize_from_bytes(
  151. vectorstore,
  152. embeddings=self.embedding,
  153. allow_dangerous_deserialization=True,
  154. )
  155. return vector
  156. def similarity_retrieval(
  157. self,
  158. query_text_list: List[str],
  159. vectorstore: "FAISS",
  160. sleep_time: float = 0.5,
  161. topk: int = 2,
  162. min_characters: int = 3500,
  163. ) -> str:
  164. """
  165. Retrieve similar contexts based on a list of query texts.
  166. Args:
  167. query_text_list (list[str]): A list of query texts to search for similar contexts.
  168. vectorstore (FAISS): The vector store where to perform the similarity search.
  169. sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5.
  170. topk (int): The number of results to retrieve per query. Default is 2.
  171. min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
  172. Returns:
  173. str: A concatenated string of all unique contexts found.
  174. """
  175. all_C = ""
  176. if vectorstore is None:
  177. return all_C
  178. for query_text in query_text_list:
  179. QUESTION = query_text
  180. time.sleep(sleep_time)
  181. docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=topk)
  182. context = [(document.page_content, score) for document, score in docs]
  183. context = sorted(context, key=lambda x: x[1])
  184. for text, score in context[::-1]:
  185. if score >= -0.1:
  186. if len(all_C) + len(text) > min_characters:
  187. break
  188. all_C += text
  189. return all_C