misc.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import threading
  16. from abc import ABCMeta
  17. from .errors import raise_class_not_found_error, raise_no_entity_registered_error, DuplicateRegistrationError
  18. from .logging import *
  19. def abspath(path: str):
  20. """get absolute path
  21. Args:
  22. path (str): the relative path
  23. Returns:
  24. str: the absolute path
  25. """
  26. return os.path.abspath(path)
  27. class CachedProperty(object):
  28. """
  29. A property that is only computed once per instance and then replaces itself
  30. with an ordinary attribute.
  31. The implementation refers to
  32. https://github.com/pydanny/cached-property/blob/master/cached_property.py .
  33. Note that this implementation does NOT work in multi-thread or coroutine
  34. senarios.
  35. """
  36. def __init__(self, func):
  37. super().__init__()
  38. self.func = func
  39. self.__doc__ = getattr(func, '__doc__', '')
  40. def __get__(self, obj, cls):
  41. if obj is None:
  42. return self
  43. val = self.func(obj)
  44. # Hack __dict__ of obj to inject the value
  45. # Note that this is only executed once
  46. obj.__dict__[self.func.__name__] = val
  47. return val
  48. class Constant(object):
  49. """ Constant """
  50. def __init__(self, val):
  51. super().__init__()
  52. self.val = val
  53. def __get__(self, obj, type_=None):
  54. return self.val
  55. def __set__(self, obj, val):
  56. raise AttributeError("The value of a constant cannot be modified!")
  57. class Singleton(type):
  58. """singleton meta class
  59. Args:
  60. type (class): type
  61. Returns:
  62. class: meta class
  63. """
  64. _insts = {}
  65. _lock = threading.Lock()
  66. def __call__(cls, *args, **kwargs):
  67. if cls not in cls._insts:
  68. with cls._lock:
  69. if cls not in cls._insts:
  70. cls._insts[cls] = super().__call__(*args, **kwargs)
  71. return cls._insts[cls]
  72. class AutoRegisterMetaClass(type):
  73. """meta class that automatically registry subclass to its baseclass
  74. Args:
  75. type (class): type
  76. Returns:
  77. class: meta class
  78. """
  79. __model_type_attr_name = 'support_models'
  80. __base_class_flag = '__is_base'
  81. __registered_map_name = '__registered_map'
  82. def __new__(mcs, name, bases, attrs):
  83. cls = super().__new__(mcs, name, bases, attrs)
  84. mcs.__register_model_entity(bases, cls, attrs)
  85. return cls
  86. @classmethod
  87. def __register_model_entity(mcs, bases, cls, attrs):
  88. if bases:
  89. for base in bases:
  90. base_cls = mcs.__find_base_class(base)
  91. if base_cls:
  92. mcs.__register_to_base_class(base_cls, cls)
  93. @classmethod
  94. def __find_base_class(mcs, cls):
  95. is_base_flag = mcs.__base_class_flag
  96. if is_base_flag.startswith("__"):
  97. is_base_flag = f"_{cls.__name__}" + is_base_flag
  98. if getattr(cls, is_base_flag, False):
  99. return cls
  100. for base in cls.__bases__:
  101. base_cls = mcs.__find_base_class(base)
  102. if base_cls:
  103. return base_cls
  104. return None
  105. @classmethod
  106. def __register_to_base_class(mcs, base, cls):
  107. cls_entity_name = getattr(cls, mcs.__model_type_attr_name, cls.__name__)
  108. if isinstance(cls_entity_name, str):
  109. cls_entity_name = [cls_entity_name]
  110. records = getattr(base, mcs.__registered_map_name, {})
  111. for name in cls_entity_name:
  112. if name in records and records[name] is not cls:
  113. raise DuplicateRegistrationError(
  114. f"The name(`{name}`) duplicated registration! The class entities are: `{cls.__name__}` and \
  115. `{records[name].__name__}`.")
  116. records[name] = cls
  117. debug(
  118. f"The class entity({cls.__name__}) has been register as name(`{name}`)."
  119. )
  120. setattr(base, mcs.__registered_map_name, records)
  121. def all(cls):
  122. """ get all subclass """
  123. if not hasattr(cls, type(cls).__registered_map_name):
  124. raise_no_entity_registered_error(cls)
  125. return getattr(cls, type(cls).__registered_map_name)
  126. def get(cls, name: str):
  127. """ get the registried class by name """
  128. all_entities = cls.all()
  129. if name not in all_entities:
  130. raise_class_not_found_error(name, cls, all_entities)
  131. return all_entities[name]
  132. class AutoRegisterABCMetaClass(ABCMeta, AutoRegisterMetaClass):
  133. """ AutoRegisterABCMetaClass """
  134. pass