Browse Source

rm unnecessary arg in remove_index

gaotingquan 1 year ago
parent
commit
64e54b3600

+ 27 - 24
paddlex/inference/components/retrieval/faiss.py

@@ -168,7 +168,10 @@ class FaissBuilder:
             metric_type in cls.SUPPORT_METRIC_TYPE
         ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
 
-        gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
+        if isinstance(gallery_label, str):
+            gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
+        else:
+            gallery_docs, gallery_list = gallery_label, gallery_imgs
 
         features = [res["feature"] for res in predict_func(gallery_list)]
         dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
@@ -199,7 +202,6 @@ class FaissBuilder:
     @classmethod
     def remove(
         cls,
-        gallery_imgs,
         gallery_label,
         index,
         index_type="HNSW32",
@@ -208,9 +210,12 @@ class FaissBuilder:
             index_type in cls.SUPPORT_INDEX_TYPE
         ), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
 
-        gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
-        index, ids = IndexData.load(index)
+        if isinstance(gallery_label, str):
+            gallery_docs, _ = cls.load_gallery(gallery_label)
+        else:
+            gallery_docs = gallery_label
 
+        index, ids = IndexData.load(index)
         if index_type == "HNSW32":
             raise RuntimeError(
                 "The index_type: HNSW32 dose not support 'remove' operation"
@@ -226,7 +231,11 @@ class FaissBuilder:
             metric_type in cls.SUPPORT_METRIC_TYPE
         ), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
 
-        gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
+        if isinstance(gallery_label, str):
+            gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
+        else:
+            gallery_docs, gallery_list = gallery_label, gallery_imgs
+
         features = [res["feature"] for res in predict_func(gallery_list)]
         dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
         features = np.array(features).astype(dtype)
@@ -271,22 +280,16 @@ class FaissBuilder:
         return index, ids
 
     @classmethod
-    def get_gallery(cls, gallery_imgs, gallery_label, delimiter=" "):
-        if isinstance(gallery_label, str):
-            assert isinstance(gallery_imgs, str)
-            gallery_imgs = Path(gallery_imgs)
-            files = []
-            labels = []
-            lines = []
-            with open(gallery_label, "r", encoding="utf-8") as f:
-                lines = f.readlines()
-            for line in lines:
-                path, label = line.strip().split(delimiter)
-                file_path = gallery_imgs / path
-                files.append(file_path.as_posix())
-                labels.append(label)
-            return files, labels
-        else:
-            assert isinstance(gallery_imgs, list)
-            assert isinstance(gallery_label, list)
-            return gallery_imgs, gallery_label
+    def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
+        lines = []
+        files = []
+        labels = []
+        root = Path(gallery_imgs_root)
+        with open(gallery_label_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        for line in lines:
+            path, label = line.strip().split(delimiter)
+            file_path = root / path
+            files.append(file_path.as_posix())
+            labels.append(label)
+        return labels, files

+ 2 - 4
paddlex/inference/pipelines/pp_shitu_v2.py

@@ -139,11 +139,9 @@ class ShiTuV2Pipeline(BasePipeline):
             **kwargs
         )
 
-    def remove_index(
-        self, gallery_imgs, gallery_label, index, index_type="HNSW32", **kwargs
-    ):
+    def remove_index(self, gallery_label, index, index_type="HNSW32", **kwargs):
         return FaissBuilder.remove(
-            gallery_imgs, gallery_label, index, index_type=index_type, **kwargs
+            gallery_label, index, index_type=index_type, **kwargs
         )
 
     def append_index(