# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
import io
import itertools
import json
import os
import re
import six
import inspect
import unicodedata
import functools
from collections import OrderedDict
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy
import numpy as np
import lazy_paddle as paddle
from jinja2 import Template
from jinja2.exceptions import TemplateError, TemplateSyntaxError
from jinja2.sandbox import ImmutableSandboxedEnvironment

from .tokenizer_utils_base import CHAT_TEMPLATE_CONFIG_NAME
from .....utils import logging
from functools import lru_cache
from .vocab import Vocab
from .tokenizer_utils_base import (
    AddedToken,
    BatchEncoding,
    EncodedInput,
    EncodedInputPair,
    PaddingStrategy,
    PreTokenizedInput,
    PreTokenizedInputPair,
    PretrainedTokenizerBase,
    TensorType,
    TextInput,
    TextInputPair,
    TruncationStrategy,
)
from .utils import convert_to_dict_message, fn_args_to_dict

__all__ = [
    "ChatTemplate",
    "Trie",
    "ChatTemplateMixin",
    "PretrainedTokenizer",
    "InitTrackerMeta",
]

@dataclass
class ChatTemplate:
    conversation: Union[List[str], None] = None
    system: Union[str, None] = None
    query: str = None

    @staticmethod
    @lru_cache()
    def _compile_jinja_template(chat_template) -> Template:
        def raise_exception(message):
            raise TemplateError(message)

        jinja_env = ImmutableSandboxedEnvironment(
            trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True
        )
        jinja_env.globals["raise_exception"] = raise_exception
        return jinja_env.from_string(chat_template)

    def render_conversation(
        self,
        conversation_data: Union[List[str], Dict[str, str]],
        index: int = 0,
        context_data: Dict[str, Any] = {},
    ) -> List[str]:
        """
        Args:
            conversation_data (list[str]): the conversation data, which must have exactly two parts
            index (int): the index of the current conversation
        Returns:
            list[str]: the rendered conversation data
        """
        if self.conversation is None:
            raise ValueError(
                "The template for multi-turn conversations is invalid, please check the `conversation` field in your chat-template."
            )

        if isinstance(conversation_data, (list, tuple)):
            assert (
                len(conversation_data) == 2
            ), "Each round/turn of conversation must have two participants, eg: [user-query, bot-query]"
            conversation_data = {
                "user": conversation_data[0],
                "bot": conversation_data[1],
                "index": index,
            }
        conversation_data.update(context_data)

        one_turn_conversation = []
        for conversation in self.conversation:
            template = self._compile_jinja_template(conversation)
            result = template.render(conversation_data)
            one_turn_conversation.append(result)
        return one_turn_conversation

    def render_query(
        self, query: str, index: int = 0, context_data: Dict[str, Union[int, str]] = {}
    ):
        if self.query is None:
            return query

        template = self._compile_jinja_template(self.query)
        return template.render(query=query, index=index, **context_data)

    def _init_context_data(
        self, context_data: Dict[str, Union[int, str]] = {}
    ) -> Dict[str, Union[int, str]]:
        """Init the context data for the chat-template."""
        context_data["is_training"] = context_data.get("is_training", False)
        return context_data

    def render_system(self, context_data: Dict[str, Union[int, str]] = {}) -> str:
        if self.system is None:
            return ""

        template = self._compile_jinja_template(self.system)
        return template.render(**context_data)

    def __call__(
        self,
        conversations: Union[List[List[str]], str],
        context_data: Dict[str, Union[int, str]] = {},
    ) -> str:
        """Render the conversations with the chat-template.
        Args:
            conversations (list[list[str]]): the conversations between user and bot
        Returns:
            str: the rendered conversation
        """
        if isinstance(conversations, str):
            conversations = [[conversations]]

        # [1 ... n-1] conversation
        final_query = self.render_system(context_data=context_data)
        context_data["length"] = len(conversations)
        for index, conversation in enumerate(conversations[:-1]):
            context_data["is_first"] = index == 0
            context_data["is_last"] = False
            final_query += "".join(
                self.render_conversation(
                    conversation, index=index, context_data=context_data
                )
            )

        if not isinstance(conversations[-1], list) and not len(conversations[-1]) != 1:
            raise ValueError(
                "The length of last conversation must be one, eg: [[user-query, bot-answer], [user-query, bot-answer], ..., [user-query]]"
            )

        if len(conversations[-1]) > 1:
            logging.warning(
                f"The last conversation is not a single-round, chat-template will skip the conversation: {conversations[-1][1:]}"
            )

        final_query += self.render_query(
            conversations[-1][0],
            index=len(conversations) - 1,
            context_data=context_data,
        )
        return final_query

    @classmethod
    def from_dict(cls, config: Dict):
        return cls(**config)

    @classmethod
    def from_file(cls, file: str):
        with open(file, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls.from_dict(config)

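# Illustrative sketch (not part of the library): a minimal ChatTemplate and how it
# renders a two-turn conversation. The template strings below are hypothetical and
# only demonstrate the call pattern of `ChatTemplate.__call__` above.
#
# >>> tpl = ChatTemplate(
# ...     system="System: be helpful\n",
# ...     conversation=["User: {{user}}\n", "Bot: {{bot}}\n"],
# ...     query="User: {{query}}\nBot: ",
# ... )
# >>> tpl([["hi", "hello"], ["how are you?"]])
# 'System: be helpful\nUser: hi\nBot: hello\nUser: how are you?\nBot: '
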
def adapt_stale_fwd_patch(self, name, value):
    """
    Since there are some monkey patches for the forward of PretrainedModel, such as
    model compression, we make these patches compatible with the latest forward
    method.
    """
    if name == "forward":
        # NOTE(guosheng): In dygraph to static, `layer.forward` would be patched
        # by an instance of `StaticFunction`. A string comparison is used to avoid
        # importing fluid.
        if type(value).__name__.endswith(
            "StaticFunction"
        ) or self.forward.__class__.__name__.endswith("StaticFunction"):
            return value
        (
            patch_spec_args,
            patch_spec_varargs,
            patch_spec_varkw,
            patch_spec_defaults,
            _,
            _,
            _,
        ) = inspect.getfullargspec(value)
        (spec_args, spec_varargs, spec_varkw, spec_defaults, _, _, _) = (
            inspect.getfullargspec(self.forward)
        )
        new_args = [
            arg
            for arg in ("output_hidden_states", "output_attentions", "return_dict")
            if arg not in patch_spec_args and arg in spec_args
        ]
        if new_args:
            if self.__module__.startswith("paddlenlp"):
                logging.warning(
                    f"The `forward` method of {self.__class__ if isinstance(self, paddle.nn.Layer) else self} is patched and the patch "
                    "might be based on an old version which is missing some "
                    f"arguments compared with the latest, such as {new_args}. "
                    "We automatically add compatibility on the patch for "
                    "these arguments, and maybe the patch should be updated."
                )
            else:
                logging.warning(
                    f"The `forward` method of {self.__class__ if isinstance(self, paddle.nn.Layer) else self} "
                    "is patched and the patch might conflict with patches made "
                    f"by paddlenlp which seem to have more arguments such as {new_args}. "
                    "We automatically add compatibility on the patch for "
                    "these arguments, and maybe the patch should be updated."
                )
            if isinstance(self, paddle.nn.Layer) and inspect.isfunction(value):

                @functools.wraps(value)
                def wrap_fwd(*args, **kwargs):
                    for arg in new_args:
                        kwargs.pop(arg, None)
                    return value(self, *args, **kwargs)

            else:

                @functools.wraps(value)
                def wrap_fwd(*args, **kwargs):
                    for arg in new_args:
                        kwargs.pop(arg, None)
                    return value(*args, **kwargs)

            return wrap_fwd
    return value

# NOTE:
# Modification:
#     class InitTrackerMeta(type(paddle.nn.Layer)) -> class InitTrackerMeta(type)
# Context:
#     1. In paddle 3.0rc, type(paddle.nn.Layer) == type
#     2. Solve the conflict between ultra-infer and paddle
class InitTrackerMeta(type):
    """
    This metaclass wraps the `__init__` method of a class to add an `init_config`
    attribute to instances of that class, and `init_config` uses a dict to track
    the initial configuration. If the class has a `_pre_init` or `_post_init`
    method, it would be hooked before or after `__init__` and called as
    `_pre_init(self, init_fn, init_args)` or `_post_init(self, init_fn, init_args)`.
    Since InitTrackerMeta is used as the metaclass for pretrained model classes,
    which are always Layers and `type(Layer)` is not `type`, `type(Layer)` rather
    than `type` was originally used as its base class to avoid metaclass conflicts
    (see the NOTE above for why plain `type` is used here now).
    """

    def __init__(cls, name, bases, attrs):
        init_func = cls.__init__
        # If attrs has `__init__`, wrap it using accessible `_pre_init, _post_init`.
        # Otherwise, no need to wrap again since the super cls has been wrapped.
        # TODO: remove reduplicated tracker if using super cls `__init__`
        pre_init_func = getattr(cls, "_pre_init", None) if "__init__" in attrs else None
        post_init_func = (
            getattr(cls, "_post_init", None) if "__init__" in attrs else None
        )
        cls.__init__ = InitTrackerMeta.init_and_track_conf(
            init_func, pre_init_func, post_init_func
        )
        super(InitTrackerMeta, cls).__init__(name, bases, attrs)

    @staticmethod
    def init_and_track_conf(init_func, pre_init_func=None, post_init_func=None):
        """
        Wraps `init_func`, which is the `__init__` method of a class, to add an
        `init_config` attribute to instances of that class.
        Args:
            init_func (callable): It should be the `__init__` method of a class.
                warning: `self` is always the class type of the downstream model, eg: BertForTokenClassification
            pre_init_func (callable, optional): If provided, it would be hooked before
                `init_func` and called as `pre_init_func(self, init_func, *init_args, **init_args)`.
                Default None.
            post_init_func (callable, optional): If provided, it would be hooked after
                `init_func` and called as `post_init_func(self, init_func, *init_args, **init_args)`.
                Default None.
        Returns:
            function: the wrapped function
        """

        @functools.wraps(init_func)
        def __impl__(self, *args, **kwargs):
            # helper registered by `pre_init_func`
            if pre_init_func:
                pre_init_func(self, init_func, *args, **kwargs)
            # keep full configuration
            init_func(self, *args, **kwargs)
            # helper registered by `post_init_func`
            if post_init_func:
                post_init_func(self, init_func, *args, **kwargs)
            self.init_config = kwargs
            if args:
                kwargs["init_args"] = args
            kwargs["init_class"] = self.__class__.__name__

        return __impl__

    def __setattr__(self, name, value):
        value = adapt_stale_fwd_patch(self, name, value)
        return super(InitTrackerMeta, self).__setattr__(name, value)

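# Illustrative sketch (not part of the library): how InitTrackerMeta records the
# arguments a class was constructed with. `DummyTokenizer` is a hypothetical class
# used only to show the resulting `init_config` dict.
#
# >>> @six.add_metaclass(InitTrackerMeta)
# ... class DummyTokenizer(object):
# ...     def __init__(self, vocab_file, do_lower_case=True):
# ...         pass
# >>> tok = DummyTokenizer("vocab.txt", do_lower_case=False)
# >>> tok.init_config
# {'do_lower_case': False, 'init_args': ('vocab.txt',), 'init_class': 'DummyTokenizer'}
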
class Trie:
    """
    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass.
    Loose reference: https://en.wikipedia.org/wiki/Trie
    """

    def __init__(self):
        self.data = {}

    def add(self, word: str):
        """
        Passes over every char (utf-8 char) of `word` and recursively adds it to the internal `data` trie representation.
        The special key `""` is used to represent termination.
        This function is idempotent: adding the same word twice leaves the trie unchanged.
        Example:
        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:
            # Prevent empty string
            return
        ref = self.data
        for char in word:
            ref[char] = char in ref and ref[char] or {}
            ref = ref[char]
        ref[""] = 1

    def split(self, text: str) -> List[str]:
        """
        Will look for the words added to the trie within `text`. Output is the original string split along the
        boundaries of the words found.
        This trie will match the longest possible word first!
        Example:
        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]
        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # indexes are counted left of the chars index.
        # "hello", index 0, is left of h, index 1 is between h and e.
        # index 5 is right of the "o".

        # States are going to capture every possible start (indexes as above)
        # as keys, and have as values, a pointer to the position in the trie
        # where we're at. This is a partial match for now.
        # This enables to keep track of multiple matches while we're iterating
        # the string.
        # If the trie contains "blowing" and "lower" and we encounter the
        # string "blower", we need to split into ["b", "lower"].
        # This is where we need to keep track of multiple possible starts.
        states = OrderedDict()

        # This will contain every index where we need
        # to cut.
        # We force to cut at offset 0 and len(text) (added later)
        offsets = [0]

        # This is used by the lookahead which needs to skip over
        # some text where the full match exceeded the place in the initial
        # for loop
        skip = 0
        # Main loop, giving this algorithm O(n) complexity
        for current, current_char in enumerate(text):
            if skip and current < skip:
                # Prevents the lookahead from matching twice
                # like extra_id_100 and id_100
                continue

            # This will track every state
            # that stopped matching; we need to stop tracking them.
            # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
            # fail on "b", we need to remove 0 from the valid states.
            to_remove = set()
            # Whenever we found a match, we need to drop everything;
            # this is a greedy algorithm, it will match on the first found token
            reset = False

            # In this case, we already have partial matches (but unfinished)
            for start, trie_pointer in states.items():
                if "" in trie_pointer:
                    # This is a final match, we need to reset and
                    # store the results in `offsets`.

                    # Lookahead to match longest first
                    # Important in case of extra_id_1 vs extra_id_100
                    # Here we are also actively looking for other earlier partial
                    # matches
                    # "[CLS]", "L", we need to match CLS even if L is special
                    for lookstart, looktrie_pointer in states.items():
                        if lookstart > start:
                            # This partial match is later, we can stop looking
                            break
                        elif lookstart < start:
                            # This partial match is earlier, the trie pointer
                            # was already updated, so index is + 1
                            lookahead_index = current + 1
                            end = current + 1
                        else:
                            # Here lookstart == start and
                            #      looktrie_pointer == trie_pointer
                            # It wasn't updated yet so indices are current ones
                            lookahead_index = current
                            end = current
                        next_char = (
                            text[lookahead_index]
                            if lookahead_index < len(text)
                            else None
                        )
                        if "" in looktrie_pointer:
                            start = lookstart
                            end = lookahead_index
                            skip = lookahead_index

                        while next_char in looktrie_pointer:
                            looktrie_pointer = looktrie_pointer[next_char]
                            lookahead_index += 1
                            if "" in looktrie_pointer:
                                start = lookstart
                                end = lookahead_index
                                skip = lookahead_index

                            if lookahead_index == len(text):
                                # End of string
                                break
                            next_char = text[lookahead_index]
                    # End lookahead

                    # Storing and resetting
                    offsets.append(start)
                    offsets.append(end)
                    reset = True
                    break
                elif current_char in trie_pointer:
                    # The current character being looked at has a match within the trie
                    # update the pointer (it will be stored back into states later).
                    trie_pointer = trie_pointer[current_char]

                    # Storing back the new pointer into the states.
                    # Partial matches got longer by one.
                    states[start] = trie_pointer
                else:
                    # The new character has no match in the trie, we need
                    # to stop keeping track of this partial match.
                    # We can't do it directly within the loop because of how
                    # python iteration works
                    to_remove.add(start)

            # Either clearing the full start (we found a real match)
            # Or clearing only the partial matches that didn't work.
            if reset:
                states = {}
            else:
                for start in to_remove:
                    del states[start]

            # If this character is a starting character within the trie
            # start keeping track of this partial match.
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]

        # We have a cut at the end with states.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                # This is a final match, we need to reset and
                # store the results in `offsets`.
                end = len(text)
                offsets.append(start)
                offsets.append(end)
                # Longest cut is always the one with lower start so the first
                # item so we need to break.
                break

        return self.cut_text(text, offsets)

    def cut_text(self, text, offsets):
        # We have all the offsets now, we just need to do the actual splitting.
        # We need to eventually add the first part of the string and the eventual
        # last part.
        offsets.append(len(text))
        tokens = []
        start = 0
        for end in offsets:
            if start > end:
                logging.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
                )
                continue
            elif start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end

        return tokens

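# Illustrative sketch (not part of the library): the trie keeps several candidate
# start positions alive at once, so overlapping vocabulary entries are resolved in
# favour of the longest usable match, as the comments in `Trie.split` describe.
#
# >>> trie = Trie()
# >>> trie.add("blowing")
# >>> trie.add("lower")
# >>> trie.split("blower")
# ['b', 'lower']
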
def _insert_one_token_to_ordered_list(token_list: List[str], new_token: str):
    """
    Inserts one token into an ordered list if it does not already exist. Note: token_list must be sorted.
    """
    insertion_idx = bisect.bisect_left(token_list, new_token)
    # Checks if new_token is already in the ordered token_list
    if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
        # new_token is in token_list, don't add
        return
    else:
        token_list.insert(insertion_idx, new_token)

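# Illustrative sketch (not part of the library): the helper keeps the list sorted
# and skips duplicates.
#
# >>> tokens = ["[CLS]", "[SEP]"]
# >>> _insert_one_token_to_ordered_list(tokens, "[MASK]")
# >>> tokens
# ['[CLS]', '[MASK]', '[SEP]']
# >>> _insert_one_token_to_ordered_list(tokens, "[MASK]")  # already present, no change
# >>> tokens
# ['[CLS]', '[MASK]', '[SEP]']
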
def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_nonnormalized_char(char):
    """Check whether `char` is a non-normalized character."""
    cp = ord(char)
    if (
        (0xFF00 <= cp <= 0xFFEF)  # Halfwidth and Fullwidth Forms
        or (0xFE50 <= cp <= 0xFE6B)  # Small Form Variants
        or (0x3358 <= cp <= 0x33FF)  # CJK Compatibility
        or (0x249C <= cp <= 0x24E9)  # Enclosed Alphanumerics: Ⓛ ⒰
        or (0x3200 <= cp <= 0x32FF)  # Enclosed CJK Letters and Months
    ):
        return True
    return False


def _is_nonnormalized_numeric(char):
    """Check whether `char` is a non-normalized numeric character."""
    cp = ord(char)
    if (
        (0x2460 <= cp <= 0x249B)  # Enclosed Alphanumerics: circled/parenthesized numbers
        or (0x24EA <= cp <= 0x24FF)  # Enclosed Alphanumerics: additional circled numbers
        or (0x2776 <= cp <= 0x2793)  # Dingbats: negative circled digits
        or (0x2160 <= cp <= 0x217F)  # Number Forms: Roman numerals
    ):
        return True
    return False

def normalize_chars(text):
    """
    Normalize the text for multilingual and Chinese models. Unicode ranges:
    https://www.ling.upenn.edu/courses/Spring_2003/ling538/UnicodeRanges.html
    """
    output = []
    for char in text:
        if _is_nonnormalized_char(char):
            for c in unicodedata.normalize("NFKC", char):
                output.append(c)
        elif _is_nonnormalized_numeric(char):
            output.append(" ")
            for c in str(int(unicodedata.numeric(char))):
                output.append(c)
            output.append(" ")
        elif ord(char) == 0xF979:  # https://www.zhihu.com/question/20697984
            output.append("凉")
        else:
            output.append(char)
    return "".join(output)

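# Illustrative sketch (not part of the library): fullwidth forms are folded to their
# NFKC equivalents and enclosed numerals are expanded to plain digits surrounded by
# spaces; the outputs below follow from the rules above for these inputs.
#
# >>> normalize_chars("ＡＢＣ")  # fullwidth Latin letters
# 'ABC'
# >>> normalize_chars("第①条")  # circled digit one
# '第 1 条'
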
class ChatTemplateMixin:
    chat_template: Optional[ChatTemplate] = None

    def apply_chat_template(
        self,
        conversation: Union[Dict[str, str], str],
        tokenize: bool = True,
        context_data: Dict[str, Any] = {},
        **tokenizer_kwargs,
    ) -> Union[str, Dict[str, Union["numpy.ndarray", "paddle.Tensor"]]]:
        """Apply the chat_template rules to a conversation, which should not be batched data.
        Args:
            conversation (List[List[str, str]] | str): the conversation messages between user and bot
            context_data (Dict[str, Any]): the context data for chat_template.json
            tokenize (bool, optional): whether to do tokenization. Defaults to True.
        Returns:
            str | dict[str, Union["numpy.ndarray", "paddle.Tensor"]]: the result of the applied data
        """
        if not self.chat_template:
            raise ValueError(
                "chat_template is not set, please set chat_template first."
            )
        elif isinstance(self.chat_template, Template):
            add_generation_prompt = tokenizer_kwargs.pop("add_generation_prompt", True)
            query = self._apply_chat_template(
                conversation, add_generation_prompt=add_generation_prompt
            )
        elif isinstance(self.chat_template, ChatTemplate):
            query = self._apply_chat_template_paddle(conversation, context_data)

        if not tokenize:
            return query

        # chat_template should not add special tokens
        tokenizer_kwargs["add_special_tokens"] = False
        return self(query, **tokenizer_kwargs)

    def _apply_chat_template_paddle(
        self,
        conversation: Union[List[Dict[str, str]], str],
        context_data: Dict[str, Any] = {},
    ) -> Union[str, Dict[str, Union["numpy.ndarray", "paddle.Tensor"]]]:
        context_data = self.chat_template._init_context_data(context_data)
        if isinstance(conversation, str):
            conversation = [[conversation]]
        elif isinstance(conversation, list) and isinstance(conversation[0], str):
            raise ValueError(
                "apply_chat_template does not support applying batched conversations, "
                "so you should apply the conversations one by one."
            )
        query = self.chat_template(conversation, context_data=context_data)
        return query

    def _apply_chat_template(
        self,
        conversation: Union[Dict[str, str], str],
        add_generation_prompt=True,
    ) -> Union[str, Dict[str, Union["numpy.ndarray", "paddle.Tensor"]]]:
        if isinstance(conversation, str):
            conversations = [{"role": "user", "content": conversation}]
        elif isinstance(conversation, list):
            assert len(conversation) > 0, "empty conversation is not allowed"
            if isinstance(conversation[0], list):
                conversations = convert_to_dict_message(conversation)
            elif isinstance(conversation[0], dict):
                conversations = conversation
            else:
                raise ValueError(
                    "apply_chat_template does not support applying batched conversations, "
                    "so you should apply the conversations one by one."
                )
        query = self.chat_template.render(
            messages=conversations,
            **self.special_tokens_map,
            add_generation_prompt=add_generation_prompt,
        )
        return query

    def encode_chat_inputs(
        self,
        conversations: List[Dict[str, str]],
        context_data: Dict[str, Any] = {},
        **kwargs,
    ):
        """Encodes a conversation into pairs of token ids.
        Turn 0: bos + system + sep + user    bot + eos
        Turn t: sep + user                   bot + eos
        Args:
            conversations (List[Dict[str, str]]): the conversation data
            context_data (Dict[str, Any]): the context data of the conversation
        Returns:
            List[list[int], list[int]]: the pairs of input_ids and target_ids
        """
        if not self.chat_template:
            raise ValueError(
                "chat_template is not set, please set chat_template first."
            )
        elif isinstance(self.chat_template, Template):
            add_generation_prompt = kwargs.pop("add_generation_prompt", True)
            query = self._encode_chat_inputs(
                conversations, context_data, add_generation_prompt=add_generation_prompt
            )
        elif isinstance(self.chat_template, ChatTemplate):
            query = self._encode_chat_inputs_paddle(conversations, context_data)
        return query

    def _encode_chat_inputs_paddle(
        self, conversations: List[Dict[str, str]], context_data: Dict[str, Any] = {}
    ):
        context_data = self.chat_template._init_context_data(context_data)
        # encode system
        result = {}
        if self.chat_template.system:
            system = self.chat_template.render_system(context_data)
            result["system"] = self.encode(system, add_special_tokens=False)[
                "input_ids"
            ]

        # encode conversation
        conversation_ids = []
        for index, conversation in enumerate(conversations):
            # give more control to chat_template
            context_data["is_first"] = index == 0
            context_data["is_last"] = index == len(conversations) - 1

            user_input, bot_output = self.chat_template.render_conversation(
                conversation, index=index, context_data=context_data
            )
            user_ids = self.encode(user_input, add_special_tokens=False)["input_ids"]
            bot_ids = self.encode(bot_output, add_special_tokens=False)["input_ids"]
            conversation_ids.append([user_ids, bot_ids])

        result["conversations"] = conversation_ids
        return result

    def _encode_chat_inputs(
        self,
        conversations: List[Dict[str, str]],
        context_data: Dict[str, Any] = {},
        system: str = None,
        add_generation_prompt=True,
    ):
        result = {}
        # Some templates do not support a system message, so we need to check it first.
        if system:
            try:
                self.chat_template.render(
                    messages={"role": "system", "content": system}
                )
            except Exception as e:
                raise ValueError("A system message is not supported in this tokenizer.", e)

        # convert list msg to role dict msg
        conversation_dict = []
        origin_msg = []
        for round in conversations:
            round_role = [
                {"role": "user", "content": round[0]},
                {"role": "assistant", "content": round[1]},
            ]
            origin_msg.extend(round_role)
            conversation_dict.append(round_role)
        ans = []

        # get the answer in a single round, then compile the chat entirely and split it by the single-round answers
        # attention: answer should include the end token!
        for conv in conversation_dict:
            roundi = [system] + conv if system else conv
            roundi_str = self.chat_template.render(
                messages=roundi, add_generation_prompt=False, **self.special_tokens_map
            )
            roundi_no_ans = [system] + [conv[0]] if system else [conv[0]]
            roundi_no_ans_str = self.chat_template.render(
                messages=roundi_no_ans,
                add_generation_prompt=add_generation_prompt,
                **self.special_tokens_map,
            )
            ans_roundi = roundi_str[len(roundi_no_ans_str) :]
            ans.append(ans_roundi)

        non_learnable_parts = self._extract_non_learnable_parts(origin_msg, ans)
        assert len(non_learnable_parts) == len(ans)

        conversation_ids = []
        for i in range(len(non_learnable_parts)):
            conversation_ids.append(
                self.batch_encode(
                    [non_learnable_parts[i], ans[i]],
                    add_special_tokens=False,
                    padding=False,
                )["input_ids"]
            )

        result["conversations"] = conversation_ids
        return result

    def _extract_non_learnable_parts(
        self, origin_msg: List[Dict[str, str]], split_s: List[str]
    ):
        """Split the entire chat by the specified words. Extract the non-learnable parts."""
        # distinguish and replace the special words in the original string with an uncompiled form, like | -> \|
        regex_pattern = "|".join(map(re.escape, split_s))
        # split by the replaced specified words
        non_learnable_parts = re.split(
            r"(?:%s)" % regex_pattern,
            self.chat_template.render(
                messages=origin_msg,
                add_generation_prompt=False,
                **self.special_tokens_map,
            ),
        )
        if non_learnable_parts[-1] == "":
            non_learnable_parts.pop()
        return non_learnable_parts

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        cache_dir = kwargs.pop("cache_dir", None)
        from_hf_hub = kwargs.pop("from_hf_hub", False)
        from_aistudio = kwargs.pop("from_aistudio", False)
        subfolder = kwargs.pop("subfolder", "")
        if subfolder is None:
            subfolder = ""

        kwargs["subfolder"] = subfolder
        kwargs["cache_dir"] = cache_dir
        kwargs["from_hf_hub"] = from_hf_hub
        kwargs["from_aistudio"] = from_aistudio
        kwargs["return_tokenizer_file_dir"] = True
        tokenizer, tokenizer_config_file_dir = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        # load chat-template
        chat_template_file = os.path.join(
            tokenizer_config_file_dir, CHAT_TEMPLATE_CONFIG_NAME
        )
        if not os.path.exists(chat_template_file):
            return tokenizer

        if tokenizer.chat_template is not None:
            logging.warning(
                "Chat-template already exists in config file, it will be overwritten by chat_template.json file."
            )
            logging.warning(
                "`chat_template.json` will be deprecated in the future! Please set it in `tokenizer_config.json`."
            )
        tokenizer.init_chat_template(chat_template_file)
        return tokenizer

    def init_chat_template(self, chat_template: Union[str, Dict]):
        """Initialize chat_template from a file path or a template dict.
        Args:
            chat_template (str | dict): file path or template dict data
        """
        if isinstance(chat_template, str):
            if not os.path.exists(chat_template):
                try:
                    self.chat_template: Template = ChatTemplate._compile_jinja_template(
                        chat_template
                    )
                except TemplateSyntaxError:
                    # It is neither a jinja string nor a path string
                    raise TemplateSyntaxError(
                        "The chat-template in json is not a valid jinja string: {}".format(
                            chat_template
                        ),
                        lineno=0,  # fake lineno; it is required by the exception but unused
                    )
            else:
                self.chat_template = ChatTemplate.from_file(chat_template)
        elif isinstance(chat_template, dict):
            self.chat_template = ChatTemplate.from_dict(chat_template)
        elif isinstance(chat_template, ChatTemplate):
            self.chat_template = chat_template
        else:
            raise ValueError("Received invalid chat_template data: ", chat_template)

    def save_resources(self, save_directory):
        super().save_resources(save_directory)

        if isinstance(
            self.chat_template, ChatTemplate
        ):  # Future remove if ChatTemplate is deprecated
            chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_CONFIG_NAME)
            with open(chat_template_file, "w", encoding="utf-8") as f:
                json.dump(asdict(self.chat_template), f, ensure_ascii=False, indent=4)
            logging.info("Chat-template config file saved in " + chat_template_file)

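# Illustrative sketch (not part of the library): typical use of ChatTemplateMixin
# through a concrete tokenizer subclass. `MyChatTokenizer` and the checkpoint name
# are hypothetical; the call pattern is what the mixin above supports.
#
# >>> tokenizer = MyChatTokenizer.from_pretrained("some-chat-model")
# >>> tokenizer.apply_chat_template("Hello!", tokenize=False)  # rendered prompt string
# >>> tokenizer.apply_chat_template("Hello!")                  # rendered, then tokenized
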
@six.add_metaclass(InitTrackerMeta)
class PretrainedTokenizer(ChatTemplateMixin, PretrainedTokenizerBase):
    """
    Base class for all tokenizers.
    Inherits from [`~tokenizer_utils_base.PretrainedTokenizerBase`].
    Handles all the shared methods for tokenization and special tokens, as well as methods for
    downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
    This class also contains the added tokens in a unified way on top of all tokenizers, so we don't have to handle the
    specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
    - **resource_files_names** (`Dict[str, str]`) -- A dictionary with, as keys, the `__init__` keyword name of each
        vocabulary file required by the model, and as associated values, the filename for saving the associated file
        (string).
    - **pretrained_resource_files_map** (`Dict[str, Dict[str, str]]`) -- A dictionary of dictionaries, with the
        high-level keys being the `__init__` keyword name of each vocabulary file required by the model, the
        low-level being the `short-cut-names` of the pretrained models with, as associated values, the `url` to the
        associated pretrained vocabulary file.
    - **max_model_input_sizes** (`Dict[str, Optional[int]]`) -- A dictionary with, as keys, the `short-cut-names`
        of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model,
        or `None` if the model has no maximum input size.
    - **pretrained_init_configuration** (`Dict[str, Dict[str, Any]]`) -- A dictionary with, as keys, the
        `short-cut-names` of the pretrained models, and as associated values, a dictionary of specific arguments to
        pass to the `__init__` method of the tokenizer class for this pretrained model when loading the tokenizer
        with the [`~tokenizer_utils_base.PretrainedTokenizerBase.from_pretrained`] method.
    - **model_input_names** (`List[str]`) -- A list of inputs expected in the forward pass of the model.
    - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied.
        Should be `'right'` or `'left'`.
    - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation
        applied. Should be `'right'` or `'left'`.
    Moreover, methods common to tokenizers for tokenization, token/id conversion
    and encoding as model inputs are also provided here.
    Besides, the metaclass `InitTrackerMeta` is used to create `PretrainedTokenizer`,
    by which subclasses can track arguments for initialization automatically
    and expose special tokens initialization used as attributes.
    """

    added_tokens_encoder: Dict[str, int] = {}
    added_tokens_decoder: Dict[int, str] = {}
    unique_no_split_tokens: List[str] = []
    tokens_trie = Trie()

    _decode_use_source_tokenizer = False

    def _pre_init(self, original_init, *args, **kwargs):
        """
        It would be hooked before `__init__` to add special tokens (arguments of
        `__init__` whose name ends with `_token`) as attributes of the tokenizer
        instance.
        """
        init_dict = fn_args_to_dict(original_init, *((self,) + args), **kwargs)
        init_dict.pop("self", None)
        super(PretrainedTokenizer, self).__init__(**init_dict)

        self.added_tokens_encoder: Dict[str, int] = {}
        self.added_tokens_decoder: Dict[int, str] = {}
        self.unique_no_split_tokens: List[str] = []
        self.tokens_trie = Trie()

        self._decode_use_source_tokenizer = False

    def _build_special_tokens_map_extended(self, **kwargs):
        for key, value in kwargs.items():
            if value is None:
                continue
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                if key == "additional_special_tokens":
                    assert isinstance(
                        value, (list, tuple)
                    ), f"Value {value} is not a list or tuple"
                    assert all(
                        isinstance(t, (str, AddedToken)) for t in value
                    ), "One of the tokens is not a string or an AddedToken"
                    setattr(self, key, value)
                elif isinstance(value, (str, AddedToken)):
                    setattr(self, key, value)
                else:
                    raise TypeError(
                        f"special token {key} has to be either str or AddedToken but got: {type(value)}"
                    )

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        raise NotImplementedError

    @property
    def is_fast(self) -> bool:
        return False

    def get_added_vocab(self) -> Dict[str, int]:
        """
        Returns the added tokens in the vocabulary as a dictionary of token to index.
        Returns:
            `Dict[str, int]`: The added tokens.
        """
        return self.added_tokens_encoder

    def __len__(self):
        """
        Size of the full vocabulary with the added tokens.
        """
        return self.vocab_size + len(self.added_tokens_encoder)

    def _add_tokens(
        self,
        new_tokens: Union[List[str], List[AddedToken]],
        special_tokens: bool = False,
    ) -> int:
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
        it with indices starting from the length of the current vocabulary.
        Args:
            new_tokens (`List[str]` or `List[AddedToken]`):
                Token(s) to add in vocabulary. A token is only added if it's not already in the vocabulary (tested by
                checking if the tokenizer assigns the index of the `unk_token` to them).
            special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the tokens should be added as special tokens.
        Returns:
            `int`: The number of tokens actually added to the vocabulary.
        Examples:
        ```python
        # Let's see how to increase the vocabulary of Bert model and tokenizer
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        model = BertModel.from_pretrained("bert-base-uncased")
        num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
        print("We have added", num_added_toks, "tokens")
        ```"""
        new_tokens = [str(tok) for tok in new_tokens]

        tokens_to_add = []
        for token in new_tokens:
            if not isinstance(token, str):
                raise TypeError(f"Token {token} is not a string but a {type(token)}.")
            if (
                not special_tokens
                and hasattr(self, "do_lower_case")
                and self.do_lower_case
            ):
                token = token.lower()
            if (
                token != self.unk_token
                and self.convert_tokens_to_ids(token)
                == self.convert_tokens_to_ids(self.unk_token)
                and token not in tokens_to_add
            ):
                tokens_to_add.append(token)
                if self.verbose:
                    logging.info(f"Adding {token} to the vocabulary")

        added_tok_encoder = dict(
            (tok, len(self) + i) for i, tok in enumerate(tokens_to_add)
        )
        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        # Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
        if special_tokens:
            if len(new_tokens) == 1:
                _insert_one_token_to_ordered_list(
                    self.unique_no_split_tokens, new_tokens[0]
                )
            else:
                self.unique_no_split_tokens = sorted(
                    set(self.unique_no_split_tokens).union(set(new_tokens))
                )
        else:
            # Or on the newly added tokens
            if len(tokens_to_add) == 1:
                _insert_one_token_to_ordered_list(
                    self.unique_no_split_tokens, tokens_to_add[0]
                )
            else:
                self.unique_no_split_tokens = sorted(
                    set(self.unique_no_split_tokens).union(set(tokens_to_add))
                )
        self._create_trie(self.unique_no_split_tokens)

        return len(tokens_to_add)

    def _create_trie(self, unique_no_split_tokens):
        trie = Trie()
        for token in unique_no_split_tokens:
            if (
                hasattr(self, "do_lower_case")
                and self.do_lower_case
                and token not in self.all_special_tokens
            ):
                trie.add(token.lower())
            else:
                trie.add(token)
        self.tokens_trie = trie

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """
        Performs any necessary transformations before tokenization.
        This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
        `kwargs` at the end of the encoding process to be sure all the arguments have been used.
        Args:
            text (`str`):
                The text to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize. This is useful for NER or token classification.
            kwargs:
                Keyword arguments to use for the tokenization.
        Returns:
            `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
        """
        return (text, kwargs)

    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
        """
        Converts a string into a sequence of tokens, using the tokenizer.
        Splits into words for word-based vocabularies or sub-words for sub-word-based vocabularies
        (BPE/SentencePieces/WordPieces). Takes care of added tokens.
        Args:
            text (`str`):
                The sequence to be encoded.
            **kwargs (additional keyword arguments):
                Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
        Returns:
            `List[str]`: The list of tokens.
        """
        # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
        all_special_tokens_extended = dict(
            (str(t), t)
            for t in self.all_special_tokens_extended
            if isinstance(t, AddedToken)
        )

        text, kwargs = self.prepare_for_tokenization(text, **kwargs)

        # TODO: should this be in the base class?
        if hasattr(self, "do_lower_case") and self.do_lower_case:
            # convert non-special tokens to lowercase
            escaped_special_toks = [
                re.escape(s_tok)
                for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
            ]
            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
            text = re.sub(
                pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text
            )

        no_split_token = set(self.unique_no_split_tokens)
        tokens = self.tokens_trie.split(text)
        # ["This is something", "<special_token_1>", " else"]
        for i, token in enumerate(tokens):
            if token in no_split_token:
                tok_extended = all_special_tokens_extended.get(token, None)
                left = tokens[i - 1] if i > 0 else None
                right = tokens[i + 1] if i < len(tokens) - 1 else None
                if isinstance(tok_extended, AddedToken):
                    if tok_extended.rstrip and right:
                        # A bit counter-intuitive but we strip the left of the string
                        # since tok_extended.rstrip means the special token is eating all white spaces on its right
                        tokens[i + 1] = right.lstrip()
                    # Strip white spaces on the left
                    if tok_extended.lstrip and left:
                        tokens[i - 1] = left.rstrip()  # Opposite here
                else:
                    # We strip left and right by default
                    if right:
                        tokens[i + 1] = right.lstrip()
                    if left:
                        tokens[i - 1] = left.rstrip()
        # ["This is something", "<special_token_1>", "else"]
        tokenized_text = []
        for token in tokens:
            # Need to skip eventual empty (fully stripped) tokens
            if not token:
                continue
            if token in no_split_token:
                tokenized_text.append(token)
            else:
                tokenized_text.extend(self._tokenize(token))
        # ["This", " is", " something", "<special_token_1>", "else"]
        return tokenized_text

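    # Illustrative sketch (not part of the library): once a token has been registered
    # as a no-split token, `tokenize` keeps it intact via the trie pass while the
    # surrounding text still goes through `_tokenize`. `MyWordTokenizer` is a
    # hypothetical subclass whose `_tokenize` splits on whitespace.
    #
    # >>> tok = MyWordTokenizer.from_pretrained("some-model")
    # >>> tok.add_tokens(["<special_token_1>"], special_tokens=True)
    # >>> tok.tokenize("This is something <special_token_1> else")
    # ['This', 'is', 'something', '<special_token_1>', 'else']
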
    def _tokenize(self, text, **kwargs):
        """
        Converts a string into a sequence of tokens (strings), using the tokenizer. Splits into words for word-based
        vocabularies or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
        Does NOT take care of added tokens.
        """
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens):
        if tokens is None:
            return None
        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)
        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        if token is None:
            return None
        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
        return self.vocab.to_indices(token)

    def convert_tokens_to_string(self, tokens):
        """
        Converts a sequence of tokens (list of strings) to a single string by
        using ``' '.join(tokens)``.
        Args:
            tokens (list[str]): A sequence of tokens.
        Returns:
            str: Converted string.
        """
        return " ".join(tokens)

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        if isinstance(ids, int):
            if ids in self.added_tokens_decoder:
                return self.added_tokens_decoder[ids]
            else:
                return self._convert_id_to_token(ids)
        tokens = []
        for index in ids:
            index = int(index)
            if skip_special_tokens and index in self.all_special_ids:
                continue
            if index in self.added_tokens_decoder:
                tokens.append(self.added_tokens_decoder[index])
            else:
                tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index):
        return self.vocab.to_tokens(index)

    @staticmethod
    def load_vocabulary(
        filepath,
        unk_token=None,
        pad_token=None,
        bos_token=None,
        eos_token=None,
        **kwargs,
    ):
        """
        Instantiates an instance of `Vocab` from a file, reserving all tokens
        by using `Vocab.from_dict`. The file contains one token per line, and the
        line number is the index of the corresponding token.

        Args:
            filepath (str): Path of the file used to construct the vocabulary.
            unk_token (str): Special token for unknown tokens. Can be `None` if
                not needed. Defaults to `None`.
            pad_token (str): Special token for padding. Can be `None` if not
                needed. Defaults to `None`.
            bos_token (str): Special token for the beginning of a sequence. Can be
                `None` if not needed. Defaults to `None`.
            eos_token (str): Special token for the end of a sequence. Can be
                `None` if not needed. Defaults to `None`.
            **kwargs (dict): Keyword arguments for `Vocab.from_dict`.

        Returns:
            Vocab: An instance of `Vocab`.
        """
        token_to_idx = {}
        with io.open(filepath, "r", encoding="utf-8") as f:
            for index, line in enumerate(f):
                token = line.rstrip("\n")
                token_to_idx[token] = int(index)
        vocab = Vocab.from_dict(
            token_to_idx,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            **kwargs,
        )
        return vocab

    @staticmethod
    def save_vocabulary(filepath, vocab):
        """
        Saves all tokens to a vocabulary file. The file contains one token per line,
        and the line number is the index of the corresponding token.

        Args:
            filepath (str): File path to save to.
            vocab (Vocab|dict): The `Vocab` or `dict` instance to be saved.
        """
        if isinstance(vocab, Vocab):
            tokens = vocab.idx_to_token
        else:
            tokens = sorted(vocab.keys(), key=lambda token: vocab[token])
        with io.open(filepath, "w", encoding="utf-8") as f:
            for token in tokens:
                f.write(token + "\n")
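
    # Illustrative sketch of the vocabulary file format (one token per line, line number = index;
    # `MyTokenizer` is a hypothetical subclass exposing these static methods):
    #
    #   MyTokenizer.save_vocabulary("vocab.txt", {"[PAD]": 0, "[UNK]": 1, "hello": 2})
    #   # vocab.txt now holds three lines: "[PAD]", "[UNK]", "hello"
    #   vocab = MyTokenizer.load_vocabulary("vocab.txt", unk_token="[UNK]", pad_token="[PAD]")
    #   vocab.to_indices("hello")   # -> 2
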
    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``encode`` methods.

        Args:
            token_ids_0 (List[int]): List of ids of the first sequence.
            token_ids_1 (List[int], optional): List of ids of the second sequence.
            already_has_special_tokens (bool, optional): Whether or not the token list is already
                formatted with special tokens for the model. Defaults to `False`.

        Returns:
            results (List[int]): A list of integers in the range [0, 1]:
                1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
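
    # Illustrative sketch: when the ids are not yet formatted with special tokens, this base
    # implementation returns zeros covering both sequences; subclasses mark the positions they
    # actually insert:
    #
    #   tok.get_special_tokens_mask([5, 6, 7], [8, 9])   # -> [0, 0, 0, 0, 0]
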
    def num_special_tokens_to_add(self, pair=False):
        """
        Returns the number of tokens added when encoding a sequence with special tokens.

        Args:
            pair (bool, optional):
                Whether the number of added tokens should be computed for a sequence pair or a single
                sequence. Defaults to `False`.

        Returns:
            int: Number of special tokens added to sequences.
        """
        token_ids_0 = []
        token_ids_1 = []
        return len(
            self.build_inputs_with_special_tokens(
                token_ids_0, token_ids_1 if pair else None
            )
        )
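
    # Illustrative sketch: the count comes from building inputs out of empty sequences, so for a
    # typical BERT-style subclass ([CLS] A [SEP] / [CLS] A [SEP] B [SEP]) it would be:
    #
    #   tok.num_special_tokens_to_add(pair=False)   # -> 2
    #   tok.num_special_tokens_to_add(pair=True)    # -> 3
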
    def _encode_plus(
        self,
        text: Union[TextInput, PreTokenizedInput, EncodedInput],
        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_position_ids: Optional[bool] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        def get_input_ids(text):
            if isinstance(text, str):
                tokens = self.tokenize(text, **kwargs)
                return self.convert_tokens_to_ids(tokens)
            elif (
                isinstance(text, (list, tuple))
                and len(text) > 0
                and isinstance(text[0], str)
            ):
                if is_split_into_words:
                    tokens = list(
                        itertools.chain(
                            *(
                                self.tokenize(t, is_split_into_words=True, **kwargs)
                                for t in text
                            )
                        )
                    )
                    return self.convert_tokens_to_ids(tokens)
                else:
                    return self.convert_tokens_to_ids(text)
            elif (
                isinstance(text, (list, tuple))
                and len(text) > 0
                and isinstance(text[0], int)
            ):
                return text
            else:
                if is_split_into_words:
                    raise ValueError(
                        f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_split_into_words=True`."
                    )
                else:
                    raise ValueError(
                        f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
                    )

        first_ids = get_input_ids(text)
        second_ids = get_input_ids(text_pair) if text_pair is not None else None

        if return_offsets_mapping:
            kwargs["text"] = text
            kwargs["text_pair"] = text_pair

        return self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy.value,
            truncation=truncation_strategy.value,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            prepend_batch_axis=True,
            return_position_ids=return_position_ids,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )
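
    # Illustrative sketch of the three accepted input forms (normally reached through the public
    # __call__/encode entry points; `tok` is a hypothetical subclass instance):
    #
    #   tok._encode_plus("a raw sentence")                                     # string -> tokenize + ids
    #   tok._encode_plus(["a", "raw", "sentence"], is_split_into_words=True)   # pre-split words
    #   tok._encode_plus([101, 2023, 102])                                     # already-encoded ids pass through
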
    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput],
            List[TextInputPair],
            List[PreTokenizedInput],
            List[PreTokenizedInputPair],
            List[EncodedInput],
            List[EncodedInputPair],
        ],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_position_ids: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_dict: bool = True,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        def get_input_ids(text):
            if isinstance(text, str):
                tokens = self.tokenize(text, **kwargs)
                return self.convert_tokens_to_ids(tokens)
            elif (
                isinstance(text, (list, tuple))
                and len(text) > 0
                and isinstance(text[0], str)
            ):
                if is_split_into_words:
                    tokens = list(
                        itertools.chain(
                            *(
                                self.tokenize(t, is_split_into_words=True, **kwargs)
                                for t in text
                            )
                        )
                    )
                    return self.convert_tokens_to_ids(tokens)
                else:
                    return self.convert_tokens_to_ids(text)
            elif (
                isinstance(text, (list, tuple))
                and len(text) > 0
                and isinstance(text[0], int)
            ):
                return text
            else:
                raise ValueError(
                    "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
                )

        input_ids = []
        for ids_or_pair_ids in batch_text_or_text_pairs:
            if not isinstance(ids_or_pair_ids, (list, tuple)):
                ids, pair_ids = ids_or_pair_ids, None
            elif is_split_into_words and not isinstance(
                ids_or_pair_ids[0], (list, tuple)
            ):
                ids, pair_ids = ids_or_pair_ids, None
            else:
                ids, pair_ids = ids_or_pair_ids
            first_ids = get_input_ids(ids)
            second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
            input_ids.append((first_ids, second_ids))

        if stride > 0 and second_ids is not None:
            kwargs["batch_text_or_text_pairs"] = batch_text_or_text_pairs
        else:
            if return_offsets_mapping:
                has_pair = False
                if len(batch_text_or_text_pairs) > 0:
                    if isinstance(batch_text_or_text_pairs[0], (list, tuple)):
                        has_pair = True
                kwargs["texts"] = None
                kwargs["text_pairs"] = None
                if has_pair:
                    kwargs["texts"] = [text[0] for text in batch_text_or_text_pairs]
                    kwargs["text_pairs"] = [
                        text[1] for text in batch_text_or_text_pairs
                    ]
                else:
                    kwargs["texts"] = [text for text in batch_text_or_text_pairs]

        batch_outputs = self._batch_prepare_for_model(
            input_ids,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_position_ids=return_position_ids,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_dict=return_dict,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            return_tensors=return_tensors,
            verbose=verbose,
            **kwargs,
        )
        return batch_outputs
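
    # Illustrative sketch of accepted batch shapes (`tok` is a hypothetical subclass instance):
    #
    #   tok._batch_encode_plus(["first text", "second text"])                                  # single texts
    #   tok._batch_encode_plus([("question 1", "context 1"), ("question 2", "context 2")])     # text pairs
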
    def _batch_prepare_for_model(
        self,
        batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_position_ids: Optional[bool] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_dict: bool = True,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
        adds special tokens, truncates sequences if overflowing while taking into account the special tokens, and
        manages a moving window (with a user-defined stride) for overflowing tokens.

        Args:
            batch_ids_pairs: list of tokenized input ids or input ids pairs
        """
        if return_token_type_ids and not add_special_tokens:
            raise ValueError(
                "Asking to return token_type_ids while setting add_special_tokens to False "
                "results in an undefined behavior. Please set add_special_tokens to True or "
                "set return_token_type_ids to None."
            )

        batch_outputs = {}
        batch_outputs_list = []
        for example_id, (first_ids, second_ids) in enumerate(batch_ids_pairs):
            if stride > 0 and second_ids is not None:
                if return_token_type_ids is None:
                    return_token_type_ids = "token_type_ids" in self.model_input_names
                if return_attention_mask is None:
                    return_attention_mask = "attention_mask" in self.model_input_names
                max_len_for_pair = (
                    max_length
                    - len(first_ids)
                    - (
                        self.num_special_tokens_to_add(pair=True)
                        if add_special_tokens
                        else 0
                    )
                )

                text, text_pair = kwargs["batch_text_or_text_pairs"][example_id]
                token_offset_mapping = self.get_offset_mapping(text)
                token_pair_offset_mapping = self.get_offset_mapping(text_pair)

                offset = 0
                while offset < len(second_ids):
                    encoded_inputs = {}
                    length = len(second_ids) - offset
                    if length > max_len_for_pair:
                        length = max_len_for_pair

                    ids = first_ids
                    pair_ids = second_ids[offset : offset + length]
                    pair = bool(pair_ids is not None)
                    mapping = token_offset_mapping
                    pair_mapping = token_pair_offset_mapping[offset : offset + length]
                    if add_special_tokens:
                        offset_mapping = self.build_offset_mapping_with_special_tokens(
                            mapping, pair_mapping
                        )
                        sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
                        token_type_ids = self.create_token_type_ids_from_sequences(
                            ids, pair_ids
                        )
                    else:
                        offset_mapping = mapping + pair_mapping
                        sequence = ids + pair_ids if pair else ids
                        token_type_ids = [0] * len(ids) + (
                            [0] * len(pair_ids) if pair else []
                        )
                    encoded_inputs["offset_mapping"] = offset_mapping

                    # Build output dictionary
                    encoded_inputs["input_ids"] = sequence
                    if return_token_type_ids:
                        encoded_inputs["token_type_ids"] = token_type_ids
                    if return_special_tokens_mask:
                        if add_special_tokens:
                            encoded_inputs["special_tokens_mask"] = (
                                self.get_special_tokens_mask(ids, pair_ids)
                            )
                        else:
                            encoded_inputs["special_tokens_mask"] = [0] * len(sequence)

                    # Check lengths
                    self._eventual_warn_about_too_long_sequence(
                        encoded_inputs["input_ids"], max_length, verbose
                    )

                    if return_position_ids:
                        encoded_inputs["position_ids"] = list(
                            range(len(encoded_inputs["input_ids"]))
                        )

                    if return_length:
                        encoded_inputs["length"] = len(encoded_inputs["input_ids"])
                        encoded_inputs["seq_len"] = encoded_inputs["length"]

                    encoded_inputs["overflow_to_sample"] = example_id
                    for key, value in encoded_inputs.items():
                        if key not in batch_outputs:
                            batch_outputs[key] = []
                        batch_outputs[key].append(value)

                    if offset + length == len(second_ids):
                        break
                    offset += min(length, stride)
            else:
                if return_offsets_mapping:
                    kwargs["text"] = kwargs["texts"][example_id]
                    kwargs["text_pair"] = None
                    if kwargs["text_pairs"] is not None:
                        kwargs["text_pair"] = kwargs["text_pairs"][example_id]

                encoded_inputs = self.prepare_for_model(
                    first_ids,
                    second_ids,
                    add_special_tokens=add_special_tokens,
                    padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterward
                    truncation=truncation_strategy.value,
                    max_length=max_length,
                    stride=stride,
                    pad_to_multiple_of=None,  # we pad in batch afterward
                    return_position_ids=return_position_ids,  # we pad in batch afterward
                    return_attention_mask=False,  # we pad in batch afterward
                    return_token_type_ids=return_token_type_ids,
                    return_overflowing_tokens=return_overflowing_tokens,
                    return_special_tokens_mask=return_special_tokens_mask,
                    return_offsets_mapping=return_offsets_mapping,
                    return_length=return_length,
                    return_tensors=None,  # We convert the whole batch to tensors at the end
                    prepend_batch_axis=False,
                    verbose=verbose,
                    **kwargs,
                )

                for key, value in encoded_inputs.items():
                    if key not in batch_outputs:
                        batch_outputs[key] = []
                    batch_outputs[key].append(value)

        batch_outputs = self.pad(
            batch_outputs,
            padding=padding_strategy.value,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
        if return_dict:
            batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
            return batch_outputs
        else:
            for k, v in batch_outputs.items():
                for i in range(len(v)):
                    if i >= len(batch_outputs_list):
                        batch_outputs_list.append({k: v[i]})
                    else:
                        batch_outputs_list[i][k] = v[i]
            return batch_outputs_list
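
    # Rough usage sketch of the stride behaviour above (hedged; exact keys depend on the
    # arguments). With stride > 0 and a text pair, the second sequence is split into overlapping
    # windows, and every produced row records the index of the originating example:
    #
    #   rows = tok._batch_prepare_for_model(
    #       [(first_ids, very_long_second_ids)],
    #       max_length=64, stride=32, return_dict=False,
    #       batch_text_or_text_pairs=[("question", "a very long context ...")],
    #   )
    #   # each row carries "overflow_to_sample" == 0, pointing back to the single input example
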
    def _get_bert_like_offset_mapping(self, text: str):
        """
        Returns the map of tokens and the start and end index of their start and end character.
        Modified from https://github.com/bojone/bert4keras/blob/master/bert4keras/tokenizers.py#L372

        Args:
            text (str):
                Input text.
        Returns:
            list: The offset map of input text.
        """
        if text is None:
            return None
        split_tokens = self.tokenize(text)

        normalized_text, char_mapping = "", []
        for i, ch in enumerate(text):
            if hasattr(self, "do_lower_case") and self.do_lower_case:
                ch = ch.lower()
                if self.basic_tokenizer.strip_accents is not False:
                    ch = unicodedata.normalize("NFD", ch)
                    ch = "".join([c for c in ch if unicodedata.category(c) != "Mn"])
            elif self.basic_tokenizer.strip_accents:
                ch = unicodedata.normalize("NFD", ch)
                ch = "".join([c for c in ch if unicodedata.category(c) != "Mn"])
            ch = "".join(
                [
                    c
                    for c in ch
                    if not (ord(c) == 0 or ord(c) == 0xFFFD or _is_control(c))
                ]
            )
            normalized_text += ch
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = normalized_text, [], 0
        char_mapping_indexes = []
        for index, token in enumerate(split_tokens):
            if token[:2] == "##":
                token = token[2:]
            if token in self.all_special_tokens:
                token = (
                    token.lower()
                    if hasattr(self, "do_lower_case") and self.do_lower_case
                    else token
                )
            # The Greek letter "sigma" has two lowercase forms, σ and ς.
            # When used as the final letter of a word, the final form (ς) is used; otherwise the form (σ) is used.
            # https://latin.stackexchange.com/questions/6168/how-and-when-did-we-get-two-forms-of-sigma
            if "σ" in token or "ς" in token:
                start = (
                    text[offset:].replace("ς", "σ").index(token.replace("ς", "σ"))
                    + offset
                )
            else:
                # try to fix: https://github.com/PaddlePaddle/PaddleNLP/issues/3985
                if token not in text[offset:]:
                    # check whether there are consecutive UNK tokens, e.g.: ['好', '[UNK]', '[UNK]', 'good']
                    if (
                        index < len(split_tokens) - 1
                        and split_tokens[index + 1] in self.all_special_tokens
                    ):
                        start = offset
                        token = " "  # only contains one char
                    else:
                        start = -1
                else:
                    start = text[offset:].index(token) + offset

            end = start + len(token)
            char_mapping_indexes.append([start, end])
            if start != -1:
                offset = end

        token_mapping = []
        for index, (start, end) in enumerate(char_mapping_indexes):
            if start == -1:
                # init start
                if index == 0:
                    start = 0
                else:
                    start = char_mapping_indexes[index - 1][1]
                # init end
                if index == len(char_mapping_indexes) - 1:
                    end = len(char_mapping)
                else:
                    # next start
                    end = char_mapping_indexes[index + 1][0]
            token_mapping.append((char_mapping[start], char_mapping[end - 1] + 1))
        return token_mapping
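
    # Illustrative result (hypothetical lower-casing wordpiece subclass): each entry is the
    # [start, end) character span of a token in the original text:
    #
    #   tok._get_bert_like_offset_mapping("I love NLP")
    #   # tokens ["i", "love", "nlp"]  ->  [(0, 1), (2, 6), (7, 10)]
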
    def get_offset_mapping(self, text: str, split_tokens: Optional[List[str]] = None):
        """
        Returns the map of tokens and the start and end index of their start and end character.
        Modified from https://github.com/bojone/bert4keras/blob/master/bert4keras/tokenizers.py#L372

        Args:
            text (str):
                Input text.
            split_tokens (Optional[List[str]]):
                The tokens the text has already been split into, which can speed up the operation.
        Returns:
            list: The offset map of input text.
        """
        if text is None:
            return None

        # bert-like tokenizers use the old-school code path
        if hasattr(self, "basic_tokenizer") or hasattr(self, "wordpiece_tokenizer"):
            return self._get_bert_like_offset_mapping(text)

        # only tokenize when pre-split tokens were not provided
        if not split_tokens:
            split_tokens = self.tokenize(text)

        normalized_text, char_mapping = "", []
        for i, ch in enumerate(text):
            normalized_text += normalize_chars(ch)
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = normalized_text, [], 0
        do_lower_case = getattr(self, "do_lower_case", False)

        # lower the text if the tokens are lower-cased,
        # to keep the alignment with the tokens
        if do_lower_case:
            text = text.lower()

        char_mapping_indexes = []
        for token in split_tokens:
            # convert tokens into the original string
            token: str = self.convert_tokens_to_string(token).strip()
            if token in self.all_special_tokens:
                if do_lower_case:
                    token = token.lower()
            # The Greek letter "sigma" has two lowercase forms, σ and ς.
            # When used as the final letter of a word, the final form (ς) is used; otherwise the form (σ) is used.
            # https://latin.stackexchange.com/questions/6168/how-and-when-did-we-get-two-forms-of-sigma
            if "σ" in token or "ς" in token:
                start = (
                    text[offset:].replace("ς", "σ").index(token.replace("ς", "σ"))
                    + offset
                )
            else:
                # try to fix: https://github.com/PaddlePaddle/PaddleNLP/issues/3985
                if token not in text[offset:]:
                    start = -1
                else:
                    start = text[offset:].index(token) + offset

            end = start + len(token)
            char_mapping_indexes.append([start, end])
            if start != -1:
                offset = end

        token_mapping = []
        for index, (start, end) in enumerate(char_mapping_indexes):
            if start == -1:
                # init start
                if index == 0:
                    start = 0
                else:
                    start = char_mapping_indexes[index - 1][1]
                # init end
                if index == len(char_mapping_indexes) - 1:
                    end = len(char_mapping)
                else:
                    # next start
                    end = char_mapping_indexes[index + 1][0]
            token_mapping.append((char_mapping[start], char_mapping[end - 1] + 1))
        return token_mapping
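
    # Illustrative sketch: the offsets let span predictions be projected back onto the raw text:
    #
    #   text = "an example"
    #   offsets = tok.get_offset_mapping(text)       # one (start, end) pair per token
    #   text[offsets[0][0]:offsets[0][1]]            # original characters behind the first token
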
    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = True,
        spaces_between_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, np.ndarray):
            token_ids = token_ids.tolist()
        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
        filtered_tokens = self.convert_ids_to_tokens(
            token_ids, skip_special_tokens=skip_special_tokens
        )

        # To avoid mixing byte-level and unicode for byte-level BPE
        # we need to build the string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if skip_special_tokens and token in self.all_special_ids:
                continue
            if token in self.added_tokens_encoder:
                if current_sub_text:
                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))

        if spaces_between_special_tokens:
            text = " ".join(sub_texts)
        else:
            text = "".join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text
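
    # Illustrative sketch (hypothetical added token "<sep>"; ids depend on the vocab): added
    # tokens are kept as separate chunks so convert_tokens_to_string never merges them:
    #
    #   tok._decode(ids)                                        # -> "hello <sep> world"
    #   tok._decode(ids, spaces_between_special_tokens=False)   # -> "hello<sep>world"
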
    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
    ) -> Tuple[str, int, int]:
        """Tokenizer decoding for the streaming generation use case. This method can be overridden for tokenizers that do not follow this API."""
        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.decode(
            all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
        )
        new_text = self.decode(all_input_ids[prefix_offset:], skip_special_tokens=False)

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # A utf-8 replacement char at the end means it's a potentially unfinished byte sequence
            # from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model.
            prefix_index = new_text.index(prefix_text)
            new_text = new_text[prefix_index + len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        else:
            return "", prefix_offset, read_offset

def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (
        (cp >= 33 and cp <= 47)
        or (cp >= 58 and cp <= 64)
        or (cp >= 91 and cp <= 96)
        or (cp >= 123 and cp <= 126)
    ):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


def _is_symbol(char):
    """Checks whether `char` is a symbol character."""
    cp = ord(char)
    if unicodedata.category(char).startswith("S") or (
        cp in [0x00AD, 0x00B2, 0x00BA, 0x3007, 0x00B5, 0x00D8, 0x014B, 0x01B1]
    ):
        return True
    return False


def _is_whitespace(char):
    """
    Checks whether `char` is a whitespace character.
    """
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False
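
# Illustrative behaviour of the character-class helpers (based on `unicodedata` categories):
#
#   _is_punctuation(",")   # True  (Unicode category "Po")
#   _is_punctuation("$")   # True  (non-letter ASCII is treated as punctuation despite category "Sc")
#   _is_whitespace("\t")   # True
#   _is_control("\n")      # False (newline is counted as whitespace, not control)
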
def convert_to_unicode(text):
    """
    Converts `text` to Unicode (if it's not already), assuming utf-8 input.

    Args:
        text (str|bytes): Text to be converted to unicode.
    Returns:
        str: Converted text.
    """
    if isinstance(text, str):
        return text
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    else:
        raise ValueError("Unsupported string type: %s" % (type(text)))


def whitespace_tokenize(text):
    """
    Runs basic whitespace cleaning and splitting on a piece of text.

    Args:
        text (str): Text to be tokenized.
    Returns:
        list(str): Token list.
    """
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens
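
# Illustrative behaviour:
#
#   convert_to_unicode(b"caf\xc3\xa9")         # -> "café"
#   whitespace_tokenize("  hello   world \n")  # -> ["hello", "world"]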