|
|
@@ -0,0 +1,145 @@
|
|
|
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+
|
|
|
+from ..cls import ClsConfig
|
|
|
+from ....utils.misc import abspath
|
|
|
+
|
|
|
+
|
|
|
+class ShiTuRecConfig(ClsConfig):
|
|
|
+ """ShiTu Recognition Config"""
|
|
|
+
|
|
|
+ def update_dataset(
|
|
|
+ self,
|
|
|
+ dataset_path: str,
|
|
|
+ dataset_type: str = None,
|
|
|
+ *,
|
|
|
+ train_list_path: str = None,
|
|
|
+ ):
|
|
|
+ """update dataset settings
|
|
|
+
|
|
|
+ Args:
|
|
|
+ dataset_path (str): the root path of dataset.
|
|
|
+ dataset_type (str, optional): dataset type. Defaults to None.
|
|
|
+ train_list_path (str, optional): the path of train dataset annotation file . Defaults to None.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ValueError: the dataset_type error.
|
|
|
+ """
|
|
|
+ dataset_path = abspath(dataset_path)
|
|
|
+
|
|
|
+ dataset_type = "ShiTuRecDataset"
|
|
|
+ if train_list_path:
|
|
|
+ train_list_path = f"{train_list_path}"
|
|
|
+ else:
|
|
|
+ train_list_path = f"{dataset_path}/train.txt"
|
|
|
+
|
|
|
+
|
|
|
+ ds_cfg = [
|
|
|
+ f"DataLoader.Train.dataset.name={dataset_type}",
|
|
|
+ f"DataLoader.Train.dataset.image_root={dataset_path}",
|
|
|
+ f"DataLoader.Train.dataset.cls_label_path={train_list_path}",
|
|
|
+ f"DataLoader.Eval.Query.dataset.name={dataset_type}",
|
|
|
+ f"DataLoader.Eval.Query.dataset.image_root={dataset_path}",
|
|
|
+ f"DataLoader.Eval.Query.dataset.cls_label_path={dataset_path}/query.txt",
|
|
|
+ f"DataLoader.Eval.Gallery.dataset.name={dataset_type}",
|
|
|
+ f"DataLoader.Eval.Gallery.dataset.image_root={dataset_path}",
|
|
|
+ f"DataLoader.Eval.Gallery.dataset.cls_label_path={dataset_path}/gallery.txt",
|
|
|
+ ]
|
|
|
+
|
|
|
+ self.update(ds_cfg)
|
|
|
+
|
|
|
+ def update_batch_size(self, batch_size: int, mode: str = "train"):
|
|
|
+ """update batch size setting
|
|
|
+
|
|
|
+ Args:
|
|
|
+ batch_size (int): the batch size number to set.
|
|
|
+ mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'.
|
|
|
+ Defaults to 'train'.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ValueError: `mode` error.
|
|
|
+ """
|
|
|
+ if mode == "train":
|
|
|
+ if self.DataLoader["Train"]["sampler"].get("batch_size", False):
|
|
|
+ _cfg = [f"DataLoader.Train.sampler.batch_size={batch_size}"]
|
|
|
+ else:
|
|
|
+ _cfg = [f"DataLoader.Train.sampler.first_bs={batch_size}"]
|
|
|
+ _cfg = [f"DataLoader.Train.dataset.name=MultiScaleDataset"]
|
|
|
+ elif mode == "eval":
|
|
|
+ _cfg = [f"DataLoader.Eval.Query.sampler.batch_size={batch_size}"]
|
|
|
+ _cfg = [f"DataLoader.Eval.Gallery.sampler.batch_size={batch_size}"]
|
|
|
+ else:
|
|
|
+ raise ValueError("The input `mode` should be train or eval")
|
|
|
+ self.update(_cfg)
|
|
|
+
|
|
|
+
|
|
|
+ def update_num_classes(self, num_classes: int):
|
|
|
+ """update classes number
|
|
|
+
|
|
|
+ Args:
|
|
|
+ num_classes (int): the classes number value to set.
|
|
|
+ """
|
|
|
+ update_str_list = [f"Arch.Head.class_num={num_classes}"]
|
|
|
+ self.update(update_str_list)
|
|
|
+
|
|
|
+
|
|
|
+ def update_num_workers(self, num_workers: int):
|
|
|
+ """update workers number of train and eval dataloader
|
|
|
+
|
|
|
+ Args:
|
|
|
+ num_workers (int): the value of train and eval dataloader workers number to set.
|
|
|
+ """
|
|
|
+ _cfg = [
|
|
|
+ f"DataLoader.Train.loader.num_workers={num_workers}",
|
|
|
+ f"DataLoader.Eval.Query.loader.num_workers={num_workers}",
|
|
|
+ f"DataLoader.Eval.Gallery.loader.num_workers={num_workers}",
|
|
|
+ ]
|
|
|
+ self.update(_cfg)
|
|
|
+
|
|
|
+ def update_shared_memory(self, shared_memeory: bool):
|
|
|
+ """update shared memory setting of train and eval dataloader
|
|
|
+
|
|
|
+ Args:
|
|
|
+ shared_memeory (bool): whether or not to use shared memory
|
|
|
+ """
|
|
|
+ assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
|
|
|
+ _cfg = [
|
|
|
+ f"DataLoader.Train.loader.use_shared_memory={shared_memeory}",
|
|
|
+ f"DataLoader.Eval.Query.loader.use_shared_memory={shared_memeory}",
|
|
|
+ f"DataLoader.Eval.Gallery.loader.use_shared_memory={shared_memeory}",
|
|
|
+ ]
|
|
|
+ self.update(_cfg)
|
|
|
+
|
|
|
+ def update_shuffle(self, shuffle: bool):
|
|
|
+ """update shuffle setting of train and eval dataloader
|
|
|
+
|
|
|
+ Args:
|
|
|
+ shuffle (bool): whether or not to shuffle the data
|
|
|
+ """
|
|
|
+ assert isinstance(shuffle, bool), "shuffle should be a bool"
|
|
|
+ _cfg = [
|
|
|
+ f"DataLoader.Train.loader.shuffle={shuffle}",
|
|
|
+ f"DataLoader.Eval.Query.loader.shuffle={shuffle}",
|
|
|
+ f"DataLoader.Eval.Gallery.loader.shuffle={shuffle}",
|
|
|
+ ]
|
|
|
+ self.update(_cfg)
|
|
|
+
|
|
|
+
|
|
|
+ def _get_backbone_name(self) -> str:
|
|
|
+ """get backbone name of rec model
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: the model backbone name, i.e., `Arch.Backbone.name` in config.
|
|
|
+ """
|
|
|
+ return self.dict["Arch"]["Backbone"]["name"]
|