# registry.py (3.2 KB)
  1. from __future__ import annotations
  2. import functools
  3. import importlib
  4. import pkgutil
  5. import threading
  6. from typing import Any, Callable, Sequence
  7. import tiktoken_ext
  8. import tiktoken
  9. from tiktoken.core import Encoding
# Re-entrant lock serialising plugin discovery and encoding construction;
# get_encoding calls _find_constructors while already holding it.
_lock = threading.RLock()
# Cache of Encoding objects already built by get_encoding, keyed by encoding name.
ENCODINGS: dict[str, Encoding] = {}
# Maps encoding name -> zero-argument callable returning the keyword arguments
# passed to Encoding; stays None until _find_constructors has scanned the plugins.
ENCODING_CONSTRUCTORS: dict[str, Callable[[], dict[str, Any]]] | None = None
  13. @functools.lru_cache
  14. def _available_plugin_modules() -> Sequence[str]:
  15. # tiktoken_ext is a namespace package
  16. # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes
  17. # - we use namespace package pattern so `pkgutil.iter_modules` is fast
  18. # - it's a separate top-level package because namespace subpackages of non-namespace
  19. # packages don't quite do what you want with editable installs
  20. mods = []
  21. plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".")
  22. for _, mod_name, _ in plugin_mods:
  23. mods.append(mod_name)
  24. return mods
  25. def _find_constructors() -> None:
  26. global ENCODING_CONSTRUCTORS
  27. with _lock:
  28. if ENCODING_CONSTRUCTORS is not None:
  29. return
  30. ENCODING_CONSTRUCTORS = {}
  31. try:
  32. for mod_name in _available_plugin_modules():
  33. mod = importlib.import_module(mod_name)
  34. try:
  35. constructors = mod.ENCODING_CONSTRUCTORS
  36. except AttributeError as e:
  37. raise ValueError(
  38. f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS"
  39. ) from e
  40. for enc_name, constructor in constructors.items():
  41. if enc_name in ENCODING_CONSTRUCTORS:
  42. raise ValueError(
  43. f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}"
  44. )
  45. ENCODING_CONSTRUCTORS[enc_name] = constructor
  46. except Exception:
  47. # Ensure we idempotently raise errors
  48. ENCODING_CONSTRUCTORS = None
  49. raise
  50. def get_encoding(encoding_name: str) -> Encoding:
  51. if not isinstance(encoding_name, str):
  52. raise ValueError(f"Expected a string in get_encoding, got {type(encoding_name)}")
  53. if encoding_name in ENCODINGS:
  54. return ENCODINGS[encoding_name]
  55. with _lock:
  56. if encoding_name in ENCODINGS:
  57. return ENCODINGS[encoding_name]
  58. if ENCODING_CONSTRUCTORS is None:
  59. _find_constructors()
  60. assert ENCODING_CONSTRUCTORS is not None
  61. if encoding_name not in ENCODING_CONSTRUCTORS:
  62. raise ValueError(
  63. f"Unknown encoding {encoding_name}.\n"
  64. f"Plugins found: {_available_plugin_modules()}\n"
  65. f"tiktoken version: {tiktoken.__version__} (are you on latest?)"
  66. )
  67. constructor = ENCODING_CONSTRUCTORS[encoding_name]
  68. enc = Encoding(**constructor())
  69. ENCODINGS[encoding_name] = enc
  70. return enc
  71. def list_encoding_names() -> list[str]:
  72. with _lock:
  73. if ENCODING_CONSTRUCTORS is None:
  74. _find_constructors()
  75. assert ENCODING_CONSTRUCTORS is not None
  76. return list(ENCODING_CONSTRUCTORS)