misc.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import threading
  16. from abc import ABCMeta
  17. from .errors import (
  18. raise_class_not_found_error,
  19. raise_no_entity_registered_error,
  20. DuplicateRegistrationError,
  21. )
  22. from .logging import *
  23. def abspath(path: str):
  24. """get absolute path
  25. Args:
  26. path (str): the relative path
  27. Returns:
  28. str: the absolute path
  29. """
  30. return os.path.abspath(path)
  31. class CachedProperty(object):
  32. """
  33. A property that is only computed once per instance and then replaces itself
  34. with an ordinary attribute.
  35. The implementation refers to
  36. https://github.com/pydanny/cached-property/blob/master/cached_property.py .
  37. Note that this implementation does NOT work in multi-thread or coroutine
  38. senarios.
  39. """
  40. def __init__(self, func):
  41. super().__init__()
  42. self.func = func
  43. self.__doc__ = getattr(func, "__doc__", "")
  44. def __get__(self, obj, cls):
  45. if obj is None:
  46. return self
  47. val = self.func(obj)
  48. # Hack __dict__ of obj to inject the value
  49. # Note that this is only executed once
  50. obj.__dict__[self.func.__name__] = val
  51. return val
  52. class Constant(object):
  53. """Constant"""
  54. def __init__(self, val):
  55. super().__init__()
  56. self.val = val
  57. def __get__(self, obj, type_=None):
  58. return self.val
  59. def __set__(self, obj, val):
  60. raise AttributeError("The value of a constant cannot be modified!")
  61. class Singleton(type):
  62. """singleton meta class
  63. Args:
  64. type (class): type
  65. Returns:
  66. class: meta class
  67. """
  68. _insts = {}
  69. _lock = threading.Lock()
  70. def __call__(cls, *args, **kwargs):
  71. if cls not in cls._insts:
  72. with cls._lock:
  73. if cls not in cls._insts:
  74. cls._insts[cls] = super().__call__(*args, **kwargs)
  75. return cls._insts[cls]
# TODO(gaotingquan): the class below has been moved to subclass_register.py.
  77. class AutoRegisterMetaClass(type):
  78. """meta class that automatically registry subclass to its baseclass
  79. Args:
  80. type (class): type
  81. Returns:
  82. class: meta class
  83. """
  84. __model_type_attr_name = "entities"
  85. __base_class_flag = "__is_base"
  86. __registered_map_name = "__registered_map"
  87. def __new__(mcs, name, bases, attrs):
  88. cls = super().__new__(mcs, name, bases, attrs)
  89. mcs.__register_model_entity(bases, cls, attrs)
  90. return cls
  91. @classmethod
  92. def __register_model_entity(mcs, bases, cls, attrs):
  93. if bases:
  94. for base in bases:
  95. base_cls = mcs.__find_base_class(base)
  96. if base_cls:
  97. mcs.__register_to_base_class(base_cls, cls)
  98. @classmethod
  99. def __find_base_class(mcs, cls):
  100. is_base_flag = mcs.__base_class_flag
  101. if is_base_flag.startswith("__"):
  102. is_base_flag = f"_{cls.__name__}" + is_base_flag
  103. if getattr(cls, is_base_flag, False):
  104. return cls
  105. for base in cls.__bases__:
  106. base_cls = mcs.__find_base_class(base)
  107. if base_cls:
  108. return base_cls
  109. return None
  110. @classmethod
  111. def __register_to_base_class(mcs, base, cls):
  112. cls_entity_name = getattr(cls, mcs.__model_type_attr_name, cls.__name__)
  113. if isinstance(cls_entity_name, str):
  114. cls_entity_name = [cls_entity_name]
  115. records = getattr(base, mcs.__registered_map_name, {})
  116. for name in cls_entity_name:
  117. if name in records and records[name] is not cls:
  118. raise DuplicateRegistrationError(
  119. f"The name(`{name}`) duplicated registration! The class entities are: `{cls.__name__}` and \
  120. `{records[name].__name__}`."
  121. )
  122. records[name] = cls
  123. debug(
  124. f"The class entity({cls.__name__}) has been register as name(`{name}`)."
  125. )
  126. setattr(base, mcs.__registered_map_name, records)
  127. def all(cls):
  128. """get all subclass"""
  129. if not hasattr(cls, type(cls).__registered_map_name):
  130. raise_no_entity_registered_error(cls)
  131. return getattr(cls, type(cls).__registered_map_name)
  132. def get(cls, name: str):
  133. """get the registried class by name"""
  134. all_entities = cls.all()
  135. if name not in all_entities:
  136. raise_class_not_found_error(name, cls, all_entities)
  137. return all_entities[name]
  138. class AutoRegisterABCMetaClass(ABCMeta, AutoRegisterMetaClass):
  139. """AutoRegisterABCMetaClass"""
  140. pass