# load.py — fetch, cache, and parse BPE vocabulary files.
  1. from __future__ import annotations
  2. import base64
  3. import hashlib
  4. import os
  5. def read_file(blobpath: str) -> bytes:
  6. if "://" not in blobpath:
  7. with open(blobpath, "rb", buffering=0) as f:
  8. return f.read()
  9. if blobpath.startswith(("http://", "https://")):
  10. # avoiding blobfile for public files helps avoid auth issues, like MFA prompts.
  11. import requests
  12. resp = requests.get(blobpath)
  13. resp.raise_for_status()
  14. return resp.content
  15. try:
  16. import blobfile
  17. except ImportError as e:
  18. raise ImportError(
  19. "blobfile is not installed. Please install it by running `pip install blobfile`."
  20. ) from e
  21. with blobfile.BlobFile(blobpath, "rb") as f:
  22. return f.read()
  23. def check_hash(data: bytes, expected_hash: str) -> bool:
  24. actual_hash = hashlib.sha256(data).hexdigest()
  25. return actual_hash == expected_hash
def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes:
    """Read ``blobpath`` through an on-disk cache keyed by SHA-1 of the path.

    The cache directory is resolved in order: ``TIKTOKEN_CACHE_DIR``,
    ``DATA_GYM_CACHE_DIR``, then ``<tempdir>/data-gym-cache``.  Setting the
    env var to the empty string disables caching entirely.

    If ``expected_hash`` is provided, both cached and freshly downloaded
    contents are verified against it: a stale cache entry is evicted and
    re-fetched; a mismatching fresh download raises ``ValueError``.

    Write failures to the *default* cache directory are swallowed (best
    effort), but failures to a user-specified cache directory are raised.
    """
    user_specified_cache = True
    if "TIKTOKEN_CACHE_DIR" in os.environ:
        cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
    elif "DATA_GYM_CACHE_DIR" in os.environ:
        cache_dir = os.environ["DATA_GYM_CACHE_DIR"]
    else:
        import tempfile

        cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
        user_specified_cache = False

    if cache_dir == "":
        # disable caching
        return read_file(blobpath)

    # SHA-1 of the path is only a cache key, not a security boundary.
    cache_key = hashlib.sha1(blobpath.encode()).hexdigest()

    cache_path = os.path.join(cache_dir, cache_key)
    if os.path.exists(cache_path):
        with open(cache_path, "rb", buffering=0) as f:
            data = f.read()
        if expected_hash is None or check_hash(data, expected_hash):
            return data

        # the cached file does not match the hash, remove it and re-fetch
        try:
            os.remove(cache_path)
        except OSError:
            pass

    contents = read_file(blobpath)
    if expected_hash and not check_hash(contents, expected_hash):
        raise ValueError(
            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
            f"This may indicate a corrupted download. Please try again."
        )

    import uuid

    try:
        os.makedirs(cache_dir, exist_ok=True)

        # Write to a unique temp name then rename, so concurrent readers never
        # see a partially written cache entry.
        tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp"
        with open(tmp_filename, "wb") as f:
            f.write(contents)
        os.rename(tmp_filename, cache_path)
    except OSError:
        # don't raise if we can't write to the default cache, e.g. issue #75
        if user_specified_cache:
            raise

    return contents
def data_gym_to_mergeable_bpe_ranks(
    vocab_bpe_file: str,
    encoder_json_file: str,
    vocab_bpe_hash: str | None = None,
    encoder_json_hash: str | None = None,
    clobber_one_byte_tokens: bool = False,
) -> dict[bytes, int]:
    """Convert GPT-2 style ``vocab.bpe`` + ``encoder.json`` files into a
    tiktoken-style mapping of token bytes -> rank.

    The "data gym" encoding maps raw bytes to printable unicode characters;
    this function inverts that mapping, assigns ranks 0..255 to the single
    bytes, then appends merged tokens in merge-priority order.  The result is
    cross-checked against ``encoder.json``.

    If ``clobber_one_byte_tokens`` is True, single-byte ranks are overwritten
    with the values from ``encoder.json`` before the consistency check.
    """
    # NB: do not add caching to this function
    # Printable, non-space bytes keep their own character in the data gym encoding.
    rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]

    data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte}
    n = 0
    for b in range(2**8):
        if b not in rank_to_intbyte:
            # Unprintable bytes are remapped to characters starting at U+0100.
            rank_to_intbyte.append(b)
            data_gym_byte_to_byte[chr(2**8 + n)] = b
            n += 1
    assert len(rank_to_intbyte) == 2**8

    # vocab_bpe contains the merges along with associated ranks
    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
    # First line is a version header and the last entry is empty; skip both.
    bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

    def decode_data_gym(value: str) -> bytes:
        # Invert the data gym char -> byte mapping for a whole token string.
        return bytes(data_gym_byte_to_byte[b] for b in value)

    # add the single byte tokens
    # if clobber_one_byte_tokens is True, we'll replace these with ones from the encoder json
    bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)}
    del rank_to_intbyte

    # add the merged tokens
    n = len(bpe_ranks)
    for first, second in bpe_merges:
        bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n
        n += 1

    import json

    # check that the encoder file matches the merges file
    # this sanity check is important since tiktoken assumes that ranks are ordered the same
    # as merge priority
    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
    encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
    # drop these two special tokens if present, since they're not mergeable bpe tokens
    encoder_json_loaded.pop(b"<|endoftext|>", None)
    encoder_json_loaded.pop(b"<|startoftext|>", None)
    if clobber_one_byte_tokens:
        for k in encoder_json_loaded:
            if len(k) == 1:
                bpe_ranks[k] = encoder_json_loaded[k]
    assert bpe_ranks == encoder_json_loaded

    return bpe_ranks
  115. def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
  116. try:
  117. import blobfile
  118. except ImportError as e:
  119. raise ImportError(
  120. "blobfile is not installed. Please install it by running `pip install blobfile`."
  121. ) from e
  122. with blobfile.BlobFile(tiktoken_bpe_file, "wb") as f:
  123. for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]):
  124. f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
  125. def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]:
  126. # NB: do not add caching to this function
  127. contents = read_file_cached(tiktoken_bpe_file, expected_hash)
  128. ret = {}
  129. for line in contents.splitlines():
  130. if not line:
  131. continue
  132. try:
  133. token, rank = line.split()
  134. ret[base64.b64decode(token)] = int(rank)
  135. except Exception as e:
  136. raise ValueError(f"Error parsing line {line!r} in {tiktoken_bpe_file}") from e
  137. return ret