config.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from functools import lru_cache
  16. import yaml
  17. from typing import Union
  18. from ..base_seg_config import BaseSegConfig
  19. from ....utils.misc import abspath
  20. from ....utils import logging
  21. class SegConfig(BaseSegConfig):
  22. """Semantic Segmentation Config"""
  23. def update_dataset(self, dataset_path: str, dataset_type: str = None):
  24. """update dataset settings
  25. Args:
  26. dataset_path (str): the root path of dataset.
  27. dataset_type (str, optional): dataset type. Defaults to None.
  28. Raises:
  29. ValueError: the dataset_type error.
  30. """
  31. dataset_dir = abspath(dataset_path)
  32. if dataset_type is None:
  33. dataset_type = "SegDataset"
  34. if dataset_type == "SegDataset":
  35. # TODO: Prune extra keys
  36. ds_cfg = self._make_custom_dataset_config(dataset_dir)
  37. self.update(ds_cfg)
  38. elif dataset_type == "_dummy":
  39. # XXX: A special dataset type to tease PaddleSeg val dataset checkers
  40. self.update(
  41. {
  42. "val_dataset": {
  43. "type": "SegDataset",
  44. "dataset_root": dataset_dir,
  45. "val_path": os.path.join(dataset_dir, "val.txt"),
  46. "mode": "val",
  47. },
  48. }
  49. )
  50. else:
  51. raise ValueError(f"{repr(dataset_type)} is not supported.")
  52. def update_num_classes(self, num_classes: int):
  53. """update classes number
  54. Args:
  55. num_classes (int): the classes number value to set.
  56. """
  57. if "train_dataset" in self:
  58. self.train_dataset["num_classes"] = num_classes
  59. if "val_dataset" in self:
  60. self.val_dataset["num_classes"] = num_classes
  61. if "model" in self:
  62. self.model["num_classes"] = num_classes
  63. if self.model_name in ["MaskFormer_tiny", "MaskFormer_small"]:
  64. for tf_cfg in self.train_dataset["transforms"]:
  65. if tf_cfg["type"] == "GenerateInstanceTargets":
  66. tf_cfg["num_classes"] = num_classes
  67. losses = self.loss["types"]
  68. for loss_cfg in losses:
  69. loss_cfg["num_classes"] = num_classes
  70. def update_train_crop_size(self, crop_size: Union[int, list]):
  71. """update the image cropping size of training preprocessing
  72. Args:
  73. crop_size (int | list): the size of image to be cropped.
  74. Raises:
  75. ValueError: the `crop_size` error.
  76. """
  77. # XXX: This method is highly coupled to PaddleSeg's internal functions
  78. if isinstance(crop_size, int):
  79. crop_size = [crop_size, crop_size]
  80. else:
  81. crop_size = list(crop_size)
  82. if len(crop_size) != 2:
  83. raise ValueError
  84. crop_size = [int(crop_size[0]), int(crop_size[1])]
  85. tf_cfg_list = self.train_dataset["transforms"]
  86. modified = False
  87. for tf_cfg in tf_cfg_list:
  88. if tf_cfg["type"] == "RandomPaddingCrop":
  89. tf_cfg["crop_size"] = crop_size
  90. modified = True
  91. if not modified:
  92. logging.warning(
  93. "Could not find configuration item of image cropping transformation operator. "
  94. "Therefore, the crop size was not updated."
  95. )
  96. def get_epochs_iters(self) -> int:
  97. """get epochs
  98. Returns:
  99. int: the epochs value, i.e., `Global.epochs` in config.
  100. """
  101. if "iters" in self:
  102. return self.iters
  103. else:
  104. # Default iters
  105. return 1000
  106. def get_learning_rate(self) -> float:
  107. """get learning rate
  108. Returns:
  109. float: the learning rate value, i.e., `Optimizer.lr.learning_rate` in config.
  110. """
  111. if "lr_scheduler" not in self or "learning_rate" not in self.lr_scheduler:
  112. # Default lr
  113. return 0.0001
  114. else:
  115. return self.lr_scheduler["learning_rate"]
  116. def get_batch_size(self, mode="train") -> int:
  117. """get batch size
  118. Args:
  119. mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
  120. Defaults to 'train'.
  121. Raises:
  122. ValueError: the `mode` error. `train` is supported only.
  123. Returns:
  124. int: the batch size value of `mode`, i.e., `DataLoader.{mode}.sampler.batch_size` in config.
  125. """
  126. if mode == "train":
  127. if "batch_size" in self:
  128. return self.batch_size
  129. else:
  130. # Default batch size
  131. return 4
  132. else:
  133. raise ValueError(
  134. f"Getting `batch_size` in {repr(mode)} mode is not supported."
  135. )
  136. def _make_custom_dataset_config(self, dataset_root_path: str) -> dict:
  137. """construct the dataset config that meets the format requirements
  138. Args:
  139. dataset_root_path (str): the root directory of dataset.
  140. Returns:
  141. dict: the dataset config.
  142. """
  143. ds_cfg = {
  144. "train_dataset": {
  145. "type": "SegDataset",
  146. "dataset_root": dataset_root_path,
  147. "train_path": os.path.join(dataset_root_path, "train.txt"),
  148. "mode": "train",
  149. },
  150. "val_dataset": {
  151. "type": "SegDataset",
  152. "dataset_root": dataset_root_path,
  153. "val_path": os.path.join(dataset_root_path, "val.txt"),
  154. "mode": "val",
  155. },
  156. }
  157. return ds_cfg