config.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import yaml
  16. from paddlex.ppcls.utils import check
  17. from paddlex.ppcls.utils import logger
  18. __all__ = ['get_config']
  19. class AttrDict(dict):
  20. def __getattr__(self, key):
  21. return self[key]
  22. def __setattr__(self, key, value):
  23. if key in self.__dict__:
  24. self.__dict__[key] = value
  25. else:
  26. self[key] = value
  27. def create_attr_dict(yaml_config):
  28. from ast import literal_eval
  29. for key, value in yaml_config.items():
  30. if type(value) is dict:
  31. yaml_config[key] = value = AttrDict(value)
  32. if isinstance(value, str):
  33. try:
  34. value = literal_eval(value)
  35. except BaseException:
  36. pass
  37. if isinstance(value, AttrDict):
  38. create_attr_dict(yaml_config[key])
  39. else:
  40. yaml_config[key] = value
  41. def parse_config(cfg_file):
  42. """Load a config file into AttrDict"""
  43. with open(cfg_file, 'r') as fopen:
  44. yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
  45. create_attr_dict(yaml_config)
  46. return yaml_config
  47. def print_dict(d, delimiter=0):
  48. """
  49. Recursively visualize a dict and
  50. indenting acrrording by the relationship of keys.
  51. """
  52. placeholder = "-" * 60
  53. for k, v in sorted(d.items()):
  54. if isinstance(v, dict):
  55. logger.info("{}{} : ".format(delimiter * " ",
  56. logger.coloring(k, "HEADER")))
  57. print_dict(v, delimiter + 4)
  58. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  59. logger.info("{}{} : ".format(delimiter * " ",
  60. logger.coloring(str(k), "HEADER")))
  61. for value in v:
  62. print_dict(value, delimiter + 4)
  63. else:
  64. logger.info("{}{} : {}".format(delimiter * " ",
  65. logger.coloring(k, "HEADER"),
  66. logger.coloring(v, "OKGREEN")))
  67. if k.isupper():
  68. logger.info(placeholder)
  69. def print_config(config):
  70. """
  71. visualize configs
  72. Arguments:
  73. config: configs
  74. """
  75. logger.advertise()
  76. print_dict(config)
  77. def check_config(config):
  78. """
  79. Check config
  80. """
  81. check.check_version()
  82. use_gpu = config.get('use_gpu', True)
  83. if use_gpu:
  84. check.check_gpu()
  85. architecture = config.get('ARCHITECTURE')
  86. check.check_architecture(architecture)
  87. check.check_model_with_running_mode(architecture)
  88. use_mix = config.get('use_mix', False)
  89. check.check_mix(architecture, use_mix)
  90. classes_num = config.get('classes_num')
  91. check.check_classes_num(classes_num)
  92. mode = config.get('mode', 'train')
  93. if mode.lower() == 'train':
  94. check.check_function_params(config, 'LEARNING_RATE')
  95. check.check_function_params(config, 'OPTIMIZER')
  96. def override(dl, ks, v):
  97. """
  98. Recursively replace dict of list
  99. Args:
  100. dl(dict or list): dict or list to be replaced
  101. ks(list): list of keys
  102. v(str): value to be replaced
  103. """
  104. def str2num(v):
  105. try:
  106. return eval(v)
  107. except Exception:
  108. return v
  109. assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
  110. assert len(ks) > 0, ('lenght of keys should larger than 0')
  111. if isinstance(dl, list):
  112. k = str2num(ks[0])
  113. if len(ks) == 1:
  114. assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
  115. dl[k] = str2num(v)
  116. else:
  117. override(dl[k], ks[1:], v)
  118. else:
  119. if len(ks) == 1:
  120. # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
  121. if not ks[0] in dl:
  122. logger.warning('A new filed ({}) detected!'.format(ks[0]))
  123. dl[ks[0]] = str2num(v)
  124. else:
  125. if not ks[0] in dl:
  126. logger.warning('A new filed ({}) detected!'.format(ks[0]))
  127. dl[ks[0]] = {}
  128. override(dl[ks[0]], ks[1:], v)
  129. def override_config(config, options=None):
  130. """
  131. Recursively override the config
  132. Args:
  133. config(dict): dict to be replaced
  134. options(list): list of pairs(key0.key1.idx.key2=value)
  135. such as: [
  136. 'topk=2',
  137. 'VALID.transforms.1.ResizeImage.resize_short=300'
  138. ]
  139. Returns:
  140. config(dict): replaced config
  141. """
  142. if options is not None:
  143. for opt in options:
  144. assert isinstance(opt, str), (
  145. "option({}) should be a str".format(opt))
  146. assert "=" in opt, (
  147. "option({}) should contain a ="
  148. "to distinguish between key and value".format(opt))
  149. pair = opt.split('=')
  150. assert len(pair) == 2, ("there can be only a = in the option")
  151. key, value = pair
  152. keys = key.split('.')
  153. override(config, keys, value)
  154. return config
  155. def get_config(fname, overrides=None, show=True):
  156. """
  157. Read config from file
  158. """
  159. assert os.path.exists(fname), (
  160. 'config file({}) is not exist'.format(fname))
  161. config = parse_config(fname)
  162. override_config(config, overrides)
  163. if show:
  164. print_config(config)
  165. check_config(config)
  166. return config