model.py

from __future__ import annotations

from .core import Encoding
from .registry import get_encoding

# TODO: these will likely be replaced by an API endpoint
MODEL_PREFIX_TO_ENCODING: dict[str, str] = {
    "o1-": "o200k_base",
    "o3-": "o200k_base",
    "o4-mini-": "o200k_base",
    # chat
    "gpt-5-": "o200k_base",
    "gpt-4.5-": "o200k_base",
    "gpt-4.1-": "o200k_base",
    "chatgpt-4o-": "o200k_base",
    "gpt-4o-": "o200k_base",  # e.g., gpt-4o-2024-05-13
    "gpt-4-": "cl100k_base",  # e.g., gpt-4-0314, etc., plus gpt-4-32k
    "gpt-3.5-turbo-": "cl100k_base",  # e.g., gpt-3.5-turbo-0301, -0401, etc.
    "gpt-35-turbo-": "cl100k_base",  # Azure deployment name
    "gpt-oss-": "o200k_harmony",
    # fine-tuned
    "ft:gpt-4o": "o200k_base",
    "ft:gpt-4": "cl100k_base",
    "ft:gpt-3.5-turbo": "cl100k_base",
    "ft:davinci-002": "cl100k_base",
    "ft:babbage-002": "cl100k_base",
}

MODEL_TO_ENCODING: dict[str, str] = {
    # reasoning
    "o1": "o200k_base",
    "o3": "o200k_base",
    "o4-mini": "o200k_base",
    # chat
    "gpt-5": "o200k_base",
    "gpt-4.1": "o200k_base",
    "gpt-4o": "o200k_base",
    "gpt-4": "cl100k_base",
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-3.5": "cl100k_base",  # Common shorthand
    "gpt-35-turbo": "cl100k_base",  # Azure deployment name
    # base
    "davinci-002": "cl100k_base",
    "babbage-002": "cl100k_base",
    # embeddings
    "text-embedding-ada-002": "cl100k_base",
    "text-embedding-3-small": "cl100k_base",
    "text-embedding-3-large": "cl100k_base",
    # DEPRECATED MODELS
    # text (DEPRECATED)
    "text-davinci-003": "p50k_base",
    "text-davinci-002": "p50k_base",
    "text-davinci-001": "r50k_base",
    "text-curie-001": "r50k_base",
    "text-babbage-001": "r50k_base",
    "text-ada-001": "r50k_base",
    "davinci": "r50k_base",
    "curie": "r50k_base",
    "babbage": "r50k_base",
    "ada": "r50k_base",
    # code (DEPRECATED)
    "code-davinci-002": "p50k_base",
    "code-davinci-001": "p50k_base",
    "code-cushman-002": "p50k_base",
    "code-cushman-001": "p50k_base",
    "davinci-codex": "p50k_base",
    "cushman-codex": "p50k_base",
    # edit (DEPRECATED)
    "text-davinci-edit-001": "p50k_edit",
    "code-davinci-edit-001": "p50k_edit",
    # old embeddings (DEPRECATED)
    "text-similarity-davinci-001": "r50k_base",
    "text-similarity-curie-001": "r50k_base",
    "text-similarity-babbage-001": "r50k_base",
    "text-similarity-ada-001": "r50k_base",
    "text-search-davinci-doc-001": "r50k_base",
    "text-search-curie-doc-001": "r50k_base",
    "text-search-babbage-doc-001": "r50k_base",
    "text-search-ada-doc-001": "r50k_base",
    "code-search-babbage-code-001": "r50k_base",
    "code-search-ada-code-001": "r50k_base",
    # open source
    "gpt2": "gpt2",
    "gpt-2": "gpt2",  # Maintains consistency with gpt-4
}


def encoding_name_for_model(model_name: str) -> str:
    """Returns the name of the encoding used by a model.

    Raises a KeyError if the model name is not recognised.
    """
    encoding_name = None
    if model_name in MODEL_TO_ENCODING:
        encoding_name = MODEL_TO_ENCODING[model_name]
    else:
        # Check if the model matches a known prefix
        # Prefix matching avoids needing library updates for every model version release
        # Note that this can match on non-existent models (e.g., gpt-3.5-turbo-FAKE)
        for model_prefix, model_encoding_name in MODEL_PREFIX_TO_ENCODING.items():
            if model_name.startswith(model_prefix):
                return model_encoding_name

    if encoding_name is None:
        raise KeyError(
            f"Could not automatically map {model_name} to a tokeniser. "
            "Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect."
        ) from None

    return encoding_name
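
# Illustrative sketch (not part of the upstream file): how the lookup above resolves
# a few names, given the tables defined in this module. The dated model name below is
# only an example taken from the comments in MODEL_PREFIX_TO_ENCODING.
#
#   encoding_name_for_model("gpt-4o")             -> "o200k_base"  (exact match in MODEL_TO_ENCODING)
#   encoding_name_for_model("gpt-4o-2024-05-13")  -> "o200k_base"  (matches prefix "gpt-4o-")
#   encoding_name_for_model("not-a-real-model")   -> raises KeyError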


def encoding_for_model(model_name: str) -> Encoding:
    """Returns the encoding used by a model.

    Raises a KeyError if the model name is not recognised.
    """
    return get_encoding(encoding_name_for_model(model_name))
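

# Usage sketch (illustrative, not part of the upstream file), assuming an installed
# `tiktoken` package. `tiktoken.encoding_for_model` exposes the function above;
# falling back to `tiktoken.get_encoding` is the escape hatch the error message
# above recommends when a model name is not recognised.
#
#     import tiktoken
#
#     try:
#         enc = tiktoken.encoding_for_model("gpt-4o")
#     except KeyError:
#         # Unrecognised model name: pick the tokeniser explicitly instead.
#         enc = tiktoken.get_encoding("o200k_base")
#     print(enc.encode("hello world"))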