faisser.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pickle
  16. from pathlib import Path
  17. import faiss
  18. import numpy as np
  19. from ....utils import logging
  20. from ...utils.io import YAMLWriter, YAMLReader
  21. class IndexData:
  22. VECTOR_FN = "vector"
  23. VECTOR_SUFFIX = ".index"
  24. IDMAP_FN = "id_map"
  25. IDMAP_SUFFIX = ".yaml"
  26. def __init__(self, index, index_info):
  27. self._index = index
  28. self._index_info = index_info
  29. self._id_map = index_info["id_map"]
  30. self._metric_type = index_info["metric_type"]
  31. self._index_type = index_info["index_type"]
  32. @property
  33. def index(self):
  34. return self._index
  35. @property
  36. def index_bytes(self):
  37. return faiss.serialize_index(self._index)
  38. @property
  39. def id_map(self):
  40. return self._id_map
  41. @property
  42. def metric_type(self):
  43. return self._metric_type
  44. @property
  45. def index_type(self):
  46. return self._index_type
  47. @property
  48. def index_info(self):
  49. return {
  50. "index_type": self.index_type,
  51. "metric_type": self.metric_type,
  52. "id_map": self._convert_int(self.id_map),
  53. }
  54. @classmethod
  55. def from_bytes(cls, bytes):
  56. tup = pickle.loads(bytes)
  57. index = faiss.deserialize_index(tup[0])
  58. return cls(index, tup[1])
  59. def to_bytes(self):
  60. tup = (faiss.serialize_index(self._index), self.index_info)
  61. return pickle.dumps(tup)
  62. def _convert_int(self, id_map):
  63. return {int(k): str(v) for k, v in id_map.items()}
  64. @staticmethod
  65. def _convert_int64(id_map):
  66. return {np.int64(k): str(v) for k, v in id_map.items()}
  67. def save(self, save_dir):
  68. save_dir = Path(save_dir)
  69. save_dir.mkdir(parents=True, exist_ok=True)
  70. vector_path = (save_dir / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}").as_posix()
  71. index_info_path = (save_dir / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}").as_posix()
  72. if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
  73. faiss.write_index_binary(self.index, vector_path)
  74. else:
  75. faiss.write_index(self.index, vector_path)
  76. yaml_writer = YAMLWriter()
  77. yaml_writer.write(
  78. index_info_path,
  79. self.index_info,
  80. default_flow_style=False,
  81. allow_unicode=True,
  82. )
  83. @classmethod
  84. def load(cls, index):
  85. if isinstance(index, str):
  86. index_root = Path(index)
  87. vector_path = index_root / f"{cls.VECTOR_FN}{cls.VECTOR_SUFFIX}"
  88. index_info_path = index_root / f"{cls.IDMAP_FN}{cls.IDMAP_SUFFIX}"
  89. assert (
  90. vector_path.exists()
  91. ), f"Not found the {cls.VECTOR_FN}{cls.VECTOR_SUFFIX} file in {index}!"
  92. assert (
  93. index_info_path.exists()
  94. ), f"Not found the {cls.IDMAP_FN}{cls.IDMAP_SUFFIX} file in {index}!"
  95. yaml_reader = YAMLReader()
  96. index_info = yaml_reader.read(index_info_path)
  97. assert (
  98. "id_map" in index_info
  99. and "metric_type" in index_info
  100. and "index_type" in index_info
  101. ), f"The index_info file({index_info_path}) may have been damaged, `id_map` or `metric_type` or `index_type` not found in `index_info`."
  102. id_map = IndexData._convert_int64(index_info["id_map"])
  103. if index_info["metric_type"] in FaissBuilder.BINARY_METRIC_TYPE:
  104. index = faiss.read_index_binary(vector_path.as_posix())
  105. else:
  106. index = faiss.read_index(vector_path.as_posix())
  107. assert index.ntotal == len(
  108. id_map
  109. ), "data number in index is not equal in in id_map"
  110. return index, id_map, index_info["metric_type"], index_info["index_type"]
  111. else:
  112. assert isinstance(index, IndexData)
  113. return index.index, index.id_map, index.metric_type, index.index_type
  114. class FaissIndexer:
  115. def __init__(
  116. self,
  117. index,
  118. ):
  119. super().__init__()
  120. self._indexer, self.id_map, self.metric_type, index_type = IndexData.load(index)
  121. def __call__(self, feature, score_thres, hamming_radius, topk):
  122. scores_list, ids_list = self._indexer.search(np.array(feature), topk)
  123. preds = []
  124. for scores, ids in zip(scores_list, ids_list):
  125. labels = []
  126. for id in ids:
  127. if id > 0:
  128. labels.append(self.id_map[id])
  129. preds.append({"score": scores, "label": labels})
  130. if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
  131. idxs = np.where(scores_list[:, 0] > hamming_radius)[0]
  132. else:
  133. idxs = np.where(scores_list[:, 0] < score_thres)[0]
  134. for idx in idxs:
  135. preds[idx] = {"score": None, "label": None}
  136. return preds
  137. class FaissBuilder:
  138. SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")
  139. SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32")
  140. BINARY_METRIC_TYPE = ("hamming",)
  141. BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash")
  142. @classmethod
  143. def _get_index_type(cls, metric_type, index_type, num=None):
  144. # if IVF method, cal ivf number automaticlly
  145. if index_type == "IVF":
  146. index_type = index_type + str(min(int(num // 8), 65536))
  147. if metric_type in cls.BINARY_METRIC_TYPE:
  148. index_type += ",BFlat"
  149. else:
  150. index_type += ",Flat"
  151. # for binary index, add B at head of index_type
  152. if metric_type in cls.BINARY_METRIC_TYPE:
  153. assert (
  154. index_type in cls.BINARY_SUPPORT_INDEX_TYPE
  155. ), f"The metric type({metric_type}) only support {cls.BINARY_SUPPORT_INDEX_TYPE} index types!"
  156. index_type = "B" + index_type
  157. if index_type == "HNSW32":
  158. logging.warning("The HNSW32 method dose not support 'remove' operation")
  159. index_type = "HNSW32"
  160. if index_type == "Flat":
  161. index_type = "Flat"
  162. return index_type
  163. @classmethod
  164. def _get_metric_type(cls, metric_type):
  165. if metric_type == "hamming":
  166. return faiss.METRIC_Hamming
  167. elif metric_type == "jaccard":
  168. return faiss.METRIC_Jaccard
  169. elif metric_type == "IP":
  170. return faiss.METRIC_INNER_PRODUCT
  171. elif metric_type == "L2":
  172. return faiss.METRIC_L2
  173. @classmethod
  174. def build(
  175. cls,
  176. gallery_imgs,
  177. gallery_label,
  178. predict_func,
  179. metric_type="IP",
  180. index_type="HNSW32",
  181. ):
  182. assert (
  183. index_type in cls.SUPPORT_INDEX_TYPE
  184. ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
  185. assert (
  186. metric_type in cls.SUPPORT_METRIC_TYPE
  187. ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
  188. if isinstance(gallery_label, str):
  189. gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
  190. else:
  191. gallery_docs, gallery_list = gallery_label, gallery_imgs
  192. features = [res["feature"] for res in predict_func(gallery_list)]
  193. dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
  194. features = np.array(features).astype(dtype)
  195. vector_num, vector_dim = features.shape
  196. if metric_type in cls.BINARY_METRIC_TYPE:
  197. index = faiss.index_binary_factory(
  198. vector_dim,
  199. cls._get_index_type(metric_type, index_type, vector_num),
  200. cls._get_metric_type(metric_type),
  201. )
  202. else:
  203. index = faiss.index_factory(
  204. vector_dim,
  205. cls._get_index_type(metric_type, index_type, vector_num),
  206. cls._get_metric_type(metric_type),
  207. )
  208. index = faiss.IndexIDMap2(index)
  209. ids = {}
  210. # calculate id for new data
  211. index, ids = cls._add_gallery(
  212. metric_type, index, ids, features, gallery_docs, mode="new"
  213. )
  214. return IndexData(
  215. index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
  216. )
  217. @classmethod
  218. def remove(
  219. cls,
  220. remove_ids,
  221. index,
  222. ):
  223. index, ids, metric_type, index_type = IndexData.load(index)
  224. if index_type == "HNSW32":
  225. raise RuntimeError(
  226. "The index_type: HNSW32 dose not support 'remove' operation"
  227. )
  228. if isinstance(remove_ids, str):
  229. lines = []
  230. with open(remove_ids) as f:
  231. lines = f.readlines()
  232. remove_ids = []
  233. for line in lines:
  234. id_ = int(line.strip().split(" ")[0])
  235. remove_ids.append(id_)
  236. remove_ids = np.asarray(remove_ids)
  237. else:
  238. remove_ids = np.asarray(remove_ids)
  239. # remove ids in id_map, remove index data in faiss index
  240. index.remove_ids(remove_ids)
  241. ids = {k: v for k, v in ids.items() if k not in remove_ids}
  242. return IndexData(
  243. index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
  244. )
  245. @classmethod
  246. def append(cls, gallery_imgs, gallery_label, predict_func, index):
  247. index, ids, metric_type, index_type = IndexData.load(index)
  248. assert (
  249. metric_type in cls.SUPPORT_METRIC_TYPE
  250. ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
  251. if isinstance(gallery_label, str):
  252. gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
  253. else:
  254. gallery_docs, gallery_list = gallery_label, gallery_imgs
  255. features = [res["feature"] for res in predict_func(gallery_list)]
  256. dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
  257. features = np.array(features).astype(dtype)
  258. # calculate id for new data
  259. index, ids = cls._add_gallery(
  260. metric_type, index, ids, features, gallery_docs, mode="append"
  261. )
  262. return IndexData(
  263. index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
  264. )
  265. @classmethod
  266. def _add_gallery(
  267. cls, metric_type, index, ids, gallery_features, gallery_docs, mode
  268. ):
  269. start_id = max(ids.keys()) + 1 if ids else 0
  270. ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)
  271. # only train when new index file
  272. if mode == "new":
  273. if metric_type in cls.BINARY_METRIC_TYPE:
  274. index.add(gallery_features)
  275. else:
  276. index.train(gallery_features)
  277. if metric_type not in cls.BINARY_METRIC_TYPE:
  278. index.add_with_ids(gallery_features, ids_now)
  279. # TODO(gaotingquan): how append when using hamming metric type
  280. # else:
  281. # pass
  282. for i, d in zip(list(ids_now), gallery_docs):
  283. ids[i] = d
  284. return index, ids
  285. @classmethod
  286. def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
  287. lines = []
  288. files = []
  289. labels = []
  290. root = Path(gallery_imgs_root)
  291. with open(gallery_label_path, "r", encoding="utf-8") as f:
  292. lines = f.readlines()
  293. for line in lines:
  294. path, label = line.strip().split(delimiter)
  295. file_path = root / path
  296. files.append(file_path.as_posix())
  297. labels.append(label)
  298. return labels, files