config.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import lru_cache
from typing import Union

import yaml
from paddleseg.utils import NoAliasDumper

from ..base_seg_config import BaseSegConfig
from ....utils.misc import abspath
from ....utils import logging


class SegConfig(BaseSegConfig):
    """Semantic Segmentation Config"""

    def update_dataset(self, dataset_path: str, dataset_type: str = None):
        """Update dataset settings.

        Args:
            dataset_path (str): Root path of the dataset.
            dataset_type (str, optional): Dataset type. Defaults to None.

        Raises:
            ValueError: If `dataset_type` is not supported.
        """
        dataset_dir = abspath(dataset_path)
        if dataset_type is None:
            dataset_type = 'SegDataset'
        if dataset_type == 'SegDataset':
            # TODO: Prune extra keys
            ds_cfg = self._make_custom_dataset_config(dataset_dir)
            self.update(ds_cfg)
        elif dataset_type == '_dummy':
            # XXX: A special dataset type used to fool the PaddleSeg val
            # dataset checkers
            self.update({
                'val_dataset': {
                    'type': 'SegDataset',
                    'dataset_root': dataset_dir,
                    'val_path': os.path.join(dataset_dir, 'val.txt'),
                    'mode': 'val'
                },
            })
        else:
            raise ValueError(f"{repr(dataset_type)} is not supported.")
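
    # For reference, the 'SegDataset' branch above merges a config of roughly
    # the following shape (see `_make_custom_dataset_config` below; the
    # dataset root path is illustrative):
    #
    #     train_dataset:
    #         type: SegDataset
    #         dataset_root: /path/to/dataset
    #         train_path: /path/to/dataset/train.txt
    #         mode: train
    #     val_dataset:
    #         type: SegDataset
    #         dataset_root: /path/to/dataset
    #         val_path: /path/to/dataset/val.txt
    #         mode: val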

    def update_num_classes(self, num_classes: int):
        """Update the number of classes.

        Args:
            num_classes (int): Number of classes to set.
        """
        if 'train_dataset' in self:
            self.train_dataset['num_classes'] = num_classes
        if 'val_dataset' in self:
            self.val_dataset['num_classes'] = num_classes
        if 'model' in self:
            self.model['num_classes'] = num_classes

    def update_train_crop_size(self, crop_size: Union[int, list]):
        """Update the image crop size used in training preprocessing.

        Args:
            crop_size (int | list): Size to crop the image to.

        Raises:
            ValueError: If `crop_size` is not an int or a length-2 sequence.
        """
        # XXX: This method is highly coupled to PaddleSeg's internal functions
        if isinstance(crop_size, int):
            crop_size = [crop_size, crop_size]
        else:
            crop_size = list(crop_size)
            if len(crop_size) != 2:
                raise ValueError(
                    f"`crop_size` must have exactly 2 elements, but got {crop_size}."
                )
        crop_size = [int(crop_size[0]), int(crop_size[1])]
        tf_cfg_list = self.train_dataset['transforms']
        modified = False
        for tf_cfg in tf_cfg_list:
            if tf_cfg['type'] == 'RandomPaddingCrop':
                tf_cfg['crop_size'] = crop_size
                modified = True
        if not modified:
            logging.warning(
                "Could not find the configuration item of the image cropping "
                "transformation operator, so the crop size was not updated.")

    def get_epochs_iters(self) -> int:
        """Get the number of training iterations.

        Returns:
            int: The number of iterations, i.e., `iters` in the config.
        """
        if 'iters' in self:
            return self.iters
        else:
            # Default iters
            return 1000

    def get_learning_rate(self) -> float:
        """Get the learning rate.

        Returns:
            float: The learning rate, i.e., `lr_scheduler.learning_rate` in
                the config.
        """
        if 'lr_scheduler' not in self or 'learning_rate' not in self.lr_scheduler:
            # Default lr
            return 0.0001
        else:
            return self.lr_scheduler['learning_rate']

    def get_batch_size(self, mode='train') -> int:
        """Get the batch size.

        Args:
            mode (str, optional): The mode to get the batch size for. Only
                'train' is currently supported. Defaults to 'train'.

        Raises:
            ValueError: If `mode` is not 'train'.

        Returns:
            int: The batch size, i.e., `batch_size` in the config.
        """
        if mode == 'train':
            if 'batch_size' in self:
                return self.batch_size
            else:
                # Default batch size
                return 4
        else:
            raise ValueError(
                f"Getting `batch_size` in {repr(mode)} mode is not supported.")

    def _make_custom_dataset_config(self, dataset_root_path: str) -> dict:
        """Construct a dataset config that meets the format requirements.

        Args:
            dataset_root_path (str): Root directory of the dataset.

        Returns:
            dict: The dataset config.
        """
        ds_cfg = {
            'train_dataset': {
                'type': 'SegDataset',
                'dataset_root': dataset_root_path,
                'train_path': os.path.join(dataset_root_path, 'train.txt'),
                'mode': 'train'
            },
            'val_dataset': {
                'type': 'SegDataset',
                'dataset_root': dataset_root_path,
                'val_path': os.path.join(dataset_root_path, 'val.txt'),
                'mode': 'val'
            },
        }
        return ds_cfg
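

# Hypothetical end-to-end sketch. How a SegConfig instance is constructed
# depends on BaseSegConfig, which is not shown in this file; paths and values
# below are illustrative only:
#
#     cfg = ...  # build a SegConfig from a base YAML config
#     cfg.update_dataset('/path/to/dataset')  # dataset_type defaults to 'SegDataset'
#     cfg.update_num_classes(19)
#     cfg.update_train_crop_size(512)
#     print(cfg.get_epochs_iters(), cfg.get_learning_rate(), cfg.get_batch_size())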