faisser.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import pickle
  15. from pathlib import Path
  16. import numpy as np
  17. from ....utils import logging
  18. from ....utils.deps import class_requires_deps, is_dep_available
  19. from ...utils.io import YAMLReader, YAMLWriter
  20. if is_dep_available("faiss-cpu"):
  21. import faiss
  22. @class_requires_deps("faiss-cpu")
  23. class IndexData:
  24. VECTOR_FN = "vector"
  25. VECTOR_SUFFIX = ".index"
  26. IDMAP_FN = "id_map"
  27. IDMAP_SUFFIX = ".yaml"
  28. def __init__(self, index, index_info):
  29. self._index = index
  30. self._index_info = index_info
  31. self._id_map = index_info["id_map"]
  32. self._metric_type = index_info["metric_type"]
  33. self._index_type = index_info["index_type"]
  34. @property
  35. def index(self):
  36. return self._index
  37. @property
  38. def index_bytes(self):
  39. return faiss.serialize_index(self._index)
  40. @property
  41. def id_map(self):
  42. return self._id_map
  43. @property
  44. def metric_type(self):
  45. return self._metric_type
  46. @property
  47. def index_type(self):
  48. return self._index_type
  49. @property
  50. def index_info(self):
  51. return {
  52. "index_type": self.index_type,
  53. "metric_type": self.metric_type,
  54. "id_map": self._convert_int(self.id_map),
  55. }
  56. @classmethod
  57. def from_bytes(cls, bytes):
  58. tup = pickle.loads(bytes)
  59. index = faiss.deserialize_index(tup[0])
  60. return cls(index, tup[1])
  61. def to_bytes(self):
  62. tup = (faiss.serialize_index(self._index), self.index_info)
  63. return pickle.dumps(tup)
  64. def _convert_int(self, id_map):
  65. return {int(k): str(v) for k, v in id_map.items()}
  66. @staticmethod
  67. def _convert_int64(id_map):
  68. return {np.int64(k): str(v) for k, v in id_map.items()}
  69. def save(self, save_dir):
  70. save_dir = Path(save_dir)
  71. save_dir.mkdir(parents=True, exist_ok=True)
  72. vector_path = (save_dir / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}").as_posix()
  73. index_info_path = (save_dir / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}").as_posix()
  74. if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
  75. faiss.write_index_binary(self.index, vector_path)
  76. else:
  77. faiss.write_index(self.index, vector_path)
  78. yaml_writer = YAMLWriter()
  79. yaml_writer.write(
  80. index_info_path,
  81. self.index_info,
  82. default_flow_style=False,
  83. allow_unicode=True,
  84. )
  85. @classmethod
  86. def load(cls, index):
  87. if isinstance(index, str):
  88. index_root = Path(index)
  89. vector_path = index_root / f"{cls.VECTOR_FN}{cls.VECTOR_SUFFIX}"
  90. index_info_path = index_root / f"{cls.IDMAP_FN}{cls.IDMAP_SUFFIX}"
  91. assert (
  92. vector_path.exists()
  93. ), f"Not found the {cls.VECTOR_FN}{cls.VECTOR_SUFFIX} file in {index}!"
  94. assert (
  95. index_info_path.exists()
  96. ), f"Not found the {cls.IDMAP_FN}{cls.IDMAP_SUFFIX} file in {index}!"
  97. yaml_reader = YAMLReader()
  98. index_info = yaml_reader.read(index_info_path)
  99. assert (
  100. "id_map" in index_info
  101. and "metric_type" in index_info
  102. and "index_type" in index_info
  103. ), f"The index_info file({index_info_path}) may have been damaged, `id_map` or `metric_type` or `index_type` not found in `index_info`."
  104. id_map = IndexData._convert_int64(index_info["id_map"])
  105. if index_info["metric_type"] in FaissBuilder.BINARY_METRIC_TYPE:
  106. index = faiss.read_index_binary(vector_path.as_posix())
  107. else:
  108. index = faiss.read_index(vector_path.as_posix())
  109. assert index.ntotal == len(
  110. id_map
  111. ), "data number in index is not equal in in id_map"
  112. return index, id_map, index_info["metric_type"], index_info["index_type"]
  113. else:
  114. assert isinstance(index, IndexData)
  115. return index.index, index.id_map, index.metric_type, index.index_type
  116. class FaissIndexer:
  117. def __init__(
  118. self,
  119. index,
  120. ):
  121. super().__init__()
  122. self._indexer, self.id_map, self.metric_type, index_type = IndexData.load(index)
  123. def __call__(self, feature, score_thres, hamming_radius, topk):
  124. scores_list, ids_list = self._indexer.search(np.array(feature), topk)
  125. preds = []
  126. for scores, ids in zip(scores_list, ids_list):
  127. preds.append({"score": [], "label": []})
  128. for score, id in zip(scores, ids):
  129. if id >= 0:
  130. preds[-1]["score"].append(score)
  131. preds[-1]["label"].append(self.id_map[id])
  132. if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
  133. idxs = np.where(scores_list[:, 0] > hamming_radius)[0]
  134. else:
  135. idxs = np.where(scores_list[:, 0] < score_thres)[0]
  136. for idx in idxs:
  137. preds[idx] = {"score": None, "label": None}
  138. return preds
  139. @class_requires_deps("faiss-cpu")
  140. class FaissBuilder:
  141. SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")
  142. SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32")
  143. BINARY_METRIC_TYPE = ("hamming",)
  144. BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash")
  145. @classmethod
  146. def _get_index_type(cls, metric_type, index_type, num=None):
  147. # if IVF method, cal ivf number automatically
  148. if index_type == "IVF":
  149. index_type = index_type + str(min(int(num // 8), 65536))
  150. if metric_type in cls.BINARY_METRIC_TYPE:
  151. index_type += ",BFlat"
  152. else:
  153. index_type += ",Flat"
  154. # for binary index, add B at head of index_type
  155. if metric_type in cls.BINARY_METRIC_TYPE:
  156. assert (
  157. index_type in cls.BINARY_SUPPORT_INDEX_TYPE
  158. ), f"The metric type({metric_type}) only support {cls.BINARY_SUPPORT_INDEX_TYPE} index types!"
  159. index_type = "B" + index_type
  160. if index_type == "HNSW32":
  161. logging.warning("The HNSW32 method dose not support 'remove' operation")
  162. index_type = "HNSW32"
  163. if index_type == "Flat":
  164. index_type = "Flat"
  165. return index_type
  166. @classmethod
  167. def _get_metric_type(cls, metric_type):
  168. if metric_type == "hamming":
  169. return faiss.METRIC_Hamming
  170. elif metric_type == "jaccard":
  171. return faiss.METRIC_Jaccard
  172. elif metric_type == "IP":
  173. return faiss.METRIC_INNER_PRODUCT
  174. elif metric_type == "L2":
  175. return faiss.METRIC_L2
  176. @classmethod
  177. def build(
  178. cls,
  179. gallery_imgs,
  180. gallery_label,
  181. predict_func,
  182. metric_type="IP",
  183. index_type="HNSW32",
  184. ):
  185. assert (
  186. index_type in cls.SUPPORT_INDEX_TYPE
  187. ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
  188. assert (
  189. metric_type in cls.SUPPORT_METRIC_TYPE
  190. ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
  191. if isinstance(gallery_label, str):
  192. gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
  193. else:
  194. gallery_docs, gallery_list = gallery_label, gallery_imgs
  195. features = [res["feature"] for res in predict_func(gallery_list)]
  196. dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
  197. features = np.array(features).astype(dtype)
  198. vector_num, vector_dim = features.shape
  199. if metric_type in cls.BINARY_METRIC_TYPE:
  200. index = faiss.index_binary_factory(
  201. vector_dim,
  202. cls._get_index_type(metric_type, index_type, vector_num),
  203. cls._get_metric_type(metric_type),
  204. )
  205. else:
  206. index = faiss.index_factory(
  207. vector_dim,
  208. cls._get_index_type(metric_type, index_type, vector_num),
  209. cls._get_metric_type(metric_type),
  210. )
  211. index = faiss.IndexIDMap2(index)
  212. ids = {}
  213. # calculate id for new data
  214. index, ids = cls._add_gallery(
  215. metric_type, index, ids, features, gallery_docs, mode="new"
  216. )
  217. return IndexData(
  218. index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
  219. )
  220. @classmethod
  221. def remove(
  222. cls,
  223. remove_ids,
  224. index,
  225. ):
  226. index, ids, metric_type, index_type = IndexData.load(index)
  227. if index_type == "HNSW32":
  228. raise RuntimeError(
  229. "The index_type: HNSW32 dose not support 'remove' operation"
  230. )
  231. if isinstance(remove_ids, str):
  232. lines = []
  233. with open(remove_ids) as f:
  234. lines = f.readlines()
  235. remove_ids = []
  236. for line in lines:
  237. id_ = int(line.strip().split(" ")[0])
  238. remove_ids.append(id_)
  239. remove_ids = np.asarray(remove_ids)
  240. else:
  241. remove_ids = np.asarray(remove_ids)
  242. # remove ids in id_map, remove index data in faiss index
  243. index.remove_ids(remove_ids)
  244. ids = {k: v for k, v in ids.items() if k not in remove_ids}
  245. return IndexData(
  246. index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
  247. )
  248. @classmethod
  249. def append(cls, gallery_imgs, gallery_label, predict_func, index):
  250. index, ids, metric_type, index_type = IndexData.load(index)
  251. assert (
  252. metric_type in cls.SUPPORT_METRIC_TYPE
  253. ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
  254. if isinstance(gallery_label, str):
  255. gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
  256. else:
  257. gallery_docs, gallery_list = gallery_label, gallery_imgs
  258. features = [res["feature"] for res in predict_func(gallery_list)]
  259. dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
  260. features = np.array(features).astype(dtype)
  261. # calculate id for new data
  262. index, ids = cls._add_gallery(
  263. metric_type, index, ids, features, gallery_docs, mode="append"
  264. )
  265. return IndexData(
  266. index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
  267. )
  268. @classmethod
  269. def _add_gallery(
  270. cls, metric_type, index, ids, gallery_features, gallery_docs, mode
  271. ):
  272. start_id = max(ids.keys()) + 1 if ids else 0
  273. ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
  274. # only train when new index file
  275. if mode == "new":
  276. if metric_type in cls.BINARY_METRIC_TYPE:
  277. index.add(gallery_features)
  278. else:
  279. index.train(gallery_features)
  280. if metric_type not in cls.BINARY_METRIC_TYPE:
  281. index.add_with_ids(gallery_features, ids_now)
  282. # TODO(gaotingquan): how append when using hamming metric type
  283. # else:
  284. # pass
  285. for i, d in zip(list(ids_now), gallery_docs):
  286. ids[i] = d
  287. return index, ids
  288. @classmethod
  289. def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
  290. lines = []
  291. files = []
  292. labels = []
  293. root = Path(gallery_imgs_root)
  294. with open(gallery_label_path, "r", encoding="utf-8") as f:
  295. lines = f.readlines()
  296. for line in lines:
  297. path, label = line.strip().split(delimiter)
  298. file_path = root / path
  299. files.append(file_path.as_posix())
  300. labels.append(label)
  301. return labels, files