subclass_register.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from abc import ABCMeta
  15. from . import logging
  16. from .errors import (
  17. raise_class_not_found_error,
  18. raise_no_entity_registered_error,
  19. DuplicateRegistrationError,
  20. )
  21. class AutoRegisterMetaClass(type):
  22. """meta class that automatically registry subclass to its baseclass
  23. Args:
  24. type (class): type
  25. Returns:
  26. class: meta class
  27. """
  28. __model_type_attr_name = "entities"
  29. __base_class_flag = "__is_base"
  30. __registered_map_name = "__registered_map"
  31. def __new__(mcs, name, bases, attrs):
  32. cls = super().__new__(mcs, name, bases, attrs)
  33. mcs.__register_model_entity(bases, cls, attrs)
  34. return cls
  35. @classmethod
  36. def __register_model_entity(mcs, bases, cls, attrs):
  37. if bases:
  38. for base in bases:
  39. base_cls = mcs.__find_base_class(base)
  40. if base_cls:
  41. mcs.__register_to_base_class(base_cls, cls)
  42. @classmethod
  43. def __find_base_class(mcs, cls):
  44. is_base_flag = mcs.__base_class_flag
  45. if is_base_flag.startswith("__"):
  46. is_base_flag = f"_{cls.__name__}" + is_base_flag
  47. if getattr(cls, is_base_flag, False):
  48. return cls
  49. for base in cls.__bases__:
  50. base_cls = mcs.__find_base_class(base)
  51. if base_cls:
  52. return base_cls
  53. return None
  54. @classmethod
  55. def __register_to_base_class(mcs, base, cls):
  56. cls_entity_name = getattr(cls, mcs.__model_type_attr_name, cls.__name__)
  57. if isinstance(cls_entity_name, str):
  58. cls_entity_name = [cls_entity_name]
  59. records = getattr(base, mcs.__registered_map_name, {})
  60. for name in cls_entity_name:
  61. if name in records and records[name] is not cls:
  62. raise DuplicateRegistrationError(
  63. f"The name(`{name}`) duplicated registration! The class entities are: `{cls.__name__}` and \
  64. `{records[name].__name__}`."
  65. )
  66. records[name] = cls
  67. logging.debug(
  68. f"The class entity({cls.__name__}) has been register as name(`{name}`)."
  69. )
  70. setattr(base, mcs.__registered_map_name, records)
  71. def all(cls):
  72. """get all subclass"""
  73. if not hasattr(cls, type(cls).__registered_map_name):
  74. raise_no_entity_registered_error(cls)
  75. return getattr(cls, type(cls).__registered_map_name)
  76. def get(cls, name: str):
  77. """get the registried class by name"""
  78. all_entities = cls.all()
  79. if name not in all_entities:
  80. raise_class_not_found_error(name, cls, all_entities)
  81. return all_entities[name]
  82. class AutoRegisterABCMetaClass(ABCMeta, AutoRegisterMetaClass):
  83. """AutoRegisterABCMetaClass"""
  84. pass