qwen2_vl.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union

import numpy as np

from .....utils import logging
from ....utils.benchmark import benchmark
from ...common.vision.funcs import resize
from .common import (
    BatchFeature,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    TensorType,
    TextInput,
    convert_to_rgb,
    fetch_image,
    get_image_size,
    infer_channel_dimension_format,
    make_batched_images,
    make_list_of_images,
    smart_resize,
    to_channel_dimension_format,
    to_numpy_array,
    valid_images,
)

OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
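# Note: MIN_PIXELS corresponds to a 2x2 grid of 28x28 vision patches (3,136 pixels)
# and MAX_PIXELS to 16,384 such patches (12,845,056 pixels); these values match the
# min_pixels/max_pixels defaults applied in Qwen2VLProcessor.__init__ below.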


def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].
    """
    if image.dtype == np.uint8:
        return False
    # It's possible the image has pixel values in [0, 255] but is of floating type
    return np.min(image) >= 0 and np.max(image) <= 1
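# For example, is_scaled_image(np.zeros((2, 2), dtype=np.float32)) is True, while a
# uint8 array is always reported as unscaled regardless of its values.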


class Qwen2VLProcessor(object):
    r"""
    Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor.

    [`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information.

    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`MIXQwen2Tokenizer`], *optional*):
            The tokenizer is a required input.
    """

    def __init__(self, image_processor, tokenizer, **kwargs):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.image_processor.min_pixels = kwargs.get("min_pixels", 3136)
        self.image_processor.max_pixels = kwargs.get("max_pixels", 12845056)

    def preprocess(
        self,
        images: ImageInput = None,
        text: Union[TextInput, List[TextInput]] = None,
        padding: bool = False,
        truncation: Union[bool, str] = None,
        max_length: int = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PADDLE,
    ):
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as a list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            padding (`bool`, *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'pd'`: Return Paddle `paddle.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
        """
        if images is not None:
            image_inputs = self.image_processor(
                images=images, return_tensors=return_tensors
            )
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None
        if not isinstance(text, list):
            text = [text]
        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
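            # Each <|image_pad|> expands to grid_thw.prod() // merge_length visual
            # tokens; e.g. a grid of (1, 16, 22) with merge_size=2 yields
            # 1 * 16 * 22 // 4 = 88 placeholder tokens for that image.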
            index = 0
            for i in range(len(text)):
                while "<|image_pad|>" in text[i]:
                    text[i] = text[i].replace(
                        "<|image_pad|>",
                        "<|placeholder|>"
                        * int(image_grid_thw[index].prod() // merge_length),
                        1,  # replace one <|image_pad|> with the corresponding number of <|placeholder|> tokens
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
        text_inputs = self.tokenizer(
            text,
            return_tensors=return_tensors,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
        )
        return BatchFeature(data={**text_inputs, **image_inputs}).data

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)


class Qwen2VLImageProcessor(object):
    r"""
    Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        min_pixels (`int`, *optional*, defaults to `56 * 56`):
            The minimum number of pixels allowed when resizing the image.
        max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
            The maximum number of pixels allowed when resizing the image.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The temporal patch size of the vision encoder.
        merge_size (`int`, *optional*, defaults to 2):
            The merge size between the vision encoder and the LLM encoder.
    """

    def __init__(
        self,
        do_resize: bool = True,
        resample=None,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255.0,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        import cv2

        resample = cv2.INTER_CUBIC if resample is None else resample
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        image_mean_ = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        image_std_ = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
        self.image_mean = np.array(image_mean_)[None, None, ...]
        self.image_std = np.array(image_std_)[None, None, ...]
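        # Note: mean/std are stored with shape (1, 1, C) so they broadcast over
        # HWC images during normalization in `_preprocess`.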
    def _preprocess(
        self,
        images,
        do_resize: bool = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
            vision_info (`List[Dict]`, *optional*):
                Optional list of dictionaries containing additional information about vision inputs.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        images = make_list_of_images(images)
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]
        if is_scaled_image(images[0]) and do_rescale:
            logging.warning(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])
        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=self.patch_size * self.merge_size,
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                    max_ratio=MAX_RATIO,
                )
                image = image.astype("uint8")
                image = resize(
                    image,
                    (resized_width, resized_height),
                    interp=None,
                    backend="cv2",
                )
            if do_rescale:
                image = image.astype("float32")
                image *= rescale_factor
            if do_normalize:
                assert input_data_format == ChannelDimension.LAST
                image = (image - self.image_mean) / self.image_std
            image = to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format
            )
            processed_images.append(image)
        patches = np.array(processed_images)
        if data_format == ChannelDimension.LAST:
            patches = patches.transpose([0, 3, 1, 2])
        if patches.shape[0] == 1:
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = (
            resized_height // self.patch_size,
            resized_width // self.patch_size,
        )
        patches = patches.reshape(
            [
                grid_t,
                self.temporal_patch_size,
                channel,
                grid_h // self.merge_size,
                self.merge_size,
                self.patch_size,
                grid_w // self.merge_size,
                self.merge_size,
                self.patch_size,
            ]
        )
        patches = patches.transpose([0, 3, 6, 4, 7, 2, 1, 5, 8])
        flatten_patches = patches.reshape(
            [
                grid_t * grid_h * grid_w,
                channel * self.temporal_patch_size * self.patch_size * self.patch_size,
            ]
        )
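        # Shape check (hypothetical example): an image resized to 448x616 gives
        # grid_h = 448 // 14 = 32 and grid_w = 616 // 14 = 44, so flatten_patches has
        # shape (1 * 32 * 44, 3 * 2 * 14 * 14) = (1408, 1176) and grid_thw is (1, 32, 44).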
        return flatten_patches, (grid_t, grid_h, grid_w)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Pixel budget used when resizing, given by the `min_pixels` and `max_pixels` keys.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.PADDLE` or `'pd'`: Return a batch of type `paddle.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )
        if images is not None:
            images = make_batched_images(images)
        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "paddle.Tensor."
            )
        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def __call__(self, images, **kwargs):
        return self.preprocess(images, **kwargs)
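
# A minimal usage sketch (hypothetical input; shapes follow the 448x616 example in
# `_preprocess` above):
#
#   from PIL import Image
#   import numpy as np
#
#   processor = Qwen2VLImageProcessor()
#   img = Image.fromarray(np.random.randint(0, 256, (448, 616, 3), dtype=np.uint8))
#   feats = processor(img, return_tensors="np")
#   feats["pixel_values"].shape   # expected: (1408, 1176)
#   feats["image_grid_thw"]       # expected: [[1, 32, 44]]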


class PPDocBeeProcessor(Qwen2VLProcessor):
    """
    PP-DocBee processor, based on Qwen2VLProcessor
    """

    @benchmark.timeit
    def preprocess(self, input_dicts):
        """
        Preprocess for the PP-DocBee series
        """
        assert (
            isinstance(input_dicts, list) and len(input_dicts) == 1
        ), f"The PP-DocBee series only supports a batch size of 1, but received {len(input_dicts)} samples."
        input_dict = input_dicts[0]
        image = input_dict["image"]
        query = input_dict["query"]
        image_inputs = fetch_image(
            image,
            size_factor=IMAGE_FACTOR,
            min_pixels=MIN_PIXELS,
            max_pixels=MAX_PIXELS,
            max_ratio=MAX_RATIO,
        )
        image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>"
        text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{image_pad_token}{query}<|im_end|>\n<|im_start|>assistant\n"
        text = [text]
        rst_inputs = super().preprocess(
            text=text,
            images=[image_inputs],
            padding=False,
            return_tensors="pd",
        )
        return rst_inputs

    @benchmark.timeit
    def postprocess(self, model_pred, *args, **kwargs):
        """
        Post-process adapted for PaddleX
        """
        return self.tokenizer.batch_decode(
            model_pred[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
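
# A minimal end-to-end sketch (hypothetical; `image_processor` and `tokenizer` must be
# constructed by the PaddleX pipeline that owns this module):
#
#   processor = PPDocBeeProcessor(image_processor, tokenizer)
#   inputs = processor.preprocess(
#       [{"image": "path/to/doc.png", "query": "What is the title?"}]
#   )
#   # `inputs` typically contains input_ids (plus tokenizer extras such as
#   # attention_mask), pixel_values and image_grid_thw, ready for the PP-DocBee
#   # model; `postprocess()` then decodes the model prediction back to text.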