config.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import copy
  13. import argparse
  14. import yaml
  15. from . import logging
  16. from .errors import raise_key_not_found_error
  17. __all__ = ['get_config']
  18. class AttrDict(dict):
  19. """ Attr Dict """
  20. def __getattr__(self, key):
  21. if key in self:
  22. return self[key]
  23. else:
  24. raise raise_key_not_found_error(key, self)
  25. def __setattr__(self, key, value):
  26. if key in self.__dict__:
  27. self.__dict__[key] = value
  28. else:
  29. self[key] = value
  30. def __deepcopy__(self, content):
  31. return copy.deepcopy(dict(self))
  32. def create_attr_dict(yaml_config):
  33. """ create attr dict """
  34. from ast import literal_eval
  35. for key, value in yaml_config.items():
  36. if type(value) is dict:
  37. yaml_config[key] = value = AttrDict(value)
  38. if isinstance(value, str):
  39. try:
  40. value = literal_eval(value)
  41. except BaseException:
  42. pass
  43. if isinstance(value, AttrDict):
  44. create_attr_dict(yaml_config[key])
  45. else:
  46. yaml_config[key] = value
  47. def parse_config(cfg_file):
  48. """Load a config file into AttrDict"""
  49. with open(cfg_file, 'r') as fopen:
  50. yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
  51. create_attr_dict(yaml_config)
  52. return yaml_config
  53. def print_dict(d, delimiter=0):
  54. """
  55. Recursively visualize a dict and
  56. indenting acrrording by the relationship of keys.
  57. """
  58. placeholder = "-" * 60
  59. for k, v in sorted(d.items()):
  60. if isinstance(v, dict):
  61. logging.info("{}{} : ".format(delimiter * " ", k))
  62. print_dict(v, delimiter + 4)
  63. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  64. logging.info("{}{} : ".format(delimiter * " ", k))
  65. for value in v:
  66. print_dict(value, delimiter + 4)
  67. else:
  68. logging.info("{}{} : {}".format(delimiter * " ", k, v))
  69. if k.isupper():
  70. logging.info(placeholder)
  71. def print_config(config):
  72. """
  73. visualize configs
  74. Arguments:
  75. config: configs
  76. """
  77. logging.advertise()
  78. print_dict(config)
  79. # def check_config(config):
  80. # """
  81. # Check config
  82. # """
  83. # from . import check
  84. # check.check_version()
  85. # use_gpu = config.get('use_gpu', True)
  86. # if use_gpu:
  87. # check.check_gpu()
  88. # architecture = config.get('ARCHITECTURE')
  89. # #check.check_architecture(architecture)
  90. # use_mix = config.get('use_mix', False)
  91. # check.check_mix(architecture, use_mix)
  92. # classes_num = config.get('classes_num')
  93. # check.check_classes_num(classes_num)
  94. # mode = config.get('mode', 'train')
  95. # if mode.lower() == 'train':
  96. # check.check_function_params(config, 'LEARNING_RATE')
  97. # check.check_function_params(config, 'OPTIMIZER')
  98. def override(dl, ks, v):
  99. """
  100. Recursively replace dict of list
  101. Args:
  102. dl(dict or list): dict or list to be replaced
  103. ks(list): list of keys
  104. v(str): value to be replaced
  105. """
  106. def str2num(v):
  107. try:
  108. return eval(v)
  109. except Exception:
  110. return v
  111. assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
  112. assert len(ks) > 0, ('lenght of keys should larger than 0')
  113. if isinstance(dl, list):
  114. k = str2num(ks[0])
  115. if len(ks) == 1:
  116. assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
  117. dl[k] = str2num(v)
  118. else:
  119. override(dl[k], ks[1:], v)
  120. else:
  121. if len(ks) == 1:
  122. # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
  123. if not ks[0] in dl:
  124. logging.warning(f'A new field ({ks[0]}) detected!')
  125. dl[ks[0]] = str2num(v)
  126. else:
  127. if ks[0] not in dl.keys():
  128. dl[ks[0]] = {}
  129. logging.warning(f"A new Series field ({ks[0]}) detected!")
  130. override(dl[ks[0]], ks[1:], v)
  131. def override_config(config, options=None):
  132. """
  133. Recursively override the config
  134. Args:
  135. config(dict): dict to be replaced
  136. options(list): list of pairs(key0.key1.idx.key2=value)
  137. such as: [
  138. 'topk=2',
  139. 'VALID.transforms.1.ResizeImage.resize_short=300'
  140. ]
  141. Returns:
  142. config(dict): replaced config
  143. """
  144. if options is not None:
  145. for opt in options:
  146. assert isinstance(opt, str), (
  147. "option({}) should be a str".format(opt))
  148. assert "=" in opt, (
  149. "option({}) should contain a ="
  150. "to distinguish between key and value".format(opt))
  151. pair = opt.split('=')
  152. assert len(pair) == 2, ("there can be only a = in the option")
  153. key, value = pair
  154. keys = key.split('.')
  155. override(config, keys, value)
  156. return config
  157. def get_config(fname, overrides=None, show=False):
  158. """
  159. Read config from file
  160. """
  161. assert os.path.exists(fname), ('config file({}) is not exist'.format(fname))
  162. config = parse_config(fname)
  163. override_config(config, overrides)
  164. if show:
  165. print_config(config)
  166. # check_config(config)
  167. return config
  168. def parse_args():
  169. """ parse args """
  170. parser = argparse.ArgumentParser("generic-image-rec train script")
  171. parser.add_argument(
  172. '-c',
  173. '--config',
  174. type=str,
  175. default='configs/config.yaml',
  176. help='config file path')
  177. parser.add_argument(
  178. '-o',
  179. '--override',
  180. action='append',
  181. default=[],
  182. help='config options to be overridden')
  183. parser.add_argument(
  184. '-p',
  185. '--profiler_options',
  186. type=str,
  187. default=None,
  188. help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
  189. )
  190. args = parser.parse_args()
  191. return args