# processors.py

import json
import numpy as np
import cv2
import math
import re
from PIL import Image, ImageOps
from typing import List, Optional, Tuple, Union, Dict, Any
from loguru import logger
from tokenizers import AddedToken
from tokenizers import Tokenizer as TokenizerFast
from mineru.model.mfr.utils import fix_latex_left_right, fix_latex_environments, remove_up_commands, \
    remove_unsupported_commands


class UniMERNetImgDecode(object):
    """Class for decoding images for UniMERNet, including cropping margins, resizing, and padding."""

    def __init__(
        self, input_size: Tuple[int, int], random_padding: bool = False, **kwargs
    ) -> None:
        """Initializes the UniMERNetImgDecode class with input size and random padding options.
        Args:
            input_size (tuple): The desired input size for the images (height, width).
            random_padding (bool): Whether to use random padding for resizing.
            **kwargs: Additional keyword arguments."""
        self.input_size = input_size
        self.random_padding = random_padding

    def crop_margin(self, img: Image.Image) -> Image.Image:
        """Crops the margin of the image based on grayscale thresholding.
        Args:
            img (PIL.Image.Image): The input image.
        Returns:
            PIL.Image.Image: The cropped image."""
        data = np.array(img.convert("L"))
        data = data.astype(np.uint8)
        max_val = data.max()
        min_val = data.min()
        if max_val == min_val:
            return img
        data = (data - min_val) / (max_val - min_val) * 255
        gray = 255 * (data < 200).astype(np.uint8)
        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
        return img.crop((a, b, w + a, h + b))
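
    # Illustrative note, not from the original source: after the grayscale copy is
    # contrast-stretched to the 0-255 range, any pixel darker than 200 is treated as
    # ink. For a hypothetical 400x100 scan whose glyphs span x in [25, 352] and
    # y in [18, 77], cv2.boundingRect returns (25, 18, 328, 60) and crop_margin
    # keeps only that box.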

    def get_dimensions(self, img: Union[Image.Image, np.ndarray]) -> List[int]:
        """Gets the dimensions of the image.
        Args:
            img (PIL.Image.Image or numpy.ndarray): The input image.
        Returns:
            list: A list containing the number of channels, height, and width."""
        if hasattr(img, "getbands"):
            channels = len(img.getbands())
        else:
            channels = img.channels
        width, height = img.size
        return [channels, height, width]

    def _compute_resized_output_size(
        self,
        image_size: Tuple[int, int],
        size: Union[int, Tuple[int, int]],
        max_size: Optional[int] = None,
    ) -> List[int]:
        """Computes the resized output size of the image.
        Args:
            image_size (tuple): The original size of the image (height, width).
            size (int or tuple): The desired size for the smallest edge or both height and width.
            max_size (int, optional): The maximum allowed size for the longer edge.
        Returns:
            list: A list containing the new height and width."""
        if len(size) == 1:  # specified size only for the smallest edge
            h, w = image_size
            short, long = (w, h) if w <= h else (h, w)
            requested_new_short = size if isinstance(size, int) else size[0]
            new_short, new_long = requested_new_short, int(
                requested_new_short * long / short
            )
            if max_size is not None:
                if max_size <= requested_new_short:
                    raise ValueError(
                        f"max_size = {max_size} must be strictly greater than the requested "
                        f"size for the smaller edge size = {size}"
                    )
                if new_long > max_size:
                    new_short, new_long = int(max_size * new_short / new_long), max_size
            new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
        else:  # specified both h and w
            new_w, new_h = size[1], size[0]
        return [new_h, new_w]
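
    # Worked example (hypothetical numbers; the logic mirrors a shortest-edge
    # resize): for a 32x128 (HxW) crop and size=[192] with max_size=None, the short
    # edge 32 is scaled to 192 and the long edge to int(192 * 128 / 32) = 768, so
    # the method returns [192, 768]. With max_size=640 the result would instead be
    # capped to [int(640 * 192 / 768), 640] = [160, 640].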

    def resize(
        self, img: Image.Image, size: Union[int, Tuple[int, int]]
    ) -> Image.Image:
        """Resizes the image to the specified size.
        Args:
            img (PIL.Image.Image): The input image.
            size (int or tuple): The desired size for the smallest edge or both height and width.
        Returns:
            PIL.Image.Image: The resized image."""
        _, image_height, image_width = self.get_dimensions(img)
        if isinstance(size, int):
            size = [size]
        max_size = None
        output_size = self._compute_resized_output_size(
            (image_height, image_width), size, max_size
        )
        img = img.resize(tuple(output_size[::-1]), resample=2)  # resample=2 is PIL's BILINEAR
        return img

    def img_decode(self, img: np.ndarray) -> Optional[np.ndarray]:
        """Decodes the image by cropping margins, resizing, and adding padding.
        Args:
            img (numpy.ndarray): The input image array.
        Returns:
            numpy.ndarray: The decoded image array."""
        try:
            img = self.crop_margin(Image.fromarray(img).convert("RGB"))
        except OSError:
            return
        if img.height == 0 or img.width == 0:
            return
        img = self.resize(img, min(self.input_size))
        img.thumbnail((self.input_size[1], self.input_size[0]))
        delta_width = self.input_size[1] - img.width
        delta_height = self.input_size[0] - img.height
        if self.random_padding:
            pad_width = np.random.randint(low=0, high=delta_width + 1)
            pad_height = np.random.randint(low=0, high=delta_height + 1)
        else:
            pad_width = delta_width // 2
            pad_height = delta_height // 2
        padding = (
            pad_width,
            pad_height,
            delta_width - pad_width,
            delta_height - pad_height,
        )
        return np.array(ImageOps.expand(img, padding))
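
    # Illustrative trace with a hypothetical input_size of (192, 672): a 32x128
    # crop is resized so its short edge becomes 192 (giving 192x768), thumbnail
    # shrinks it to fit within 672x192 (giving 168x672), and the remaining
    # delta_height=24 / delta_width=0 are padded as (left=0, top=12, right=0,
    # bottom=12), producing exactly 192x672 when random_padding is off.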

    def __call__(self, imgs: List[np.ndarray]) -> List[Optional[np.ndarray]]:
        """Calls the img_decode method on a list of images.
        Args:
            imgs (list of numpy.ndarray): The list of input image arrays.
        Returns:
            list of numpy.ndarray: The list of decoded image arrays."""
        return [self.img_decode(img) for img in imgs]


class UniMERNetTestTransform:
    """
    A class for transforming images according to UniMERNet test specifications.
    """

    def __init__(self, **kwargs) -> None:
        """
        Initializes the UniMERNetTestTransform class.
        """
        super().__init__()
        self.num_output_channels = 3

    def transform(self, img: np.ndarray) -> np.ndarray:
        """
        Transforms a single image for UniMERNet testing.
        Args:
            img (numpy.ndarray): The input image.
        Returns:
            numpy.ndarray: The transformed image.
        """
        mean = [0.7931, 0.7931, 0.7931]
        std = [0.1738, 0.1738, 0.1738]
        scale = float(1 / 255.0)
        shape = (1, 1, 3)
        mean = np.array(mean).reshape(shape).astype("float32")
        std = np.array(std).reshape(shape).astype("float32")
        img = (img.astype("float32") * scale - mean) / std
        grayscale_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        squeezed = np.squeeze(grayscale_image)
        img = cv2.merge([squeezed] * self.num_output_channels)
        return img
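
    # Sanity check derived from the constants above: a white pixel (255) maps to
    # (255 / 255 - 0.7931) / 0.1738 ~ 1.19 and a black pixel (0) to
    # -0.7931 / 0.1738 ~ -4.56; the normalized image is then collapsed to one gray
    # channel and replicated back to num_output_channels identical channels.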

    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
        """
        Applies the transform to a list of images.
        Args:
            imgs (list of numpy.ndarray): The list of input images.
        Returns:
            list of numpy.ndarray: The list of transformed images.
        """
        return [self.transform(img) for img in imgs]


class LatexImageFormat:
    """Class for formatting images to a specific format suitable for LaTeX."""

    def __init__(self, **kwargs) -> None:
        """Initializes the LatexImageFormat class with optional keyword arguments."""
        super().__init__()

    def format(self, img: np.ndarray) -> np.ndarray:
        """Formats a single image to the LaTeX-compatible format.
        Args:
            img (numpy.ndarray): The input image as a numpy array.
        Returns:
            numpy.ndarray: The formatted image as a numpy array with an added dimension for color.
        """
        im_h, im_w = img.shape[:2]
        divide_h = math.ceil(im_h / 16) * 16
        divide_w = math.ceil(im_w / 16) * 16
        img = img[:, :, 0]
        img = np.pad(
            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
        )
        img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
        return img_expanded[np.newaxis, :]
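
    # Shape example with hypothetical sizes: a 190x650x3 input is reduced to its
    # first channel, padded with the constant value 1 up to 192x656 (the next
    # multiples of 16), and returned with batch and channel axes added, i.e. with
    # shape (1, 1, 192, 656).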

    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
        """Applies the format method to a list of images.
        Args:
            imgs (list of numpy.ndarray): A list of input images as numpy arrays.
        Returns:
            list of numpy.ndarray: A list of formatted images as numpy arrays.
        """
        return [self.format(img) for img in imgs]


class ToBatch(object):
    """A class for batching images."""

    def __init__(self, **kwargs) -> None:
        """Initializes the ToBatch object."""
        super(ToBatch, self).__init__()

    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
        """Concatenates a list of images into a single batch.
        Args:
            imgs (list): A list of image arrays to be concatenated.
        Returns:
            list: A list containing the concatenated batch of images wrapped in another list (to comply with common batch processing formats).
        """
        batch_imgs = np.concatenate(imgs)
        batch_imgs = batch_imgs.copy()
        x = [batch_imgs]
        return x
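

# Usage sketch with hypothetical shapes: given four formatted images of shape
# (1, 1, 192, 672), ToBatch()([a, b, c, d]) concatenates them along axis 0 and
# returns [batch] with batch.shape == (4, 1, 192, 672), i.e. a single batch array
# wrapped in a list as described in the docstring above.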


class UniMERNetDecode(object):
    """Class for decoding tokenized inputs using UniMERNet tokenizer.
    Attributes:
        SPECIAL_TOKENS_ATTRIBUTES (List[str]): List of special token attributes.
        model_input_names (List[str]): List of model input names.
        max_seq_len (int): Maximum sequence length.
        pad_token_id (int): ID for the padding token.
        bos_token_id (int): ID for the beginning-of-sequence token.
        eos_token_id (int): ID for the end-of-sequence token.
        padding_side (str): Padding side, either 'left' or 'right'.
        pad_token (str): Padding token.
        pad_token_type_id (int): Type ID for the padding token.
        pad_to_multiple_of (Optional[int]): If set, pad to a multiple of this value.
        tokenizer (TokenizerFast): Fast tokenizer instance.
    Args:
        character_list (Dict[str, Any]): Dictionary containing tokenizer configuration.
        **kwargs: Additional keyword arguments.
    """

    SPECIAL_TOKENS_ATTRIBUTES = [
        "bos_token",
        "eos_token",
        "unk_token",
        "sep_token",
        "pad_token",
        "cls_token",
        "mask_token",
        "additional_special_tokens",
    ]

    def __init__(
        self,
        character_list: Dict[str, Any],
        **kwargs,
    ) -> None:
        """Initializes the UniMERNetDecode class.
        Args:
            character_list (Dict[str, Any]): Dictionary containing tokenizer configuration.
            **kwargs: Additional keyword arguments.
        """
        self._unk_token = "<unk>"
        self._bos_token = "<s>"
        self._eos_token = "</s>"
        self._pad_token = "<pad>"
        self._sep_token = None
        self._cls_token = None
        self._mask_token = None
        self._additional_special_tokens = []
        self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
        self.max_seq_len = 2048
        self.pad_token_id = 1
        self.bos_token_id = 0
        self.eos_token_id = 2
        self.padding_side = "right"
        self.pad_token_id = 1
        self.pad_token = "<pad>"
        self.pad_token_type_id = 0
        self.pad_to_multiple_of = None
        fast_tokenizer_str = json.dumps(character_list["fast_tokenizer_file"])
        fast_tokenizer_buffer = fast_tokenizer_str.encode("utf-8")
        self.tokenizer = TokenizerFast.from_buffer(fast_tokenizer_buffer)
        tokenizer_config = (
            character_list["tokenizer_config_file"]
            if "tokenizer_config_file" in character_list
            else None
        )
        added_tokens_decoder = {}
        added_tokens_map = {}
        if tokenizer_config is not None:
            init_kwargs = tokenizer_config
            if "added_tokens_decoder" in init_kwargs:
                for idx, token in init_kwargs["added_tokens_decoder"].items():
                    if isinstance(token, dict):
                        token = AddedToken(**token)
                    if isinstance(token, AddedToken):
                        added_tokens_decoder[int(idx)] = token
                        added_tokens_map[str(token)] = token
                    else:
                        raise ValueError(
                            f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
                        )
            init_kwargs["added_tokens_decoder"] = added_tokens_decoder
            added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
            tokens_to_add = [
                token
                for index, token in sorted(
                    added_tokens_decoder.items(), key=lambda x: x[0]
                )
                if token not in added_tokens_decoder
            ]
            added_tokens_encoder = self.added_tokens_encoder(added_tokens_decoder)
            encoder = list(added_tokens_encoder.keys()) + [
                str(token) for token in tokens_to_add
            ]
            tokens_to_add += [
                token
                for token in self.all_special_tokens_extended
                if token not in encoder and token not in tokens_to_add
            ]
            if len(tokens_to_add) > 0:
                is_last_special = None
                tokens = []
                special_tokens = self.all_special_tokens
                for token in tokens_to_add:
                    is_special = (
                        (token.special or str(token) in special_tokens)
                        if isinstance(token, AddedToken)
                        else str(token) in special_tokens
                    )
                    if is_last_special is None or is_last_special == is_special:
                        tokens.append(token)
                    else:
                        self._add_tokens(tokens, special_tokens=is_last_special)
                        tokens = [token]
                    is_last_special = is_special
                if tokens:
                    self._add_tokens(tokens, special_tokens=is_last_special)

    def _add_tokens(
        self, new_tokens: "List[Union[AddedToken, str]]", special_tokens: bool = False
    ) -> int:
        """Adds new tokens to the tokenizer.
        Args:
            new_tokens (List[Union[AddedToken, str]]): Tokens to be added.
            special_tokens (bool): Indicates whether the tokens are special tokens.
        Returns:
            int: The number of tokens added to the vocabulary."""
        if special_tokens:
            return self.tokenizer.add_special_tokens(new_tokens)
        return self.tokenizer.add_tokens(new_tokens)

    def added_tokens_encoder(
        self, added_tokens_decoder: "Dict[int, AddedToken]"
    ) -> Dict[str, int]:
        """Creates an encoder dictionary from added tokens.
        Args:
            added_tokens_decoder (Dict[int, AddedToken]): Dictionary mapping token IDs to tokens.
        Returns:
            Dict[str, int]: Dictionary mapping token strings to IDs.
        """
        return {
            k.content: v
            for v, k in sorted(added_tokens_decoder.items(), key=lambda item: item[0])
        }

    @property
    def all_special_tokens(self) -> List[str]:
        """Retrieves all special tokens.
        Returns:
            List[str]: List of all special tokens as strings.
        """
        all_toks = [str(s) for s in self.all_special_tokens_extended]
        return all_toks

    @property
    def all_special_tokens_extended(self) -> "List[Union[str, AddedToken]]":
        """Retrieves all special tokens, including extended ones.
        Returns:
            List[Union[str, AddedToken]]: List of all special tokens.
        """
        all_tokens = []
        seen = set()
        for value in self.special_tokens_map_extended.values():
            if isinstance(value, (list, tuple)):
                tokens_to_add = [token for token in value if str(token) not in seen]
            else:
                tokens_to_add = [value] if str(value) not in seen else []
            seen.update(map(str, tokens_to_add))
            all_tokens.extend(tokens_to_add)
        return all_tokens

    @property
    def special_tokens_map_extended(self) -> Dict[str, Union[str, List[str]]]:
        """Retrieves the extended map of special tokens.
        Returns:
            Dict[str, Union[str, List[str]]]: Dictionary mapping special token attributes to their values.
        """
        set_attr = {}
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
            attr_value = getattr(self, "_" + attr)
            if attr_value:
                set_attr[attr] = attr_value
        return set_attr

    @property
    def all_special_ids(self) -> List[int]:
        """Retrieves the IDs of all special tokens.
        Returns:
            List[int]: List of special token IDs.
        """
        # Referenced by convert_ids_to_tokens when skip_special_tokens=True; IDs are
        # looked up through the fast tokenizer's vocabulary.
        return [self.tokenizer.token_to_id(tok) for tok in self.all_special_tokens]

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """Converts token IDs to token strings.
        Args:
            ids (Union[int, List[int]]): Token ID(s) to convert.
            skip_special_tokens (bool): Whether to skip special tokens during conversion.
        Returns:
            Union[str, List[str]]: Converted token string(s).
        """
        if isinstance(ids, int):
            return self.tokenizer.id_to_token(ids)
        tokens = []
        for index in ids:
            index = int(index)
            if skip_special_tokens and index in self.all_special_ids:
                continue
            tokens.append(self.tokenizer.id_to_token(index))
        return tokens

    def detokenize(self, tokens: List[List[int]]) -> List[List[str]]:
        """Detokenizes a list of token IDs back into strings.
        Args:
            tokens (List[List[int]]): List of token ID lists.
        Returns:
            List[List[str]]: List of detokenized strings.
        """
        self.tokenizer.bos_token = "<s>"
        self.tokenizer.eos_token = "</s>"
        self.tokenizer.pad_token = "<pad>"
        toks = [self.convert_ids_to_tokens(tok) for tok in tokens]
        for b in range(len(toks)):
            for i in reversed(range(len(toks[b]))):
                if toks[b][i] is None:
                    toks[b][i] = ""
                toks[b][i] = toks[b][i].replace("Ġ", " ").strip()
                if toks[b][i] in [
                    self.tokenizer.bos_token,
                    self.tokenizer.eos_token,
                    self.tokenizer.pad_token,
                ]:
                    del toks[b][i]
        return toks

    def token2str(self, token_ids: List[List[int]]) -> List[str]:
        """Converts a list of token IDs to strings.
        Args:
            token_ids (List[List[int]]): List of token ID lists.
        Returns:
            List[str]: List of converted strings.
        """
        generated_text = []
        for tok_id in token_ids:
            end_idx = np.argwhere(tok_id == 2)
            if len(end_idx) > 0:
                end_idx = int(end_idx[0][0])
                tok_id = tok_id[: end_idx + 1]
            generated_text.append(
                self.tokenizer.decode(tok_id, skip_special_tokens=True)
            )
        generated_text = [self.post_process(text) for text in generated_text]
        return generated_text
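
    # Illustrative trace with hypothetical token IDs: for a row such as
    # np.array([0, 151, 89, 2, 1, 1]), the first occurrence of the EOS id (2) is at
    # index 3, so the row is truncated to [0, 151, 89, 2], decoded with
    # skip_special_tokens=True, and finally passed through post_process.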

    def normalize(self, s: str) -> str:
        """Normalizes a string by removing unnecessary spaces.
        Args:
            s (str): String to normalize.
        Returns:
            str: Normalized string.
        """
        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
        letter = "[a-zA-Z]"
        noletter = r"[\W_^\d]"
        names = []
        for x in re.findall(text_reg, s):
            pattern = r"(\\[a-zA-Z]+)\s(?=\w)|\\[a-zA-Z]+\s(?=})"
            matches = re.findall(pattern, x[0])
            for m in matches:
                if (
                    m
                    not in [
                        "\\operatorname",
                        "\\mathrm",
                        "\\text",
                        "\\mathbf",
                    ]
                    and m.strip() != ""
                ):
                    s = s.replace(m, m + "XXXXXXX")
            s = s.replace(" ", "")
            names.append(s)
        if len(names) > 0:
            s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
        news = s
        while True:
            s = news
            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
            news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
            if news == s:
                break
        return s.replace("XXXXXXX", " ")
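
    # Hedged example (note that the normalize call in fix_latex below is commented
    # out): for input without \operatorname/\mathrm/\text/\mathbf groups only the
    # whitespace-collapsing loop applies, so "\alpha ^ { 2 } + 1" becomes
    # "\alpha^{2}+1", while a space between two letters (e.g. "a b") is preserved.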

    def remove_chinese_text_wrapping(self, formula: str) -> str:
        """Unwraps \\text{...} groups containing Chinese characters and strips double quotes."""
        pattern = re.compile(r"\\text\s*{\s*([^}]*?[\u4e00-\u9fff]+[^}]*?)\s*}")

        def replacer(match):
            return match.group(1)

        replaced_formula = pattern.sub(replacer, formula)
        return replaced_formula.replace('"', "")

    def post_process(self, text: str) -> str:
        """Post-processes a string by fixing text and normalizing it.
        Args:
            text (str): String to post-process.
        Returns:
            str: Post-processed string.
        """
        from ftfy import fix_text

        text = self.remove_chinese_text_wrapping(text)
        text = fix_text(text)
        # logger.debug(f"Text after ftfy fix: {text}")
        text = self.fix_latex(text)
        # logger.debug(f"Text after LaTeX fix: {text}")
        return text

    def fix_latex(self, text: str) -> str:
        """Fixes LaTeX formatting in a string.
        Args:
            text (str): String to fix.
        Returns:
            str: Fixed string.
        """
        text = fix_latex_left_right(text, fix_delimiter=False)
        text = fix_latex_environments(text)
        text = remove_up_commands(text)
        text = remove_unsupported_commands(text)
        # text = self.normalize(text)
        return text

    def __call__(
        self,
        preds: np.ndarray,
        label: Optional[np.ndarray] = None,
        mode: str = "eval",
        *args,
        **kwargs,
    ) -> Union[List[str], tuple]:
        """Processes predictions and optionally labels, returning the decoded text.
        Args:
            preds (np.ndarray): Model predictions.
            label (Optional[np.ndarray]): True labels, if available.
            mode (str): Mode of operation, either 'train' or 'eval'.
        Returns:
            Union[List[str], tuple]: Decoded text, optionally with labels.
        """
        if mode == "train":
            preds_idx = np.array(preds.argmax(axis=2))
            text = self.token2str(preds_idx)
        else:
            text = self.token2str(np.array(preds))
        if label is None:
            return text
        label = self.token2str(np.array(label))
        return text, label
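

# Minimal end-to-end sketch of the preprocessing side of this module; it is not part
# of the original file. The (192, 672) input size and the synthetic white image are
# assumptions for illustration, and decoding real model output would additionally
# require a `character_list` dict carrying the fast-tokenizer JSON, which is not
# fabricated here.
if __name__ == "__main__":
    dummy = np.full((60, 240, 3), 255, dtype=np.uint8)  # blank white "formula" crop
    imgs = UniMERNetImgDecode(input_size=(192, 672))([dummy])
    imgs = UniMERNetTestTransform()(imgs)
    imgs = LatexImageFormat()(imgs)
    batch = ToBatch()(imgs)
    logger.info(f"batch shape: {batch[0].shape}")
    # With a tokenizer config available, predictions would then be decoded via
    # something like:
    #   decoder = UniMERNetDecode(character_list)
    #   latex_strings = decoder(pred_token_ids)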