config.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. import collections.abc
  16. from collections import OrderedDict
  17. from ...utils.errors import UnsupportedParamError
  18. from .register import get_registered_model_info, get_registered_suite_info
  19. __all__ = ["Config", "BaseConfig"]
  20. def _create_config(model_name, config_path=None):
  21. """_create_config"""
  22. # Build config from model name
  23. try:
  24. model_info = get_registered_model_info(model_name)
  25. except KeyError as e:
  26. raise UnsupportedParamError(
  27. f"{repr(model_name)} is not a registered model name."
  28. ) from e
  29. suite_name = model_info["suite"]
  30. suite_info = get_registered_suite_info(suite_name)
  31. config_cls = suite_info["config"]
  32. config_obj = config_cls(model_name=model_name, config_path=config_path)
  33. return config_obj
  34. Config = _create_config
  35. class _Config(object):
  36. """_Config"""
  37. _DICT_TYPE_ = OrderedDict
  38. def __init__(self, cfg=None):
  39. super().__init__()
  40. self._dict = self._DICT_TYPE_()
  41. if cfg is not None:
  42. # Manipulate the internal `_dict` such that we avoid an extra copy
  43. self.reset_from_dict(cfg._dict)
  44. @property
  45. def dict(self):
  46. """dict"""
  47. return dict(self._dict)
  48. def __getattr__(self, key):
  49. try:
  50. val = self._dict[key]
  51. return val
  52. except KeyError:
  53. raise AttributeError
  54. def set_val(self, key, val):
  55. """set_val"""
  56. self._dict[key] = val
  57. def __getitem__(self, key):
  58. return self._dict[key]
  59. def __setitem__(self, key, val):
  60. self._dict[key] = val
  61. def __contains__(self, key):
  62. return key in self._dict
  63. def new_config(self, **kwargs):
  64. """new_config"""
  65. cfg = self.copy()
  66. cfg.update(kwargs)
  67. def copy(self):
  68. """copy"""
  69. return type(self)(cfg=self)
  70. def pop(self, key):
  71. """pop"""
  72. self._dict.pop(key)
  73. def __repr__(self):
  74. return format_cfg(self, indent=0)
  75. def reset_from_dict(self, dict_like_obj):
  76. """reset_from_dict"""
  77. self._dict.clear()
  78. self._dict.update(dict_like_obj)
  79. class BaseConfig(_Config, metaclass=abc.ABCMeta):
  80. """
  81. Abstract base class of Config.
  82. Config provides the functionality to load, parse, or dump to a configuration
  83. file with a specific format. Also, it provides APIs to update configurations
  84. of several important hyperparameters and model components.
  85. """
  86. def __init__(self, model_name, config_path=None, cfg=None):
  87. """
  88. Initialize the instance.
  89. Args:
  90. model_name (str): A registered model name.
  91. config_path (str|None): Path of a configuration file. Default: None.
  92. cfg (BaseConfig|None): `BaseConfig` object to initialize from.
  93. Default: None.
  94. """
  95. super().__init__(cfg=cfg)
  96. self.model_name = model_name
  97. if cfg is None:
  98. # Initialize from file if no `cfg` is specified to initialize from
  99. if config_path is None:
  100. model_info = get_registered_model_info(self.model_name)
  101. config_path = model_info["config_path"]
  102. self.load(config_path)
  103. def update_device(self, device):
  104. """Update the device"""
  105. @abc.abstractmethod
  106. def load(self, config_path):
  107. """Load configurations from a file."""
  108. raise NotImplementedError
  109. @abc.abstractmethod
  110. def dump(self, config_path):
  111. """Dump configurations to a file."""
  112. raise NotImplementedError
  113. @abc.abstractmethod
  114. def update(self, dict_like_obj):
  115. """Update configurations from a dict-like object."""
  116. raise NotImplementedError
  117. @abc.abstractmethod
  118. def update_dataset(self, dataset_dir, dataset_type=None):
  119. """Update configurations of dataset."""
  120. raise NotImplementedError
  121. @abc.abstractmethod
  122. def update_learning_rate(self, learning_rate):
  123. """Update learning rate."""
  124. raise NotImplementedError
  125. @abc.abstractmethod
  126. def update_batch_size(self, batch_size, mode="train"):
  127. """
  128. Update batch size.
  129. By default this method modifies the training batch size.
  130. """
  131. raise NotImplementedError
  132. @abc.abstractmethod
  133. def update_pretrained_weights(self, weight_path, is_backbone=False):
  134. """
  135. Update path to pretrained weights.
  136. By default this method modifies the weight path for the entire model.
  137. """
  138. raise NotImplementedError
  139. def get_epochs_iters(self):
  140. """Get total number of epochs or iterations in training."""
  141. raise NotImplementedError
  142. def get_learning_rate(self):
  143. """Get learning rate used in training."""
  144. raise NotImplementedError
  145. def get_batch_size(self, mode="train"):
  146. """
  147. Get batch size.
  148. By default this method returns the training batch size.
  149. """
  150. raise NotImplementedError
  151. def get_qat_epochs_iters(self):
  152. """Get total number of epochs or iterations in QAT."""
  153. raise NotImplementedError
  154. def get_qat_learning_rate(self):
  155. """Get learning rate used in QAT."""
  156. raise NotImplementedError
  157. def copy(self):
  158. """copy"""
  159. return type(self)(model_name=self.model_name, cfg=self)
  160. def format_cfg(cfg, indent=0):
  161. """format_cfg"""
  162. MAP_TYPES = (collections.abc.Mapping,)
  163. SEQ_TYPES = (list, tuple)
  164. NESTED_TYPES = (*MAP_TYPES, *SEQ_TYPES)
  165. s = " " * indent
  166. if isinstance(cfg, _Config):
  167. cfg = cfg.dict
  168. if isinstance(cfg, MAP_TYPES):
  169. for i, (k, v) in enumerate(sorted(cfg.items())):
  170. s += str(k) + ": "
  171. if isinstance(v, NESTED_TYPES):
  172. s += "\n" + format_cfg(v, indent=indent + 1)
  173. else:
  174. s += str(v)
  175. if i != len(cfg) - 1:
  176. s += "\n"
  177. elif isinstance(cfg, SEQ_TYPES):
  178. for i, v in enumerate(cfg):
  179. s += "- "
  180. if isinstance(v, NESTED_TYPES):
  181. s += "\n" + format_cfg(v, indent=indent + 1)
  182. else:
  183. s += str(v)
  184. if i != len(cfg) - 1:
  185. s += "\n"
  186. else:
  187. s += str(cfg)
  188. return s