- # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- import re
- import numpy as np
- from PIL import Image, ImageOps
- import cv2
- import math
- import json
- import tempfile
- from tokenizers import Tokenizer as TokenizerFast
- from tokenizers import AddedToken
- from typing import List, Tuple, Optional, Any, Dict, Union
- from ....utils import logging
- class MinMaxResize:
- """Class for resizing images to be within specified minimum and maximum dimensions, with padding and normalization."""
- def __init__(
- self,
- min_dimensions: Optional[List[int]] = None,
- max_dimensions: Optional[List[int]] = None,
- **kwargs,
- ) -> None:
- """Initializes the MinMaxResize class with minimum and maximum dimensions.
- Args:
- min_dimensions (list of int, optional): Minimum dimensions (width, height). Defaults to [32, 32].
- max_dimensions (list of int, optional): Maximum dimensions (width, height). Defaults to [672, 192].
- **kwargs: Additional keyword arguments for future expansion.
- """
- # Avoid mutable default arguments; fall back to the documented defaults.
- self.min_dimensions = min_dimensions if min_dimensions is not None else [32, 32]
- self.max_dimensions = max_dimensions if max_dimensions is not None else [672, 192]
- def pad_(self, img: Image.Image, divable: int = 32) -> Image.Image:
- """Pads the image to ensure its dimensions are divisible by a specified value.
- Args:
- img (PIL.Image.Image): The input image.
- divable (int, optional): The value by which the dimensions should be divisible. Defaults to 32.
- Returns:
- PIL.Image.Image: The padded image.
- """
- threshold = 128
- data = np.array(img.convert("LA"))
- if data[..., -1].var() == 0:
- data = (data[..., 0]).astype(np.uint8)
- else:
- data = (255 - data[..., -1]).astype(np.uint8)
- if data.max() == data.min():
- # Constant image: the contrast stretch below would divide by zero and
- # there is no text to crop, so pad the original image directly.
- dims = tuple(divable * math.ceil(x / divable) for x in img.size)
- padded = Image.new("L", dims, 255)
- padded.paste(img.convert("L"), (0, 0))
- return padded
- data = (data - data.min()) / (data.max() - data.min()) * 255
- if data.mean() > threshold:
- # To invert the text to white
- gray = 255 * (data < threshold).astype(np.uint8)
- else:
- gray = 255 * (data > threshold).astype(np.uint8)
- data = 255 - data
- coords = cv2.findNonZero(gray) # Find all non-zero points (text)
- a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
- rect = data[b : b + h, a : a + w]
- im = Image.fromarray(rect).convert("L")
- dims = []
- for x in [w, h]:
- div, mod = divmod(x, divable)
- dims.append(divable * (div + (1 if mod > 0 else 0)))
- padded = Image.new("L", dims, 255)
- padded.paste(im, (0, 0, im.size[0], im.size[1]))
- return padded
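- # Worked example (illustrative, not from the original source): for a cropped
- # text region of 100x50 pixels and divable=32, divmod(100, 32) = (3, 4) and
- # divmod(50, 32) = (1, 18), so the white canvas grows to 128x64 and the crop
- # is pasted at the top-left corner.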
- def minmax_size_(
- self,
- img: Image.Image,
- max_dimensions: Optional[List[int]],
- min_dimensions: Optional[List[int]],
- ) -> Image.Image:
- """Resizes the image to be within the specified minimum and maximum dimensions.
- Args:
- img (PIL.Image.Image): The input image.
- max_dimensions (list of int or None): Maximum dimensions (width, height).
- min_dimensions (list of int or None): Minimum dimensions (width, height).
- Returns:
- PIL.Image.Image: The resized image.
- """
- if max_dimensions is not None:
- ratios = [a / b for a, b in zip(img.size, max_dimensions)]
- if any([r > 1 for r in ratios]):
- size = np.array(img.size) // max(ratios)
- img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
- if min_dimensions is not None:
- # If any dimension is below the minimum, pad it on white up to min_dimensions.
- padded_size = [
- max(img_dim, min_dim)
- for img_dim, min_dim in zip(img.size, min_dimensions)
- ]
- if padded_size != list(img.size):  # padding is needed
- padded_im = Image.new("L", padded_size, 255)
- padded_im.paste(img, img.getbbox())
- img = padded_im
- return img
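- # Worked example (illustrative): with max_dimensions=[672, 192], a 1344x192
- # image has ratios [2.0, 1.0], so it is downscaled by the larger ratio to
- # 672x96; with min_dimensions=[32, 32] both dimensions are already above the
- # minimum, so no padding is applied.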
- def resize(self, img: np.ndarray) -> np.ndarray:
- """Resizes the input image according to the specified minimum and maximum dimensions.
- Args:
- img (np.ndarray): The input image as a numpy array.
- Returns:
- np.ndarray: The resized image as a numpy array with three channels.
- """
- h, w = img.shape[:2]
- if (
- self.min_dimensions[0] <= w <= self.max_dimensions[0]
- and self.min_dimensions[1] <= h <= self.max_dimensions[1]
- ):
- return img
- else:
- img = Image.fromarray(np.uint8(img))
- img = self.minmax_size_(
- self.pad_(img), self.max_dimensions, self.min_dimensions
- )
- img = np.array(img)
- img = np.dstack((img, img, img))
- return img
- def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
- """Applies the resize method to a list of images.
- Args:
- imgs (list of np.ndarray): The list of input images as numpy arrays.
- Returns:
- list of np.ndarray: The list of resized images as numpy arrays with three channels.
- """
- return [self.resize(img) for img in imgs]
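- # Hedged usage sketch (not part of the pipeline code): the file name and the
- # cv2 loading step are illustrative assumptions.
- # >>> resizer = MinMaxResize(min_dimensions=[32, 32], max_dimensions=[672, 192])
- # >>> out = resizer([cv2.imread("formula.png")])  # list with one HxWx3 array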
- class LatexTestTransform:
- """
- A transform class for processing images according to LaTeX test requirements.
- """
- def __init__(self, **kwargs) -> None:
- """
- Initialize the transform with default number of output channels.
- """
- super().__init__()
- self.num_output_channels = 3
- def transform(self, img: np.ndarray) -> np.ndarray:
- """
- Convert the input image to grayscale, squeeze it, and merge to create an output
- image with the specified number of output channels.
- Parameters:
- img (np.array): The input image.
- Returns:
- np.array: The transformed image.
- """
- grayscale_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- squeezed = np.squeeze(grayscale_image)
- return cv2.merge([squeezed] * self.num_output_channels)
- def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
- """
- Apply the transform to a list of images.
- Parameters:
- imgs (list of np.array): The list of input images.
- Returns:
- list of np.array: The list of transformed images.
- """
- return [self.transform(img) for img in imgs]
- class LatexImageFormat:
- """Class for formatting images to a specific format suitable for LaTeX."""
- def __init__(self, **kwargs) -> None:
- """Initializes the LatexImageFormat class with optional keyword arguments."""
- super().__init__()
- def format(self, img: np.ndarray) -> np.ndarray:
- """Formats a single image to the LaTeX-compatible format.
- Args:
- img (numpy.ndarray): The input image as a numpy array.
- Returns:
- numpy.ndarray: The formatted image as a numpy array with an added dimension for color.
- """
- im_h, im_w = img.shape[:2]
- divide_h = math.ceil(im_h / 16) * 16
- divide_w = math.ceil(im_w / 16) * 16
- img = img[:, :, 0]
- img = np.pad(
- img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
- )
- img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
- return img_expanded[np.newaxis, :]
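- # Shape example (illustrative): a 60x180x3 input keeps only its first
- # channel, is padded to 64x192 (the next multiples of 16), and is returned
- # with shape (1, 1, 64, 192), i.e. batch x channel x height x width.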
- def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
- """Applies the format method to a list of images.
- Args:
- imgs (list of numpy.ndarray): A list of input images as numpy arrays.
- Returns:
- list of numpy.ndarray: A list of formatted images as numpy arrays.
- """
- return [self.format(img) for img in imgs]
- class NormalizeImage(object):
- """Normalize an image by subtracting the mean and dividing by the standard deviation.
- Args:
- scale (float or str): The scale factor to apply to the image. If a string is provided, it will be evaluated as a Python expression.
- mean (list of float): The mean values to subtract from each channel. Defaults to [0.485, 0.456, 0.406].
- std (list of float): The standard deviation values to divide by for each channel. Defaults to [0.229, 0.224, 0.225].
- order (str): The order of dimensions for the mean and std. 'chw' for channels-height-width, 'hwc' for height-width-channels. Defaults to 'chw'.
- **kwargs: Additional keyword arguments that may be used by subclasses.
- Attributes:
- scale (float): The scale factor applied to the image.
- mean (numpy.ndarray): The mean values reshaped according to the specified order.
- std (numpy.ndarray): The standard deviation values reshaped according to the specified order.
- """
- def __init__(
- self,
- scale: Optional[Union[float, str]] = None,
- mean: Optional[List[float]] = None,
- std: Optional[List[float]] = None,
- order: str = "chw",
- **kwargs,
- ) -> None:
- if isinstance(scale, str):
- scale = eval(scale)
- self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
- mean = mean if mean is not None else [0.485, 0.456, 0.406]
- std = std if std is not None else [0.229, 0.224, 0.225]
- shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
- self.mean = np.array(mean).reshape(shape).astype("float32")
- self.std = np.array(std).reshape(shape).astype("float32")
- def normalize(self, img: Union[np.ndarray, Image.Image]) -> np.ndarray:
- """Scales the image, then subtracts the mean and divides by the std."""
- if isinstance(img, Image.Image):
- img = np.array(img)
- assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
- img = (img.astype("float32") * self.scale - self.mean) / self.std
- return img
- def __call__(self, imgs: List[Union[np.ndarray, Image.Image]]) -> List[np.ndarray]:
- """Apply normalization to a list of images."""
- return [self.normalize(img) for img in imgs]
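- # Worked example (illustrative): with the default scale of 1/255, mean 0.485
- # and std 0.229 for the first channel, a pixel value of 128 maps to
- # (128 / 255 - 0.485) / 0.229 ≈ 0.074.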
- class ToBatch(object):
- """A class for batching images."""
- def __init__(self, **kwargs) -> None:
- """Initializes the ToBatch object."""
- super(ToBatch, self).__init__()
- def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
- """Concatenates a list of images into a single batch.
- Args:
- imgs (list): A list of image arrays to be concatenated.
- Returns:
- list: A list containing the concatenated batch of images wrapped in another list (to comply with common batch processing formats).
- """
- batch_imgs = np.concatenate(imgs)
- batch_imgs = batch_imgs.copy()
- x = [batch_imgs]
- return x
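- # Hedged end-to-end sketch of how the LaTeX-OCR preprocessing classes above
- # compose; the image path and the normalization constants are illustrative
- # assumptions, not values taken from a shipped config.
- # >>> steps = [MinMaxResize(), LatexTestTransform(),
- # ...          NormalizeImage(mean=[0.7931] * 3, std=[0.1738] * 3, order="hwc"),
- # ...          LatexImageFormat(), ToBatch()]
- # >>> data = [cv2.imread("formula.png")]
- # >>> for step in steps:
- # ...     data = step(data)  # ends as [array of shape (N, 1, H, W)]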
- class LaTeXOCRDecode(object):
- """Class for decoding LaTeX OCR tokens based on a provided character list."""
- def __init__(self, character_list: List[str], **kwargs) -> None:
- """Initializes the LaTeXOCRDecode object.
- Args:
- character_list (list): The list of characters to use for tokenization.
- **kwargs: Additional keyword arguments for initialization.
- """
- super(LaTeXOCRDecode, self).__init__()
- temp_path = tempfile.gettempdir()
- rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
- try:
- with open(rec_char_dict_path, "w") as f:
- json.dump(character_list, f)
- except Exception as e:
- logging.error(f"Failed to create latexocr_tokenizer.json, reason: {e}")
- self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
- def post_process(self, s: str) -> str:
- """Post-processes the decoded LaTeX string.
- Args:
- s (str): The decoded LaTeX string to post-process.
- Returns:
- str: The post-processed LaTeX string.
- """
- text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
- letter = "[a-zA-Z]"
- noletter = "[\W_^\d]"
- names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
- s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
- news = s
- while True:
- s = news
- news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
- news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
- news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
- if news == s:
- break
- return s
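- # Worked example (illustrative): the fixpoint loop strips the spacing the
- # tokenizer introduced around non-letter characters, e.g. the input string
- # "\frac { a } { b }" converges to "\frac{a}{b}"; the lookahead (?!\\ )
- # preserves the escaped space "\ ".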
- def decode(self, tokens: np.ndarray) -> List[str]:
- """Decodes the provided tokens into LaTeX strings.
- Args:
- tokens (np.array): The tokens to decode.
- Returns:
- list: The decoded LaTeX strings.
- """
- if len(tokens.shape) == 1:
- tokens = tokens[None, :]
- dec = [self.tokenizer.decode(tok) for tok in tokens]
- dec_str_list = [
- "".join(detok.split(" "))
- .replace("Ġ", " ")
- .replace("[EOS]", "")
- .replace("[BOS]", "")
- .replace("[PAD]", "")
- .strip()
- for detok in dec
- ]
- return [self.post_process(dec_str) for dec_str in dec_str_list]
- def __call__(
- self,
- preds: np.ndarray,
- label: Optional[np.ndarray] = None,
- mode: str = "eval",
- *args,
- **kwargs,
- ) -> Tuple[List[str], List[str]]:
- """Calls the object with the provided predictions and label.
- Args:
- preds (np.array): The predictions to decode.
- label (np.array, optional): The labels to decode. Defaults to None.
- mode (str): The mode to run in, either 'train' or 'eval'. Defaults to 'eval'.
- *args: Positional arguments to pass.
- **kwargs: Keyword arguments to pass.
- Returns:
- tuple or list: The decoded text and optionally the decoded label.
- """
- if mode == "train":
- preds_idx = np.array(preds.argmax(axis=2))
- text = self.decode(preds_idx)
- else:
- text = self.decode(np.array(preds))
- if label is None:
- return text
- label = self.decode(np.array(label))
- return text, label
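- # Hedged usage sketch: `character_list` is the LaTeX-OCR tokenizer JSON
- # supplied by the pipeline config; the token ids below are illustrative.
- # >>> decoder = LaTeXOCRDecode(character_list)
- # >>> texts = decoder(np.array([[0, 87, 256, 2]]))  # -> ['...decoded LaTeX...']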
- class UniMERNetImgDecode(object):
- """Class for decoding images for UniMERNet, including cropping margins, resizing, and padding."""
- def __init__(
- self, input_size: Tuple[int, int], random_padding: bool = False, **kwargs
- ) -> None:
- """Initializes the UniMERNetImgDecode class with input size and random padding options.
- Args:
- input_size (tuple): The desired input size for the images (height, width).
- random_padding (bool): Whether to use random padding for resizing.
- **kwargs: Additional keyword arguments."""
- self.input_size = input_size
- self.random_padding = random_padding
- def crop_margin(self, img: Image.Image) -> Image.Image:
- """Crops the margin of the image based on grayscale thresholding.
- Args:
- img (PIL.Image.Image): The input image.
- Returns:
- PIL.Image.Image: The cropped image."""
- data = np.array(img.convert("L"))
- data = data.astype(np.uint8)
- max_val = data.max()
- min_val = data.min()
- if max_val == min_val:
- return img
- data = (data - min_val) / (max_val - min_val) * 255
- gray = 255 * (data < 200).astype(np.uint8)
- coords = cv2.findNonZero(gray) # Find all non-zero points (text)
- a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
- return img.crop((a, b, w + a, h + b))
- def get_dimensions(self, img: Union[Image.Image, np.ndarray]) -> List[int]:
- """Gets the dimensions of the image.
- Args:
- img (PIL.Image.Image or numpy.ndarray): The input image.
- Returns:
- list: A list containing the number of channels, height, and width."""
- if hasattr(img, "getbands"):
- # PIL image: size is (width, height).
- channels = len(img.getbands())
- width, height = img.size
- else:
- # numpy array: shape is (height, width[, channels]).
- height, width = img.shape[:2]
- channels = img.shape[2] if img.ndim == 3 else 1
- return [channels, height, width]
- def _compute_resized_output_size(
- self,
- image_size: Tuple[int, int],
- size: Union[int, Tuple[int, int]],
- max_size: Optional[int] = None,
- ) -> List[int]:
- """Computes the resized output size of the image.
- Args:
- image_size (tuple): The original size of the image (height, width).
- size (int or tuple): The desired size for the smallest edge or both height and width.
- max_size (int, optional): The maximum allowed size for the longer edge.
- Returns:
- list: A list containing the new height and width."""
- if len(size) == 1: # specified size only for the smallest edge
- h, w = image_size
- short, long = (w, h) if w <= h else (h, w)
- requested_new_short = size if isinstance(size, int) else size[0]
- new_short, new_long = requested_new_short, int(
- requested_new_short * long / short
- )
- if max_size is not None:
- if max_size <= requested_new_short:
- raise ValueError(
- f"max_size = {max_size} must be strictly greater than the requested "
- f"size for the smaller edge size = {size}"
- )
- if new_long > max_size:
- new_short, new_long = int(max_size * new_short / new_long), max_size
- new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
- else: # specified both h and w
- new_w, new_h = size[1], size[0]
- return [new_h, new_w]
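- # Worked example (illustrative): image_size=(60, 120) (h, w) with size=[48]
- # gives short=60 and long=120, so new_short=48 and new_long=96; since w > h,
- # the result is [new_h, new_w] = [48, 96].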
- def resize(
- self, img: Image.Image, size: Union[int, Tuple[int, int]]
- ) -> Image.Image:
- """Resizes the image to the specified size.
- Args:
- img (PIL.Image.Image): The input image.
- size (int or tuple): The desired size for the smallest edge or both height and width.
- Returns:
- PIL.Image.Image: The resized image."""
- _, image_height, image_width = self.get_dimensions(img)
- if isinstance(size, int):
- size = [size]
- max_size = None
- output_size = self._compute_resized_output_size(
- (image_height, image_width), size, max_size
- )
- img = img.resize(tuple(output_size[::-1]), resample=Image.BILINEAR)
- return img
- def img_decode(self, img: np.ndarray) -> Optional[np.ndarray]:
- """Decodes the image by cropping margins, resizing, and adding padding.
- Args:
- img (numpy.ndarray): The input image array.
- Returns:
- numpy.ndarray: The decoded image array."""
- try:
- img = self.crop_margin(Image.fromarray(img).convert("RGB"))
- except OSError:
- return
- if img.height == 0 or img.width == 0:
- return
- img = self.resize(img, min(self.input_size))
- img.thumbnail((self.input_size[1], self.input_size[0]))
- delta_width = self.input_size[1] - img.width
- delta_height = self.input_size[0] - img.height
- if self.random_padding:
- pad_width = np.random.randint(low=0, high=delta_width + 1)
- pad_height = np.random.randint(low=0, high=delta_height + 1)
- else:
- pad_width = delta_width // 2
- pad_height = delta_height // 2
- padding = (
- pad_width,
- pad_height,
- delta_width - pad_width,
- delta_height - pad_height,
- )
- return np.array(ImageOps.expand(img, padding))
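- # Worked example (illustrative): with input_size=(192, 672), a 100x300
- # (h x w) image is resized so its short edge is 192 (giving 192x576), then
- # centred with 48 px of padding on the left and right (ImageOps.expand fills
- # the border with 0 by default).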
- def __call__(self, imgs: List[np.ndarray]) -> List[Optional[np.ndarray]]:
- """Calls the img_decode method on a list of images.
- Args:
- imgs (list of numpy.ndarray): The list of input image arrays.
- Returns:
- list of numpy.ndarray: The list of decoded image arrays."""
- return [self.img_decode(img) for img in imgs]
- class UniMERNetDecode(object):
- """Class for decoding tokenized inputs using UniMERNet tokenizer.
- Attributes:
- SPECIAL_TOKENS_ATTRIBUTES (List[str]): List of special token attributes.
- model_input_names (List[str]): List of model input names.
- max_seq_len (int): Maximum sequence length.
- pad_token_id (int): ID for the padding token.
- bos_token_id (int): ID for the beginning-of-sequence token.
- eos_token_id (int): ID for the end-of-sequence token.
- padding_side (str): Padding side, either 'left' or 'right'.
- pad_token (str): Padding token.
- pad_token_type_id (int): Type ID for the padding token.
- pad_to_multiple_of (Optional[int]): If set, pad to a multiple of this value.
- tokenizer (TokenizerFast): Fast tokenizer instance.
- Args:
- character_list (Dict[str, Any]): Dictionary containing tokenizer configuration.
- **kwargs: Additional keyword arguments.
- """
- SPECIAL_TOKENS_ATTRIBUTES = [
- "bos_token",
- "eos_token",
- "unk_token",
- "sep_token",
- "pad_token",
- "cls_token",
- "mask_token",
- "additional_special_tokens",
- ]
- def __init__(
- self,
- character_list: Dict[str, Any],
- **kwargs,
- ) -> None:
- """Initializes the UniMERNetDecode class.
- Args:
- character_list (Dict[str, Any]): Dictionary containing tokenizer configuration.
- **kwargs: Additional keyword arguments.
- """
- self._unk_token = "<unk>"
- self._bos_token = "<s>"
- self._eos_token = "</s>"
- self._pad_token = "<pad>"
- self._sep_token = None
- self._cls_token = None
- self._mask_token = None
- self._additional_special_tokens = []
- self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
- self.max_seq_len = 2048
- self.pad_token_id = 1
- self.bos_token_id = 0
- self.eos_token_id = 2
- self.padding_side = "right"
- self.pad_token_id = 1
- self.pad_token = "<pad>"
- self.pad_token_type_id = 0
- self.pad_to_multiple_of = None
- temp_path = tempfile.gettempdir()
- fast_tokenizer_file = os.path.join(temp_path, "tokenizer.json")
- tokenizer_config_file = os.path.join(temp_path, "tokenizer_config.json")
- try:
- with open(fast_tokenizer_file, "w") as f:
- json.dump(character_list["fast_tokenizer_file"], f)
- with open(tokenizer_config_file, "w") as f:
- json.dump(character_list["tokenizer_config_file"], f)
- except Exception as e:
- logging.error(
- f"Failed to create tokenizer.json and tokenizer_config.json, reason: {e}"
- )
- self.tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
- added_tokens_decoder = {}
- added_tokens_map = {}
- if tokenizer_config_file is not None:
- with open(
- tokenizer_config_file, encoding="utf-8"
- ) as tokenizer_config_handle:
- init_kwargs = json.load(tokenizer_config_handle)
- if "added_tokens_decoder" in init_kwargs:
- for idx, token in init_kwargs["added_tokens_decoder"].items():
- if isinstance(token, dict):
- token = AddedToken(**token)
- if isinstance(token, AddedToken):
- added_tokens_decoder[int(idx)] = token
- added_tokens_map[str(token)] = token
- else:
- raise ValueError(
- f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
- )
- init_kwargs["added_tokens_decoder"] = added_tokens_decoder
- added_tokens_decoder = init_kwargs.pop("added_tokens_decoder", {})
- tokens_to_add = [
- token
- for index, token in sorted(
- added_tokens_decoder.items(), key=lambda x: x[0]
- )
- if token not in added_tokens_decoder
- ]
- added_tokens_encoder = self.added_tokens_encoder(added_tokens_decoder)
- encoder = list(added_tokens_encoder.keys()) + [
- str(token) for token in tokens_to_add
- ]
- tokens_to_add += [
- token
- for token in self.all_special_tokens_extended
- if token not in encoder and token not in tokens_to_add
- ]
- if len(tokens_to_add) > 0:
- is_last_special = None
- tokens = []
- special_tokens = self.all_special_tokens
- for token in tokens_to_add:
- is_special = (
- (token.special or str(token) in special_tokens)
- if isinstance(token, AddedToken)
- else str(token) in special_tokens
- )
- if is_last_special is None or is_last_special == is_special:
- tokens.append(token)
- else:
- self._add_tokens(tokens, special_tokens=is_last_special)
- tokens = [token]
- is_last_special = is_special
- if tokens:
- self._add_tokens(tokens, special_tokens=is_last_special)
- def _add_tokens(
- self, new_tokens: List[Union[AddedToken, str]], special_tokens: bool = False
- ) -> List[Union[AddedToken, str]]:
- """Adds new tokens to the tokenizer.
- Args:
- new_tokens (List[Union[AddedToken, str]]): Tokens to be added.
- special_tokens (bool): Indicates whether the tokens are special tokens.
- Returns:
- List[Union[AddedToken, str]]: added tokens.
- """
- if special_tokens:
- return self.tokenizer.add_special_tokens(new_tokens)
- return self.tokenizer.add_tokens(new_tokens)
- def added_tokens_encoder(
- self, added_tokens_decoder: Dict[int, AddedToken]
- ) -> Dict[str, int]:
- """Creates an encoder dictionary from added tokens.
- Args:
- added_tokens_decoder (Dict[int, AddedToken]): Dictionary mapping token IDs to tokens.
- Returns:
- Dict[str, int]: Dictionary mapping token strings to IDs.
- """
- return {
- k.content: v
- for v, k in sorted(added_tokens_decoder.items(), key=lambda item: item[0])
- }
- @property
- def all_special_tokens(self) -> List[str]:
- """Retrieves all special tokens.
- Returns:
- List[str]: List of all special tokens as strings.
- """
- all_toks = [str(s) for s in self.all_special_tokens_extended]
- return all_toks
- @property
- def all_special_ids(self) -> List[int]:
- """Retrieves the IDs of all special tokens; convert_ids_to_tokens relies on this.
- Returns:
- List[int]: List of all special token IDs.
- """
- return [self.tokenizer.token_to_id(tok) for tok in self.all_special_tokens]
- @property
- def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]:
- """Retrieves all special tokens, including extended ones.
- Returns:
- List[Union[str, AddedToken]]: List of all special tokens.
- """
- all_tokens = []
- seen = set()
- for value in self.special_tokens_map_extended.values():
- if isinstance(value, (list, tuple)):
- tokens_to_add = [token for token in value if str(token) not in seen]
- else:
- tokens_to_add = [value] if str(value) not in seen else []
- seen.update(map(str, tokens_to_add))
- all_tokens.extend(tokens_to_add)
- return all_tokens
- @property
- def special_tokens_map_extended(self) -> Dict[str, Union[str, List[str]]]:
- """Retrieves the extended map of special tokens.
- Returns:
- Dict[str, Union[str, List[str]]]: Dictionary mapping special token attributes to their values.
- """
- set_attr = {}
- for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
- attr_value = getattr(self, "_" + attr)
- if attr_value:
- set_attr[attr] = attr_value
- return set_attr
- def convert_ids_to_tokens(
- self, ids: Union[int, List[int]], skip_special_tokens: bool = False
- ) -> Union[str, List[str]]:
- """Converts token IDs to token strings.
- Args:
- ids (Union[int, List[int]]): Token ID(s) to convert.
- skip_special_tokens (bool): Whether to skip special tokens during conversion.
- Returns:
- Union[str, List[str]]: Converted token string(s).
- """
- if isinstance(ids, int):
- return self.tokenizer.id_to_token(ids)
- tokens = []
- for index in ids:
- index = int(index)
- if skip_special_tokens and index in self.all_special_ids:
- continue
- tokens.append(self.tokenizer.id_to_token(index))
- return tokens
- def detokenize(self, tokens: List[List[int]]) -> List[List[str]]:
- """Detokenizes a list of token IDs back into strings.
- Args:
- tokens (List[List[int]]): List of token ID lists.
- Returns:
- List[List[str]]: List of detokenized strings.
- """
- self.tokenizer.bos_token = "<s>"
- self.tokenizer.eos_token = "</s>"
- self.tokenizer.pad_token = "<pad>"
- toks = [self.convert_ids_to_tokens(tok) for tok in tokens]
- for b in range(len(toks)):
- for i in reversed(range(len(toks[b]))):
- if toks[b][i] is None:
- toks[b][i] = ""
- toks[b][i] = toks[b][i].replace("Ġ", " ").strip()
- if toks[b][i] in [
- self.tokenizer.bos_token,
- self.tokenizer.eos_token,
- self.tokenizer.pad_token,
- ]:
- del toks[b][i]
- return toks
- def token2str(self, token_ids: List[List[int]]) -> List[str]:
- """Converts a list of token IDs to strings.
- Args:
- token_ids (List[List[int]]): List of token ID lists.
- Returns:
- List[str]: List of converted strings.
- """
- generated_text = []
- for tok_id in token_ids:
- end_idx = np.argwhere(tok_id == self.eos_token_id)
- if len(end_idx) > 0:
- end_idx = int(end_idx[0][0])
- tok_id = tok_id[: end_idx + 1]
- generated_text.append(
- self.tokenizer.decode(tok_id, skip_special_tokens=True)
- )
- generated_text = [self.post_process(text) for text in generated_text]
- return generated_text
- def normalize(self, s: str) -> str:
- """Normalizes a string by removing unnecessary spaces.
- Args:
- s (str): String to normalize.
- Returns:
- str: Normalized string.
- """
- text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
- letter = "[a-zA-Z]"
- noletter = "[\W_^\d]"
- names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
- s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
- news = s
- while True:
- s = news
- news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
- news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
- news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
- if news == s:
- break
- return s
- def post_process(self, text: str) -> str:
- """Post-processes a string by fixing text and normalizing it.
- Args:
- text (str): String to post-process.
- Returns:
- str: Post-processed string.
- """
- from ftfy import fix_text
- text = fix_text(text)
- text = self.normalize(text)
- return text
- def __call__(
- self,
- preds: np.ndarray,
- label: Optional[np.ndarray] = None,
- mode: str = "eval",
- *args,
- **kwargs,
- ) -> Union[List[str], tuple]:
- """Processes predictions and optionally labels, returning the decoded text.
- Args:
- preds (np.ndarray): Model predictions.
- label (Optional[np.ndarray]): True labels, if available.
- mode (str): Mode of operation, either 'train' or 'eval'.
- Returns:
- Union[List[str], tuple]: Decoded text, optionally with labels.
- """
- if mode == "train":
- preds_idx = np.array(preds.argmax(axis=2))
- text = self.token2str(preds_idx)
- else:
- text = self.token2str(np.array(preds))
- if label is None:
- return text
- label = self.token2str(np.array(label))
- return text, label
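- # Hedged usage sketch: `character_list` carries the serialized tokenizer.json
- # and tokenizer_config.json contents from the model package; the prediction
- # array below is illustrative.
- # >>> decoder = UniMERNetDecode(character_list)
- # >>> texts = decoder(np.array([[0, 500, 2, 1, 1]]))  # truncated at eos id 2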
- class UniMERNetTestTransform:
- """
- A class for transforming images according to UniMERNet test specifications.
- """
- def __init__(self, **kwargs) -> None:
- """
- Initializes the UniMERNetTestTransform class.
- """
- super().__init__()
- self.num_output_channels = 3
- def transform(self, img: np.ndarray) -> np.ndarray:
- """
- Transforms a single image for UniMERNet testing.
- Args:
- img (numpy.ndarray): The input image.
- Returns:
- numpy.ndarray: The transformed image.
- """
- mean = [0.7931, 0.7931, 0.7931]
- std = [0.1738, 0.1738, 0.1738]
- scale = 1.0 / 255.0
- shape = (1, 1, 3)
- mean = np.array(mean).reshape(shape).astype("float32")
- std = np.array(std).reshape(shape).astype("float32")
- img = (img.astype("float32") * scale - mean) / std
- grayscale_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- squeezed = np.squeeze(grayscale_image)
- img = cv2.merge([squeezed] * self.num_output_channels)
- return img
- def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
- """
- Applies the transform to a list of images.
- Args:
- imgs (list of numpy.ndarray): The list of input images.
- Returns:
- list of numpy.ndarray: The list of transformed images.
- """
- return [self.transform(img) for img in imgs]
- class UniMERNetImageFormat:
- """Class for formatting images to UniMERNet's required format."""
- def __init__(self, **kwargs) -> None:
- """Initializes the UniMERNetImageFormat instance."""
- super().__init__()
- def format(self, img: np.ndarray) -> np.ndarray:
- """Formats a single image to UniMERNet's required format.
- Args:
- img (numpy.ndarray): The input image to be formatted.
- Returns:
- numpy.ndarray: The formatted image.
- """
- im_h, im_w = img.shape[:2]
- divide_h = math.ceil(im_h / 32) * 32
- divide_w = math.ceil(im_w / 32) * 32
- img = img[:, :, 0]
- img = np.pad(
- img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
- )
- img_expanded = img[:, :, np.newaxis].transpose(2, 0, 1)
- return img_expanded[np.newaxis, :]
- def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
- """Applies the format method to a list of images.
- Args:
- imgs (list of numpy.ndarray): The list of input images to be formatted.
- Returns:
- list of numpy.ndarray: The list of formatted images.
- """
- return [self.format(img) for img in imgs]
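- # Hedged end-to-end sketch of how the UniMERNet preprocessing classes in this
- # module compose; the input size and image path are illustrative assumptions,
- # not values taken from a shipped config.
- # >>> steps = [UniMERNetImgDecode(input_size=(192, 672)),
- # ...          UniMERNetTestTransform(), UniMERNetImageFormat(), ToBatch()]
- # >>> data = [cv2.imread("formula.png")]
- # >>> for step in steps:
- # ...     data = step(data)  # ends as [array of shape (N, 1, H, W)]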