config.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import copy
  16. import os
  17. import yaml
  18. from . import logging
  19. from .file_interface import custom_open
  20. __all__ = ["get_config"]
  21. class AttrDict(dict):
  22. """Attr Dict"""
  23. def __getattr__(self, key):
  24. if key in self:
  25. return self[key]
  26. else:
  27. raise AttributeError(key)
  28. def __setattr__(self, key, value):
  29. if key in self.__dict__:
  30. self.__dict__[key] = value
  31. else:
  32. self[key] = value
  33. def __deepcopy__(self, content):
  34. return copy.deepcopy(dict(self))
  35. def create_attr_dict(yaml_config):
  36. """create attr dict"""
  37. from ast import literal_eval
  38. for key, value in yaml_config.items():
  39. if type(value) is dict:
  40. yaml_config[key] = value = AttrDict(value)
  41. if isinstance(value, str):
  42. try:
  43. value = literal_eval(value)
  44. except BaseException:
  45. pass
  46. if isinstance(value, AttrDict):
  47. create_attr_dict(yaml_config[key])
  48. else:
  49. yaml_config[key] = value
  50. def parse_config(cfg_file):
  51. """Load a config file into AttrDict"""
  52. with custom_open(cfg_file, "r") as fopen:
  53. yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
  54. create_attr_dict(yaml_config)
  55. return yaml_config
  56. def print_dict(d, delimiter=0):
  57. """
  58. Recursively visualize a dict and
  59. indenting according by the relationship of keys.
  60. """
  61. placeholder = "-" * 60
  62. for k, v in sorted(d.items()):
  63. if isinstance(v, dict):
  64. logging.info("{}{} : ".format(delimiter * " ", k))
  65. print_dict(v, delimiter + 4)
  66. elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
  67. logging.info("{}{} : ".format(delimiter * " ", k))
  68. for value in v:
  69. print_dict(value, delimiter + 4)
  70. else:
  71. logging.info("{}{} : {}".format(delimiter * " ", k, v))
  72. if k.isupper():
  73. logging.info(placeholder)
  74. def print_config(config):
  75. """
  76. visualize configs
  77. Arguments:
  78. config: configs
  79. """
  80. logging.advertise()
  81. print_dict(config)
  82. def override(dl, ks, v):
  83. """
  84. Recursively replace dict of list
  85. Args:
  86. dl(dict or list): dict or list to be replaced
  87. ks(list): list of keys
  88. v(str): value to be replaced
  89. """
  90. def parse_str(s):
  91. """convert str type value
  92. to None type if it is "None",
  93. to bool type if it means True or False,
  94. to int type if it can be eval().
  95. """
  96. if s in ("None"):
  97. return None
  98. elif s in ("TRUE", "True", "true", "T", "t"):
  99. return True
  100. elif s in ("FALSE", "False", "false", "F", "f"):
  101. return False
  102. try:
  103. return eval(v)
  104. except Exception:
  105. return s
  106. assert isinstance(dl, (list, dict)), "{} should be a list or a dict"
  107. assert len(ks) > 0, "length of keys should larger than 0"
  108. if isinstance(dl, list):
  109. k = parse_str(ks[0])
  110. if len(ks) == 1:
  111. assert k < len(dl), "index({}) out of range({})".format(k, dl)
  112. dl[k] = parse_str(v)
  113. else:
  114. override(dl[k], ks[1:], v)
  115. else:
  116. if len(ks) == 1:
  117. # assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
  118. if not ks[0] in dl:
  119. logging.warning(f"A new field ({ks[0]}) detected!")
  120. dl[ks[0]] = parse_str(v)
  121. else:
  122. if ks[0] not in dl.keys():
  123. dl[ks[0]] = {}
  124. logging.warning(f"A new Series field ({ks[0]}) detected!")
  125. override(dl[ks[0]], ks[1:], v)
  126. def override_config(config, options=None):
  127. """
  128. Recursively override the config
  129. Args:
  130. config(dict): dict to be replaced
  131. options(list): list of pairs(key0.key1.idx.key2=value)
  132. such as: [
  133. 'topk=2',
  134. 'VALID.transforms.1.ResizeImage.resize_short=300'
  135. ]
  136. Returns:
  137. config(dict): replaced config
  138. """
  139. if options is not None:
  140. for opt in options:
  141. assert isinstance(opt, str), "option({}) should be a str".format(opt)
  142. assert "=" in opt, (
  143. "option({}) should contain a ="
  144. "to distinguish between key and value".format(opt)
  145. )
  146. pair = opt.split("=")
  147. assert len(pair) == 2, "there can be only a = in the option"
  148. key, value = pair
  149. keys = key.split(".")
  150. override(config, keys, value)
  151. return config
  152. def get_config(fname, overrides=None, show=False):
  153. """
  154. Read config from file
  155. """
  156. assert os.path.exists(fname), "config file({}) is not exist".format(fname)
  157. config = parse_config(fname)
  158. override_config(config, overrides)
  159. if show:
  160. print_config(config)
  161. # check_config(config)
  162. return config
  163. def parse_args():
  164. """parse args"""
  165. parser = argparse.ArgumentParser("PaddleX script")
  166. parser.add_argument(
  167. "-c",
  168. "--config",
  169. type=str,
  170. default="configs/config.yaml",
  171. help="config file path",
  172. )
  173. parser.add_argument(
  174. "-o",
  175. "--override",
  176. action="append",
  177. default=[],
  178. help="config options to be overridden",
  179. )
  180. parser.add_argument(
  181. "-p",
  182. "--profiler_options",
  183. type=str,
  184. default=None,
  185. help='The option of profiler, which should be in format "key1=value1;key2=value2;key3=value3".',
  186. )
  187. args = parser.parse_args()
  188. return args