config.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from ..cls import ClsConfig
  15. from ....utils.misc import abspath
  16. class ShiTuRecConfig(ClsConfig):
  17. """ShiTu Recognition Config"""
  18. def update_dataset(
  19. self,
  20. dataset_path: str,
  21. dataset_type: str = None,
  22. *,
  23. train_list_path: str = None,
  24. ):
  25. """update dataset settings
  26. Args:
  27. dataset_path (str): the root path of dataset.
  28. dataset_type (str, optional): dataset type. Defaults to None.
  29. train_list_path (str, optional): the path of train dataset annotation file . Defaults to None.
  30. Raises:
  31. ValueError: the dataset_type error.
  32. """
  33. dataset_path = abspath(dataset_path)
  34. dataset_type = "ShiTuRecDataset"
  35. if train_list_path:
  36. train_list_path = f"{train_list_path}"
  37. else:
  38. train_list_path = f"{dataset_path}/train.txt"
  39. ds_cfg = [
  40. f"DataLoader.Train.dataset.name={dataset_type}",
  41. f"DataLoader.Train.dataset.image_root={dataset_path}",
  42. f"DataLoader.Train.dataset.cls_label_path={train_list_path}",
  43. f"DataLoader.Eval.Query.dataset.name={dataset_type}",
  44. f"DataLoader.Eval.Query.dataset.image_root={dataset_path}",
  45. f"DataLoader.Eval.Query.dataset.cls_label_path={dataset_path}/query.txt",
  46. f"DataLoader.Eval.Gallery.dataset.name={dataset_type}",
  47. f"DataLoader.Eval.Gallery.dataset.image_root={dataset_path}",
  48. f"DataLoader.Eval.Gallery.dataset.cls_label_path={dataset_path}/gallery.txt",
  49. ]
  50. self.update(ds_cfg)
  51. def update_batch_size(self, batch_size: int, mode: str = "train"):
  52. """update batch size setting
  53. Args:
  54. batch_size (int): the batch size number to set.
  55. mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'.
  56. Defaults to 'train'.
  57. Raises:
  58. ValueError: `mode` error.
  59. """
  60. if mode == "train":
  61. if self.DataLoader["Train"]["sampler"].get("batch_size", False):
  62. _cfg = [f"DataLoader.Train.sampler.batch_size={batch_size}"]
  63. else:
  64. _cfg = [f"DataLoader.Train.sampler.first_bs={batch_size}"]
  65. _cfg = [f"DataLoader.Train.dataset.name=MultiScaleDataset"]
  66. elif mode == "eval":
  67. _cfg = [f"DataLoader.Eval.Query.sampler.batch_size={batch_size}"]
  68. _cfg = [f"DataLoader.Eval.Gallery.sampler.batch_size={batch_size}"]
  69. else:
  70. raise ValueError("The input `mode` should be train or eval")
  71. self.update(_cfg)
  72. def update_num_classes(self, num_classes: int):
  73. """update classes number
  74. Args:
  75. num_classes (int): the classes number value to set.
  76. """
  77. update_str_list = [f"Arch.Head.class_num={num_classes}"]
  78. self.update(update_str_list)
  79. def update_num_workers(self, num_workers: int):
  80. """update workers number of train and eval dataloader
  81. Args:
  82. num_workers (int): the value of train and eval dataloader workers number to set.
  83. """
  84. _cfg = [
  85. f"DataLoader.Train.loader.num_workers={num_workers}",
  86. f"DataLoader.Eval.Query.loader.num_workers={num_workers}",
  87. f"DataLoader.Eval.Gallery.loader.num_workers={num_workers}",
  88. ]
  89. self.update(_cfg)
  90. def update_shared_memory(self, shared_memeory: bool):
  91. """update shared memory setting of train and eval dataloader
  92. Args:
  93. shared_memeory (bool): whether or not to use shared memory
  94. """
  95. assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
  96. _cfg = [
  97. f"DataLoader.Train.loader.use_shared_memory={shared_memeory}",
  98. f"DataLoader.Eval.Query.loader.use_shared_memory={shared_memeory}",
  99. f"DataLoader.Eval.Gallery.loader.use_shared_memory={shared_memeory}",
  100. ]
  101. self.update(_cfg)
  102. def update_shuffle(self, shuffle: bool):
  103. """update shuffle setting of train and eval dataloader
  104. Args:
  105. shuffle (bool): whether or not to shuffle the data
  106. """
  107. assert isinstance(shuffle, bool), "shuffle should be a bool"
  108. _cfg = [
  109. f"DataLoader.Train.loader.shuffle={shuffle}",
  110. f"DataLoader.Eval.Query.loader.shuffle={shuffle}",
  111. f"DataLoader.Eval.Gallery.loader.shuffle={shuffle}",
  112. ]
  113. self.update(_cfg)
  114. def _get_backbone_name(self) -> str:
  115. """get backbone name of rec model
  116. Returns:
  117. str: the model backbone name, i.e., `Arch.Backbone.name` in config.
  118. """
  119. return self.dict["Arch"]["Backbone"]["name"]