config.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. import collections.abc
  16. from collections import OrderedDict
  17. from .register import get_registered_model_info, get_registered_suite_info
  18. from ...utils.errors import UnsupportedParamError
  19. __all__ = ['Config', 'BaseConfig']
  20. def _create_config(model_name, config_path=None):
  21. """ _create_config """
  22. # Build config from model name
  23. try:
  24. model_info = get_registered_model_info(model_name)
  25. except KeyError as e:
  26. raise UnsupportedParamError(
  27. f"{repr(model_name)} is not a registered model name.") from e
  28. suite_name = model_info['suite']
  29. suite_info = get_registered_suite_info(suite_name)
  30. config_cls = suite_info['config']
  31. config_obj = config_cls(model_name=model_name, config_path=config_path)
  32. return config_obj
  33. Config = _create_config
  34. class _Config(object):
  35. """ _Config """
  36. _DICT_TYPE_ = OrderedDict
  37. def __init__(self, cfg=None):
  38. super().__init__()
  39. self._dict = self._DICT_TYPE_()
  40. if cfg is not None:
  41. # Manipulate the internal `_dict` such that we avoid an extra copy
  42. self.reset_from_dict(cfg._dict)
  43. @property
  44. def dict(self):
  45. """ dict """
  46. return dict(self._dict)
  47. def __getattr__(self, key):
  48. try:
  49. val = self._dict[key]
  50. return val
  51. except KeyError:
  52. raise AttributeError
  53. def set_val(self, key, val):
  54. """ set_val """
  55. self._dict[key] = val
  56. def __getitem__(self, key):
  57. return self._dict[key]
  58. def __setitem__(self, key, val):
  59. self._dict[key] = val
  60. def __contains__(self, key):
  61. return key in self._dict
  62. def new_config(self, **kwargs):
  63. """ new_config """
  64. cfg = self.copy()
  65. cfg.update(kwargs)
  66. def copy(self):
  67. """ copy """
  68. return type(self)(cfg=self)
  69. def pop(self, key):
  70. """ pop """
  71. self._dict.pop(key)
  72. def __repr__(self):
  73. return format_cfg(self, indent=0)
  74. def reset_from_dict(self, dict_like_obj):
  75. """ reset_from_dict """
  76. self._dict.clear()
  77. self._dict.update(dict_like_obj)
  78. class BaseConfig(_Config, metaclass=abc.ABCMeta):
  79. """
  80. Abstract base class of Config.
  81. Config provides the funtionality to load, parse, or dump to a configuration
  82. file with a specific format. Also, it provides APIs to update configurations
  83. of several important hyperparameters and model components.
  84. """
  85. def __init__(self, model_name, config_path=None, cfg=None):
  86. """
  87. Initialize the instance.
  88. Args:
  89. model_name (str): A registered model name.
  90. config_path (str|None): Path of a configuration file. Default: None.
  91. cfg (BaseConfig|None): `BaseConfig` object to initialize from.
  92. Default: None.
  93. """
  94. super().__init__(cfg=cfg)
  95. self.model_name = model_name
  96. if cfg is None:
  97. # Initialize from file if no `cfg` is specified to initialize from
  98. if config_path is None:
  99. model_info = get_registered_model_info(self.model_name)
  100. config_path = model_info['config_path']
  101. self.load(config_path)
  102. def update_device(self, device):
  103. """Update the device"""
  104. pass
  105. @abc.abstractmethod
  106. def load(self, config_path):
  107. """Load configurations from a file."""
  108. raise NotImplementedError
  109. @abc.abstractmethod
  110. def dump(self, config_path):
  111. """Dump configurations to a file."""
  112. raise NotImplementedError
  113. @abc.abstractmethod
  114. def update(self, dict_like_obj):
  115. """Update configurations from a dict-like object."""
  116. raise NotImplementedError
  117. @abc.abstractmethod
  118. def update_dataset(self, dataset_dir, dataset_type=None):
  119. """Update configurations of dataset."""
  120. raise NotImplementedError
  121. @abc.abstractmethod
  122. def update_learning_rate(self, learning_rate):
  123. """Update learning rate."""
  124. raise NotImplementedError
  125. @abc.abstractmethod
  126. def update_batch_size(self, batch_size, mode='train'):
  127. """
  128. Update batch size.
  129. By default this method modifies the training batch size.
  130. """
  131. raise NotImplementedError
  132. @abc.abstractmethod
  133. def update_pretrained_weights(self, weight_path, is_backbone=False):
  134. """
  135. Update path to pretrained weights.
  136. By default this method modifies the weight path for the entire model.
  137. """
  138. raise NotImplementedError
  139. def get_epochs_iters(self):
  140. """Get total number of epochs or iterations in training."""
  141. raise NotImplementedError
  142. def get_learning_rate(self):
  143. """Get learning rate used in training."""
  144. raise NotImplementedError
  145. def get_batch_size(self, mode='train'):
  146. """
  147. Get batch size.
  148. By default this method returns the training batch size.
  149. """
  150. raise NotImplementedError
  151. def get_qat_epochs_iters(self):
  152. """Get total number of epochs or iterations in QAT."""
  153. raise NotImplementedError
  154. def get_qat_learning_rate(self):
  155. """Get learning rate used in QAT."""
  156. raise NotImplementedError
  157. def copy(self):
  158. """ copy """
  159. return type(self)(model_name=self.model_name, cfg=self)
  160. def format_cfg(cfg, indent=0):
  161. """ format_cfg """
  162. MAP_TYPES = (collections.abc.Mapping, )
  163. SEQ_TYPES = (list, tuple)
  164. NESTED_TYPES = (*MAP_TYPES, *SEQ_TYPES)
  165. s = ' ' * indent
  166. if isinstance(cfg, _Config):
  167. cfg = cfg.dict
  168. if isinstance(cfg, MAP_TYPES):
  169. for i, (k, v) in enumerate(sorted(cfg.items())):
  170. s += str(k) + ': '
  171. if isinstance(v, NESTED_TYPES):
  172. s += '\n' + format_cfg(v, indent=indent + 1)
  173. else:
  174. s += str(v)
  175. if i != len(cfg) - 1:
  176. s += '\n'
  177. elif isinstance(cfg, SEQ_TYPES):
  178. for i, v in enumerate(cfg):
  179. s += '- '
  180. if isinstance(v, NESTED_TYPES):
  181. s += '\n' + format_cfg(v, indent=indent + 1)
  182. else:
  183. s += str(v)
  184. if i != len(cfg) - 1:
  185. s += '\n'
  186. else:
  187. s += str(cfg)
  188. return s