|
|
@@ -168,7 +168,10 @@ class FaissBuilder:
|
|
|
metric_type in cls.SUPPORT_METRIC_TYPE
|
|
|
), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
|
|
|
|
|
|
- gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
|
|
|
+ if isinstance(gallery_label, str):
|
|
|
+ gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
|
|
|
+ else:
|
|
|
+ gallery_docs, gallery_list = gallery_label, gallery_imgs
|
|
|
|
|
|
features = [res["feature"] for res in predict_func(gallery_list)]
|
|
|
dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
|
|
|
@@ -199,7 +202,6 @@ class FaissBuilder:
|
|
|
@classmethod
|
|
|
def remove(
|
|
|
cls,
|
|
|
- gallery_imgs,
|
|
|
gallery_label,
|
|
|
index,
|
|
|
index_type="HNSW32",
|
|
|
@@ -208,9 +210,12 @@ class FaissBuilder:
|
|
|
index_type in cls.SUPPORT_INDEX_TYPE
|
|
|
), f"Supported index types only: {cls.SUPPORT_INDEX_TYPE}!"
|
|
|
|
|
|
- gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
|
|
|
- index, ids = IndexData.load(index)
|
|
|
+ if isinstance(gallery_label, str):
|
|
|
+ gallery_docs, _ = cls.load_gallery(gallery_label)
|
|
|
+ else:
|
|
|
+ gallery_docs = gallery_label
|
|
|
|
|
|
+ index, ids = IndexData.load(index)
|
|
|
if index_type == "HNSW32":
|
|
|
raise RuntimeError(
|
|
|
"The index_type: HNSW32 dose not support 'remove' operation"
|
|
|
@@ -226,7 +231,11 @@ class FaissBuilder:
|
|
|
metric_type in cls.SUPPORT_METRIC_TYPE
|
|
|
), f"Supported metric types only: {cls.SUPPORT_METRIC_TYPE}!"
|
|
|
|
|
|
- gallery_list, gallery_docs = cls.get_gallery(gallery_imgs, gallery_label)
|
|
|
+ if isinstance(gallery_label, str):
|
|
|
+ gallery_docs, gallery_list = cls.load_gallery(gallery_label, gallery_imgs)
|
|
|
+ else:
|
|
|
+ gallery_docs, gallery_list = gallery_label, gallery_imgs
|
|
|
+
|
|
|
features = [res["feature"] for res in predict_func(gallery_list)]
|
|
|
dtype = np.uint8 if metric_type in cls.BINARY_METRIC_TYPE else np.float32
|
|
|
features = np.array(features).astype(dtype)
|
|
|
@@ -271,22 +280,16 @@ class FaissBuilder:
|
|
|
return index, ids
|
|
|
|
|
|
@classmethod
|
|
|
- def get_gallery(cls, gallery_imgs, gallery_label, delimiter=" "):
|
|
|
- if isinstance(gallery_label, str):
|
|
|
- assert isinstance(gallery_imgs, str)
|
|
|
- gallery_imgs = Path(gallery_imgs)
|
|
|
- files = []
|
|
|
- labels = []
|
|
|
- lines = []
|
|
|
- with open(gallery_label, "r", encoding="utf-8") as f:
|
|
|
- lines = f.readlines()
|
|
|
- for line in lines:
|
|
|
- path, label = line.strip().split(delimiter)
|
|
|
- file_path = gallery_imgs / path
|
|
|
- files.append(file_path.as_posix())
|
|
|
- labels.append(label)
|
|
|
- return files, labels
|
|
|
- else:
|
|
|
- assert isinstance(gallery_imgs, list)
|
|
|
- assert isinstance(gallery_label, list)
|
|
|
- return gallery_imgs, gallery_label
|
|
|
+ def load_gallery(cls, gallery_label_path, gallery_imgs_root="", delimiter=" "):
|
|
|
+ lines = []
|
|
|
+ files = []
|
|
|
+ labels = []
|
|
|
+ root = Path(gallery_imgs_root)
|
|
|
+ with open(gallery_label_path, "r", encoding="utf-8") as f:
|
|
|
+ lines = f.readlines()
|
|
|
+ for line in lines:
|
|
|
+ path, label = line.strip().split(delimiter)
|
|
|
+ file_path = root / path
|
|
|
+ files.append(file_path.as_posix())
|
|
|
+ labels.append(label)
|
|
|
+ return labels, files
|