cache.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. from pathlib import Path
  17. import inspect
  18. import functools
  19. import pickle
  20. import hashlib
  21. import tempfile
  22. import atexit
  23. import filelock
  24. from . import logging
  25. DEFAULT_CACHE_DIR = osp.abspath(osp.join(os.path.expanduser("~"), ".paddlex"))
  26. CACHE_DIR = os.environ.get("PADDLE_PDX_CACHE_HOME", DEFAULT_CACHE_DIR)
  27. FUNC_CACHE_DIR = osp.join(CACHE_DIR, "func_ret")
  28. FILE_LOCK_DIR = osp.join(CACHE_DIR, "locks")
  29. TEMP_DIR = osp.join(CACHE_DIR, "temp")
  30. def create_cache_dir(*args, **kwargs):
  31. """create cache dir"""
  32. # `args` and `kwargs` reserved for extension
  33. os.makedirs(CACHE_DIR, exist_ok=True)
  34. os.makedirs(FUNC_CACHE_DIR, exist_ok=True)
  35. os.makedirs(FILE_LOCK_DIR, exist_ok=True)
  36. # TODO: Ensure permission
  37. def get_cache_dir(*args, **kwargs):
  38. """get cache dir"""
  39. # `args` and `kwargs` reserved for extension
  40. return CACHE_DIR
  41. def persist(cond=None):
  42. """persist"""
  43. # FIXME: Current implementation creates files in cache dir and we do
  44. # not set a limit on number of files
  45. # TODO: Faster implementation and support more arg types
  46. FILENAME_PATTERN = "persist_{key}.pkl"
  47. SUPPORTED_ARG_TYPES = (str, int, float)
  48. if cond is None:
  49. cond = lambda ret: ret is not None
  50. def _to_bytes(obj):
  51. return str(obj).encode("utf-8")
  52. def _make_key(func, bnd_args):
  53. # Use MD5 algorithm to make deterministic hashing
  54. # Note that the object-to-bytes conversion should be deterministic,
  55. # we ensure this by restricting types of arguments.
  56. m = hashlib.md5()
  57. m.update(_to_bytes(osp.realpath(inspect.getsourcefile(func))))
  58. m.update(_to_bytes(func.__name__))
  59. for k, v in bnd_args.arguments.items():
  60. if not isinstance(v, SUPPORTED_ARG_TYPES):
  61. raise TypeError(
  62. f"{repr(k)}: {v}, {type(v)} is unhashable or not a supported type."
  63. )
  64. m.update(_to_bytes(k))
  65. m.update(_to_bytes(v))
  66. hash_ = m.hexdigest()
  67. return hash_
  68. def _deco(func):
  69. @functools.wraps(func)
  70. def _wrapper(*args, **kwargs):
  71. sig = inspect.signature(func)
  72. bnd_args = sig.bind(*args, **kwargs)
  73. bnd_args.apply_defaults()
  74. key = _make_key(func, bnd_args)
  75. cache_file_path = osp.join(
  76. FUNC_CACHE_DIR, FILENAME_PATTERN.format(key=str(key))
  77. )
  78. lock = filelock.FileLock(osp.join(FILE_LOCK_DIR, f"{key}.lock"))
  79. with lock:
  80. if osp.exists(cache_file_path):
  81. with open(cache_file_path, "rb") as f:
  82. ret = pickle.load(f)
  83. else:
  84. ret = func(*args, **kwargs)
  85. if cond(ret):
  86. with open(cache_file_path, "wb") as f:
  87. pickle.dump(ret, f)
  88. return ret
  89. return _wrapper
  90. return _deco
  91. class TempFileManager:
  92. def __init__(self):
  93. self.temp_files = []
  94. Path(TEMP_DIR).mkdir(parents=True, exist_ok=True)
  95. atexit.register(self.cleanup)
  96. def create_temp_file(self, **kwargs):
  97. temp_file = tempfile.NamedTemporaryFile(delete=False, dir=TEMP_DIR, **kwargs)
  98. self.temp_files.append(temp_file)
  99. return temp_file
  100. def cleanup(self):
  101. for temp_file in self.temp_files:
  102. try:
  103. temp_file.close()
  104. os.remove(temp_file.name)
  105. except FileNotFoundError as e:
  106. pass
  107. self.temp_files = []
  108. class TempFileContextManager:
  109. def __init__(self, manager, **kwargs):
  110. self.manager = manager
  111. self.kwargs = kwargs
  112. self.temp_file = None
  113. def __enter__(self):
  114. self.temp_file = self.manager.create_temp_file(**self.kwargs)
  115. return self.temp_file
  116. def __exit__(self, exc_type, exc_value, traceback):
  117. if self.temp_file:
  118. self.temp_file.close()
  119. def temp_file_context(self, **kwargs):
  120. return self.TempFileContextManager(self, **kwargs)
# Shared module-level manager. Instantiating it creates TEMP_DIR if needed and
# registers its cleanup() to run at interpreter exit.
temp_file_manager = TempFileManager()