faiss.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pickle
  16. from pathlib import Path
  17. import faiss
  18. import numpy as np
  19. from ....utils import logging
  20. from ..base import BaseComponent
  21. class FaissIndexer(BaseComponent):
  22. INPUT_KEYS = "feature"
  23. OUTPUT_KEYS = ["label", "score"]
  24. DEAULT_INPUTS = {"feature": "feature"}
  25. DEAULT_OUTPUTS = {"label": "label", "score": "score"}
  26. ENABLE_BATCH = True
  27. def __init__(
  28. self,
  29. index_dir,
  30. metric_type="IP",
  31. return_k=1,
  32. score_thres=None,
  33. hamming_radius=None,
  34. ):
  35. super().__init__()
  36. index_dir = Path(index_dir)
  37. vector_path = (index_dir / "vector.index").as_posix()
  38. id_map_path = (index_dir / "id_map.pkl").as_posix()
  39. if metric_type == "hamming":
  40. self._indexer = faiss.read_index_binary(vector_path)
  41. self.hamming_radius = hamming_radius
  42. else:
  43. self._indexer = faiss.read_index(vector_path)
  44. self.score_thres = score_thres
  45. with open(id_map_path, "rb") as fd:
  46. self.id_map = pickle.load(fd)
  47. self.metric_type = metric_type
  48. self.return_k = return_k
  49. def apply(self, feature):
  50. """apply"""
  51. scores_list, ids_list = self._indexer.search(np.array(feature), self.return_k)
  52. preds = []
  53. for scores, ids in zip(scores_list, ids_list):
  54. labels = []
  55. for id in ids:
  56. if id > 0:
  57. labels.append(self.id_map[id])
  58. preds.append({"score": scores, "label": labels})
  59. if self.metric_type == "hamming":
  60. idxs = np.where(scores_list[:, 0] > self.hamming_radius)[0]
  61. else:
  62. idxs = np.where(scores_list[:, 0] < self.score_thres)[0]
  63. for idx in idxs:
  64. preds[idx] = {"score": None, "label": None}
  65. return preds
  66. class FaissBuilder:
  67. SUPPORT_MODE = ("new", "remove", "append")
  68. SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")
  69. SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32")
  70. BINARY_METRIC_TYPE = ("hamming",)
  71. BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash")
  72. def __init__(self, predict, mode="new", index_type="HNSW32", metric_type="IP"):
  73. super().__init__()
  74. assert (
  75. mode in self.SUPPORT_MODE
  76. ), f"Supported modes only: {self.SUPPORT_MODE}. But received {mode}!"
  77. assert (
  78. metric_type in self.SUPPORT_METRIC_TYPE
  79. ), f"Supported metric types only: {self.SUPPORT_METRIC_TYPE}!"
  80. assert (
  81. index_type in self.SUPPORT_INDEX_TYPE
  82. ), f"Supported index types only: {self.SUPPORT_INDEX_TYPE}!"
  83. self._predict = predict
  84. self._mode = mode
  85. self._metric_type = metric_type
  86. self._index_type = index_type
  87. def _get_index_type(self, num=None):
  88. # if IVF method, cal ivf number automaticlly
  89. if self._index_type == "IVF":
  90. index_type = self._index_type + str(min(int(num // 8), 65536))
  91. if self._metric_type in self.BINARY_METRIC_TYPE:
  92. index_type += ",BFlat"
  93. else:
  94. index_type += ",Flat"
  95. # for binary index, add B at head of index_type
  96. if self._metric_type in self.BINARY_METRIC_TYPE:
  97. assert (
  98. self._index_type in self.BINARY_SUPPORT_INDEX_TYPE
  99. ), f"The metric type({self._metric_type}) only support {self.BINARY_SUPPORT_INDEX_TYPE} index types!"
  100. index_type = "B" + index_type
  101. if self._index_type == "HNSW32":
  102. logging.warning("The HNSW32 method dose not support 'remove' operation")
  103. index_type = "HNSW32"
  104. if self._index_type == "Flat":
  105. index_type = "Flat"
  106. return index_type
  107. def _get_metric_type(self):
  108. if self._metric_type == "hamming":
  109. return faiss.METRIC_Hamming
  110. elif self._metric_type == "jaccard":
  111. return faiss.METRIC_Jaccard
  112. elif self._metric_type == "IP":
  113. return faiss.METRIC_INNER_PRODUCT
  114. elif self._metric_type == "L2":
  115. return faiss.METRIC_L2
  116. def build(
  117. self,
  118. label_file,
  119. image_root,
  120. index_dir,
  121. ):
  122. file_list, gallery_docs = get_file_list(label_file, image_root)
  123. features = [res["feature"] for res in self._predict(file_list)]
  124. dtype = np.uint8 if self._metric_type in self.BINARY_METRIC_TYPE else np.float32
  125. features = np.array(features).astype(dtype)
  126. vector_num, vector_dim = features.shape
  127. if self._metric_type in self.BINARY_METRIC_TYPE:
  128. index = faiss.index_binary_factory(
  129. vector_dim,
  130. self._get_index_type(vector_num),
  131. self._get_metric_type(),
  132. )
  133. else:
  134. index = faiss.index_factory(
  135. vector_dim,
  136. self._get_index_type(vector_num),
  137. self._get_metric_type(),
  138. )
  139. index = faiss.IndexIDMap2(index)
  140. ids = {}
  141. # calculate id for new data
  142. index, ids = self._add_gallery(index, ids, features, gallery_docs)
  143. self._save_gallery(index, ids, index_dir)
  144. def remove(
  145. self,
  146. label_file,
  147. image_root,
  148. index_dir,
  149. ):
  150. file_list, gallery_docs = get_file_list(label_file, image_root)
  151. # load vector.index and id_map.pkl
  152. index, ids = self._load_index(index_dir)
  153. if self._index_type == "HNSW32":
  154. raise RuntimeError(
  155. "The index_type: HNSW32 dose not support 'remove' operation"
  156. )
  157. # remove ids in id_map, remove index data in faiss index
  158. index, ids = self._rm_id_in_galllery(index, ids, gallery_docs)
  159. self._save_gallery(index, ids, index_dir)
  160. def append(
  161. self,
  162. label_file,
  163. image_root,
  164. index_dir,
  165. ):
  166. file_list, gallery_docs = get_file_list(label_file, image_root)
  167. features = [res["feature"] for res in self._predict(file_list)]
  168. dtype = np.uint8 if self._metric_type in self.BINARY_METRIC_TYPE else np.float32
  169. features = np.array(features).astype(dtype)
  170. # load vector.index and id_map.pkl
  171. index, ids = self._load_index(index_dir)
  172. # calculate id for new data
  173. index, ids = self._add_gallery(index, ids, features, gallery_docs)
  174. self._save_gallery(index, ids, index_dir)
  175. def _load_index(self, index_dir):
  176. assert os.path.join(
  177. index_dir, "vector.index"
  178. ), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
  179. index_dir
  180. )
  181. assert os.path.join(
  182. index_dir, "id_map.pkl"
  183. ), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
  184. index_dir
  185. )
  186. index = faiss.read_index(os.path.join(index_dir, "vector.index"))
  187. with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
  188. ids = pickle.load(fd)
  189. assert index.ntotal == len(
  190. ids.keys()
  191. ), "data number in index is not equal in in id_map"
  192. return index, ids
  193. def _add_gallery(self, index, ids, gallery_features, gallery_docs):
  194. start_id = max(ids.keys()) + 1 if ids else 0
  195. ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
  196. # only train when new index file
  197. if self._mode == "new":
  198. if self._metric_type in self.BINARY_METRIC_TYPE:
  199. index.add(gallery_features)
  200. else:
  201. index.train(gallery_features)
  202. if not self._metric_type in self.BINARY_METRIC_TYPE:
  203. index.add_with_ids(gallery_features, ids_now)
  204. for i, d in zip(list(ids_now), gallery_docs):
  205. ids[i] = d
  206. return index, ids
  207. def _rm_id_in_galllery(self, index, ids, gallery_docs):
  208. remove_ids = list(filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
  209. remove_ids = np.asarray(remove_ids)
  210. index.remove_ids(remove_ids)
  211. for k in remove_ids:
  212. del ids[k]
  213. return index, ids
  214. def _save_gallery(self, index, ids, index_dir):
  215. Path(index_dir).mkdir(parents=True, exist_ok=True)
  216. if self._metric_type in self.BINARY_METRIC_TYPE:
  217. faiss.write_index_binary(index, os.path.join(index_dir, "vector.index"))
  218. else:
  219. faiss.write_index(index, os.path.join(index_dir, "vector.index"))
  220. with open(os.path.join(index_dir, "id_map.pkl"), "wb") as fd:
  221. pickle.dump(ids, fd)
  222. def get_file_list(data_file, root_dir, delimiter="\t"):
  223. root_dir = Path(root_dir)
  224. files = []
  225. labels = []
  226. lines = []
  227. with open(data_file, "r", encoding="utf-8") as f:
  228. lines = f.readlines()
  229. for line in lines:
  230. path, label = line.strip().split(delimiter)
  231. file_path = root_dir / path
  232. files.append(file_path.as_posix())
  233. labels.append(label)
  234. return files, labels