# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community import vectorstores
from langchain_community.vectorstores import FAISS

from paddlex.utils import logging

from .....utils.subclass_register import AutoRegisterABCMetaClass
  24. class BaseRetriever(ABC, metaclass=AutoRegisterABCMetaClass):
  25. """Base Retriever"""
  26. __is_base = True
  27. VECTOR_STORE_PREFIX = "PADDLEX_VECTOR_STORE"
  28. def __init__(self):
  29. """Initializes an instance of base retriever."""
  30. super().__init__()
  31. self.model_name = None
  32. self.embedding = None
  33. @abstractmethod
  34. def generate_vector_database(self):
  35. """
  36. Declaration of an abstract method. Subclasses are expected to
  37. provide a concrete implementation of generate_vector_database.
  38. """
  39. raise NotImplementedError(
  40. "The method `generate_vector_database` has not been implemented yet."
  41. )
  42. @abstractmethod
  43. def similarity_retrieval(self):
  44. """
  45. Declaration of an abstract method. Subclasses are expected to
  46. provide a concrete implementation of similarity_retrieval.
  47. """
  48. raise NotImplementedError(
  49. "The method `similarity_retrieval` has not been implemented yet."
  50. )
  51. def get_model_name(self) -> str:
  52. """
  53. Get the model name used for generating vectors.
  54. Returns:
  55. str: The model name.
  56. """
  57. return self.model_name
  58. def is_vector_store(self, s: str) -> bool:
  59. """
  60. Check if the given string starts with the vector store prefix.
  61. Args:
  62. s (str): The input string to check.
  63. Returns:
  64. bool: True if the string starts with the vector store prefix, False otherwise.
  65. """
  66. return s.startswith(self.VECTOR_STORE_PREFIX)
  67. def encode_vector_store(self, vector_store_bytes: bytes) -> str:
  68. """
  69. Encode the vector store bytes into a base64 string prefixed with a specific prefix.
  70. Args:
  71. vector_store_bytes (bytes): The bytes to encode.
  72. Returns:
  73. str: The encoded string with the prefix.
  74. """
  75. return self.VECTOR_STORE_PREFIX + base64.b64encode(vector_store_bytes).decode(
  76. "ascii"
  77. )
  78. def decode_vector_store(self, vector_store_str: str) -> bytes:
  79. """
  80. Decodes the vector store string by removing the prefix and decoding the base64 encoded string.
  81. Args:
  82. vector_store_str (str): The vector store string with a prefix.
  83. Returns:
  84. bytes: The decoded vector store data.
  85. """
  86. return base64.b64decode(vector_store_str[len(self.VECTOR_STORE_PREFIX) :])
  87. def generate_vector_database(
  88. self,
  89. text_list: List[str],
  90. block_size: int = 300,
  91. separators: List[str] = ["\t", "\n", "。", "\n\n", ""],
  92. ) -> FAISS:
  93. """
  94. Generates a vector database from a list of texts.
  95. Args:
  96. text_list (list[str]): A list of texts to generate the vector database from.
  97. block_size (int): The size of each chunk to split the text into.
  98. separators (list[str]): A list of separators to use when splitting the text.
  99. Returns:
  100. FAISS: The generated vector database.
  101. Raises:
  102. ValueError: If an unsupported API type is configured.
  103. """
  104. text_splitter = RecursiveCharacterTextSplitter(
  105. chunk_size=block_size, chunk_overlap=20, separators=separators
  106. )
  107. texts = text_splitter.split_text("\t".join(text_list))
  108. all_splits = [Document(page_content=text) for text in texts]
  109. try:
  110. vectorstore = FAISS.from_documents(
  111. documents=all_splits, embedding=self.embedding
  112. )
  113. except ValueError as e:
  114. print(e)
  115. vectorstore = None
  116. return vectorstore
  117. def encode_vector_store_to_bytes(self, vectorstore: FAISS) -> str:
  118. """
  119. Encode the vector store serialized to bytes.
  120. Args:
  121. vectorstore (FAISS): The vector store to be serialized and encoded.
  122. Returns:
  123. str: The encoded vector store.
  124. """
  125. if vectorstore is None:
  126. vectorstore = self.VECTOR_STORE_PREFIX
  127. else:
  128. vectorstore = self.encode_vector_store(vectorstore.serialize_to_bytes())
  129. return vectorstore
  130. def decode_vector_store_from_bytes(self, vectorstore: str) -> FAISS:
  131. """
  132. Decode a vector store from bytes according to the specified API type.
  133. Args:
  134. vectorstore (str): The serialized vector store string.
  135. Returns:
  136. FAISS: Deserialized vector store object.
  137. Raises:
  138. ValueError: If the retrieved vector store is not for PaddleX
  139. or if an unsupported API type is specified.
  140. """
  141. if not self.is_vector_store(vectorstore):
  142. raise ValueError("The retrieved vectorstore is not for PaddleX.")
  143. vectorstore = self.decode_vector_store(vectorstore)
  144. if vectorstore == b"":
  145. logging.warning("The retrieved vectorstore is empty,will empty vector.")
  146. return None
  147. print(vectorstore)
  148. vector = vectorstores.FAISS.deserialize_from_bytes(
  149. vectorstore,
  150. embeddings=self.embedding,
  151. allow_dangerous_deserialization=True,
  152. )
  153. return vector
  154. def similarity_retrieval(
  155. self,
  156. query_text_list: List[str],
  157. vectorstore: FAISS,
  158. sleep_time: float = 0.5,
  159. topk: int = 2,
  160. min_characters: int = 3500,
  161. ) -> str:
  162. """
  163. Retrieve similar contexts based on a list of query texts.
  164. Args:
  165. query_text_list (list[str]): A list of query texts to search for similar contexts.
  166. vectorstore (FAISS): The vector store where to perform the similarity search.
  167. sleep_time (float): The time to sleep between each query, in seconds. Default is 0.5.
  168. topk (int): The number of results to retrieve per query. Default is 2.
  169. min_characters (int): The minimum number of characters required for text processing, defaults to 3500.
  170. Returns:
  171. str: A concatenated string of all unique contexts found.
  172. """
  173. C = []
  174. all_C = ""
  175. if vectorstore is None:
  176. return all_C
  177. for query_text in query_text_list:
  178. QUESTION = query_text
  179. time.sleep(sleep_time)
  180. docs = vectorstore.similarity_search_with_relevance_scores(QUESTION, k=topk)
  181. context = [(document.page_content, score) for document, score in docs]
  182. context = sorted(context, key=lambda x: x[1])
  183. for text, score in context[::-1]:
  184. if score >= -0.1:
  185. if len(all_C) + len(text) > min_characters:
  186. break
  187. all_C += text
  188. return all_C