| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import abc
- import inspect
- import functools
- from .....utils.misc import AutoRegisterABCMetaClass
class _KeyMissingError(Exception):
    """Raised internally when a required key is absent from the checked data."""
class _NodeMeta(AutoRegisterABCMetaClass):
    """Metaclass that remembers the arguments an instance was built with.

    Whenever a class body defines ``__init__``, that initializer is replaced
    by a wrapper which binds the call arguments to the initializer's
    signature and stores them on the instance as ``_raw_args`` (without the
    instance itself). Classes that do not define ``__init__`` in their own
    body are left untouched.
    """

    def __new__(cls, name, bases, attrs):
        def _record_raw_args(init_func):
            """Wrap `init_func` so that its call arguments get recorded."""

            @functools.wraps(init_func)
            def _wrapper(self, *args, **kwargs):
                # Only record when the instance has no `_raw_args` yet, so a
                # wrapped `__init__` invoked later on the same instance (for
                # example through `super().__init__`) does not overwrite it.
                if not hasattr(self, '_raw_args'):
                    bound = inspect.signature(init_func).bind(
                        self, *args, **kwargs)
                    captured = bound.arguments
                    # The first bound entry is the instance itself; drop it
                    # so that only the user-supplied arguments remain.
                    captured.pop(next(iter(captured.keys())))
                    self._raw_args = captured
                return init_func(self, *args, **kwargs)

            return _wrapper

        if '__init__' in attrs:
            attrs['__init__'] = _record_raw_args(attrs['__init__'])
        return super().__new__(cls, name, bases, attrs)
class Node(metaclass=_NodeMeta):
    """Base class for objects that declare required input and output keys.

    Subclasses implement `get_input_keys` and `get_output_keys`. Each returns
    either a flat sequence of keys that must all be present, or a list of
    lists in which every inner list is an alternative group and at least one
    group must be fully present.
    """

    @classmethod
    @abc.abstractmethod
    def get_input_keys(cls):
        """Return the keys required in the input; must be overridden."""
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def get_output_keys(cls):
        """Return the keys required in the output; must be overridden."""
        raise NotImplementedError

    @classmethod
    def check_input_keys(cls, data):
        """Verify that `data` contains the keys from `get_input_keys`."""
        cls._check_keys(data, cls.get_input_keys(), 'input')

    @classmethod
    def check_output_keys(cls, data):
        """Verify that `data` contains the keys from `get_output_keys`."""
        cls._check_keys(data, cls.get_output_keys(), 'output')

    @classmethod
    def _check_keys(cls, data, required_keys, tag):
        """Check `data` against `required_keys`.

        Args:
            data: Container tested with the `in` operator.
            required_keys: Flat sequence of keys, or a list of lists of keys
                where satisfying any single inner list is sufficient.
            tag: Label ('input' or 'output') used in the error messages.

        Raises:
            TypeError: If the first element of `required_keys` is a list but
                some other element is not.
            _KeyMissingError: If a required key is missing, or if none of the
                alternative groups is satisfied.
        """
        if len(required_keys) == 0:
            return

        # Flat form: every listed key has to be present.
        if not isinstance(required_keys[0], list):
            for key in required_keys:
                if key not in data:
                    raise _KeyMissingError(
                        f"{repr(key)} is a required key in {tag} for `{cls.__name__}` object, but not found."
                    )
            return

        # Grouped form: mixing lists and non-lists is rejected.
        if any(not isinstance(ele, list) for ele in required_keys):
            raise TypeError

        # The first group that is fully satisfied ends the check.
        for group in required_keys:
            try:
                cls._check_keys(data, group, tag)
            except _KeyMissingError:
                continue
            return

        raise _KeyMissingError(
            f"The {tag} does not contain the keys required by `{cls.__name__}` object."
        )

    def __repr__(self):
        # TODO: Use fully qualified name which is globally unique

        # Formatting of the arguments refers to
        # https://github.com/albumentations-team/albumentations/blob/
        # e3b47b3a127f92541cfeb16abbb44a6f8bf79cc8/albumentations/core/utils.py#L30
        def _quote_if_str(value):
            """Wrap string values in single quotes; pass others through."""
            return f"'{value}'" if isinstance(value, str) else value

        recorded = getattr(self, '_raw_args', {})
        rendered = ', '.join(f"{k}={_quote_if_str(v)}"
                             for k, v in recorded.items())
        return '{}({})'.format(self.__class__.__name__, rendered)
|