misc.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import threading
  13. from abc import ABCMeta
  14. from .errors import raise_class_not_found_error, DuplicateRegistrationError
  15. from .logging import *
  16. def abspath(path: str):
  17. """get absolute path
  18. Args:
  19. path (str): the relative path
  20. Returns:
  21. str: the absolute path
  22. """
  23. return os.path.abspath(path)
  24. class CachedProperty(object):
  25. """
  26. A property that is only computed once per instance and then replaces itself
  27. with an ordinary attribute.
  28. The implementation refers to
  29. https://github.com/pydanny/cached-property/blob/master/cached_property.py .
  30. Note that this implementation does NOT work in multi-thread or coroutine
  31. senarios.
  32. """
  33. def __init__(self, func):
  34. super().__init__()
  35. self.func = func
  36. self.__doc__ = getattr(func, '__doc__', '')
  37. def __get__(self, obj, cls):
  38. if obj is None:
  39. return self
  40. val = self.func(obj)
  41. # Hack __dict__ of obj to inject the value
  42. # Note that this is only executed once
  43. obj.__dict__[self.func.__name__] = val
  44. return val
  45. class Constant(object):
  46. """ Constant """
  47. def __init__(self, val):
  48. super().__init__()
  49. self.val = val
  50. def __get__(self, obj, type_=None):
  51. return self.val
  52. def __set__(self, obj, val):
  53. raise AttributeError("The value of a constant cannot be modified!")
  54. class Singleton(type):
  55. """singleton meta class
  56. Args:
  57. type (class): type
  58. Returns:
  59. class: meta class
  60. """
  61. _insts = {}
  62. _lock = threading.Lock()
  63. def __call__(cls, *args, **kwargs):
  64. if cls not in cls._insts:
  65. with cls._lock:
  66. if cls not in cls._insts:
  67. cls._insts[cls] = super().__call__(*args, **kwargs)
  68. return cls._insts[cls]
  69. class AutoRegisterMetaClass(type):
  70. """meta class that automatically registry subclass to its baseclass
  71. Args:
  72. type (class): type
  73. Returns:
  74. class: meta class
  75. """
  76. __model_type_attr_name = 'support_models'
  77. __base_class_flag = '__is_base'
  78. __registered_map_name = '__registered_map'
  79. def __new__(mcs, name, bases, attrs):
  80. cls = super().__new__(mcs, name, bases, attrs)
  81. mcs.__register_model_entity(bases, cls, attrs)
  82. return cls
  83. @classmethod
  84. def __register_model_entity(mcs, bases, cls, attrs):
  85. if bases:
  86. for base in bases:
  87. base_cls = mcs.__find_base_class(base)
  88. if base_cls:
  89. mcs.__register_to_base_class(base_cls, cls)
  90. @classmethod
  91. def __find_base_class(mcs, cls):
  92. is_base_flag = mcs.__base_class_flag
  93. if is_base_flag.startswith("__"):
  94. is_base_flag = f"_{cls.__name__}" + is_base_flag
  95. if getattr(cls, is_base_flag, False):
  96. return cls
  97. for base in cls.__bases__:
  98. base_cls = mcs.__find_base_class(base)
  99. if base_cls:
  100. return base_cls
  101. return None
  102. @classmethod
  103. def __register_to_base_class(mcs, base, cls):
  104. cls_entity_name = getattr(cls, mcs.__model_type_attr_name, cls.__name__)
  105. if isinstance(cls_entity_name, str):
  106. cls_entity_name = [cls_entity_name]
  107. records = getattr(base, mcs.__registered_map_name, {})
  108. for name in cls_entity_name:
  109. if name in records and records[name] is not cls:
  110. raise DuplicateRegistrationError(
  111. f"The name(`{name}`) duplicated registration! The class entities are: `{cls.__name__}` and \
  112. `{records[name].__name__}`.")
  113. records[name] = cls
  114. debug(
  115. f"The class entity({cls.__name__}) has been register as name(`{name}`)."
  116. )
  117. setattr(base, mcs.__registered_map_name, records)
  118. def all(cls):
  119. """ get all subclass """
  120. return getattr(cls, type(cls).__registered_map_name)
  121. def get(cls, name: str):
  122. """ get the registried class by name """
  123. all_entities = cls.all()
  124. if name not in all_entities:
  125. raise_class_not_found_error(name, cls, all_entities)
  126. return all_entities[name]
  127. class AutoRegisterABCMetaClass(ABCMeta, AutoRegisterMetaClass):
  128. """ AutoRegisterABCMetaClass """
  129. pass