misc.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import threading
  16. from abc import ABCMeta
  17. import numpy as np
  18. from .errors import (
  19. DuplicateRegistrationError,
  20. raise_class_not_found_error,
  21. raise_no_entity_registered_error,
  22. )
  23. from .logging import *
  24. def convert_and_remove_types(data):
  25. if isinstance(data, dict):
  26. return {
  27. k: convert_and_remove_types(v)
  28. for k, v in data.items()
  29. if not isinstance(v, type)
  30. }
  31. elif isinstance(data, list):
  32. return [convert_and_remove_types(v) for v in data]
  33. elif isinstance(data, np.ndarray):
  34. return data.tolist()
  35. elif isinstance(data, (np.float32, np.float64)):
  36. return float(data)
  37. elif isinstance(data, (np.int32, np.int64)):
  38. return int(data)
  39. elif isinstance(data, np.bool_):
  40. return bool(data)
  41. elif isinstance(data, (np.str_, np.unicode_)):
  42. return str(data)
  43. return data
  44. def abspath(path: str):
  45. """get absolute path
  46. Args:
  47. path (str): the relative path
  48. Returns:
  49. str: the absolute path
  50. """
  51. return os.path.abspath(path)
  52. class CachedProperty(object):
  53. """
  54. A property that is only computed once per instance and then replaces itself
  55. with an ordinary attribute.
  56. The implementation refers to
  57. https://github.com/pydanny/cached-property/blob/master/cached_property.py .
  58. Note that this implementation does NOT work in multi-thread or coroutine
  59. senarios.
  60. """
  61. def __init__(self, func):
  62. super().__init__()
  63. self.func = func
  64. self.__doc__ = getattr(func, "__doc__", "")
  65. def __get__(self, obj, cls):
  66. if obj is None:
  67. return self
  68. val = self.func(obj)
  69. # Hack __dict__ of obj to inject the value
  70. # Note that this is only executed once
  71. obj.__dict__[self.func.__name__] = val
  72. return val
  73. class Constant(object):
  74. """Constant"""
  75. def __init__(self, val):
  76. super().__init__()
  77. self.val = val
  78. def __get__(self, obj, type_=None):
  79. return self.val
  80. def __set__(self, obj, val):
  81. raise AttributeError("The value of a constant cannot be modified!")
  82. class Singleton(type):
  83. """singleton meta class
  84. Args:
  85. type (class): type
  86. Returns:
  87. class: meta class
  88. """
  89. _insts = {}
  90. _lock = threading.Lock()
  91. def __call__(cls, *args, **kwargs):
  92. if cls not in cls._insts:
  93. with cls._lock:
  94. if cls not in cls._insts:
  95. cls._insts[cls] = super().__call__(*args, **kwargs)
  96. return cls._insts[cls]
# TODO(gaotingquan): this class has been moved to subclass_register.py
  98. class AutoRegisterMetaClass(type):
  99. """meta class that automatically registry subclass to its baseclass
  100. Args:
  101. type (class): type
  102. Returns:
  103. class: meta class
  104. """
  105. __model_type_attr_name = "entities"
  106. __base_class_flag = "__is_base"
  107. __registered_map_name = "__registered_map"
  108. def __new__(mcs, name, bases, attrs):
  109. cls = super().__new__(mcs, name, bases, attrs)
  110. mcs.__register_model_entity(bases, cls, attrs)
  111. return cls
  112. @classmethod
  113. def __register_model_entity(mcs, bases, cls, attrs):
  114. if bases:
  115. for base in bases:
  116. base_cls = mcs.__find_base_class(base)
  117. if base_cls:
  118. mcs.__register_to_base_class(base_cls, cls)
  119. @classmethod
  120. def __find_base_class(mcs, cls):
  121. is_base_flag = mcs.__base_class_flag
  122. if is_base_flag.startswith("__"):
  123. is_base_flag = f"_{cls.__name__}" + is_base_flag
  124. if getattr(cls, is_base_flag, False):
  125. return cls
  126. for base in cls.__bases__:
  127. base_cls = mcs.__find_base_class(base)
  128. if base_cls:
  129. return base_cls
  130. return None
  131. @classmethod
  132. def __register_to_base_class(mcs, base, cls):
  133. cls_entity_name = getattr(cls, mcs.__model_type_attr_name, cls.__name__)
  134. if isinstance(cls_entity_name, str):
  135. cls_entity_name = [cls_entity_name]
  136. records = getattr(base, mcs.__registered_map_name, {})
  137. for name in cls_entity_name:
  138. if name in records and records[name] is not cls:
  139. raise DuplicateRegistrationError(
  140. f"The name(`{name}`) duplicated registration! The class entities are: `{cls.__name__}` and \
  141. `{records[name].__name__}`."
  142. )
  143. records[name] = cls
  144. debug(
  145. f"The class entity({cls.__name__}) has been register as name(`{name}`)."
  146. )
  147. setattr(base, mcs.__registered_map_name, records)
  148. def all(cls):
  149. """get all subclass"""
  150. if not hasattr(cls, type(cls).__registered_map_name):
  151. raise_no_entity_registered_error(cls)
  152. return getattr(cls, type(cls).__registered_map_name)
  153. def get(cls, name: str):
  154. """get the registried class by name"""
  155. all_entities = cls.all()
  156. if name not in all_entities:
  157. raise_class_not_found_error(name, cls, all_entities)
  158. return all_entities[name]
  159. class AutoRegisterABCMetaClass(ABCMeta, AutoRegisterMetaClass):
  160. """AutoRegisterABCMetaClass"""