_educational.py

  1. """This is an educational implementation of the byte pair encoding algorithm."""
  2. from __future__ import annotations
  3. import collections
  4. import regex
  5. import tiktoken
  6. class SimpleBytePairEncoding:
  7. def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
  8. """Creates an Encoding object."""
  9. # A regex pattern string that is used to split the input text
  10. self.pat_str = pat_str
  11. # A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority
  12. self.mergeable_ranks = mergeable_ranks
  13. self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()}
  14. self._pat = regex.compile(pat_str)
  15. def encode(self, text: str, visualise: str | None = "colour") -> list[int]:
  16. """Encodes a string into tokens.
  17. >>> enc.encode("hello world")
  18. [388, 372]
  19. """
  20. # Use the regex to split the text into (approximately) words
  21. words = self._pat.findall(text)
  22. tokens = []
  23. for word in words:
  24. # Turn each word into tokens, using the byte pair encoding algorithm
  25. word_bytes = word.encode("utf-8")
  26. word_tokens = bpe_encode(self.mergeable_ranks, word_bytes, visualise=visualise)
  27. tokens.extend(word_tokens)
  28. return tokens
  29. def decode_bytes(self, tokens: list[int]) -> bytes:
  30. """Decodes a list of tokens into bytes.
  31. >>> enc.decode_bytes([388, 372])
  32. b'hello world'
  33. """
  34. return b"".join(self._decoder[token] for token in tokens)
  35. def decode(self, tokens: list[int]) -> str:
  36. """Decodes a list of tokens into a string.
  37. Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
  38. the invalid bytes with the replacement character "�".
  39. >>> enc.decode([388, 372])
  40. 'hello world'
  41. """
  42. return self.decode_bytes(tokens).decode("utf-8", errors="replace")
  43. def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
  44. """Decodes a list of tokens into a list of bytes.
  45. Useful for visualising how a string is tokenised.
  46. >>> enc.decode_tokens_bytes([388, 372])
  47. [b'hello', b' world']
  48. """
  49. return [self._decoder[token] for token in tokens]
  50. @staticmethod
  51. def train(training_data: str, vocab_size: int, pat_str: str):
  52. """Train a BPE tokeniser on some data!"""
  53. mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
  54. return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks)
  55. @staticmethod
  56. def from_tiktoken(encoding):
  57. if isinstance(encoding, str):
  58. encoding = tiktoken.get_encoding(encoding)
  59. return SimpleBytePairEncoding(
  60. pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks
  61. )
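# Usage sketch: the class above can wrap an existing tiktoken encoding and round-trip a string.
# This assumes the "gpt2" encoding is available via tiktoken.get_encoding.
#
#     enc = SimpleBytePairEncoding.from_tiktoken("gpt2")
#     tokens = enc.encode("hello world", visualise=None)
#     assert enc.decode(tokens) == "hello world"
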
def bpe_encode(
    mergeable_ranks: dict[bytes, int], input: bytes, visualise: str | None = "colour"
) -> list[int]:
    parts = [bytes([b]) for b in input]
    while True:
        # See the intermediate merges play out!
        if visualise:
            if visualise in ["colour", "color"]:
                visualise_tokens(parts)
            elif visualise == "simple":
                print(parts)

        # Iterate over all pairs and find the pair we want to merge the most
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank

        # If there were no pairs we could merge, we're done!
        if min_rank is None:
            break
        assert min_idx is not None

        # Otherwise, merge that pair and leave the rest unchanged. Then repeat.
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]

    if visualise:
        print()

    tokens = [mergeable_ranks[part] for part in parts]
    return tokens

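# Illustrative trace of bpe_encode: for input b"hello", the parts list shrinks by one merge per
# iteration until no adjacent pair is in mergeable_ranks. The exact order depends on the learned
# ranks; one plausible sequence is:
#     [b'h', b'e', b'l', b'l', b'o']
#     [b'h', b'e', b'll', b'o']
#     [b'he', b'll', b'o']
#     [b'hell', b'o']
#     [b'hello']
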
def bpe_train(
    data: str, vocab_size: int, pat_str: str, visualise: str | None = "colour"
) -> dict[bytes, int]:
    # First, add tokens for each individual byte value
    if vocab_size < 2**8:
        raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
    ranks = {}
    for i in range(2**8):
        ranks[bytes([i])] = i

    # Splinter up our data into lists of bytes
    # data = "Hello world"
    # words = [
    #     [b'H', b'e', b'l', b'l', b'o'],
    #     [b' ', b'w', b'o', b'r', b'l', b'd']
    # ]
    words: list[list[bytes]] = [
        [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
    ]

    # Now, use our data to figure out which merges we should make
    while len(ranks) < vocab_size:
        # Find the most common pair. This will become our next token
        stats = collections.Counter()
        for piece in words:
            for pair in zip(piece[:-1], piece[1:]):
                stats[pair] += 1

        most_common_pair = max(stats, key=lambda x: stats[x])
        token_bytes = most_common_pair[0] + most_common_pair[1]
        token = len(ranks)
        # Add the new token!
        ranks[token_bytes] = token

        # Now merge that most common pair in all the words. That is, update our training data
        # to reflect our decision to make that pair into a new token.
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word) - 1:
                if (word[i], word[i + 1]) == most_common_pair:
                    # We found our pair! Merge it
                    new_word.append(token_bytes)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            if i == len(word) - 1:
                new_word.append(word[i])
            new_words.append(new_word)
        words = new_words

        # See the intermediate merges play out!
        if visualise:
            print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}")
            print(f"So we made {token_bytes} our {len(ranks)}th token")
            if visualise in ["colour", "color"]:
                print("Now the first fifty words in our training data look like:")
                visualise_tokens([token for word in words[:50] for token in word])
            elif visualise == "simple":
                print("Now the first twenty words in our training data look like:")
                for word in words[:20]:
                    print(word)
            print("\n")

    return ranks

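# Minimal training sketch: the corpus, pattern and vocab_size below are arbitrary values chosen
# for illustration. vocab_size=260 leaves room for exactly four merges beyond the raw bytes.
#
#     ranks = bpe_train(data="aaabdaaabac " * 10, vocab_size=260, pat_str=r"\S+", visualise=None)
#     learned = [tok for tok, rank in ranks.items() if rank >= 256]
#     # learned now holds four multi-byte tokens such as b'aa' and b'aaa'
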
def visualise_tokens(token_values: list[bytes]) -> None:
    background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
    # If token boundaries do not occur at unicode character boundaries, it's unclear how best to
    # visualise the token. Here, we'll just use the unicode replacement character to represent some
    # fraction of a character.
    unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values]

    running_length = 0
    last_color = None
    for token in unicode_token_values:
        color = background[running_length % len(background)]
        if color == last_color:
            color = background[(running_length + 1) % len(background)]
            assert color != last_color
        last_color = color
        running_length += len(token)
        print(color + token, end="")
    print("\u001b[0m")

def train_simple_encoding():
    gpt2_pattern = (
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    )
    with open(__file__) as f:
        data = f.read()

    enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pattern)

    print("This is the sequence of merges performed in order to encode 'hello world':")
    tokens = enc.encode("hello world")
    assert enc.decode(tokens) == "hello world"
    assert enc.decode_bytes(tokens) == b"hello world"
    assert enc.decode_tokens_bytes(tokens) == [b"hello", b" world"]

    return enc

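# Convenience entry point (a sketch, not required by the module): run the demo training when
# this file is executed directly.
if __name__ == "__main__":
    train_simple_encoding()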