config.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. import argparse
  17. import yaml
  18. from . import logging
  19. from .errors import raise_key_not_found_error
  20. __all__ = ['get_config']
  21. class AttrDict(dict):
  22. """ Attr Dict """
  23. def __getattr__(self, key):
  24. if key in self:
  25. return self[key]
  26. else:
  27. raise raise_key_not_found_error(key, self)
  28. def __setattr__(self, key, value):
  29. if key in self.__dict__:
  30. self.__dict__[key] = value
  31. else:
  32. self[key] = value
  33. def __deepcopy__(self, content):
  34. return copy.deepcopy(dict(self))
  35. def create_attr_dict(yaml_config):
  36. """ create attr dict """
  37. from ast import literal_eval
  38. for key, value in yaml_config.items():
  39. if type(value) is dict:
  40. yaml_config[key] = value = AttrDict(value)
  41. if isinstance(value, str):
  42. try:
  43. value = literal_eval(value)
  44. except BaseException:
  45. pass
  46. if isinstance(value, AttrDict):
  47. create_attr_dict(yaml_config[key])
  48. else:
  49. yaml_config[key] = value
  50. def parse_config(cfg_file):
  51. """Load a config file into AttrDict"""
  52. with open(cfg_file, 'r') as fopen:
  53. yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
  54. create_attr_dict(yaml_config)
  55. return yaml_config
  56. def print_dict(d, delimiter=0):
  57. """
  58. Recursively visualize a dict and
  59. indenting acrrording by the relationship of keys.
  60. """
  61. placeholder = "-" * 60
  62. for k, v in sorted(d.items()):
  63. if isinstance(v, dict):
  64. logging.info("{}{} : ".format(delimiter * " ", k))
  65. print_dict(v, delimiter + 4)
  66. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  67. logging.info("{}{} : ".format(delimiter * " ", k))
  68. for value in v:
  69. print_dict(value, delimiter + 4)
  70. else:
  71. logging.info("{}{} : {}".format(delimiter * " ", k, v))
  72. if k.isupper():
  73. logging.info(placeholder)
  74. def print_config(config):
  75. """
  76. visualize configs
  77. Arguments:
  78. config: configs
  79. """
  80. logging.advertise()
  81. print_dict(config)
  82. def override(dl, ks, v):
  83. """
  84. Recursively replace dict of list
  85. Args:
  86. dl(dict or list): dict or list to be replaced
  87. ks(list): list of keys
  88. v(str): value to be replaced
  89. """
  90. def str2num(v):
  91. try:
  92. return eval(v)
  93. except Exception:
  94. return v
  95. assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
  96. assert len(ks) > 0, ('lenght of keys should larger than 0')
  97. if isinstance(dl, list):
  98. k = str2num(ks[0])
  99. if len(ks) == 1:
  100. assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
  101. dl[k] = str2num(v)
  102. else:
  103. override(dl[k], ks[1:], v)
  104. else:
  105. if len(ks) == 1:
  106. # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
  107. if not ks[0] in dl:
  108. logging.warning(f'A new field ({ks[0]}) detected!')
  109. dl[ks[0]] = str2num(v)
  110. else:
  111. if ks[0] not in dl.keys():
  112. dl[ks[0]] = {}
  113. logging.warning(f"A new Series field ({ks[0]}) detected!")
  114. override(dl[ks[0]], ks[1:], v)
  115. def override_config(config, options=None):
  116. """
  117. Recursively override the config
  118. Args:
  119. config(dict): dict to be replaced
  120. options(list): list of pairs(key0.key1.idx.key2=value)
  121. such as: [
  122. 'topk=2',
  123. 'VALID.transforms.1.ResizeImage.resize_short=300'
  124. ]
  125. Returns:
  126. config(dict): replaced config
  127. """
  128. if options is not None:
  129. for opt in options:
  130. assert isinstance(opt, str), (
  131. "option({}) should be a str".format(opt))
  132. assert "=" in opt, (
  133. "option({}) should contain a ="
  134. "to distinguish between key and value".format(opt))
  135. pair = opt.split('=')
  136. assert len(pair) == 2, ("there can be only a = in the option")
  137. key, value = pair
  138. keys = key.split('.')
  139. override(config, keys, value)
  140. return config
  141. def get_config(fname, overrides=None, show=False):
  142. """
  143. Read config from file
  144. """
  145. assert os.path.exists(fname), ('config file({}) is not exist'.format(fname))
  146. config = parse_config(fname)
  147. override_config(config, overrides)
  148. if show:
  149. print_config(config)
  150. # check_config(config)
  151. return config
  152. def parse_args():
  153. """ parse args """
  154. parser = argparse.ArgumentParser("PaddleX script")
  155. parser.add_argument(
  156. '-c',
  157. '--config',
  158. type=str,
  159. default='configs/config.yaml',
  160. help='config file path')
  161. parser.add_argument(
  162. '-o',
  163. '--override',
  164. action='append',
  165. default=[],
  166. help='config options to be overridden')
  167. parser.add_argument(
  168. '-p',
  169. '--profiler_options',
  170. type=str,
  171. default=None,
  172. help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
  173. )
  174. args = parser.parse_args()
  175. return args