config.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. import argparse
  17. import yaml
  18. from paddlex.ppcls.utils import logger
  19. from paddlex.ppcls.utils import check
  20. __all__ = ['get_config']
  21. class AttrDict(dict):
  22. def __getattr__(self, key):
  23. return self[key]
  24. def __setattr__(self, key, value):
  25. if key in self.__dict__:
  26. self.__dict__[key] = value
  27. else:
  28. self[key] = value
  29. def __deepcopy__(self, content):
  30. return copy.deepcopy(dict(self))
  31. def create_attr_dict(yaml_config):
  32. from ast import literal_eval
  33. for key, value in yaml_config.items():
  34. if type(value) is dict:
  35. yaml_config[key] = value = AttrDict(value)
  36. if isinstance(value, str):
  37. try:
  38. value = literal_eval(value)
  39. except BaseException:
  40. pass
  41. if isinstance(value, AttrDict):
  42. create_attr_dict(yaml_config[key])
  43. else:
  44. yaml_config[key] = value
  45. def parse_config(cfg_file):
  46. """Load a config file into AttrDict"""
  47. with open(cfg_file, 'r') as fopen:
  48. yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
  49. create_attr_dict(yaml_config)
  50. return yaml_config
  51. def print_dict(d, delimiter=0):
  52. """
  53. Recursively visualize a dict and
  54. indenting acrrording by the relationship of keys.
  55. """
  56. placeholder = "-" * 60
  57. for k, v in sorted(d.items()):
  58. if isinstance(v, dict):
  59. logger.info("{}{} : ".format(delimiter * " ", k))
  60. print_dict(v, delimiter + 4)
  61. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  62. logger.info("{}{} : ".format(delimiter * " ", k))
  63. for value in v:
  64. print_dict(value, delimiter + 4)
  65. else:
  66. logger.info("{}{} : {}".format(delimiter * " ", k, v))
  67. if k.isupper():
  68. logger.info(placeholder)
  69. def print_config(config):
  70. """
  71. visualize configs
  72. Arguments:
  73. config: configs
  74. """
  75. logger.advertise()
  76. print_dict(config)
  77. def check_config(config):
  78. """
  79. Check config
  80. """
  81. check.check_version()
  82. use_gpu = config.get('use_gpu', True)
  83. if use_gpu:
  84. check.check_gpu()
  85. architecture = config.get('ARCHITECTURE')
  86. #check.check_architecture(architecture)
  87. use_mix = config.get('use_mix', False)
  88. check.check_mix(architecture, use_mix)
  89. classes_num = config.get('classes_num')
  90. check.check_classes_num(classes_num)
  91. mode = config.get('mode', 'train')
  92. if mode.lower() == 'train':
  93. check.check_function_params(config, 'LEARNING_RATE')
  94. check.check_function_params(config, 'OPTIMIZER')
  95. def override(dl, ks, v):
  96. """
  97. Recursively replace dict of list
  98. Args:
  99. dl(dict or list): dict or list to be replaced
  100. ks(list): list of keys
  101. v(str): value to be replaced
  102. """
  103. def str2num(v):
  104. try:
  105. return eval(v)
  106. except Exception:
  107. return v
  108. assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
  109. assert len(ks) > 0, ('lenght of keys should larger than 0')
  110. if isinstance(dl, list):
  111. k = str2num(ks[0])
  112. if len(ks) == 1:
  113. assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
  114. dl[k] = str2num(v)
  115. else:
  116. override(dl[k], ks[1:], v)
  117. else:
  118. if len(ks) == 1:
  119. # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
  120. if not ks[0] in dl:
  121. print('A new filed ({}) detected!'.format(ks[0], dl))
  122. dl[ks[0]] = str2num(v)
  123. else:
  124. override(dl[ks[0]], ks[1:], v)
  125. def override_config(config, options=None):
  126. """
  127. Recursively override the config
  128. Args:
  129. config(dict): dict to be replaced
  130. options(list): list of pairs(key0.key1.idx.key2=value)
  131. such as: [
  132. 'topk=2',
  133. 'VALID.transforms.1.ResizeImage.resize_short=300'
  134. ]
  135. Returns:
  136. config(dict): replaced config
  137. """
  138. if options is not None:
  139. for opt in options:
  140. assert isinstance(opt, str), (
  141. "option({}) should be a str".format(opt))
  142. assert "=" in opt, (
  143. "option({}) should contain a ="
  144. "to distinguish between key and value".format(opt))
  145. pair = opt.split('=')
  146. assert len(pair) == 2, ("there can be only a = in the option")
  147. key, value = pair
  148. keys = key.split('.')
  149. override(config, keys, value)
  150. return config
  151. def get_config(fname, overrides=None, show=False):
  152. """
  153. Read config from file
  154. """
  155. assert os.path.exists(fname), (
  156. 'config file({}) is not exist'.format(fname))
  157. config = parse_config(fname)
  158. override_config(config, overrides)
  159. if show:
  160. print_config(config)
  161. # check_config(config)
  162. return config
  163. def parse_args():
  164. parser = argparse.ArgumentParser("generic-image-rec train script")
  165. parser.add_argument(
  166. '-c',
  167. '--config',
  168. type=str,
  169. default='configs/config.yaml',
  170. help='config file path')
  171. parser.add_argument(
  172. '-o',
  173. '--override',
  174. action='append',
  175. default=[],
  176. help='config options to be overridden')
  177. parser.add_argument(
  178. '-p',
  179. '--profiler_options',
  180. type=str,
  181. default=None,
  182. help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
  183. )
  184. args = parser.parse_args()
  185. return args