cache.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import atexit
  15. import functools
  16. import hashlib
  17. import inspect
  18. import os
  19. import os.path as osp
  20. import pickle
  21. import tempfile
  22. from pathlib import Path
  23. import filelock
  24. DEFAULT_CACHE_DIR = osp.abspath(osp.join(os.path.expanduser("~"), ".paddlex"))
  25. CACHE_DIR = os.environ.get("PADDLE_PDX_CACHE_HOME", DEFAULT_CACHE_DIR)
  26. FUNC_CACHE_DIR = osp.join(CACHE_DIR, "func_ret")
  27. FILE_LOCK_DIR = osp.join(CACHE_DIR, "locks")
  28. TEMP_DIR = osp.join(CACHE_DIR, "temp")
  29. def create_cache_dir(*args, **kwargs):
  30. """create cache dir"""
  31. # `args` and `kwargs` reserved for extension
  32. os.makedirs(CACHE_DIR, exist_ok=True)
  33. os.makedirs(FUNC_CACHE_DIR, exist_ok=True)
  34. os.makedirs(FILE_LOCK_DIR, exist_ok=True)
  35. # TODO: Ensure permission
  36. def get_cache_dir(*args, **kwargs):
  37. """get cache dir"""
  38. # `args` and `kwargs` reserved for extension
  39. return CACHE_DIR
  40. def persist(cond=None):
  41. """persist"""
  42. # FIXME: Current implementation creates files in cache dir and we do
  43. # not set a limit on number of files
  44. # TODO: Faster implementation and support more arg types
  45. FILENAME_PATTERN = "persist_{key}.pkl"
  46. SUPPORTED_ARG_TYPES = (str, int, float)
  47. if cond is None:
  48. cond = lambda ret: ret is not None
  49. def _to_bytes(obj):
  50. return str(obj).encode("utf-8")
  51. def _make_key(func, bnd_args):
  52. # Use MD5 algorithm to make deterministic hashing
  53. # Note that the object-to-bytes conversion should be deterministic,
  54. # we ensure this by restricting types of arguments.
  55. m = hashlib.md5()
  56. m.update(_to_bytes(osp.realpath(inspect.getsourcefile(func))))
  57. m.update(_to_bytes(func.__name__))
  58. for k, v in bnd_args.arguments.items():
  59. if not isinstance(v, SUPPORTED_ARG_TYPES):
  60. raise TypeError(
  61. f"{repr(k)}: {v}, {type(v)} is unhashable or not a supported type."
  62. )
  63. m.update(_to_bytes(k))
  64. m.update(_to_bytes(v))
  65. hash_ = m.hexdigest()
  66. return hash_
  67. def _deco(func):
  68. @functools.wraps(func)
  69. def _wrapper(*args, **kwargs):
  70. sig = inspect.signature(func)
  71. bnd_args = sig.bind(*args, **kwargs)
  72. bnd_args.apply_defaults()
  73. key = _make_key(func, bnd_args)
  74. cache_file_path = osp.join(
  75. FUNC_CACHE_DIR, FILENAME_PATTERN.format(key=str(key))
  76. )
  77. lock = filelock.FileLock(osp.join(FILE_LOCK_DIR, f"{key}.lock"))
  78. with lock:
  79. if osp.exists(cache_file_path):
  80. with open(cache_file_path, "rb") as f:
  81. ret = pickle.load(f)
  82. else:
  83. ret = func(*args, **kwargs)
  84. if cond(ret):
  85. with open(cache_file_path, "wb") as f:
  86. pickle.dump(ret, f)
  87. return ret
  88. return _wrapper
  89. return _deco
  90. class TempFileManager:
  91. def __init__(self):
  92. self.temp_files = []
  93. Path(TEMP_DIR).mkdir(parents=True, exist_ok=True)
  94. atexit.register(self.cleanup)
  95. def create_temp_file(self, **kwargs):
  96. temp_file = tempfile.NamedTemporaryFile(delete=False, dir=TEMP_DIR, **kwargs)
  97. self.temp_files.append(temp_file)
  98. return temp_file
  99. def cleanup(self):
  100. for temp_file in self.temp_files:
  101. try:
  102. temp_file.close()
  103. os.remove(temp_file.name)
  104. except FileNotFoundError:
  105. pass
  106. self.temp_files = []
  107. class TempFileContextManager:
  108. def __init__(self, manager, **kwargs):
  109. self.manager = manager
  110. self.kwargs = kwargs
  111. self.temp_file = None
  112. def __enter__(self):
  113. self.temp_file = self.manager.create_temp_file(**self.kwargs)
  114. return self.temp_file
  115. def __exit__(self, exc_type, exc_value, traceback):
  116. if self.temp_file:
  117. self.temp_file.close()
  118. def temp_file_context(self, **kwargs):
  119. return self.TempFileContextManager(self, **kwargs)
# Module-level shared manager instance. Constructing it creates TEMP_DIR if
# needed and registers the manager's cleanup() with atexit.
temp_file_manager = TempFileManager()