cache.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import os.path as osp
  13. import inspect
  14. import functools
  15. import pickle
  16. import hashlib
  17. import filelock
  18. DEFAULT_CACHE_DIR = osp.abspath(osp.join(os.path.expanduser('~'), '.paddlex'))
  19. CACHE_DIR = os.environ.get('PADDLE_PDX_CACHE_HOME', DEFAULT_CACHE_DIR)
  20. FUNC_CACHE_DIR = osp.join(CACHE_DIR, 'func_ret')
  21. FILE_LOCK_DIR = osp.join(CACHE_DIR, 'locks')
  22. def create_cache_dir(*args, **kwargs):
  23. """ create cache dir """
  24. # `args` and `kwargs` reserved for extension
  25. os.makedirs(CACHE_DIR, exist_ok=True)
  26. os.makedirs(FUNC_CACHE_DIR, exist_ok=True)
  27. os.makedirs(FILE_LOCK_DIR, exist_ok=True)
  28. # TODO: Ensure permission
  29. def get_cache_dir(*args, **kwargs):
  30. """ get cache dir """
  31. # `args` and `kwargs` reserved for extension
  32. return CACHE_DIR
  33. def persist(cond=None):
  34. """ persist """
  35. # FIXME: Current implementation creates files in cache dir and we do
  36. # not set a limit on number of files
  37. # TODO: Faster implementation and support more arg types
  38. FILENAME_PATTERN = 'persist_{key}.pkl'
  39. SUPPORTED_ARG_TYPES = (str, int, float)
  40. if cond is None:
  41. cond = lambda ret: ret is not None
  42. def _to_bytes(obj):
  43. return str(obj).encode('utf-8')
  44. def _make_key(func, bnd_args):
  45. # Use MD5 algorithm to make deterministic hashing
  46. # Note that the object-to-bytes conversion should be deterministic,
  47. # we ensure this by restricting types of arguments.
  48. m = hashlib.md5()
  49. m.update(_to_bytes(osp.realpath(inspect.getsourcefile(func))))
  50. m.update(_to_bytes(func.__name__))
  51. for k, v in bnd_args.arguments.items():
  52. if not isinstance(v, SUPPORTED_ARG_TYPES):
  53. raise TypeError(
  54. f"{repr(k)}: {v}, {type(v)} is unhashable or not a supported type."
  55. )
  56. m.update(_to_bytes(k))
  57. m.update(_to_bytes(v))
  58. hash_ = m.hexdigest()
  59. return hash_
  60. def _deco(func):
  61. @functools.wraps(func)
  62. def _wrapper(*args, **kwargs):
  63. sig = inspect.signature(func)
  64. bnd_args = sig.bind(*args, **kwargs)
  65. bnd_args.apply_defaults()
  66. key = _make_key(func, bnd_args)
  67. cache_file_path = osp.join(
  68. FUNC_CACHE_DIR, FILENAME_PATTERN.format(key=str(key)))
  69. lock = filelock.FileLock(osp.join(FILE_LOCK_DIR, f"{key}.lock"))
  70. with lock:
  71. if osp.exists(cache_file_path):
  72. with open(cache_file_path, 'rb') as f:
  73. ret = pickle.load(f)
  74. else:
  75. ret = func(*args, **kwargs)
  76. if cond(ret):
  77. with open(cache_file_path, 'wb') as f:
  78. pickle.dump(ret, f)
  79. return ret
  80. return _wrapper
  81. return _deco