faiss.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pickle
  16. from pathlib import Path
  17. import faiss
  18. import numpy as np
  19. from ....utils import logging
  20. from ...utils.io import YAMLWriter, YAMLReader
  21. from ..base import BaseComponent
  22. class IndexData:
  23. VECTOR_FN = "vector"
  24. VECTOR_SUFFIX = ".index"
  25. IDMAP_FN = "id_map"
  26. IDMAP_SUFFIX = ".yaml"
  27. def __init__(self, index, id_map, metric_type):
  28. self._index = index
  29. self._id_map = id_map
  30. self._metric_type = metric_type
  31. @property
  32. def index(self):
  33. return self._index
  34. @property
  35. def index_bytes(self):
  36. return faiss.serialize_index(self._index)
  37. @property
  38. def id_map(self):
  39. return self._id_map
  40. @property
  41. def metric_type(self):
  42. return self._metric_type
  43. def _convert_int(self, id_map):
  44. converted = {int(k): str(v) for k, v in id_map.items()}
  45. return converted
  46. def _convert_np(self, id_map):
  47. converted = {np.int(k): str(v) for k, v in id_map.items()}
  48. return converted
  49. def save(self, save_dir):
  50. save_dir = Path(save_dir)
  51. save_dir.mkdir(parents=True, exist_ok=True)
  52. vector_path = (save_dir / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}").as_posix()
  53. index_info_path = (save_dir / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}").as_posix()
  54. if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
  55. faiss.write_index_binary(self.index, vector_path)
  56. else:
  57. faiss.write_index(self.index, vector_path)
  58. index_info = {
  59. "metric_type": self.metric_type,
  60. "id_map": self._convert_int(self.id_map),
  61. }
  62. yaml_writer = YAMLWriter()
  63. yaml_writer.write(
  64. index_info_path, index_info, default_flow_style=False, allow_unicode=True
  65. )
  66. @classmethod
  67. def load(self, index):
  68. if isinstance(index, str):
  69. index_root = Path(index)
  70. vector_path = index_root / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}"
  71. index_info_path = index_root / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}"
  72. assert (
  73. vector_path.exists()
  74. ), f"Not found the {self.VECTOR_FN}{self.VECTOR_SUFFIX} file in {index}!"
  75. assert (
  76. index_info_path.exists()
  77. ), f"Not found the {self.IDMAP_FN}{self.IDMAP_SUFFIX} file in {index}!"
  78. yaml_reader = YAMLReader()
  79. index_info = yaml_reader.read(index_info_path)
  80. assert "id_map" in index_info and "metric_type" in index_info, index_info
  81. # id_map = self._convert_np(index_info["id_map"])
  82. # print(index_info["id_map"])
  83. id_map = index_info["id_map"]
  84. metric_type = index_info["metric_type"]
  85. if metric_type in FaissBuilder.BINARY_METRIC_TYPE:
  86. index = faiss.read_index_binary(vector_path.as_posix())
  87. else:
  88. index = faiss.read_index(vector_path.as_posix())
  89. assert index.ntotal == len(
  90. id_map
  91. ), "data number in index is not equal in in id_map"
  92. return index, id_map, metric_type
  93. else:
  94. assert isinstance(index, IndexData)
  95. return index.index, index.id_map, index.metric_type
  96. class FaissIndexer(BaseComponent):
  97. INPUT_KEYS = "feature"
  98. OUTPUT_KEYS = ["label", "score"]
  99. DEAULT_INPUTS = {"feature": "feature"}
  100. DEAULT_OUTPUTS = {"label": "label", "score": "score"}
  101. ENABLE_BATCH = True
  102. def __init__(
  103. self,
  104. index,
  105. return_k=1,
  106. score_thres=None,
  107. hamming_radius=None,
  108. ):
  109. super().__init__()
  110. self._indexer, self.id_map, self.metric_type = IndexData.load(index)
  111. self.return_k = return_k
  112. if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
  113. self.hamming_radius = hamming_radius
  114. else:
  115. self.score_thres = score_thres
  116. def apply(self, feature):
  117. """apply"""
  118. scores_list, ids_list = self._indexer.search(np.array(feature), self.return_k)
  119. preds = []
  120. for scores, ids in zip(scores_list, ids_list):
  121. labels = []
  122. for id in ids:
  123. if id > 0:
  124. labels.append(self.id_map[id])
  125. preds.append({"score": scores, "label": labels})
  126. if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
  127. idxs = np.where(scores_list[:, 0] > self.hamming_radius)[0]
  128. else:
  129. idxs = np.where(scores_list[:, 0] < self.score_thres)[0]
  130. for idx in idxs:
  131. preds[idx] = {"score": None, "label": None}
  132. return preds
  133. class FaissBuilder:
  134. SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")
  135. SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32")
  136. BINARY_METRIC_TYPE = ("hamming",)
  137. BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash")
  138. @classmethod
  139. def _get_index_type(cls, metric_type, index_type, num=None):
  140. # if IVF method, cal ivf number automaticlly
  141. if index_type == "IVF":
  142. index_type = index_type + str(min(int(num // 8), 65536))
  143. if metric_type in cls.BINARY_METRIC_TYPE:
  144. index_type += ",BFlat"
  145. else:
  146. index_type += ",Flat"
  147. # for binary index, add B at head of index_type
  148. if metric_type in cls.BINARY_METRIC_TYPE:
  149. assert (
  150. index_type in cls.BINARY_SUPPORT_INDEX_TYPE
  151. ), f"The metric type({metric_type}) only support {cls.BINARY_SUPPORT_INDEX_TYPE} index types!"
  152. index_type = "B" + index_type
  153. if index_type == "HNSW32":
  154. logging.warning("The HNSW32 method dose not support 'remove' operation")
  155. index_type = "HNSW32"
  156. if index_type == "Flat":
  157. index_type = "Flat"
  158. return index_type
  159. @classmethod
  160. def _get_metric_type(cls, metric_type):
  161. if metric_type == "hamming":
  162. return faiss.METRIC_Hamming
  163. elif metric_type == "jaccard":
  164. return faiss.METRIC_Jaccard
  165. elif metric_type == "IP":
  166. return faiss.METRIC_INNER_PRODUCT
  167. elif metric_type == "L2":
  168. return faiss.METRIC_L2
  169. @classmethod
  170. def build(
  171. cls,
  172. gallery_imgs,
  173. gallery_label,
  174. predict_func,
  175. metric_type="IP",
  176. index_type="HNSW32",
  177. ):
  178. assert (
  179. metric_type in cls.SUPPORT_METRIC_TYPE
  180. ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
  181. if isinstance(gallery_label, str):
  182. gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
  183. else:
  184. gallery_docs, gallery_list = gallery_label, gallery_imgs
  185. features = [res["feature"] for res in predict_func(gallery_list)]
  186. dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
  187. features = np.array(features).astype(dtype)
  188. vector_num, vector_dim = features.shape
  189. if metric_type in cls.BINARY_METRIC_TYPE:
  190. index = faiss.index_binary_factory(
  191. vector_dim,
  192. cls._get_index_type(metric_type, index_type, vector_num),
  193. cls._get_metric_type(metric_type),
  194. )
  195. else:
  196. index = faiss.index_factory(
  197. vector_dim,
  198. cls._get_index_type(metric_type, index_type, vector_num),
  199. cls._get_metric_type(metric_type),
  200. )
  201. index = faiss.IndexIDMap2(index)
  202. ids = {}
  203. # calculate id for new data
  204. index, ids = cls._add_gallery(
  205. metric_type, index, ids, features, gallery_docs, mode="new"
  206. )
  207. return IndexData(index, ids, metric_type)
  208. @classmethod
  209. def remove(
  210. cls,
  211. remove_ids,
  212. index,
  213. # index_type="HNSW32",
  214. ):
  215. # assert (
  216. # index_type in cls.SUPPORT_INDEX_TYPE
  217. # ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
  218. # if index_type == "HNSW32":
  219. # raise RuntimeError(
  220. # "The index_type: HNSW32 dose not support 'remove' operation"
  221. # )
  222. index, ids, metric_type = IndexData.load(index)
  223. if isinstance(remove_ids, str):
  224. lines = []
  225. with open(remove_ids) as f:
  226. lines = f.readlines()
  227. remove_ids = []
  228. for line in lines:
  229. id_ = int(line.strip().split(" ")[0])
  230. remove_ids.append(id_)
  231. remove_ids = np.asarray(remove_ids)
  232. else:
  233. remove_ids = np.asarray(remove_ids)
  234. # remove ids in id_map, remove index data in faiss index
  235. index.remove_ids(remove_ids)
  236. ids = {k: v for k, v in ids.items() if k not in remove_ids}
  237. return IndexData(index, ids, metric_type)
  238. @classmethod
  239. def append(cls, gallery_imgs, gallery_label, predict_func, index, metric_type="IP"):
  240. assert (
  241. metric_type in cls.SUPPORT_METRIC_TYPE
  242. ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
  243. if isinstance(gallery_label, str):
  244. gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
  245. else:
  246. gallery_docs, gallery_list = gallery_label, gallery_imgs
  247. features = [res["feature"] for res in predict_func(gallery_list)]
  248. dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
  249. features = np.array(features).astype(dtype)
  250. index, ids, metric_type = IndexData.load(index)
  251. # calculate id for new data
  252. index, ids = cls._add_gallery(
  253. metric_type, index, ids, features, gallery_docs, mode="append"
  254. )
  255. return IndexData(index, ids, metric_type)
  256. @classmethod
  257. def _add_gallery(
  258. cls, metric_type, index, ids, gallery_features, gallery_docs, mode
  259. ):
  260. start_id = max(ids.keys()) + 1 if ids else 0
  261. ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
  262. # only train when new index file
  263. if mode == "new":
  264. if metric_type in cls.BINARY_METRIC_TYPE:
  265. index.add(gallery_features)
  266. else:
  267. index.train(gallery_features)
  268. if not metric_type in cls.BINARY_METRIC_TYPE:
  269. index.add_with_ids(gallery_features, ids_now)
  270. for i, d in zip(list(ids_now), gallery_docs):
  271. ids[i] = d
  272. return index, ids
  273. @classmethod
  274. def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
  275. lines = []
  276. files = []
  277. labels = []
  278. root = Path(gallery_imgs_root)
  279. with open(gallery_label_path, "r", encoding="utf-8") as f:
  280. lines = f.readlines()
  281. for line in lines:
  282. path, label = line.strip().split(delimiter)
  283. file_path = root / path
  284. files.append(file_path.as_posix())
  285. labels.append(label)
  286. return labels, files