qwen2_5_vl.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union

import numpy as np

from .....utils import logging
from ....utils.benchmark import benchmark
from ...common.tokenizer.tokenizer_utils_base import (
    PreTokenizedInput,
    TensorType,
    TextInput,
    TruncationStrategy,
)
from ...common.vision.funcs import resize
from .common import (
    BatchFeature,
    ChannelDimension,
    ImageInput,
    PaddingStrategy,
    PILImageResampling,
    convert_to_rgb,
    fetch_image,
    get_image_size,
    infer_channel_dimension_format,
    make_batched_images,
    make_list_of_images,
    smart_resize,
    to_channel_dimension_format,
    to_numpy_array,
    valid_images,
)

OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

__all__ = [
    "Qwen2_5_VLProcessor",
    "Qwen2_5_VLImageProcessor",
    "PPDocBee2Processor",
]


def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].
    """
    if image.dtype == np.uint8:
        return False
    return np.min(image) >= 0 and np.max(image) <= 1
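
# Note: the check above is only a heuristic. A floating-point image whose raw
# values happen to fall in [0, 1] is treated as already rescaled even if it was
# never divided by 255; callers that know better should set `do_rescale`
# explicitly when invoking the image processor.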


class Qwen2_5_VLProcessor(object):
    """
    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.

    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2_5_VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.

    Args:
        image_processor ([`Qwen2_5_VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    def __init__(self, image_processor, tokenizer, **kwargs):
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.image_processor.min_pixels = kwargs.get("min_pixels", 3136)
        self.image_processor.max_pixels = kwargs.get("max_pixels", 12845056)

    def preprocess(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: int = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PADDLE,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2_5_VLImageProcessor's [`~Qwen2_5_VLImageProcessor.__call__`] if `vision_infos` is not `None`.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `paddle.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[paddle.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or Paddle
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'pd'`: Return Paddle `paddle.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
        """
        if images is not None:
            image_inputs = self.image_processor(
                images=images, return_tensors=return_tensors
            )
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if not isinstance(text, list):
            text = [text]
        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while "<|image_pad|>" in text[i]:
                    text[i] = text[i].replace(
                        "<|image_pad|>",
                        "<|placeholder|>"
                        * int(image_grid_thw[index].prod() // merge_length),
                        1,
                    )
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")

        text_inputs = self.tokenizer(
            text,
            return_tensors=return_tensors,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
        )

        return BatchFeature(data={**text_inputs, **image_inputs})
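
    # Rough illustration of the placeholder expansion above (assumed values, not
    # taken from this file): with the default patch_size=14 and merge_size=2, a
    # 448x448 image produces image_grid_thw == (1, 32, 32), i.e. 1024 patches.
    # Since merge_length == merge_size**2 == 4, each "<|image_pad|>" in the prompt
    # expands to 1024 // 4 == 256 pad tokens before tokenization.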

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)


class Qwen2_5_VLImageProcessor(object):
    """
    Constructs a Qwen2.5-VL image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        min_pixels (`int`, *optional*, defaults to `56 * 56`):
            The minimum allowed number of pixels (height * width) after resizing.
        max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
            The maximum allowed number of pixels (height * width) after resizing.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The temporal patch size of the vision encoder.
        merge_size (`int`, *optional*, defaults to 2):
            The spatial merge factor between the vision encoder and the LLM: a `merge_size` x `merge_size`
            window of patches is merged into a single LLM token.
    """

    model_input_names = ["pixel_values", "image_grid_thw", "second_per_grid_ts"]

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        image_mean_ = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        image_std_ = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
        self.image_mean = np.array(image_mean_)[None, None, ...]
        self.image_std = np.array(image_std_)[None, None, ...]
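
    # self.image_mean / self.image_std are stored with shape (1, 1, num_channels)
    # so they broadcast over channels-last (H, W, C) images during normalization.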

    def _preprocess(
        self,
        images: Union[ImageInput],
        do_resize: bool = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
            vision_info (`List[Dict]`, *optional*):
                Optional list of dictionaries containing additional information about vision inputs.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logging.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=self.patch_size * self.merge_size,
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                    max_ratio=MAX_RATIO,
                )
                image = image.astype("uint8")
                image = resize(
                    image,
                    (resized_width, resized_height),
                    interp=None,
                    backend="cv2",
                )

            if do_rescale:
                image = image.astype("float32")
                image *= rescale_factor

            if do_normalize:
                assert input_data_format == ChannelDimension.LAST
                image = (image - self.image_mean) / self.image_std

            image = to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format
            )
            processed_images.append(image)

        patches = np.array(processed_images)
        if data_format == ChannelDimension.LAST:
            patches = patches.transpose([0, 3, 1, 2])
        if patches.shape[0] == 1:
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = (
            resized_height // self.patch_size,
            resized_width // self.patch_size,
        )
        patches = patches.reshape(
            [
                grid_t,
                self.temporal_patch_size,
                channel,
                grid_h // self.merge_size,
                self.merge_size,
                self.patch_size,
                grid_w // self.merge_size,
                self.merge_size,
                self.patch_size,
            ]
        )
        patches = patches.transpose([0, 3, 6, 4, 7, 2, 1, 5, 8])
        flatten_patches = patches.reshape(
            [
                grid_t * grid_h * grid_w,
                channel * self.temporal_patch_size * self.patch_size * self.patch_size,
            ]
        )
        return flatten_patches, (grid_t, grid_h, grid_w)
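
    # Shape sketch for _preprocess (illustrative, assuming the default
    # patch_size=14, merge_size=2, temporal_patch_size=2): one 448x448 RGB image
    # is tiled along time to patches of shape (2, 3, 448, 448), giving
    # grid_t, grid_h, grid_w = 1, 32, 32 and flatten_patches of shape
    # (1 * 32 * 32, 3 * 2 * 14 * 14) == (1024, 1176).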

    def __call__(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.PADDLE` or `'pd'`: Return a batch of type `paddle.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = (
            do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        )

        if images is not None:
            images = make_batched_images(images)
        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "paddle.Tensor."
            )

        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}

        return BatchFeature(data=data, tensor_type=return_tensors)
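

# Minimal usage sketch for the image processor (illustrative only; the image
# variable below is hypothetical):
#
#     image_processor = Qwen2_5_VLImageProcessor()
#     feats = image_processor(images=[pil_image], return_tensors="np")
#     feats["pixel_values"]    # (total_patches, channel * temporal_patch_size * patch_size**2)
#     feats["image_grid_thw"]  # (num_images, 3) rows of (grid_t, grid_h, grid_w)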


class PPDocBee2Processor(Qwen2_5_VLProcessor):
    """
    PP-DocBee2 processor, based on Qwen2_5_VLProcessor
    """

    @benchmark.timeit
    def preprocess(self, input_dicts: List[Dict]):
        """
        PreProcess for PP-DocBee2 Series
        """
        assert all(isinstance(input_dict, dict) for input_dict in input_dicts)
        prompt = (
            "<|im_start|>system\n"
            "You are a helpful assistant.<|im_end|>\n"
            "<|im_start|>user\n"
            "<|vision_start|><|image_pad|><|vision_end|>{query}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        query_inputs = [
            prompt.format(query=input_dict["query"]) for input_dict in input_dicts
        ]
        image_inputs = [
            fetch_image(
                input_dict["image"],
                size_factor=IMAGE_FACTOR,
                min_pixels=MIN_PIXELS,
                max_pixels=MAX_PIXELS,
                max_ratio=MAX_RATIO,
            )
            for input_dict in input_dicts
        ]
        rst_inputs = super().preprocess(
            text=query_inputs,
            images=image_inputs,
            padding=True,
            return_tensors="pd",
        )
        return rst_inputs

    @benchmark.timeit
    def postprocess(self, model_pred, **kwargs) -> List[str]:
        """
        Post-process adapted for PaddleX
        """
        return self.tokenizer.batch_decode(
            model_pred[0],
            skip_special_tokens=kwargs.get("skip_special_tokens", True),
            clean_up_tokenization_spaces=False,
        )
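

# Usage sketch for PPDocBee2Processor.preprocess (hypothetical objects; the real
# tokenizer and image processor are constructed elsewhere in PaddleX):
#
#     processor = PPDocBee2Processor(
#         image_processor=Qwen2_5_VLImageProcessor(), tokenizer=qwen2_tokenizer
#     )
#     inputs = processor.preprocess(
#         [{"image": "path/to/doc.png", "query": "What is the title?"}]
#     )
#     # `inputs` is a BatchFeature holding input_ids, attention_mask, pixel_values,
#     # and image_grid_thw as Paddle tensors ("pd"), padded across the batch.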