Ver Fonte

use index_info to cover other data except index

gaotingquan há 1 ano atrás
pai
commit
752a814daf
1 ficheiros alterados com 14 adições e 7 exclusões
  1. 14 7
      paddlex/inference/components/retrieval/faiss.py

+ 14 - 7
paddlex/inference/components/retrieval/faiss.py

@@ -29,11 +29,12 @@ class IndexData:
     IDMAP_FN = "id_map"
     IDMAP_SUFFIX = ".yaml"
 
-    def __init__(self, index, id_map, metric_type, index_type):
+    def __init__(self, index, index_info):
         self._index = index
-        self._id_map = id_map
-        self._metric_type = metric_type
-        self._index_type = index_type
+        self._index_info = index_info
+        self._id_map = index_info["id_map"]
+        self._metric_type = index_info["metric_type"]
+        self._index_type = index_info["index_type"]
 
     @property
     def index(self):
@@ -260,7 +261,9 @@ class FaissBuilder:
         index, ids = cls._add_gallery(
             metric_type, index, ids, features, gallery_docs, mode="new"
         )
-        return IndexData(index, ids, metric_type, index_type)
+        return IndexData(
+            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
+        )
 
     @classmethod
     def remove(
@@ -288,7 +291,9 @@ class FaissBuilder:
         # remove ids in id_map, remove index data in faiss index
         index.remove_ids(remove_ids)
         ids = {k: v for k, v in ids.items() if k not in remove_ids}
-        return IndexData(index, ids, metric_type, index_type)
+        return IndexData(
+            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
+        )
 
     @classmethod
     def append(cls, gallery_imgs, gallery_label, predict_func, index):
@@ -310,7 +315,9 @@ class FaissBuilder:
         index, ids = cls._add_gallery(
             metric_type, index, ids, features, gallery_docs, mode="append"
         )
-        return IndexData(index, ids, metric_type, index_type)
+        return IndexData(
+            index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
+        )
 
     @classmethod
     def _add_gallery(