# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union

import numpy as np

from .....utils import logging
from ....utils.benchmark import benchmark
from ...common.tokenizer.tokenizer_utils_base import (
    PreTokenizedInput,
    TensorType,
    TextInput,
    TruncationStrategy,
)
from ...common.vision.funcs import resize
from .common import (
    BatchFeature,
    ChannelDimension,
    ImageInput,
    PaddingStrategy,
    PILImageResampling,
    convert_to_rgb,
    fetch_image,
    get_image_size,
    infer_channel_dimension_format,
    make_batched_images,
    make_list_of_images,
    smart_resize,
    to_channel_dimension_format,
    to_numpy_array,
    valid_images,
)

OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
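# For reference, the pixel budget above works out as follows:
# MIN_PIXELS = 4 * 28 * 28 = 3136 and MAX_PIXELS = 16384 * 28 * 28 = 12845056,
# i.e. every resized image keeps between 4 and 16384 cells of 28x28 pixels,
# and MAX_RATIO caps the aspect ratio at 200:1. (Explanatory note, not in the
# original file; reading one 28x28 cell as one merged vision token is an
# assumption based on patch_size=14 and merge_size=2 below, since 28 = 2 * 14.)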

__all__ = [
    "Qwen2_5_VLProcessor",
    "Qwen2_5_VLImageProcessor",
    "PPDocBee2Processor",
]

def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].
    """
    if image.dtype == np.uint8:
        return False
    return np.min(image) >= 0 and np.max(image) <= 1

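# A minimal sketch of the heuristic above (illustrative, not in the original
# file):
#
#     is_scaled_image(np.zeros((2, 2, 3), dtype=np.uint8))        # -> False
#     is_scaled_image(np.full((2, 2, 3), 0.5, dtype=np.float32))  # -> True
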
class Qwen2_5_VLProcessor(object):
    """
    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single
    processor.

    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2_5_VLImageProcessor`] and [`Qwen2TokenizerFast`].
    See [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.

    Args:
        image_processor ([`Qwen2_5_VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """
    def __init__(self, image_processor, tokenizer, **kwargs):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        # 3136 == MIN_PIXELS and 12845056 == MAX_PIXELS; use the named constants.
        self.image_processor.min_pixels = kwargs.get("min_pixels", MIN_PIXELS)
        self.image_processor.max_pixels = kwargs.get("max_pixels", MAX_PIXELS)
    def preprocess(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: int = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PADDLE,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the
        `text` and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not
        `None` to encode the text. To prepare the vision inputs, this method forwards the `vision_infos` and
        `kwargs` arguments to Qwen2_5_VLImageProcessor's [`~Qwen2_5_VLImageProcessor.__call__`] if `vision_infos`
        is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'pd'`: Return Paddle `paddle.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is
              not `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **image_grid_thw** -- List of 3D image grids (temporal, height, width) as seen by the LLM. Returned
              when `images` is not `None`.
        """
        if images is not None:
            image_inputs = self.image_processor(
                images=images, return_tensors=return_tensors
            )
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while "<|image_pad|>" in text[i]:
                    text[i] = text[i].replace(
                        "<|image_pad|>",
                        "<|placeholder|>"
                        * int(image_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")

        text_inputs = self.tokenizer(
            text,
            return_tensors=return_tensors,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
        )
        return BatchFeature(data={**text_inputs, **image_inputs})
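    # Worked example for the `<|image_pad|>` expansion above (illustrative
    # numbers, not from the original file): with merge_size = 2 and
    # image_grid_thw = [1, 34, 46], one `<|image_pad|>` marker is expanded to
    # 1 * 34 * 46 // 2**2 = 391 placeholder tokens before tokenization, so the
    # prompt length matches the number of merged vision tokens for that image.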
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

class Qwen2_5_VLImageProcessor(object):
    """
    Constructs a Qwen2.5-VL image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        min_pixels (`int`, *optional*, defaults to `56 * 56`):
            The minimum number of pixels an image may be resized to.
        max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
            The maximum number of pixels an image may be resized to.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The temporal patch size of the vision encoder.
        merge_size (`int`, *optional*, defaults to 2):
            The merge (downsampling) factor applied between the vision encoder and the LLM.
    """
    model_input_names = ["pixel_values", "image_grid_thw", "second_per_grid_ts"]

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        image_mean_ = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        image_std_ = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
        self.image_mean = np.array(image_mean_)[None, None, ...]
        self.image_std = np.array(image_std_)[None, None, ...]
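    # Note: `image_mean` / `image_std` are stored with shape (1, 1, 3) so that
    # they broadcast over (height, width, channels) arrays in `_preprocess`.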
    def _preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values
                range from 0 to 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number
                of channels in the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding
                to the number of channels in the image.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logging.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=self.patch_size * self.merge_size,
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                    max_ratio=MAX_RATIO,
                )
                image = image.astype("uint8")
                image = resize(
                    image,
                    (resized_width, resized_height),
                    interp=None,
                    backend="cv2",
                )

            if do_rescale:
                image = image.astype("float32")
                image *= rescale_factor

            if do_normalize:
                assert input_data_format == ChannelDimension.LAST
                image = (image - self.image_mean) / self.image_std

            image = to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format
            )
            processed_images.append(image)

        patches = np.array(processed_images)
        if data_format == ChannelDimension.LAST:
            patches = patches.transpose([0, 3, 1, 2])
        if patches.shape[0] == 1:
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = (
            resized_height // self.patch_size,
            resized_width // self.patch_size,
        )
        patches = patches.reshape(
            [
                grid_t,
                self.temporal_patch_size,
                channel,
                grid_h // self.merge_size,
                self.merge_size,
                self.patch_size,
                grid_w // self.merge_size,
                self.merge_size,
                self.patch_size,
            ]
        )
        patches = patches.transpose([0, 3, 6, 4, 7, 2, 1, 5, 8])
        flatten_patches = patches.reshape(
            [
                grid_t * grid_h * grid_w,
                channel * self.temporal_patch_size * self.patch_size * self.patch_size,
            ]
        )
        return flatten_patches, (grid_t, grid_h, grid_w)
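    # Shape walkthrough for `_preprocess` with the default patch_size=14,
    # merge_size=2, temporal_patch_size=2 (illustrative numbers, not from the
    # original file): a 476x700 input is already a multiple of 28 per side, is
    # tiled to 2 identical temporal frames, and yields grid_t=1,
    # grid_h=476//14=34, grid_w=700//14=50, so `flatten_patches` has shape
    # (1*34*50, 3*2*14*14) = (1700, 1176).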
    def __call__(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size bounds for resizing, given as a dict with `min_pixels` and `max_pixels` keys bounding the total
                number of pixels; the aspect ratio of the input is preserved.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.PADDLE` or `'pd'`: Return a batch of type `paddle.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        if images is not None:
            images = make_batched_images(images)
        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "paddle.Tensor."
            )

        # Initialize `data` so an empty BatchFeature is returned if `images` is None.
        data = {}
        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
        return BatchFeature(data=data, tensor_type=return_tensors)

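# A minimal usage sketch for the image processor (a hedged example, not part of
# the original file; the exact grid depends on `smart_resize` rounding):
#
#     processor = Qwen2_5_VLImageProcessor()
#     img = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
#     out = processor(images=img, return_tensors="np")
#     out["pixel_values"].shape  # (grid_t * grid_h * grid_w, 1176)
#     out["image_grid_thw"]      # e.g. [[1, 34, 46]] for a 480x640 input
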
class PPDocBee2Processor(Qwen2_5_VLProcessor):
    """
    PP-DocBee2 processor, based on Qwen2_5_VLProcessor
    """

    @benchmark.timeit
    def preprocess(self, input_dicts: List[Dict]):
        """
        Preprocessing for the PP-DocBee2 series
        """
        # `assert (... for ...)` always passes because a generator is truthy;
        # `all(...)` performs the intended per-item check.
        assert all(isinstance(input_dict, dict) for input_dict in input_dicts)
        prompt = (
            "<|im_start|>system\n"
            "You are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n"
            "<|vision_start|><|image_pad|><|vision_end|>{query}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        query_inputs = [
            prompt.format(query=input_dict["query"]) for input_dict in input_dicts
        ]
        image_inputs = [
            fetch_image(
                input_dict["image"],
                size_factor=IMAGE_FACTOR,
                min_pixels=MIN_PIXELS,
                max_pixels=MAX_PIXELS,
                max_ratio=MAX_RATIO,
            )
            for input_dict in input_dicts
        ]
        rst_inputs = super().preprocess(
            text=query_inputs,
            images=image_inputs,
            padding=True,
            return_tensors="pd",
        )
        return rst_inputs
    @benchmark.timeit
    def postprocess(self, model_pred, **kwargs) -> List[str]:
        """
        Postprocessing adapted for PaddleX
        """
        return self.tokenizer.batch_decode(
            model_pred[0],
            skip_special_tokens=kwargs.get("skip_special_tokens", True),
            clean_up_tokenization_spaces=False,
        )
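
# A hedged end-to-end sketch for PPDocBee2Processor (assumes a compatible
# Qwen2 tokenizer instance is available; not part of the original file):
#
#     image_processor = Qwen2_5_VLImageProcessor()
#     processor = PPDocBee2Processor(image_processor, tokenizer)
#     batch = processor.preprocess(
#         [{"image": "path/to/page.png", "query": "What is the title?"}]
#     )
#     # `batch` is a BatchFeature with input_ids, attention_mask, pixel_values
#     # and image_grid_thw, ready to feed to the PP-DocBee2 model.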