node.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import abc
  12. import inspect
  13. import functools
  14. from .....utils.misc import AutoRegisterABCMetaClass
  15. class _KeyMissingError(Exception):
  16. """ _KeyMissingError """
  17. pass
  class _NodeMeta(AutoRegisterABCMetaClass):
      """Metaclass that records the arguments an instance was constructed with.

      When a class body defines `__init__`, that `__init__` is wrapped.
      On the first wrapped `__init__` call for an instance, the wrapper binds
      the supplied arguments against the original signature. It drops the
      first bound parameter (the instance itself) and stores the rest on the
      instance as `_raw_args`. Later wrapped calls on the same instance, e.g.
      a parent `__init__` reached via `super().__init__`, see that `_raw_args`
      already exists and leave it untouched. So `_raw_args` reflects the
      outermost constructor call.
      """

      def __new__(cls, name, bases, attrs):
          def _deco(init_func):
              # The signature of a given `__init__` never changes, so inspect
              # it once per class rather than on every instantiation.
              sig = inspect.signature(init_func)

              @functools.wraps(init_func)
              def _wrapper(self, *args, **kwargs):
                  if not hasattr(self, '_raw_args'):
                      # `Signature.bind` records only explicitly supplied
                      # arguments; defaults are not filled in.
                      raw_args = sig.bind(self, *args, **kwargs).arguments
                      # The first bound parameter is the instance; drop it.
                      raw_args.pop(next(iter(raw_args)))
                      self._raw_args = raw_args
                  return init_func(self, *args, **kwargs)

              return _wrapper

          if '__init__' in attrs:
              attrs['__init__'] = _deco(attrs['__init__'])
          return super().__new__(cls, name, bases, attrs)
  class Node(metaclass=_NodeMeta):
      """Base class for nodes that declare the data keys they consume and produce.

      Subclasses must implement `get_input_keys` and `get_output_keys`. Each
      returns a key specification, which takes one of two forms:

      * a flat list of keys: every key must be present;
      * a list of key lists: alternative groups, at least one of which must be
        fully present.
      """

      @classmethod
      @abc.abstractmethod
      def get_input_keys(cls):
          """Return the key specification required in the input data."""
          raise NotImplementedError

      @classmethod
      @abc.abstractmethod
      def get_output_keys(cls):
          """Return the key specification required in the output data."""
          raise NotImplementedError

      @classmethod
      def check_input_keys(cls, data):
          """Check that `data` satisfies the specification from `get_input_keys()`.

          Raises:
              _KeyMissingError: If the required keys are not present.
              TypeError: If the key specification is malformed.
          """
          cls._check_keys(data, cls.get_input_keys(), 'input')

      @classmethod
      def check_output_keys(cls, data):
          """Check that `data` satisfies the specification from `get_output_keys()`.

          Raises:
              _KeyMissingError: If the required keys are not present.
              TypeError: If the key specification is malformed.
          """
          cls._check_keys(data, cls.get_output_keys(), 'output')

      @classmethod
      def _check_keys(cls, data, required_keys, tag):
          """Verify that `data` contains the keys described by `required_keys`.

          Args:
              data: Container whose keys are tested with the `in` operator.
              required_keys: Either a flat list of keys, all of which must be in
                  `data`, or a list of key lists, at least one of which must be
                  fully contained in `data`. An empty specification always
                  passes.
              tag: Label (e.g. 'input' or 'output') used in error messages.

          Raises:
              _KeyMissingError: If the requirement is not met.
              TypeError: If the specification starts with a key list but also
                  contains entries that are not lists.
          """
          if len(required_keys) == 0:
              return

          if isinstance(required_keys[0], list):
              if not all(isinstance(ele, list) for ele in required_keys):
                  raise TypeError(
                      f"The {tag} key specification of `{cls.__name__}` must be "
                      "either a list of keys or a list of key lists, but it "
                      "mixes the two."
                  )
              for group in required_keys:
                  try:
                      cls._check_keys(data, group, tag)
                  except _KeyMissingError:
                      # This alternative is not satisfied; try the next one.
                      continue
                  # The first fully satisfied group is enough.
                  return
              raise _KeyMissingError(
                  f"The {tag} does not contain the keys required by `{cls.__name__}` object."
              )

          for key in required_keys:
              if key not in data:
                  raise _KeyMissingError(
                      f"{repr(key)} is a required key in {tag} for `{cls.__name__}` object, but not found."
                  )

      def __repr__(self):
          """Return `ClassName(k=v, ...)` built from the recorded `_raw_args`, if any."""
          # TODO: Use fully qualified name which is globally unique

          def _format_args(args_dict):
              """Format a mapping as comma-separated `k=v` pairs, quoting str values."""
              # Adapted from https://github.com/albumentations-team/albumentations/blob/e3b47b3a127f92541cfeb16abbb44a6f8bf79cc8/albumentations/core/utils.py#L30
              return ', '.join(
                  f"{k}='{v}'" if isinstance(v, str) else f"{k}={v}"
                  for k, v in args_dict.items()
              )

          return '{}({})'.format(self.__class__.__name__,
                                 _format_args(getattr(self, '_raw_args', {})))