config_helper.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import copy
  13. import collections.abc
  14. import yaml
  15. class PPDetConfigMixin(object):
  16. """ PPDetConfigMixin """
  17. def load_config_literally(self, config_path):
  18. """ load_config_literally """
  19. # Adapted from
  20. # https://github.com/PaddlePaddle/PaddleDetection/blob/e3f8dd16bffca04060ec1edc388c5a618e15bbf8/ppdet/core/workspace.py#L77
  21. # XXX: This function relies on implementation details of PaddleDetection.
  22. BASE_KEY = '_BASE_'
  23. with open(config_path, 'r', encoding='utf-8') as f:
  24. dic = yaml.load(f, Loader=_PPDetSerializableLoader)
  25. if not isinstance(dic, dict):
  26. raise TypeError
  27. if BASE_KEY in dic:
  28. all_base_cfg = dict()
  29. base_ymls = list(dic[BASE_KEY])
  30. for base_yml in base_ymls:
  31. if base_yml.startswith("~"):
  32. base_yml = os.path.expanduser(base_yml)
  33. if not base_yml.startswith('/'):
  34. base_yml = os.path.join(
  35. os.path.dirname(config_path), base_yml)
  36. with open(base_yml, 'r', encoding='utf-8') as f:
  37. base_cfg = self.load_config_literally(base_yml)
  38. all_base_cfg = merge_dicts(base_cfg, all_base_cfg)
  39. del dic[BASE_KEY]
  40. return merge_dicts(dic, all_base_cfg)
  41. return dic
  42. def dump_literal_config(self, config_path, dic):
  43. """ dump_literal_config """
  44. with open(config_path, 'w', encoding='utf-8') as f:
  45. # XXX: We make an extra copy here by calling `dict()`
  46. # to ensure that `dic` can be represented.
  47. yaml.dump(dict(dic), f, Dumper=_PPDetSerializableDumper)
  48. def update_from_dict(self, src_dic, dst_dic):
  49. """ update_from_dict """
  50. return merge_dicts(src_dic, dst_dic)
  51. class _PPDetSerializableHandler(collections.abc.MutableMapping):
  52. """ _PPDetSerializableHandler """
  53. TYPE_KEY = '_type_'
  54. EMPTY_TAG = object()
  55. def __init__(self, tag=None, dic=None):
  56. super().__init__()
  57. if tag is None:
  58. tag = self.EMPTY_TAG
  59. if dic is None:
  60. dic = dict()
  61. self.tag = tag
  62. self.dic = dic
  63. def __repr__(self):
  64. # TODO: Prettier format
  65. return repr({self.TYPE_KEY: self.tag, ** self.dic})
  66. def __getitem__(self, key):
  67. if key == self.TYPE_KEY:
  68. return self.tag
  69. else:
  70. return self.dic[key]
  71. def __setitem__(self, key, val):
  72. if key == self.TYPE_KEY:
  73. self.tag = val
  74. else:
  75. self.dic[key] = val
  76. def __delitem__(self, key):
  77. if key == self.TYPE_KEY:
  78. self.tag = self.EMPTY_TAG
  79. else:
  80. del self.dic[key]
  81. def __len__(self):
  82. return len(self.dic) + 1
  83. def __iter__(self):
  84. if self.has_nonempty_tag():
  85. yield self.TYPE_KEY
  86. yield from self.dic
  87. def has_nonempty_tag(self):
  88. """ has_nonempty_tag """
  89. return self.tag != self.EMPTY_TAG
  90. @classmethod
  91. def is_convertible(cls, obj):
  92. """ is_convertible """
  93. if isinstance(obj, cls):
  94. return False
  95. elif isinstance(obj, collections.abc.Mapping):
  96. return cls.TYPE_KEY in obj
  97. else:
  98. return False
  99. @classmethod
  100. def build_from_dict(cls, dic):
  101. """ build_from_dict """
  102. dic = copy.deepcopy(dic)
  103. tag = dic.pop(cls.TYPE_KEY)
  104. return cls(tag=tag, dic=dic)
  105. def merge_dicts(src_dic, dst_dic):
  106. """ merge_dicts """
  107. # Refer to
  108. # https://github.com/PaddlePaddle/PaddleDetection/blob/e3f8dd16bffca04060ec1edc388c5a618e15bbf8/ppdet/core/workspace.py#L121
  109. # Additionally, this function deals with the case when `src_dic`
  110. # or `dst_dic` contains `_PPDetSerializableHandler` objects.
  111. def _update_sohandler(src_handler, dst_handler):
  112. """ _update_sohandler """
  113. dst_handler.update(src_handler)
  114. def _convert_to_sohandler_if_possible(obj):
  115. """ _convert_to_sohandler_if_possible """
  116. if _PPDetSerializableHandler.is_convertible(obj):
  117. return _PPDetSerializableHandler.build_from_dict(obj)
  118. else:
  119. return obj
  120. def _convert_dict_to_sohandler_with_tag(dic, tag):
  121. """ _convert_dict_to_sohandler_with_tag """
  122. return _PPDetSerializableHandler(tag, dic)
  123. for k, v in src_dic.items():
  124. v = _convert_to_sohandler_if_possible(v)
  125. if k not in dst_dic:
  126. dst_dic[k] = v
  127. else:
  128. dst_dic[k] = _convert_to_sohandler_if_possible(dst_dic[k])
  129. if isinstance(dst_dic[k], _PPDetSerializableHandler):
  130. if isinstance(v, _PPDetSerializableHandler):
  131. _update_sohandler(v, dst_dic[k])
  132. elif isinstance(v, collections.abc.Mapping):
  133. v = _convert_dict_to_sohandler_with_tag(v, dst_dic[k].tag)
  134. _update_sohandler(v, dst_dic[k])
  135. else:
  136. dst_dic[k] = v
  137. elif isinstance(dst_dic[k], collections.abc.Mapping):
  138. if isinstance(v, _PPDetSerializableHandler):
  139. dst_dic[k] = _convert_dict_to_sohandler_with_tag(dst_dic[k],
  140. v.tag)
  141. _update_sohandler(v, dst_dic[k])
  142. elif isinstance(v, collections.abc.Mapping):
  143. merge_dicts(v, dst_dic[k])
  144. else:
  145. dst_dic[k] = v
  146. else:
  147. dst_dic[k] = v
  148. return dst_dic
  149. class _PPDetSerializableConstructor(yaml.constructor.SafeConstructor):
  150. """ _PPDetSerializableConstructor """
  151. def construct_sohandler(self, tag_suffix, node):
  152. """ construct_sohandler """
  153. if not isinstance(node, yaml.nodes.MappingNode):
  154. raise TypeError("Currently, we can only handle a MappingNode.")
  155. mapping = self.construct_mapping(node)
  156. return _PPDetSerializableHandler(tag_suffix, mapping)
  157. class _PPDetSerializableLoader(_PPDetSerializableConstructor,
  158. yaml.loader.SafeLoader):
  159. """ _PPDetSerializableLoader """
  160. def __init__(self, stream):
  161. _PPDetSerializableConstructor.__init__(self)
  162. yaml.loader.SafeLoader.__init__(self, stream)
  163. class _PPDetSerializableRepresenter(yaml.representer.SafeRepresenter):
  164. """ _PPDetSerializableRepresenter """
  165. def represent_sohandler(self, data):
  166. """ represent_sohandler """
  167. # If `data` has empty tag, we represent `data.dic` as a dict
  168. if not data.has_nonempty_tag:
  169. return self.represent_dict(data.dic)
  170. else:
  171. # XXX: Manually represent a serializable object according to the rules defined in
  172. # https://github.com/PaddlePaddle/PaddleDetection/blob/e3f8dd16bffca04060ec1edc388c5a618e15bbf8/ppdet/core/config/yaml_helpers.py#L80
  173. # We prepend a '!' to reconstruct the complete tag
  174. tag = u'!' + data.tag
  175. return self.represent_mapping(tag, data.dic)
  176. class _PPDetSerializableDumper(_PPDetSerializableRepresenter,
  177. yaml.dumper.SafeDumper):
  178. """ _PPDetSerializableDumper """
  179. def __init__(self,
  180. stream,
  181. default_style=None,
  182. default_flow_style=False,
  183. canonical=None,
  184. indent=None,
  185. width=None,
  186. allow_unicode=None,
  187. line_break=None,
  188. encoding=None,
  189. explicit_start=None,
  190. explicit_end=None,
  191. version=None,
  192. tags=None,
  193. sort_keys=True):
  194. _PPDetSerializableRepresenter.__init__(
  195. self,
  196. default_style=default_style,
  197. default_flow_style=default_flow_style,
  198. sort_keys=sort_keys)
  199. yaml.dumper.SafeDumper.__init__(
  200. self,
  201. stream,
  202. default_style=default_style,
  203. default_flow_style=default_flow_style,
  204. canonical=canonical,
  205. indent=indent,
  206. width=width,
  207. allow_unicode=allow_unicode,
  208. line_break=line_break,
  209. encoding=encoding,
  210. explicit_start=explicit_start,
  211. explicit_end=explicit_end,
  212. version=version,
  213. tags=tags,
  214. sort_keys=sort_keys)
  215. def ignore_aliases(self, data):
  216. """ ignore_aliases """
  217. return True
  218. # We note that all custom tags defined in ppdet starts with a '!'.
  219. # We assume that all unknown tags in the config file corresponds to a serializable class defined in ppdet.
  220. _PPDetSerializableLoader.add_multi_constructor(
  221. u'!', _PPDetSerializableConstructor.construct_sohandler)
  222. _PPDetSerializableDumper.add_representer(
  223. _PPDetSerializableHandler,
  224. _PPDetSerializableRepresenter.represent_sohandler)