# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union

import numpy as np

from .....utils import logging
from ....utils.benchmark import benchmark
from ...common.vision.funcs import resize
from .common import (
    BatchFeature,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    TensorType,
    TextInput,
    convert_to_rgb,
    fetch_image,
    get_image_size,
    infer_channel_dimension_format,
    make_batched_images,
    make_list_of_images,
    smart_resize,
    to_channel_dimension_format,
    to_numpy_array,
    valid_images,
)

OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].
    """
    if image.dtype == np.uint8:
        return False

    # It's possible the image has pixel values in [0, 255] but is of floating type
    return np.min(image) >= 0 and np.max(image) <= 1

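# Illustrative sketch only: the constants above parameterize the `smart_resize`
# helper imported from `.common`. Target heights/widths are rounded to multiples
# of IMAGE_FACTOR (patch_size 14 * merge_size 2 = 28), and the total pixel count
# is kept within [MIN_PIXELS, MAX_PIXELS]; the imported helper additionally
# validates the aspect ratio against MAX_RATIO. The function below re-derives
# that rounding for documentation purposes, following the commonly published
# Qwen2-VL behavior; it is an assumption, is NOT the imported `smart_resize`,
# and is not called anywhere in this module.
def _smart_resize_sketch(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
):
    import math

    # Round each side to the nearest multiple of `factor`.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Too large: shrink (keeping aspect ratio), then round down to a multiple of `factor`.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too small: enlarge (keeping aspect ratio), then round up to a multiple of `factor`.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
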
class Qwen2VLProcessor(object):
    r"""
    Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.

    [`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.

    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`MIXQwen2Tokenizer`], *optional*):
            The tokenizer is a required input.
    """

    def __init__(self, image_processor, tokenizer, **kwargs):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.image_processor.min_pixels = kwargs.get("min_pixels", 3136)
        self.image_processor.max_pixels = kwargs.get("max_pixels", 12845056)

    def preprocess(
        self,
        images: ImageInput = None,
        text: Union[TextInput, List[TextInput]] = None,
        padding: bool = False,
        truncation: Union[bool, str] = None,
        max_length: int = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PADDLE,
    ):
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to
        encode the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments
        to Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as a list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            padding (`bool`, *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
                  different lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'pd'`: Return Paddle `paddle.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is
              not `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **image_grid_thw** -- List of 3D image grids for the LLM. Returned when `images` is not `None`.
        """
        if images is not None:
            image_inputs = self.image_processor(
                images=images, return_tensors=return_tensors
            )
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while "<|image_pad|>" in text[i]:
                    text[i] = text[i].replace(
                        "<|image_pad|>",
                        "<|placeholder|>"
                        * int(image_grid_thw[index].prod() // merge_length),
                        1,
                        # Replace a single <|image_pad|> with as many <|placeholder|>
                        # tokens as the image contributes visual tokens.
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")

        text_inputs = self.tokenizer(
            text,
            return_tensors=return_tensors,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
        )

        return BatchFeature(data={**text_inputs, **image_inputs}).data

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

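# Usage sketch (hedged, kept as a comment): `Qwen2VLProcessor.preprocess` expands
# each `<|image_pad|>` into one token per visual token, where the count is
# `t * h * w // merge_size**2` taken from `image_grid_thw`. With a hypothetical
# grid of (1, 36, 52) and merge_size == 2, the net effect is:
#
#     image_grid_thw = np.array([[1, 36, 52]])
#     merge_length = 2**2
#     num_tokens = int(image_grid_thw[0].prod() // merge_length)  # 1 * 36 * 52 // 4 == 468
#     prompt = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."
#     prompt = prompt.replace("<|image_pad|>", "<|image_pad|>" * num_tokens, 1)
#
# so the tokenizer sees exactly as many `<|image_pad|>` tokens as the vision
# encoder will produce merged patch embeddings for that image.
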
class Qwen2VLImageProcessor(object):
    r"""
    Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats, one for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats, one for each
            channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        min_pixels (`int`, *optional*, defaults to `56 * 56`):
            The minimum number of pixels in the resized image.
        max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
            The maximum number of pixels in the resized image.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The temporal patch size of the vision encoder.
        merge_size (`int`, *optional*, defaults to 2):
            The merge size from the vision encoder to the LLM encoder.
    """

    def __init__(
        self,
        do_resize: bool = True,
        resample=None,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255.0,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        import cv2

        resample = cv2.INTER_CUBIC if resample is None else resample
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        image_mean_ = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        image_std_ = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
        # Store mean/std with a (1, 1, C) shape so they broadcast over HWC images.
        self.image_mean = np.array(image_mean_)[None, None, ...]
        self.image_std = np.array(image_std_)[None, None, ...]

    def _preprocess(
        self,
        images,
        do_resize: bool = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values
                range from 0 to 1, set `do_rescale=False`.
            vision_info (`List[Dict]`, *optional*):
                Optional list of dictionaries containing additional information about vision inputs.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number
                of channels in the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding
                to the number of channels in the image.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logging.warning(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=self.patch_size * self.merge_size,
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                    max_ratio=MAX_RATIO,
                )
                image = image.astype("uint8")
                image = resize(
                    image,
                    (resized_width, resized_height),
                    interp=None,
                    backend="cv2",
                )

            if do_rescale:
                image = image.astype("float32")
                image *= rescale_factor

            if do_normalize:
                assert input_data_format == ChannelDimension.LAST
                image = (image - self.image_mean) / self.image_std

            image = to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format
            )
            processed_images.append(image)

        patches = np.array(processed_images)
        # Work in channels-first layout for patch extraction.
        if data_format == ChannelDimension.LAST:
            patches = patches.transpose([0, 3, 1, 2])
        # A single image is tiled along the temporal axis to fill one temporal patch.
        if patches.shape[0] == 1:
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = (
            resized_height // self.patch_size,
            resized_width // self.patch_size,
        )
        # Regroup the (T, C, H, W) stack into merge_size x merge_size blocks of
        # patch_size x patch_size patches.
        patches = patches.reshape(
            [
                grid_t,
                self.temporal_patch_size,
                channel,
                grid_h // self.merge_size,
                self.merge_size,
                self.patch_size,
                grid_w // self.merge_size,
                self.merge_size,
                self.patch_size,
            ]
        )
        patches = patches.transpose([0, 3, 6, 4, 7, 2, 1, 5, 8])
        # Flatten each spatio-temporal patch into one row for the vision encoder.
        flatten_patches = patches.reshape(
            [
                grid_t * grid_h * grid_w,
                channel * self.temporal_patch_size * self.patch_size * self.patch_size,
            ]
        )

        return flatten_patches, (grid_t, grid_h, grid_w)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single image or a batch of images with pixel values ranging from 0 to
                255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. The shortest edge of the image is resized to `size["shortest_edge"]`,
                with the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.PADDLE` or `'pd'`: Return a batch of type `paddle.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        if images is not None:
            images = make_batched_images(images)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "paddle.Tensor."
            )

        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}

        return BatchFeature(data=data, tensor_type=return_tensors)

    def __call__(self, images, **kwargs):
        return self.preprocess(images, **kwargs)

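# Shape bookkeeping for `Qwen2VLImageProcessor._preprocess`, written out as an
# illustrative sketch with hypothetical numbers (this helper is not used by the
# pipeline). A single resized 392 x 476 RGB image with patch_size=14,
# temporal_patch_size=2 and merge_size=2 works out as follows:
def _patch_shapes_sketch():
    patch_size, temporal_patch_size, merge_size = 14, 2, 2
    resized_height, resized_width, channel = 392, 476, 3

    # A single image is tiled along the temporal axis to fill one temporal patch.
    grid_t = 1
    grid_h = resized_height // patch_size  # 28
    grid_w = resized_width // patch_size  # 34

    # Rows fed to the vision encoder: one per spatio-temporal patch.
    num_patches = grid_t * grid_h * grid_w  # 952
    patch_dim = channel * temporal_patch_size * patch_size * patch_size  # 1176
    flattened_shape = (num_patches, patch_dim)  # shape of `flatten_patches`

    # Visual tokens seen by the LLM after 2x2 spatial merging.
    tokens_for_llm = num_patches // merge_size**2  # 238
    return flattened_shape, tokens_for_llm
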
class PPDocBeeProcessor(Qwen2VLProcessor):
    """
    PP-DocBee processor, based on Qwen2VLProcessor.
    """

    @benchmark.timeit
    def preprocess(self, input_dicts):
        """
        Preprocess for the PP-DocBee series.
        """
        assert (
            isinstance(input_dicts, list) and len(input_dicts) == 1
        ), f"The PP-DocBee series only supports a batch size of 1, but received {len(input_dicts)} samples."
        input_dict = input_dicts[0]
        image = input_dict["image"]
        query = input_dict["query"]

        image_inputs = fetch_image(
            image,
            size_factor=IMAGE_FACTOR,
            min_pixels=MIN_PIXELS,
            max_pixels=MAX_PIXELS,
            max_ratio=MAX_RATIO,
        )
        image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>"
        text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{image_pad_token}{query}<|im_end|>\n<|im_start|>assistant\n"
        text = [text]

        rst_inputs = super().preprocess(
            text=text,
            images=[image_inputs],
            padding=False,
            return_tensors="pd",
        )

        return rst_inputs

    @benchmark.timeit
    def postprocess(self, model_pred, **kwargs):
        """
        Postprocess adapted for PaddleX.
        """
        return self.tokenizer.batch_decode(
            model_pred[0],
            skip_special_tokens=kwargs.get("skip_special_tokens", True),
            clean_up_tokenization_spaces=False,
        )
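
# Hedged end-to-end usage sketch for the PP-DocBee preprocessing path. The
# tokenizer below is an assumption: it is constructed by the surrounding
# PaddleX pipeline rather than by this module, so the snippet is illustrative
# only and is kept as a comment:
#
#     processor = PPDocBeeProcessor(
#         image_processor=Qwen2VLImageProcessor(),
#         tokenizer=tokenizer,  # a Qwen2-compatible tokenizer supplied by the pipeline
#     )
#     inputs = processor.preprocess(
#         [{"image": "path/or/url/to/image.png", "query": "What is the title of this document?"}]
#     )
#     # `inputs` contains `input_ids`, `attention_mask`, `pixel_values` and
#     # `image_grid_thw`. After the model generates token ids, pass them to
#     # `processor.postprocess(...)` to decode the answer text.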