upgrade api to save vector and id_map independently

gaotingquan · 1 year ago · commit 095239c625
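
The change replaces the single pickled index file with a directory that holds the faiss vectors and the id map side by side. A sketch of the resulting layout and of id_map.yaml, with hypothetical labels:

    my_index_dir/
        vector.index    # faiss.write_index / faiss.write_index_binary output
        id_map.yaml     # metric_type plus the {int id: str label} map

    # id_map.yaml
    metric_type: IP
    id_map:
      0: label_0
      1: label_1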

+ 97 - 52
paddlex/inference/components/retrieval/faiss.py

@@ -19,13 +19,20 @@ import faiss
 import numpy as np
 
 from ....utils import logging
+from pathlib import Path
+from ...utils.io import YAMLWriter, YAMLReader
 from ..base import BaseComponent
 
 
 class IndexData:
-    def __init__(self, index, id_map):
+    VECTOR_FN = "vector"
+    VECTOR_SUFFIX = ".index"
+    IDMAP_FN = "id_map"
+    IDMAP_SUFFIX = ".yaml"
+
+    def __init__(self, index, id_map, metric_type):
         self._index = index
         self._id_map = id_map
+        self._metric_type = metric_type
 
     @property
     def index(self):
@@ -39,28 +46,72 @@ class IndexData:
     def id_map(self):
         return self._id_map
 
-    def save(self, save_path):
-        index_data = {
-            "index_bytes": self.index_bytes,
-            "id_map": self.id_map,
+    @property
+    def metric_type(self):
+        return self._metric_type
+
+    def _convert_int(self, id_map):
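+        # plain int keys and str values so the map round-trips through YAML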
+        converted = {int(k): str(v) for k, v in id_map.items()}
+        return converted
+
+    def _convert_np(self, id_map):
+        converted = {np.int64(k): str(v) for k, v in id_map.items()}
+        return converted
+
+    def save(self, save_dir):
+        save_dir = Path(save_dir)
+        save_dir.mkdir(parents=True, exist_ok=True)
+        vector_path = (save_dir / f"{self.VECTOR_FN}{self.VECTOR_SUFFIX}").as_posix()
+        index_info_path = (save_dir / f"{self.IDMAP_FN}{self.IDMAP_SUFFIX}").as_posix()
+
+        if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
+            faiss.write_index_binary(self.index, vector_path)
+        else:
+            faiss.write_index(self.index, vector_path)
+
+        index_info = {
+            "metric_type": self.metric_type,
+            "id_map": self._convert_int(self.id_map),
         }
-        with open(save_path, "wb") as fd:
-            pickle.dump(index_data, fd)
+        yaml_writer = YAMLWriter()
+        yaml_writer.write(
+            index_info_path, index_info, default_flow_style=False, allow_unicode=True
+        )
 
     @classmethod
     def load(cls, index):
         if isinstance(index, str):
-            with open(index, "rb") as fd:
-                index_data = pickle.load(fd)
-            index_ = faiss.deserialize_index(index_data["index_bytes"])
-            id_map = index_data["id_map"]
-            assert index_.ntotal == len(
+            index_root = Path(index)
+            vector_path = index_root / f"{cls.VECTOR_FN}{cls.VECTOR_SUFFIX}"
+            index_info_path = index_root / f"{cls.IDMAP_FN}{cls.IDMAP_SUFFIX}"
+
+            assert (
+                vector_path.exists()
+            ), f"The {cls.VECTOR_FN}{cls.VECTOR_SUFFIX} file was not found in {index}!"
+            assert (
+                index_info_path.exists()
+            ), f"The {cls.IDMAP_FN}{cls.IDMAP_SUFFIX} file was not found in {index}!"
+
+            yaml_reader = YAMLReader()
+            index_info = yaml_reader.read(index_info_path)
+            assert "id_map" in index_info and "metric_type" in index_info, index_info
+            id_map = index_info["id_map"]
+            metric_type = index_info["metric_type"]
+
+            if metric_type in FaissBuilder.BINARY_METRIC_TYPE:
+                index = faiss.read_index_binary(vector_path.as_posix())
+            else:
+                index = faiss.read_index(vector_path.as_posix())
+            assert index.ntotal == len(
                 id_map
             ), "data number in index is not equal in in id_map"
-            return index_, id_map
+
+            return index, id_map, metric_type
         else:
             assert isinstance(index, IndexData)
-            return index.index, index.id_map
+            return index.index, index.id_map, index.metric_type
 
 
 class FaissIndexer(BaseComponent):
@@ -75,22 +126,18 @@ class FaissIndexer(BaseComponent):
     def __init__(
         self,
         index,
-        metric_type="IP",
         return_k=1,
         score_thres=None,
         hamming_radius=None,
     ):
         super().__init__()
-
-        if metric_type == "hamming":
+        self._indexer, self.id_map, self.metric_type = IndexData.load(index)
+        self.return_k = return_k
+        if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
             self.hamming_radius = hamming_radius
         else:
             self.score_thres = score_thres
 
-        self._indexer, self.id_map = IndexData.load(index)
-        self.metric_type = metric_type
-        self.return_k = return_k
-
     def apply(self, feature):
         """apply"""
         scores_list, ids_list = self._indexer.search(np.array(feature), self.return_k)
@@ -102,7 +149,7 @@ class FaissIndexer(BaseComponent):
                     labels.append(self.id_map[id])
             preds.append({"score": scores, "label": labels})
 
-        if self.metric_type == "hamming":
+        if self.metric_type in FaissBuilder.BINARY_METRIC_TYPE:
             idxs = np.where(scores_list[:, 0] > self.hamming_radius)[0]
         else:
             idxs = np.where(scores_list[:, 0] < self.score_thres)[0]
@@ -197,33 +244,41 @@ class FaissBuilder:
         index, ids = cls._add_gallery(
             metric_type, index, ids, features, gallery_docs, mode="new"
         )
-        return IndexData(index, ids)
+        return IndexData(index, ids, metric_type)
 
     @classmethod
     def remove(
         cls,
-        gallery_label,
+        remove_ids,
         index,
-        index_type="HNSW32",
     ):
-        assert (
-            index_type in cls.SUPPORT_INDEX_TYPE
-        ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
-
-        if isinstance(gallery_label, str):
-            gallery_docs, _ = cls.load_gallery(gallery_label)
+        index, ids, metric_type = IndexData.load(index)
+        if isinstance(remove_ids, str):
+            # read ids from a text file: one entry per line, the integer id first
+            with open(remove_ids) as f:
+                remove_ids = np.asarray(
+                    [int(line.strip().split(" ")[0]) for line in f]
+                )
         else:
-            gallery_docs = gallery_label
-
-        index, ids = IndexData.load(index)
-        if index_type == "HNSW32":
-            raise RuntimeError(
-                "The index_type: HNSW32 dose not support 'remove' operation"
-            )
+            remove_ids = np.asarray(remove_ids)
 
         # remove ids in id_map, remove index data in faiss index
-        index, ids = cls._rm_id_in_gallery(index, ids, gallery_docs)
-        return IndexData(index, ids)
+        index.remove_ids(remove_ids)
+        ids = {k: v for k, v in ids.items() if k not in remove_ids}
+        return IndexData(index, ids, metric_type)
 
     @classmethod
     def append(cls, gallery_imgs, gallery_label, predict_func, index, metric_type="IP"):
@@ -240,13 +295,13 @@ class FaissBuilder:
         dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
         features = np.array(features).astype(dtype)
 
-        index, ids = IndexData.load(index)
+        index, ids, metric_type = IndexData.load(index)
 
         # calculate id for new data
         index, ids = cls._add_gallery(
             metric_type, index, ids, features, gallery_docs, mode="append"
         )
-        return IndexData(index, ids)
+        return IndexData(index, ids, metric_type)
 
     @classmethod
     def _add_gallery(
@@ -270,16 +325,6 @@ class FaissBuilder:
         return index, ids
 
     @classmethod
-    def _rm_id_in_gallery(cls, index, ids, gallery_docs):
-        remove_ids = list(filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
-        remove_ids = np.asarray(remove_ids)
-        index.remove_ids(remove_ids)
-        for k in remove_ids:
-            del ids[k]
-
-        return index, ids
-
-    @classmethod
     def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
         lines = []
         files = []
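
Taken together, the new IndexData API round-trips like this; a minimal sketch, assuming a non-binary ("IP") metric and a hypothetical directory my_index_dir (the import path follows the file location above):

    import faiss
    import numpy as np

    from paddlex.inference.components.retrieval.faiss import IndexData

    # a tiny inner-product index over 128-d features with explicit int ids
    features = np.random.rand(4, 128).astype(np.float32)
    index = faiss.IndexIDMap2(faiss.index_factory(128, "Flat", faiss.METRIC_INNER_PRODUCT))
    index.add_with_ids(features, np.arange(4, dtype=np.int64))
    id_map = {i: f"label_{i}" for i in range(4)}

    # writes my_index_dir/vector.index and my_index_dir/id_map.yaml
    IndexData(index, id_map, "IP").save("my_index_dir")

    # load() accepts either the directory path or an IndexData instance
    index, id_map, metric_type = IndexData.load("my_index_dir")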

+ 3 - 8
paddlex/inference/pipelines/pp_shitu_v2.py

@@ -35,7 +35,6 @@ class ShiTuV2Pipeline(BasePipeline):
         det_batch_size=1,
         rec_batch_size=1,
         index=None,
-        metric_type="IP",
         score_thres=None,
         hamming_radius=None,
         return_k=5,
@@ -45,8 +44,7 @@ class ShiTuV2Pipeline(BasePipeline):
         super().__init__(device, predictor_kwargs)
         self._build_predictor(det_model, rec_model)
         self.set_predictor(det_batch_size, rec_batch_size, device)
-        self._metric_type, self._return_k, self._score_thres, self._hamming_radius = (
-            metric_type,
+        self._return_k, self._score_thres, self._hamming_radius = (
             return_k,
             score_thres,
             hamming_radius,
@@ -56,7 +54,6 @@ class ShiTuV2Pipeline(BasePipeline):
     def _build_indexer(self, index):
         return FaissIndexer(
             index=index,
-            metric_type=self._metric_type,
             return_k=self._return_k,
             score_thres=self._score_thres,
             hamming_radius=self._hamming_radius,
@@ -139,10 +136,8 @@ class ShiTuV2Pipeline(BasePipeline):
             **kwargs
         )
 
-    def remove_index(self, gallery_label, index, index_type="HNSW32", **kwargs):
-        return FaissBuilder.remove(
-            gallery_label, index, index_type=index_type, **kwargs
-        )
+    def remove_index(self, remove_ids, index, **kwargs):
+        return FaissBuilder.remove(remove_ids, index, **kwargs)
 
     def append_index(
         self,

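Since metric_type is now recovered from id_map.yaml, the pipeline no longer takes it as an argument, and removal works on integer ids rather than gallery labels. A hedged sketch, where pipeline stands for a hypothetical ShiTuV2Pipeline instance and my_index_dir for an index saved as above:

    # drop entries 0 and 3; remove_ids may also be the path to a text file
    # whose lines start with the integer id to remove
    new_index = pipeline.remove_index([0, 3], "my_index_dir")
    new_index.save("my_index_dir")
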
+ 9 - 1
paddlex/inference/utils/io/__init__.py

@@ -13,7 +13,14 @@
 # limitations under the License.
 
 
-from .readers import ReaderType, ImageReader, VideoReader, CSVReader, PDFReader
+from .readers import (
+    ReaderType,
+    ImageReader,
+    VideoReader,
+    CSVReader,
+    PDFReader,
+    YAMLReader,
+)
 from .writers import (
     WriterType,
     ImageWriter,
@@ -22,4 +29,5 @@ from .writers import (
     CSVWriter,
     HtmlWriter,
     XlsxWriter,
+    YAMLWriter,
 )

+ 36 - 1
paddlex/inference/utils/io/readers.py

@@ -20,8 +20,16 @@ import fitz
 from PIL import Image, ImageOps
 import pandas as pd
 import numpy as np
+import yaml
 
-__all__ = ["ReaderType", "ImageReader", "VideoReader", "CSVReader", "PDFReader"]
+__all__ = [
+    "ReaderType",
+    "ImageReader",
+    "VideoReader",
+    "CSVReader",
+    "PDFReader",
+    "YAMLReader",
+]
 
 
 class ReaderType(enum.Enum):
@@ -33,6 +41,7 @@ class ReaderType(enum.Enum):
     JSON = 4
     TS = 5
     PDF = 6
+    YAML = 8
 
 
 class _BaseReader(object):
@@ -162,6 +171,24 @@ class VideoReader(_GenerativeReader):
             raise ValueError("Unsupported backend type")
 
 
+class YAMLReader(_BaseReader):
+
+    def __init__(self, backend="PyYAML", **bk_args):
+        super().__init__(backend, **bk_args)
+
+    def read(self, in_path):
+        return self._backend.read_file(str(in_path))
+
+    def _init_backend(self, bk_type, bk_args):
+        if bk_type == "PyYAML":
+            return YAMLReaderBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        return ReaderType.YAML
+
+
 class _BaseReaderBackend(object):
     """_BaseReaderBackend"""
 
@@ -316,3 +343,11 @@ class PandasCSVReaderBackend(_CSVReaderBackend):
     def read_file(self, in_path):
         """read image file from path by OpenCV"""
         return pd.read_csv(in_path)
+
+
+class YAMLReaderBackend(_BaseReaderBackend):
+
+    def read_file(self, in_path, **kwargs):
+        with open(in_path, "r", encoding="utf-8", **kwargs) as yaml_file:
+            data = yaml.safe_load(yaml_file)
+        return data

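The new reader mirrors the existing reader classes and parses with yaml.safe_load; a minimal sketch, assuming a hypothetical id_map.yaml on disk:

    from paddlex.inference.utils.io import YAMLReader

    reader = YAMLReader()              # "PyYAML" is the only supported backend
    info = reader.read("id_map.yaml")  # returns the parsed dict
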
+ 37 - 3
paddlex/inference/utils/io/writers.py

@@ -22,6 +22,7 @@ import cv2
 import numpy as np
 from PIL import Image
 import pandas as pd
+import yaml
 from .tablepyxl import document_to_xl
 
 
@@ -33,6 +34,7 @@ __all__ = [
     "CSVWriter",
     "HtmlWriter",
     "XlsxWriter",
+    "YAMLWriter",
 ]
 
 
@@ -46,6 +48,7 @@ class WriterType(enum.Enum):
     HTML = 5
     XLSX = 6
     CSV = 7
+    YAML = 8
 
 
 class _BaseWriter(object):
@@ -189,15 +192,33 @@ class XlsxWriter(_BaseWriter):
         return WriterType.XLSX
 
 
+class YAMLWriter(_BaseWriter):
+    def __init__(self, backend="PyYAML", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def write(self, out_path, obj, **bk_args):
+        return self._backend.write_obj(str(out_path), obj, **bk_args)
+
+    def _init_backend(self, bk_type, bk_args):
+        if bk_type == "PyYAML":
+            return YAMLWriterBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return WriterType.YAML
+
+
 class _BaseWriterBackend(object):
     """_BaseWriterBackend"""
 
-    def write_obj(self, out_path, obj):
+    def write_obj(self, out_path, obj, **bk_args):
         """write object"""
         Path(out_path).parent.mkdir(parents=True, exist_ok=True)
-        return self._write_obj(out_path, obj)
+        return self._write_obj(out_path, obj, **bk_args)
 
-    def _write_obj(self, out_path, obj):
+    def _write_obj(self, out_path, obj, **bk_args):
         """write object"""
         raise NotImplementedError
 
@@ -299,6 +320,19 @@ class UJsonWriterBackend(_BaseJsonWriterBackend):
         raise NotImplementedError
 
 
+class YAMLWriterBackend(_BaseWriterBackend):
+
+    def __init__(self, mode="w", encoding="utf-8"):
+        super().__init__()
+        self.mode = mode
+        self.encoding = encoding
+
+    def _write_obj(self, out_path, obj, **bk_args):
+        """write text object"""
+        with open(out_path, mode=self.mode, encoding=self.encoding) as f:
+            yaml.dump(obj, f, **bk_args)
+
+
 class CSVWriter(_BaseWriter):
     """CSVWriter"""