# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
from pathlib import Path

import faiss
import numpy as np

from ....utils import logging
from ...utils.io import YAMLWriter, YAMLReader
from ..base import BaseComponent


class IndexData:
    """Holds a faiss index together with its id_map and index metadata."""

    VECTOR_FN = "vector"
    VECTOR_SUFFIX = ".index"
    IDMAP_FN = "id_map"
    IDMAP_SUFFIX = ".yaml"

    def __init__(self, index, index_info):
        self._index = index
        self._index_info = index_info
        self._id_map = index_info["id_map"]
        self._metric_type = index_info["metric_type"]
        self._index_type = index_info["index_type"]

    @property
    def index(self):
        return self._index

    @property
    def index_bytes(self):
        return faiss.serialize_index(self._index)

    @property
    def id_map(self):
        return self._id_map

    @property
    def metric_type(self):
        return self._metric_type

    @property
    def index_type(self):
        return self._index_type

    @property
    def index_info(self):
        return {
            "index_type": self.index_type,
            "metric_type": self.metric_type,
            "id_map": self._convert_int(self.id_map),
        }

    def _convert_int(self, id_map):
        return {int(k): str(v) for k, v in id_map.items()}

    @staticmethod
    def _convert_int64(id_map):
        return {np.int64(k): str(v) for k, v in id_map.items()}

    def save(self, save_dir):
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        vector_path = (save_dir / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}").as_posix()
        index_info_path = (save_dir / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}").as_posix()

        if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
            faiss.write_index_binary(self.index, vector_path)
        else:
            faiss.write_index(self.index, vector_path)

        yaml_writer = YAMLWriter()
        yaml_writer.write(
            index_info_path,
            self.index_info,
            default_flow_style=False,
            allow_unicode=True,
        )

    @classmethod
    def load(cls, index):
        if isinstance(index, str):
            index_root = Path(index)
            vector_path = index_root / f"{cls.VECTOR_FN}{cls.VECTOR_SUFFIX}"
            index_info_path = index_root / f"{cls.IDMAP_FN}{cls.IDMAP_SUFFIX}"
            assert (
                vector_path.exists()
            ), f"The {cls.VECTOR_FN}{cls.VECTOR_SUFFIX} file was not found in {index}!"
            assert (
                index_info_path.exists()
            ), f"The {cls.IDMAP_FN}{cls.IDMAP_SUFFIX} file was not found in {index}!"

            yaml_reader = YAMLReader()
            index_info = yaml_reader.read(index_info_path)
            assert (
                "id_map" in index_info
                and "metric_type" in index_info
                and "index_type" in index_info
            ), f"The index_info file ({index_info_path}) may be damaged: `id_map`, `metric_type`, or `index_type` was not found in `index_info`."
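            # Illustrative sketch of what id_map.yaml (and hence index_info) is
            # expected to contain; the labels below are made-up examples:
            #
            #   index_type: HNSW32
            #   metric_type: IP
            #   id_map:
            #     0: cat
            #     1: dog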
            id_map = IndexData._convert_int64(index_info["id_map"])

            if index_info["metric_type"] in FaissBuilder.BINARY_METRIC_TYPE:
                index = faiss.read_index_binary(vector_path.as_posix())
            else:
                index = faiss.read_index(vector_path.as_posix())
            assert index.ntotal == len(
                id_map
            ), "The number of vectors in the index is not equal to the number of entries in id_map!"

            return index, id_map, index_info["metric_type"], index_info["index_type"]
        else:
            assert isinstance(index, IndexData)
            return index.index, index.id_map, index.metric_type, index.index_type


class FaissIndexer(BaseComponent):
    """Searches a faiss index and maps the returned ids back to gallery labels."""

    INPUT_KEYS = "feature"
    OUTPUT_KEYS = ["label", "score"]
    DEAULT_INPUTS = {"feature": "feature"}
    DEAULT_OUTPUTS = {"label": "label", "score": "score"}

    ENABLE_BATCH = True

    def __init__(
        self,
        index,
        return_k=1,
        score_thres=None,
        hamming_radius=None,
    ):
        super().__init__()
        self._indexer, self.id_map, self.metric_type, index_type = IndexData.load(
            index
        )
        self.return_k = return_k
        if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
            self.hamming_radius = hamming_radius
        else:
            self.score_thres = score_thres

    def apply(self, feature):
        """Search the index for each feature and collect the matched labels."""
        scores_list, ids_list = self._indexer.search(np.array(feature), self.return_k)
        preds = []
        for scores, ids in zip(scores_list, ids_list):
            labels = []
            for id in ids:
                # faiss returns -1 for missing neighbors, so only non-negative
                # ids map to gallery entries
                if id >= 0:
                    labels.append(self.id_map[id])
            preds.append({"score": scores, "label": labels})

        if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
            idxs = np.where(scores_list[:, 0] > self.hamming_radius)[0]
        else:
            idxs = np.where(scores_list[:, 0] < self.score_thres)[0]
        for idx in idxs:
            preds[idx] = {"score": None, "label": None}
        return preds


class FaissBuilder:
    """Builds, extends, and prunes faiss indexes from gallery images and labels."""

    SUPPORT_METRIC_TYPE = ("hamming", "IP", "L2")
    SUPPORT_INDEX_TYPE = ("Flat", "IVF", "HNSW32")
    BINARY_METRIC_TYPE = ("hamming",)
    BINARY_SUPPORT_INDEX_TYPE = ("Flat", "IVF", "BinaryHash")

    @classmethod
    def _get_index_type(cls, metric_type, index_type, num=None):
        # for the IVF method, choose the number of inverted lists automatically
        if index_type == "IVF":
            index_type = index_type + str(min(int(num // 8), 65536))
            if metric_type in cls.BINARY_METRIC_TYPE:
                index_type += ",BFlat"
            else:
                index_type += ",Flat"

        # for a binary index, prefix index_type with "B"
        if metric_type in cls.BINARY_METRIC_TYPE:
            assert (
                index_type in cls.BINARY_SUPPORT_INDEX_TYPE
            ), f"The metric type({metric_type}) only supports {cls.BINARY_SUPPORT_INDEX_TYPE} index types!"
            index_type = "B" + index_type

        if index_type == "HNSW32":
            logging.warning(
                "The HNSW32 index type does not support the 'remove' operation"
            )

        return index_type

    @classmethod
    def _get_metric_type(cls, metric_type):
        if metric_type == "hamming":
            return faiss.METRIC_Hamming
        elif metric_type == "jaccard":
            return faiss.METRIC_Jaccard
        elif metric_type == "IP":
            return faiss.METRIC_INNER_PRODUCT
        elif metric_type == "L2":
            return faiss.METRIC_L2

    @classmethod
    def build(
        cls,
        gallery_imgs,
        gallery_label,
        predict_func,
        metric_type="IP",
        index_type="HNSW32",
    ):
        assert (
            index_type in cls.SUPPORT_INDEX_TYPE
        ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"

        assert (
            metric_type in cls.SUPPORT_METRIC_TYPE
        ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
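        # When `gallery_label` is a path, it is parsed by `load_gallery`: each
        # line of the file pairs an image path (relative to `gallery_imgs`) with
        # a label, separated by a single space, e.g. (paths are illustrative):
        #
        #   images/cat_001.jpg cat
        #   images/dog_001.jpg dog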
        if isinstance(gallery_label, str):
            gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
        else:
            gallery_docs, gallery_list = gallery_label, gallery_imgs

        features = [res["feature"] for res in predict_func(gallery_list)]
        dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
        features = np.array(features).astype(dtype)
        vector_num, vector_dim = features.shape

        if metric_type in cls.BINARY_METRIC_TYPE:
            index = faiss.index_binary_factory(
                vector_dim,
                cls._get_index_type(metric_type, index_type, vector_num),
                cls._get_metric_type(metric_type),
            )
        else:
            index = faiss.index_factory(
                vector_dim,
                cls._get_index_type(metric_type, index_type, vector_num),
                cls._get_metric_type(metric_type),
            )
            index = faiss.IndexIDMap2(index)
        ids = {}

        # assign ids to the new data
        index, ids = cls._add_gallery(
            metric_type, index, ids, features, gallery_docs, mode="new"
        )
        return IndexData(
            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
        )

    @classmethod
    def remove(
        cls,
        remove_ids,
        index,
    ):
        index, ids, metric_type, index_type = IndexData.load(index)
        if index_type == "HNSW32":
            raise RuntimeError(
                "The index_type HNSW32 does not support the 'remove' operation"
            )

        if isinstance(remove_ids, str):
            lines = []
            with open(remove_ids) as f:
                lines = f.readlines()
            remove_ids = []
            for line in lines:
                id_ = int(line.strip().split(" ")[0])
                remove_ids.append(id_)
            remove_ids = np.asarray(remove_ids)
        else:
            remove_ids = np.asarray(remove_ids)

        # remove the ids from id_map and the corresponding vectors from the faiss index
        index.remove_ids(remove_ids)
        ids = {k: v for k, v in ids.items() if k not in remove_ids}

        return IndexData(
            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
        )

    @classmethod
    def append(cls, gallery_imgs, gallery_label, predict_func, index):
        index, ids, metric_type, index_type = IndexData.load(index)
        assert (
            metric_type in cls.SUPPORT_METRIC_TYPE
        ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"

        if isinstance(gallery_label, str):
            gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
        else:
            gallery_docs, gallery_list = gallery_label, gallery_imgs

        features = [res["feature"] for res in predict_func(gallery_list)]
        dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
        features = np.array(features).astype(dtype)

        # assign ids to the new data
        index, ids = cls._add_gallery(
            metric_type, index, ids, features, gallery_docs, mode="append"
        )
        return IndexData(
            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
        )

    @classmethod
    def _add_gallery(
        cls, metric_type, index, ids, gallery_features, gallery_docs, mode
    ):
        start_id = max(ids.keys()) + 1 if ids else 0
        ids_now = (np.arange(0, len(gallery_docs)) + start_id).astype(np.int64)

        # only train when building a new index
        if mode == "new":
            if metric_type in cls.BINARY_METRIC_TYPE:
                index.add(gallery_features)
            else:
                index.train(gallery_features)

        if metric_type not in cls.BINARY_METRIC_TYPE:
            index.add_with_ids(gallery_features, ids_now)

        for i, d in zip(list(ids_now), gallery_docs):
            ids[i] = d
        return index, ids

    @classmethod
    def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
        lines = []
        files = []
        labels = []
        root = Path(gallery_imgs_root)
        with open(gallery_label_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for line in lines:
            path, label = line.strip().split(delimiter)
            file_path = root / path
            files.append(file_path.as_posix())
            labels.append(label)
        return labels, files
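

# Example usage (an illustrative sketch only; `my_feature_extractor`,
# `batch_of_query_features`, the file paths, and the threshold values below are
# hypothetical and not part of this module):
#
#   def my_feature_extractor(img_paths):
#       """Should yield one {"feature": 1-D feature vector} dict per image."""
#       ...
#
#   index_data = FaissBuilder.build(
#       gallery_imgs="gallery_root/",            # root directory of gallery images
#       gallery_label="gallery_root/label.txt",  # lines of "<image path> <label>"
#       predict_func=my_feature_extractor,
#       metric_type="IP",
#       index_type="HNSW32",
#   )
#   index_data.save("index_dir/")                # writes vector.index and id_map.yaml
#
#   indexer = FaissIndexer(index="index_dir/", return_k=5, score_thres=0.4)
#   preds = indexer.apply(feature=batch_of_query_features)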