gpt_tokenizer.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import shutil
  17. from functools import lru_cache
  18. from typing import Dict, Optional, Union
  19. import jieba
  20. import numpy as np
  21. import sentencepiece as spm
  22. import lazy_paddle as paddle
  23. import regex as re
  24. from .tokenizer_utils import PretrainedTokenizer
  25. from .tokenizer_utils_base import (
  26. AddedToken,
  27. BatchEncoding,
  28. EncodedInput,
  29. PaddingStrategy,
  30. )
  31. __all__ = [
  32. "GPTTokenizer",
  33. ]
  34. @lru_cache()
  35. def bytes_to_unicode():
  36. """
  37. Returns list of utf-8 byte and a corresponding list of unicode strings.
  38. The reversible bpe codes work on unicode strings.
  39. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
  40. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
  41. This is a signficant percentage of your normal, say, 32K bpe vocab.
  42. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
  43. And avoids mapping to whitespace/control characters the bpe code barfs on.
  44. """
  45. _chr = chr
  46. bs = (
  47. list(range(ord("!"), ord("~") + 1))
  48. + list(range(ord("¡"), ord("¬") + 1))
  49. + list(range(ord("®"), ord("ÿ") + 1))
  50. )
  51. cs = bs[:]
  52. n = 0
  53. for b in range(2**8):
  54. if b not in bs:
  55. bs.append(b)
  56. cs.append(2**8 + n)
  57. n += 1
  58. cs = [_chr(n) for n in cs]
  59. return dict(zip(bs, cs))
  60. def get_pairs(word):
  61. """Return set of symbol pairs in a word.
  62. Word is represented as tuple of symbols (symbols being variable-length strings).
  63. """
  64. pairs = set()
  65. prev_char = word[0]
  66. for char in word[1:]:
  67. pairs.add((prev_char, char))
  68. prev_char = char
  69. return pairs
  70. class GPTTokenizer(PretrainedTokenizer):
  71. """
  72. Constructs a GPT tokenizer based on byte-level Byte-Pair-Encoding.
  73. This tokenizer inherits from :class:`~paddlenlp.transformers.tokenizer_utils.PretrainedTokenizer`
  74. which contains most of the main methods. For more information regarding those methods,
  75. please refer to this superclass.
  76. Args:
  77. vocab_file (str):
  78. Path to the vocab file.
  79. The vocab file contains a mapping from vocabulary strings to indices.
  80. merges_file (str):
  81. Path to the merge file.
  82. The merge file is used to split the input sentence into "subword" units.
  83. The vocab file is then used to encode those units as intices.
  84. errors (str):
  85. Paradigm to follow when decoding bytes to UTF-8.
  86. Defaults to `'replace'`.
  87. max_len (int, optional):
  88. The maximum value of the input sequence length.
  89. Defaults to `None`.
  90. Examples:
  91. .. code-block::
  92. from paddlenlp.transformers import GPTTokenizer
  93. tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
  94. print(tokenizer('Welcome to use PaddlePaddle and PaddleNLP'))
  95. '''
  96. {'input_ids': [14618, 284, 779, 350, 37382, 47, 37382, 290, 350, 37382, 45, 19930],
  97. 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
  98. '''
  99. """
  100. resource_files_names = {
  101. "vocab_file": "vocab.json",
  102. "merges_file": "merges.txt",
  103. } # for save_pretrained
  104. gpt_vocab_link = (
  105. "http://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt-en-vocab.json"
  106. )
  107. gpt_merges_link = (
  108. "http://bj.bcebos.com/paddlenlp/models/transformers/gpt/gpt-en-merges.txt"
  109. )
  110. pretrained_resource_files_map = {
  111. "vocab_file": {
  112. "gpt3-175B-en": gpt_vocab_link,
  113. "gpt3-89B-en": gpt_vocab_link,
  114. "gpt3-13B-en": gpt_vocab_link,
  115. "gpt3-6.7B-en": gpt_vocab_link,
  116. "gpt3-1.3B-en": gpt_vocab_link,
  117. "gpt2-xl-en": gpt_vocab_link,
  118. "gpt2-large-en": gpt_vocab_link,
  119. "gpt2-medium-en": gpt_vocab_link,
  120. "gpt2-en": gpt_vocab_link,
  121. "gpt2-small-en": gpt_vocab_link,
  122. },
  123. "merges_file": {
  124. "gpt3-175B-en": gpt_merges_link,
  125. "gpt3-89B-en": gpt_merges_link,
  126. "gpt3-13B-en": gpt_merges_link,
  127. "gpt3-6.7B-en": gpt_merges_link,
  128. "gpt3-1.3B-en": gpt_merges_link,
  129. "gpt2-xl-en": gpt_merges_link,
  130. "gpt2-large-en": gpt_merges_link,
  131. "gpt2-medium-en": gpt_merges_link,
  132. "gpt2-en": gpt_merges_link,
  133. "gpt2-small-en": gpt_merges_link,
  134. },
  135. }
  136. pretrained_init_configuration = {
  137. "gpt3-175B-en": {},
  138. "gpt3-89B-en": {},
  139. "gpt3-13B-en": {},
  140. "gpt3-6.7B-en": {},
  141. "gpt3-1.3B-en": {},
  142. "gpt2-xl-en": {},
  143. "gpt2-large-en": {},
  144. "gpt2-medium-en": {},
  145. "gpt2-en": {},
  146. "gpt2-small-en": {},
  147. }
  148. def __init__(
  149. self,
  150. vocab_file,
  151. merges_file,
  152. errors="replace",
  153. max_len=None,
  154. pad_token="<|endoftext|>",
  155. eos_token="<|endoftext|>",
  156. unk_token="<|endoftext|>",
  157. eol_token="\u010a",
  158. add_prefix_space=False,
  159. add_bos_token=False,
  160. **kwargs # The token of newline.
  161. ):
  162. pad_token = (
  163. AddedToken(pad_token, lstrip=False, rstrip=False)
  164. if isinstance(pad_token, str)
  165. else pad_token
  166. )
  167. eos_token = (
  168. AddedToken(eos_token, lstrip=False, rstrip=False)
  169. if isinstance(eos_token, str)
  170. else eos_token
  171. )
  172. unk_token = (
  173. AddedToken(unk_token, lstrip=False, rstrip=False)
  174. if isinstance(unk_token, str)
  175. else unk_token
  176. )
  177. self.eol_token = eol_token
  178. self._build_special_tokens_map_extended(
  179. bos_token=(
  180. pad_token
  181. if getattr(self, "bos_token", None) is None
  182. else self.bos_token
  183. ),
  184. eos_token=eos_token,
  185. unk_token=unk_token,
  186. )
  187. self._vocab_file = vocab_file
  188. self._merges_file = merges_file
  189. self.max_len = max_len if max_len is not None else int(1e12)
  190. self.num_command_tokens = 2
  191. self.num_type_tokens = 2
  192. with open(vocab_file, "r", encoding="utf-8") as f:
  193. self.encoder = json.load(f)
  194. self.decoder = {v: k for k, v in self.encoder.items()}
  195. self.num_tokens = len(self.encoder)
  196. self.num_text_tokens = self.num_tokens - 1
  197. self.errors = errors # how to handle errors in decoding
  198. self.byte_encoder = bytes_to_unicode()
  199. self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
  200. with open(merges_file, encoding="utf-8") as f:
  201. bpe_data = f.read().split("\n")[1:-1]
  202. bpe_merges = [tuple(merge.split()) for merge in bpe_data]
  203. self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
  204. self.cache = {}
  205. self.add_prefix_space = add_prefix_space
  206. self.add_bos_token = add_bos_token
  207. self.pat = re.compile(
  208. r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
  209. )
  210. @property
  211. def vocab_size(self):
  212. """
  213. Returns the size of vocabulary.
  214. Returns:
  215. int: The sum of size of vocabulary and the size of speical tokens.
  216. """
  217. return len(self.encoder)
  218. @property
  219. def eol_token_id(self):
  220. if self.eol_token is None:
  221. return None
  222. return self.convert_tokens_to_ids(self.eol_token)
  223. def bpe(self, token):
  224. if token in self.cache:
  225. return self.cache[token]
  226. word = tuple(token)
  227. pairs = get_pairs(word)
  228. if not pairs:
  229. return token
  230. while True:
  231. bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
  232. if bigram not in self.bpe_ranks:
  233. break
  234. first, second = bigram
  235. new_word = []
  236. i = 0
  237. while i < len(word):
  238. try:
  239. j = word.index(first, i)
  240. new_word.extend(word[i:j])
  241. i = j
  242. except:
  243. new_word.extend(word[i:])
  244. break
  245. if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
  246. new_word.append(first + second)
  247. i += 2
  248. else:
  249. new_word.append(word[i])
  250. i += 1
  251. new_word = tuple(new_word)
  252. word = new_word
  253. if len(word) == 1:
  254. break
  255. else:
  256. pairs = get_pairs(word)
  257. word = " ".join(word)
  258. self.cache[token] = word
  259. return word
  260. def _tokenize(self, text):
  261. """Tokenize a string."""
  262. bpe_tokens = []
  263. for token in re.findall(self.pat, text):
  264. token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
  265. bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
  266. return bpe_tokens
  267. def _convert_token_to_id(self, token):
  268. return self.encoder.get(token, self.encoder.get(self.unk_token))
  269. def _convert_id_to_token(self, index):
  270. return self.decoder[index]
  271. def convert_ids_to_string(self, ids):
  272. """
  273. Converts a single index or a sequence of indices to texts.
  274. Args:
  275. ids (int|List[int]):
  276. The token id (or token ids) to be converted to text.
  277. Returns:
  278. str: The decoded text.
  279. Example:
  280. .. code-block::
  281. from paddlenlp.transformers import GPTTokenizer
  282. tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en')
  283. print(tokenizer.convert_ids_to_string(tokenizer.convert_ids_to_string([14618, 284, 779, 350, 37382, 47, 37382, 290, 350, 37382, 45, 19930]))
  284. # 'Welcome to use PaddlePaddle and PaddleNLP'
  285. """
  286. text = "".join([self.decoder[id] for id in ids])
  287. text = bytearray([self.byte_decoder[c] for c in text]).decode(
  288. "utf-8", errors=self.errors
  289. )
  290. return text
  291. def save_resources(self, save_directory):
  292. """
  293. Saves `SentencePiece <https://github.com/google/sentencepiece>`__ file
  294. (ends with '.spm') under `save_directory`.
  295. Args:
  296. save_directory (str): Directory to save files into.
  297. """
  298. for name, file_name in self.resource_files_names.items():
  299. source_path = getattr(self, "_%s" % name)
  300. save_path = os.path.join(save_directory, file_name)
  301. if os.path.abspath(source_path) != os.path.abspath(save_path):
  302. shutil.copyfile(source_path, save_path)
  303. def convert_tokens_to_string(self, tokens):
  304. """
  305. Converts a sequence of tokens (string) in a single string.
  306. """
  307. text = "".join(tokens)
  308. text = bytearray([self.byte_decoder[c] for c in text]).decode(
  309. "utf-8", errors=self.errors
  310. )
  311. return text
  312. def get_vocab(self):
  313. return dict(self.encoder, **self.added_tokens_encoder)
  314. def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
  315. add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
  316. if is_split_into_words or add_prefix_space:
  317. text = " " + text
  318. return (text, kwargs)
  319. def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
  320. if self.add_bos_token:
  321. bos_token_ids = [self.bos_token_id]
  322. else:
  323. bos_token_ids = []
  324. output = bos_token_ids + token_ids_0
  325. if token_ids_1 is None:
  326. return output
  327. return output + bos_token_ids + token_ids_1
  328. def _pad(
  329. self,
  330. encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
  331. max_length: Optional[int] = None,
  332. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  333. pad_to_multiple_of: Optional[int] = None,
  334. return_attention_mask: Optional[bool] = None,
  335. ) -> dict:
  336. """
  337. Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
  338. Args:
  339. encoded_inputs:
  340. Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
  341. max_length: maximum length of the returned list and optionally padding length (see below).
  342. Will truncate by taking into account the special tokens.
  343. padding_strategy: PaddingStrategy to use for padding.
  344. - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
  345. - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
  346. - PaddingStrategy.DO_NOT_PAD: Do not pad
  347. The tokenizer padding sides are defined in self.padding_side:
  348. - 'left': pads on the left of the sequences
  349. - 'right': pads on the right of the sequences
  350. pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
  351. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
  352. >= 7.5 (Volta).
  353. return_attention_mask:
  354. (optional) Set to False to avoid returning attention mask (default: set to model specifics)
  355. """
  356. # Load from model defaults
  357. # attention_mask shape [1,seq_len,seq_len]
  358. if (
  359. "attention_mask" in encoded_inputs
  360. and len(np.shape(encoded_inputs["attention_mask"])) > 2
  361. ):
  362. attention_mask = encoded_inputs["attention_mask"]
  363. encoded_inputs.pop("attention_mask")
  364. else:
  365. attention_mask = None
  366. required_input = encoded_inputs[self.model_input_names[0]]
  367. encoded_inputs = super()._pad(
  368. encoded_inputs,
  369. max_length,
  370. padding_strategy,
  371. pad_to_multiple_of,
  372. return_attention_mask,
  373. )
  374. if attention_mask is not None and len(np.shape(attention_mask)) > 2:
  375. encoded_inputs["attention_mask"] = attention_mask
  376. needs_to_be_padded = (
  377. padding_strategy != PaddingStrategy.DO_NOT_PAD
  378. and len(required_input) != max_length
  379. )
  380. if needs_to_be_padded:
  381. difference = max_length - len(required_input)
  382. if "attention_mask" in encoded_inputs:
  383. encoded_inputs["attention_mask"] = np.pad(
  384. encoded_inputs["attention_mask"],
  385. pad_width=[(0, 0), (difference, 0), (difference, 0)],
  386. mode="constant",
  387. constant_values=0,
  388. )
  389. return encoded_inputs