# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import unicodedata
from functools import lru_cache
from typing import List, Optional, Tuple

from .....utils import logging
from .....utils.deps import is_dep_available
from .tokenizer_utils import PretrainedTokenizer
from .tokenizer_utils_base import AddedToken, TextInput

if is_dep_available("regex"):
    import regex as re
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
}

__all__ = ["Qwen2Tokenizer", "MIXQwen2Tokenizer"]

MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
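
# The pre-tokenization pattern below (GPT-2-style byte-level BPE) splits text into:
# English contractions ('s, 't, 're, ...; case-insensitive via (?i:...)), runs of letters with an
# optional leading non-letter/non-digit character, single digits, runs of punctuation with optional
# trailing newlines, runs of newlines, and remaining whitespace.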
PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""


@lru_cache()
def bytes_to_unicode():
    """
    Returns a mapping between utf-8 bytes and unicode strings. We specifically avoid mapping to whitespace/control
    characters that the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """
    Return the set of symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Qwen2Tokenizer(PretrainedTokenizer):
    """
    Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.

    Like GPT2Tokenizer, this tokenizer has been trained to treat spaces as parts of the tokens, so a word will
    be encoded differently depending on whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import Qwen2Tokenizer

    >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
    >>> tokenizer("Hello world")["input_ids"]
    [9707, 1879]

    >>> tokenizer(" Hello world")["input_ids"]
    [21927, 1879]
    ```
    This is expected.

    You should not use GPT2Tokenizer instead, because of the different pretokenization rules.

    This tokenizer inherits from [`PretrainedTokenizer`], which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*):
            The beginning of sequence token. Not applicable for this tokenizer.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding, for example when batching sequences of different lengths.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not the model should clean up the spaces that were added when splitting the input text during
            the tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
        split_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not the special tokens should be split during the tokenization process. The default behavior is
            to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then
            `tokenizer.tokenize("<|endoftext|>")` gives `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`,
            then `tokenizer.tokenize("<|endoftext|>")` gives `['<', '|', 'endo', 'ft', 'ext', '|', '>']`. This argument
            is only supported for `slow` tokenizers for the moment.
    """

    resource_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    max_model_input_sizes = MAX_MODEL_INPUT_SIZES

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        clean_up_tokenization_spaces=False,
        split_special_tokens=False,
        **kwargs,
    ):
        if unk_token is None:
            logging.info(
                "The `unk_token` parameter needs to be defined: we use `eos_token` by default."
            )
            unk_token = eos_token

        # Qwen vocab does not contain control tokens; added tokens need to be special
        bos_token = (
            AddedToken(
                bos_token, lstrip=False, rstrip=False, special=True, normalized=False
            )
            if isinstance(bos_token, str)
            else bos_token
        )
        eos_token = (
            AddedToken(
                eos_token, lstrip=False, rstrip=False, special=True, normalized=False
            )
            if isinstance(eos_token, str)
            else eos_token
        )
        unk_token = (
            AddedToken(
                unk_token, lstrip=False, rstrip=False, special=True, normalized=False
            )
            if isinstance(unk_token, str)
            else unk_token
        )
        pad_token = (
            AddedToken(
                pad_token, lstrip=False, rstrip=False, special=True, normalized=False
            )
            if isinstance(pad_token, str)
            else pad_token
        )

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        bpe_merges = []
        with open(merges_file, encoding="utf-8") as merges_handle:
            for i, line in enumerate(merges_handle):
                line = line.strip()
                if (i == 0 and line.startswith("#version:")) or not line:
                    continue
                bpe_merges.append(tuple(line.split()))
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
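        # A merges.txt line such as "Ġ t" becomes the pair ("Ġ", "t"); its rank is its line
        # position, and a lower rank means the merge was learned earlier and is applied first.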

        # NOTE: the cache can grow without bound and will get really large for long-running
        # processes (especially for texts in languages that do not put spaces between words,
        # e.g. Chinese); technically not a memory leak, but it looks like one.
        # GPT2Tokenizer has the same problem, so let's be consistent.
        self.cache = {}

        self.pat = re.compile(PRETOKENIZE_REGEX)

        self.bos_token_id = kwargs.get("bos_token_id", None)
        self.eos_token_id = kwargs.get("eos_token_id", None)
        self.unk_token_id = kwargs.get("unk_token_id", None)
        self.pad_token_id = kwargs.get("pad_token_id", None)

        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            split_special_tokens=split_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token
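        # Greedily merge the lowest-ranked (earliest-learned) adjacent pair until no pair
        # that appears in the merge table remains.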
        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.encoder.get(
            token, self.added_tokens_encoder.get(token, len(self.encoder))
        )

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.decoder.get(
            index, self.added_tokens_decoder.get(index, self.unk_token)
        )

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
- text = "".join(tokens)
- text = bytearray([self.byte_decoder[c] for c in text]).decode(
- "utf-8", errors=self.errors
- )
- return text

    def _decode(
        self,
        token_ids,
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
        # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
        return super()._decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )
        merge_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["merges_file"],
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(
                json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False)
                + "\n"
            )

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(
                self.bpe_ranks.items(), key=lambda kv: kv[1]
            ):
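                # Ranks are expected to be consecutive; if a gap appears, resynchronize the
                # counter so the merges are still written out in rank order.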
                if index != token_index:
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file

    def prepare_for_tokenization(self, text, **kwargs):
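        # NFC-normalize so that canonically equivalent inputs (e.g. "e" + U+0301 vs. the
        # precomposed "é") map to the same byte sequence before BPE.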
        text = unicodedata.normalize("NFC", text)
        return (text, kwargs)


class MIXQwen2Tokenizer(Qwen2Tokenizer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
        """
        Converts a string into a sequence of tokens, using the tokenizer.

        Splits into words for word-based vocabularies or into sub-words for sub-word-based vocabularies
        (BPE/SentencePiece/WordPiece). Takes care of added tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            **kwargs (additional keyword arguments):
                Passed along to the model-specific `prepare_for_tokenization` preprocessing method.

        Returns:
            `List[str]`: The list of tokens.
        """
        split_special_tokens = kwargs.pop(
            "split_special_tokens", self.split_special_tokens
        )

        # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
        all_special_tokens_extended = dict(
            (str(t), t)
            for t in self.all_special_tokens_extended
            if isinstance(t, AddedToken)
        )

        text, kwargs = self.prepare_for_tokenization(text, **kwargs)

        # TODO: should this be in the base class?
        if hasattr(self, "do_lower_case") and self.do_lower_case:
            # convert non-special tokens to lowercase
            escaped_special_toks = [
                re.escape(s_tok)
                for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
            ]
            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
            text = re.sub(
                pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text
            )

        if split_special_tokens:
            no_split_token = []
            tokens = [text]
        else:
            no_split_token = set(
                self.unique_no_split_tokens
            )  # don't split on any of the added tokens
            # "This is something<special_token_1> else"
            tokens = self.tokens_trie.split(text)
            # ["This is something", "<special_token_1>", " else"]

        for i, token in enumerate(tokens):
            if token in no_split_token:
                tok_extended = all_special_tokens_extended.get(token, None)
                left = tokens[i - 1] if i > 0 else None
                right = tokens[i + 1] if i < len(tokens) - 1 else None
                if isinstance(tok_extended, AddedToken):
                    if tok_extended.rstrip and right:
                        # A bit counter-intuitive, but we strip the *left* of the following segment,
                        # since tok_extended.rstrip means the special token eats all whitespace on its right
                        tokens[i + 1] = right.lstrip()
                    if tok_extended.lstrip and left:
                        tokens[i - 1] = left.rstrip()

        tokenized_text = []
        for token in tokens:
            # Skip any empty (fully stripped) tokens
            if not token:
                continue
            if token in no_split_token:
                tokenized_text.append(token)
            else:
                tokenized_text.extend(self._tokenize(token))
        return tokenized_text
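

# Minimal usage sketch (hypothetical file names; a real vocab.json / merges.txt must exist on disk):
#
#     tokenizer = Qwen2Tokenizer(vocab_file="vocab.json", merges_file="merges.txt")
#     tokenizer.tokenize("Hello world")   # -> byte-level BPE tokens such as ["Hello", "Ġworld"];
#                                         #    the exact split depends on the learned merges
#     tokenizer.convert_tokens_to_string(["Hello", "Ġworld"])  # -> "Hello world"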