config.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. import argparse
  17. import yaml
  18. from . import logging
  19. from .errors import raise_key_not_found_error
  20. from .file_interface import custom_open
  21. __all__ = ["get_config"]
  22. class AttrDict(dict):
  23. """Attr Dict"""
  24. def __getattr__(self, key):
  25. if key in self:
  26. return self[key]
  27. else:
  28. raise raise_key_not_found_error(key, self)
  29. def __setattr__(self, key, value):
  30. if key in self.__dict__:
  31. self.__dict__[key] = value
  32. else:
  33. self[key] = value
  34. def __deepcopy__(self, content):
  35. return copy.deepcopy(dict(self))
  36. def create_attr_dict(yaml_config):
  37. """create attr dict"""
  38. from ast import literal_eval
  39. for key, value in yaml_config.items():
  40. if type(value) is dict:
  41. yaml_config[key] = value = AttrDict(value)
  42. if isinstance(value, str):
  43. try:
  44. value = literal_eval(value)
  45. except BaseException:
  46. pass
  47. if isinstance(value, AttrDict):
  48. create_attr_dict(yaml_config[key])
  49. else:
  50. yaml_config[key] = value
  51. def parse_config(cfg_file):
  52. """Load a config file into AttrDict"""
  53. with custom_open(cfg_file, "r") as fopen:
  54. yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
  55. create_attr_dict(yaml_config)
  56. return yaml_config
  57. def print_dict(d, delimiter=0):
  58. """
  59. Recursively visualize a dict and
  60. indenting according by the relationship of keys.
  61. """
  62. placeholder = "-" * 60
  63. for k, v in sorted(d.items()):
  64. if isinstance(v, dict):
  65. logging.info("{}{} : ".format(delimiter * " ", k))
  66. print_dict(v, delimiter + 4)
  67. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  68. logging.info("{}{} : ".format(delimiter * " ", k))
  69. for value in v:
  70. print_dict(value, delimiter + 4)
  71. else:
  72. logging.info("{}{} : {}".format(delimiter * " ", k, v))
  73. if k.isupper():
  74. logging.info(placeholder)
  75. def print_config(config):
  76. """
  77. visualize configs
  78. Arguments:
  79. config: configs
  80. """
  81. logging.advertise()
  82. print_dict(config)
  83. def override(dl, ks, v):
  84. """
  85. Recursively replace dict of list
  86. Args:
  87. dl(dict or list): dict or list to be replaced
  88. ks(list): list of keys
  89. v(str): value to be replaced
  90. """
  91. def parse_str(s):
  92. """convert str type value
  93. to None type if it is "None",
  94. to bool type if it means True or False,
  95. to int type if it can be eval().
  96. """
  97. if s in ("None"):
  98. return None
  99. elif s in ("TRUE", "True", "true", "T", "t"):
  100. return True
  101. elif s in ("FALSE", "False", "false", "F", "f"):
  102. return False
  103. try:
  104. return eval(v)
  105. except Exception:
  106. return s
  107. assert isinstance(dl, (list, dict)), "{} should be a list or a dict"
  108. assert len(ks) > 0, "length of keys should larger than 0"
  109. if isinstance(dl, list):
  110. k = parse_str(ks[0])
  111. if len(ks) == 1:
  112. assert k < len(dl), "index({}) out of range({})".format(k, dl)
  113. dl[k] = parse_str(v)
  114. else:
  115. override(dl[k], ks[1:], v)
  116. else:
  117. if len(ks) == 1:
  118. # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
  119. if not ks[0] in dl:
  120. logging.warning(f"A new field ({ks[0]}) detected!")
  121. dl[ks[0]] = parse_str(v)
  122. else:
  123. if ks[0] not in dl.keys():
  124. dl[ks[0]] = {}
  125. logging.warning(f"A new Series field ({ks[0]}) detected!")
  126. override(dl[ks[0]], ks[1:], v)
  127. def override_config(config, options=None):
  128. """
  129. Recursively override the config
  130. Args:
  131. config(dict): dict to be replaced
  132. options(list): list of pairs(key0.key1.idx.key2=value)
  133. such as: [
  134. 'topk=2',
  135. 'VALID.transforms.1.ResizeImage.resize_short=300'
  136. ]
  137. Returns:
  138. config(dict): replaced config
  139. """
  140. if options is not None:
  141. for opt in options:
  142. assert isinstance(opt, str), "option({}) should be a str".format(opt)
  143. assert "=" in opt, (
  144. "option({}) should contain a ="
  145. "to distinguish between key and value".format(opt)
  146. )
  147. pair = opt.split("=")
  148. assert len(pair) == 2, "there can be only a = in the option"
  149. key, value = pair
  150. keys = key.split(".")
  151. override(config, keys, value)
  152. return config
  153. def get_config(fname, overrides=None, show=False):
  154. """
  155. Read config from file
  156. """
  157. assert os.path.exists(fname), "config file({}) is not exist".format(fname)
  158. config = parse_config(fname)
  159. override_config(config, overrides)
  160. if show:
  161. print_config(config)
  162. # check_config(config)
  163. return config
  164. def parse_args():
  165. """parse args"""
  166. parser = argparse.ArgumentParser("PaddleX script")
  167. parser.add_argument(
  168. "-c",
  169. "--config",
  170. type=str,
  171. default="configs/config.yaml",
  172. help="config file path",
  173. )
  174. parser.add_argument(
  175. "-o",
  176. "--override",
  177. action="append",
  178. default=[],
  179. help="config options to be overridden",
  180. )
  181. parser.add_argument(
  182. "-p",
  183. "--profiler_options",
  184. type=str,
  185. default=None,
  186. help='The option of profiler, which should be in format "key1=value1;key2=value2;key3=value3".',
  187. )
  188. args = parser.parse_args()
  189. return args