|
|
@@ -19,13 +19,20 @@ import faiss
|
|
|
import numpy as np
|
|
|
|
|
|
from ....utils import logging
|
|
|
+from ...utils.io import YAMLWriter, YAMLReader
|
|
|
from ..base import BaseComponent
|
|
|
|
|
|
|
|
|
class IndexData:
|
|
|
- def __init__(self, index, id_map):
|
|
|
+ VECTOR_FN = "vector"
|
|
|
+ VECTOR_SUFFIX = ".index"
|
|
|
+ IDMAP_FN = "id_map"
|
|
|
+ IDMAP_SUFFIX = ".yaml"
|
|
|
+
|
|
|
+ def __init__(self, index, id_map, metric_type):
|
|
|
self._index = index
|
|
|
self._id_map = id_map
|
|
|
+ self._metric_type = metric_type
|
|
|
|
|
|
@property
|
|
|
def index(self):
|
|
|
@@ -39,28 +46,72 @@ class IndexData:
|
|
|
def id_map(self):
|
|
|
return self._id_map
|
|
|
|
|
|
- def save(self, save_path):
|
|
|
- index_data = {
|
|
|
- "index_bytes": self.index_bytes,
|
|
|
- "id_map": self.id_map,
|
|
|
+ @property
|
|
|
+ def metric_type(self):
|
|
|
+ return self._metric_type
|
|
|
+
|
|
|
+ def _convert_int(self, id_map):
|
|
|
+ converted = {int(k): str(v) for k, v in id_map.items()}
|
|
|
+ return converted
|
|
|
+
|
|
|
+ def _convert_np(self, id_map):
|
|
|
+ converted = {np.int(k): str(v) for k, v in id_map.items()}
|
|
|
+ return converted
|
|
|
+
|
|
|
+ def save(self, save_dir):
|
|
|
+ save_dir = Path(save_dir)
|
|
|
+ save_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+ vector_path = (save_dir / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}").as_posix()
|
|
|
+ index_info_path = (save_dir / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}").as_posix()
|
|
|
+
|
|
|
+ if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
|
|
|
+ faiss.write_index_binary(self.index, vector_path)
|
|
|
+ else:
|
|
|
+ faiss.write_index(self.index, vector_path)
|
|
|
+
|
|
|
+ index_info = {
|
|
|
+ "metric_type": self.metric_type,
|
|
|
+ "id_map": self._convert_int(self.id_map),
|
|
|
}
|
|
|
- with open(save_path, "wb") as fd:
|
|
|
- pickle.dump(index_data, fd)
|
|
|
+ yaml_writer = YAMLWriter()
|
|
|
+ yaml_writer.write(
|
|
|
+ index_info_path, index_info, default_flow_style=False, allow_unicode=True
|
|
|
+ )
|
|
|
|
|
|
@classmethod
|
|
|
def load(self, index):
|
|
|
if isinstance(index, str):
|
|
|
- with open(index, "rb") as fd:
|
|
|
- index_data = pickle.load(fd)
|
|
|
- index_ = faiss.deserialize_index(index_data["index_bytes"])
|
|
|
- id_map = index_data["id_map"]
|
|
|
- assert index_.ntotal == len(
|
|
|
+ index_root = Path(index)
|
|
|
+ vector_path = index_root / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}"
|
|
|
+ index_info_path = index_root / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}"
|
|
|
+
|
|
|
+ assert (
|
|
|
+ vector_path.exists()
|
|
|
+ ), f"Not found the {self.VECTOR_FN}{self.VECTOR_SUFFIX} file in {index}!"
|
|
|
+ assert (
|
|
|
+ index_info_path.exists()
|
|
|
+ ), f"Not found the {self.IDMAP_FN}{self.IDMAP_SUFFIX} file in {index}!"
|
|
|
+
|
|
|
+ yaml_reader = YAMLReader()
|
|
|
+ index_info = yaml_reader.read(index_info_path)
|
|
|
+ assert "id_map" in index_info and "metric_type" in index_info, index_info
|
|
|
+ # id_map = self._convert_np(index_info["id_map"])
|
|
|
+ # print(index_info["id_map"])
|
|
|
+ id_map = index_info["id_map"]
|
|
|
+ metric_type = index_info["metric_type"]
|
|
|
+
|
|
|
+ if metric_type in FaissBuilder.BINARY_METRIC_TYPE:
|
|
|
+ index = faiss.read_index_binary(vector_path.as_posix())
|
|
|
+ else:
|
|
|
+ index = faiss.read_index(vector_path.as_posix())
|
|
|
+ assert index.ntotal == len(
|
|
|
id_map
|
|
|
), "data number in index is not equal in in id_map"
|
|
|
- return index_, id_map
|
|
|
+
|
|
|
+ return index, id_map, metric_type
|
|
|
else:
|
|
|
assert isinstance(index, IndexData)
|
|
|
- return index.index, index.id_map
|
|
|
+ return index.index, index.id_map, index.metric_type
|
|
|
|
|
|
|
|
|
class FaissIndexer(BaseComponent):
|
|
|
@@ -75,22 +126,18 @@ class FaissIndexer(BaseComponent):
|
|
|
def __init__(
|
|
|
self,
|
|
|
index,
|
|
|
- metric_type="IP",
|
|
|
return_k=1,
|
|
|
score_thres=None,
|
|
|
hamming_radius=None,
|
|
|
):
|
|
|
super().__init__()
|
|
|
-
|
|
|
- if metric_type == "hamming":
|
|
|
+ self._indexer, self.id_map, self.metric_type = IndexData.load(index)
|
|
|
+ self.return_k = return_k
|
|
|
+ if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
|
|
|
self.hamming_radius = hamming_radius
|
|
|
else:
|
|
|
self.score_thres = score_thres
|
|
|
|
|
|
- self._indexer, self.id_map = IndexData.load(index)
|
|
|
- self.metric_type = metric_type
|
|
|
- self.return_k = return_k
|
|
|
-
|
|
|
def apply(self, feature):
|
|
|
"""apply"""
|
|
|
scores_list, ids_list = self._indexer.search(np.array(feature), self.return_k)
|
|
|
@@ -102,7 +149,7 @@ class FaissIndexer(BaseComponent):
|
|
|
labels.append(self.id_map[id])
|
|
|
preds.append({"score": scores, "label": labels})
|
|
|
|
|
|
- if self.metric_type == "hamming":
|
|
|
+ if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
|
|
|
idxs = np.where(scores_list[:, 0] > self.hamming_radius)[0]
|
|
|
else:
|
|
|
idxs = np.where(scores_list[:, 0] < self.score_thres)[0]
|
|
|
@@ -197,33 +244,41 @@ class FaissBuilder:
|
|
|
index, ids = cls._add_gallery(
|
|
|
metric_type, index, ids, features, gallery_docs, mode="new"
|
|
|
)
|
|
|
- return IndexData(index, ids)
|
|
|
+ return IndexData(index, ids, metric_type)
|
|
|
|
|
|
@classmethod
|
|
|
def remove(
|
|
|
cls,
|
|
|
- gallery_label,
|
|
|
+ remove_ids,
|
|
|
index,
|
|
|
- index_type="HNSW32",
|
|
|
+ # index_type="HNSW32",
|
|
|
):
|
|
|
- assert (
|
|
|
- index_type in cls.SUPPORT_INDEX_TYPE
|
|
|
- ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
|
|
|
-
|
|
|
- if isinstance(gallery_label, str):
|
|
|
- gallery_docs, _ = cls.load_gallery(gallery_label)
|
|
|
+ # assert (
|
|
|
+ # index_type in cls.SUPPORT_INDEX_TYPE
|
|
|
+ # ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
|
|
|
+
|
|
|
+ # if index_type == "HNSW32":
|
|
|
+ # raise RuntimeError(
|
|
|
+ # "The index_type: HNSW32 dose not support 'remove' operation"
|
|
|
+ # )
|
|
|
+
|
|
|
+ index, ids, metric_type = IndexData.load(index)
|
|
|
+ if isinstance(remove_ids, str):
|
|
|
+ lines = []
|
|
|
+ with open(remove_ids) as f:
|
|
|
+ lines = f.readlines()
|
|
|
+ remove_ids = []
|
|
|
+ for line in lines:
|
|
|
+ id_ = int(line.strip().split(" ")[0])
|
|
|
+ remove_ids.append(id_)
|
|
|
+ remove_ids = np.asarray(remove_ids)
|
|
|
else:
|
|
|
- gallery_docs = gallery_label
|
|
|
-
|
|
|
- index, ids = IndexData.load(index)
|
|
|
- if index_type == "HNSW32":
|
|
|
- raise RuntimeError(
|
|
|
- "The index_type: HNSW32 dose not support 'remove' operation"
|
|
|
- )
|
|
|
+ remove_ids = np.asarray(remove_ids)
|
|
|
|
|
|
# remove ids in id_map, remove index data in faiss index
|
|
|
- index, ids = cls._rm_id_in_gallery(index, ids, gallery_docs)
|
|
|
- return IndexData(index, ids)
|
|
|
+ index.remove_ids(remove_ids)
|
|
|
+ ids = {k: v for k, v in ids.items() if k not in remove_ids}
|
|
|
+ return IndexData(index, ids, metric_type)
|
|
|
|
|
|
@classmethod
|
|
|
def append(cls, gallery_imgs, gallery_label, predict_func, index, metric_type="IP"):
|
|
|
@@ -240,13 +295,13 @@ class FaissBuilder:
|
|
|
dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
|
|
|
features = np.array(features).astype(dtype)
|
|
|
|
|
|
- index, ids = IndexData.load(index)
|
|
|
+ index, ids, metric_type = IndexData.load(index)
|
|
|
|
|
|
# calculate id for new data
|
|
|
index, ids = cls._add_gallery(
|
|
|
metric_type, index, ids, features, gallery_docs, mode="append"
|
|
|
)
|
|
|
- return IndexData(index, ids)
|
|
|
+ return IndexData(index, ids, metric_type)
|
|
|
|
|
|
@classmethod
|
|
|
def _add_gallery(
|
|
|
@@ -270,16 +325,6 @@ class FaissBuilder:
|
|
|
return index, ids
|
|
|
|
|
|
@classmethod
|
|
|
- def _rm_id_in_gallery(cls, index, ids, gallery_docs):
|
|
|
- remove_ids = list(filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
|
|
|
- remove_ids = np.asarray(remove_ids)
|
|
|
- index.remove_ids(remove_ids)
|
|
|
- for k in remove_ids:
|
|
|
- del ids[k]
|
|
|
-
|
|
|
- return index, ids
|
|
|
-
|
|
|
- @classmethod
|
|
|
def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
|
|
|
lines = []
|
|
|
files = []
|