cache.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import inspect
  17. import functools
  18. import pickle
  19. import hashlib
  20. import filelock
  21. DEFAULT_CACHE_DIR = osp.abspath(osp.join(os.path.expanduser('~'), '.paddlex'))
  22. CACHE_DIR = os.environ.get('PADDLE_PDX_CACHE_HOME', DEFAULT_CACHE_DIR)
  23. FUNC_CACHE_DIR = osp.join(CACHE_DIR, 'func_ret')
  24. FILE_LOCK_DIR = osp.join(CACHE_DIR, 'locks')
  25. def create_cache_dir(*args, **kwargs):
  26. """ create cache dir """
  27. # `args` and `kwargs` reserved for extension
  28. os.makedirs(CACHE_DIR, exist_ok=True)
  29. os.makedirs(FUNC_CACHE_DIR, exist_ok=True)
  30. os.makedirs(FILE_LOCK_DIR, exist_ok=True)
  31. # TODO: Ensure permission
  32. def get_cache_dir(*args, **kwargs):
  33. """ get cache dir """
  34. # `args` and `kwargs` reserved for extension
  35. return CACHE_DIR
  36. def persist(cond=None):
  37. """ persist """
  38. # FIXME: Current implementation creates files in cache dir and we do
  39. # not set a limit on number of files
  40. # TODO: Faster implementation and support more arg types
  41. FILENAME_PATTERN = 'persist_{key}.pkl'
  42. SUPPORTED_ARG_TYPES = (str, int, float)
  43. if cond is None:
  44. cond = lambda ret: ret is not None
  45. def _to_bytes(obj):
  46. return str(obj).encode('utf-8')
  47. def _make_key(func, bnd_args):
  48. # Use MD5 algorithm to make deterministic hashing
  49. # Note that the object-to-bytes conversion should be deterministic,
  50. # we ensure this by restricting types of arguments.
  51. m = hashlib.md5()
  52. m.update(_to_bytes(osp.realpath(inspect.getsourcefile(func))))
  53. m.update(_to_bytes(func.__name__))
  54. for k, v in bnd_args.arguments.items():
  55. if not isinstance(v, SUPPORTED_ARG_TYPES):
  56. raise TypeError(
  57. f"{repr(k)}: {v}, {type(v)} is unhashable or not a supported type."
  58. )
  59. m.update(_to_bytes(k))
  60. m.update(_to_bytes(v))
  61. hash_ = m.hexdigest()
  62. return hash_
  63. def _deco(func):
  64. @functools.wraps(func)
  65. def _wrapper(*args, **kwargs):
  66. sig = inspect.signature(func)
  67. bnd_args = sig.bind(*args, **kwargs)
  68. bnd_args.apply_defaults()
  69. key = _make_key(func, bnd_args)
  70. cache_file_path = osp.join(
  71. FUNC_CACHE_DIR, FILENAME_PATTERN.format(key=str(key)))
  72. lock = filelock.FileLock(osp.join(FILE_LOCK_DIR, f"{key}.lock"))
  73. with lock:
  74. if osp.exists(cache_file_path):
  75. with open(cache_file_path, 'rb') as f:
  76. ret = pickle.load(f)
  77. else:
  78. ret = func(*args, **kwargs)
  79. if cond(ret):
  80. with open(cache_file_path, 'wb') as f:
  81. pickle.dump(ret, f)
  82. return ret
  83. return _wrapper
  84. return _deco