node.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. import inspect
  16. import functools
  17. from .....utils.misc import AutoRegisterABCMetaClass
  18. class _KeyMissingError(Exception):
  19. """ _KeyMissingError """
  20. pass
  21. class _NodeMeta(AutoRegisterABCMetaClass):
  22. """ _Node Meta Class """
  23. def __new__(cls, name, bases, attrs):
  24. def _deco(init_func):
  25. @functools.wraps(init_func)
  26. def _wrapper(self, *args, **kwargs):
  27. if not hasattr(self, '_raw_args'):
  28. sig = inspect.signature(init_func)
  29. bnd_args = sig.bind(self, *args, **kwargs)
  30. raw_args = bnd_args.arguments
  31. self_key = next(iter(raw_args.keys()))
  32. raw_args.pop(self_key)
  33. setattr(self, '_raw_args', raw_args)
  34. ret = init_func(self, *args, **kwargs)
  35. return ret
  36. return _wrapper
  37. if '__init__' in attrs:
  38. old_init_func = attrs['__init__']
  39. attrs['__init__'] = _deco(old_init_func)
  40. return super().__new__(cls, name, bases, attrs)
  41. class Node(metaclass=_NodeMeta):
  42. """ Node Class """
  43. @classmethod
  44. @abc.abstractmethod
  45. def get_input_keys(cls):
  46. """ get input keys """
  47. raise NotImplementedError
  48. @classmethod
  49. @abc.abstractmethod
  50. def get_output_keys(cls):
  51. """ get output keys """
  52. raise NotImplementedError
  53. @classmethod
  54. def check_input_keys(cls, data):
  55. """ check input keys """
  56. required_keys = cls.get_input_keys()
  57. cls._check_keys(data, required_keys, 'input')
  58. @classmethod
  59. def check_output_keys(cls, data):
  60. """ check output keys """
  61. required_keys = cls.get_output_keys()
  62. cls._check_keys(data, required_keys, 'output')
  63. @classmethod
  64. def _check_keys(cls, data, required_keys, tag):
  65. """ check keys """
  66. if len(required_keys) == 0:
  67. return
  68. if isinstance(required_keys[0], list):
  69. if not all(isinstance(ele, list) for ele in required_keys):
  70. raise TypeError
  71. for group in required_keys:
  72. try:
  73. cls._check_keys(data, group, tag)
  74. except _KeyMissingError:
  75. pass
  76. else:
  77. break
  78. else:
  79. raise _KeyMissingError(
  80. f"The {tag} does not contain the keys required by `{cls.__name__}` object."
  81. )
  82. else:
  83. for key in required_keys:
  84. if key not in data:
  85. raise _KeyMissingError(
  86. f"{repr(key)} is a required key in {tag} for `{cls.__name__}` object, but not found."
  87. )
  88. def __repr__(self):
  89. # TODO: Use fully qualified name which is globally unique
  90. def _format_args(args_dict):
  91. """ format arguments
  92. Refer to https://github.com/albumentations-team/albumentations/blob/\
  93. e3b47b3a127f92541cfeb16abbb44a6f8bf79cc8/albumentations/core/utils.py#L30
  94. """
  95. formatted_args = []
  96. for k, v in args_dict.items():
  97. if isinstance(v, str):
  98. v = f"'{v}'"
  99. formatted_args.append(f"{k}={v}")
  100. return ', '.join(formatted_args)
  101. return '{}({})'.format(self.__class__.__name__,
  102. _format_args(getattr(self, '_raw_args', {})))