faiss.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pickle
  16. from pathlib import Path
  17. import faiss
  18. import numpy as np
  19. from ....utils import logging
  20. from ..base import BaseComponent
  21. class IndexData:
  22. def __init__(self, index, id_map):
  23. self._index = index
  24. self._id_map = id_map
  25. @property
  26. def index(self):
  27. return self._index
  28. @property
  29. def index_bytes(self):
  30. return faiss.serialize_index(self._index)
  31. @property
  32. def id_map(self):
  33. return self._id_map
  34. def save(self, save_path):
  35. index_data = {
  36. "index_bytes": self.index_bytes,
  37. "id_map": self.id_map,
  38. }
  39. with open(save_path, "wb") as fd:
  40. pickle.dump(index_data, fd)
  41. @classmethod
  42. def load(self, index):
  43. if isinstance(index, str):
  44. with open(index, "rb") as fd:
  45. index_data = pickle.load(fd)
  46. index_ = faiss.deserialize_index(index_data["index_bytes"])
  47. id_map = index_data["id_map"]
  48. assert index_.ntotal == len(
  49. id_map
  50. ), "data number in index is not equal in in id_map"
  51. return index_, id_map
  52. else:
  53. assert isinstance(index, IndexData)
  54. return index.index, index.id_map
  55. class FaissIndexer(BaseComponent):
  56. INPUT_KEYS = "feature"
  57. OUTPUT_KEYS = ["label", "score"]
  58. DEAULT_INPUTS = {"feature": "feature"}
  59. DEAULT_OUTPUTS = {"label": "label", "score": "score"}
  60. ENABLE_BATCH = True
  61. def __init__(
  62. self,
  63. index,
  64. metric_type="IP",
  65. return_k=1,
  66. score_thres=None,
  67. hamming_radius=None,
  68. ):
  69. super().__init__()
  70. if metric_type == "hamming":
  71. self.hamming_radius = hamming_radius
  72. else:
  73. self.score_thres = score_thres
  74. self._indexer, self.id_map = IndexData.load(index)
  75. self.metric_type = metric_type
  76. self.return_k = return_k
  77. def apply(self, feature):
  78. """apply"""
  79. scores_list, ids_list = self._indexer.search(np.array(feature), self.return_k)
  80. preds = []
  81. for scores, ids in zip(scores_list, ids_list):
  82. labels = []
  83. for id in ids:
  84. if id > 0:
  85. labels.append(self.id_map[id])
  86. preds.append({"score": scores, "label": labels})
  87. if self.metric_type == "hamming":
  88. idxs = np.where(scores_list[:, 0] > self.hamming_radius)[0]
  89. else:
  90. idxs = np.where(scores_list[:, 0] < self.score_thres)[0]
  91. for idx in idxs:
  92. preds[idx] = {"score": None, "label": None}
  93. return preds
  94. class FaissBuilder:
  95. SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")
  96. SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32")
  97. BINARY_METRIC_TYPE = ("hamming",)
  98. BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash")
  99. @classmethod
  100. def _get_index_type(cls, metric_type, index_type, num=None):
  101. # if IVF method, cal ivf number automaticlly
  102. if index_type == "IVF":
  103. index_type = index_type + str(min(int(num // 8), 65536))
  104. if metric_type in cls.BINARY_METRIC_TYPE:
  105. index_type += ",BFlat"
  106. else:
  107. index_type += ",Flat"
  108. # for binary index, add B at head of index_type
  109. if metric_type in cls.BINARY_METRIC_TYPE:
  110. assert (
  111. index_type in cls.BINARY_SUPPORT_INDEX_TYPE
  112. ), f"The metric type({metric_type}) only support {cls.BINARY_SUPPORT_INDEX_TYPE} index types!"
  113. index_type = "B" + index_type
  114. if index_type == "HNSW32":
  115. logging.warning("The HNSW32 method dose not support 'remove' operation")
  116. index_type = "HNSW32"
  117. if index_type == "Flat":
  118. index_type = "Flat"
  119. return index_type
  120. @classmethod
  121. def _get_metric_type(cls, metric_type):
  122. if metric_type == "hamming":
  123. return faiss.METRIC_Hamming
  124. elif metric_type == "jaccard":
  125. return faiss.METRIC_Jaccard
  126. elif metric_type == "IP":
  127. return faiss.METRIC_INNER_PRODUCT
  128. elif metric_type == "L2":
  129. return faiss.METRIC_L2
  130. @classmethod
  131. def build(
  132. cls,
  133. gallery_imgs,
  134. gallery_label,
  135. predict_func,
  136. metric_type="IP",
  137. index_type="HNSW32",
  138. ):
  139. assert (
  140. metric_type in cls.SUPPORT_METRIC_TYPE
  141. ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
  142. if isinstance(gallery_label, str):
  143. gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
  144. else:
  145. gallery_docs, gallery_list = gallery_label, gallery_imgs
  146. features = [res["feature"] for res in predict_func(gallery_list)]
  147. dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
  148. features = np.array(features).astype(dtype)
  149. vector_num, vector_dim = features.shape
  150. if metric_type in cls.BINARY_METRIC_TYPE:
  151. index = faiss.index_binary_factory(
  152. vector_dim,
  153. cls._get_index_type(metric_type, index_type, vector_num),
  154. cls._get_metric_type(metric_type),
  155. )
  156. else:
  157. index = faiss.index_factory(
  158. vector_dim,
  159. cls._get_index_type(metric_type, index_type, vector_num),
  160. cls._get_metric_type(metric_type),
  161. )
  162. index = faiss.IndexIDMap2(index)
  163. ids = {}
  164. # calculate id for new data
  165. index, ids = cls._add_gallery(
  166. metric_type, index, ids, features, gallery_docs, mode="new"
  167. )
  168. return IndexData(index, ids)
  169. @classmethod
  170. def remove(
  171. cls,
  172. gallery_label,
  173. index,
  174. index_type="HNSW32",
  175. ):
  176. assert (
  177. index_type in cls.SUPPORT_INDEX_TYPE
  178. ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
  179. if isinstance(gallery_label, str):
  180. gallery_docs, _ = cls.load_gallery(gallery_label)
  181. else:
  182. gallery_docs = gallery_label
  183. index, ids = IndexData.load(index)
  184. if index_type == "HNSW32":
  185. raise RuntimeError(
  186. "The index_type: HNSW32 dose not support 'remove' operation"
  187. )
  188. # remove ids in id_map, remove index data in faiss index
  189. index, ids = cls._rm_id_in_gallery(index, ids, gallery_docs)
  190. return IndexData(index, ids)
  191. @classmethod
  192. def append(cls, gallery_imgs, gallery_label, predict_func, index, metric_type="IP"):
  193. assert (
  194. metric_type in cls.SUPPORT_METRIC_TYPE
  195. ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
  196. if isinstance(gallery_label, str):
  197. gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
  198. else:
  199. gallery_docs, gallery_list = gallery_label, gallery_imgs
  200. features = [res["feature"] for res in predict_func(gallery_list)]
  201. dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
  202. features = np.array(features).astype(dtype)
  203. index, ids = IndexData.load(index)
  204. # calculate id for new data
  205. index, ids = cls._add_gallery(
  206. metric_type, index, ids, features, gallery_docs, mode="append"
  207. )
  208. return IndexData(index, ids)
  209. @classmethod
  210. def _add_gallery(
  211. cls, metric_type, index, ids, gallery_features, gallery_docs, mode
  212. ):
  213. start_id = max(ids.keys()) + 1 if ids else 0
  214. ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
  215. # only train when new index file
  216. if mode == "new":
  217. if metric_type in cls.BINARY_METRIC_TYPE:
  218. index.add(gallery_features)
  219. else:
  220. index.train(gallery_features)
  221. if not metric_type in cls.BINARY_METRIC_TYPE:
  222. index.add_with_ids(gallery_features, ids_now)
  223. for i, d in zip(list(ids_now), gallery_docs):
  224. ids[i] = d
  225. return index, ids
  226. @classmethod
  227. def _rm_id_in_gallery(cls, index, ids, gallery_docs):
  228. remove_ids = list(filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
  229. remove_ids = np.asarray(remove_ids)
  230. index.remove_ids(remove_ids)
  231. for k in remove_ids:
  232. del ids[k]
  233. return index, ids
  234. @classmethod
  235. def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
  236. lines = []
  237. files = []
  238. labels = []
  239. root = Path(gallery_imgs_root)
  240. with open(gallery_label_path, "r", encoding="utf-8") as f:
  241. lines = f.readlines()
  242. for line in lines:
  243. path, label = line.strip().split(delimiter)
  244. file_path = root / path
  245. files.append(file_path.as_posix())
  246. labels.append(label)
  247. return labels, files