common.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import base64
  15. import math
  16. from collections import UserDict
  17. from io import BytesIO
  18. from typing import Any, Dict, List, Optional, Tuple, Union
  19. import numpy as np
  20. import paddle
  21. import PIL.Image
  22. import requests
  23. from packaging import version
  24. from PIL import Image
  25. from ...common.tokenizer.tokenizer_utils_base import ExplicitEnum
  26. def is_paddle_tensor(tensor):
  27. return paddle.is_tensor(tensor)
  28. def to_numpy(obj):
  29. """
  30. Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a Numpy array.
  31. """
  32. if isinstance(obj, (dict, UserDict)):
  33. return {k: to_numpy(v) for k, v in obj.items()}
  34. elif isinstance(obj, (list, tuple)):
  35. return np.array(obj)
  36. elif is_paddle_tensor(obj):
  37. return obj.detach().cpu().numpy()
  38. else:
  39. return obj
  40. if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
  41. PILImageResampling = PIL.Image.Resampling
  42. else:
  43. PILImageResampling = PIL.Image
  44. ImageInput = Union[
  45. "PIL.Image.Image",
  46. np.ndarray,
  47. "paddle.Tensor",
  48. List["PIL.Image.Image"],
  49. List[np.ndarray],
  50. List["paddle.Tensor"],
  51. ] # noqa
  52. TextInput = str
class ChannelDimension(ExplicitEnum):
    """
    Position of the channel axis in an image array.

    Members can be built from their string values, e.g. ``ChannelDimension("channels_first")``.
    """

    # Channel axis before the two spatial axes, e.g. (C, H, W).
    FIRST = "channels_first"
    # Channel axis after the two spatial axes, e.g. (H, W, C).
    LAST = "channels_last"
class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PretrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.

    `BatchFeature.convert_to_tensors` converts values with `paddle.to_tensor` for `PADDLE` and with
    `np.asarray` for any other member.
    """

    # Paddle tensors.
    PADDLE = "pd"
    # NumPy arrays.
    NUMPY = "np"
  63. def is_valid_image(img):
  64. return (
  65. isinstance(img, PIL.Image.Image)
  66. or isinstance(img, np.ndarray)
  67. or is_paddle_tensor(img)
  68. )
  69. def valid_images(imgs):
  70. # If we have an list of images, make sure every image is valid
  71. if isinstance(imgs, (list, tuple)):
  72. for img in imgs:
  73. if not valid_images(img):
  74. return False
  75. # If not a list of tuple, we have been given a single image or batched tensor of images
  76. elif not is_valid_image(imgs):
  77. return False
  78. return True
  79. def is_batched(img):
  80. if isinstance(img, (list, tuple)):
  81. return is_valid_image(img[0])
  82. return False
  83. def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
  84. """
  85. Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
  86. If the input is a batch of images, it is converted to a list of images.
  87. Args:
  88. images (`ImageInput`):
  89. Image of images to turn into a list of images.
  90. expected_ndims (`int`, *optional*, defaults to 3):
  91. Expected number of dimensions for a single input image. If the input image has a different number of
  92. dimensions, an error is raised.
  93. """
  94. if is_batched(images):
  95. return images
  96. # Either the input is a single image, in which case we create a list of length 1
  97. if isinstance(images, PIL.Image.Image):
  98. # PIL images are never batched
  99. return [images]
  100. if is_valid_image(images):
  101. if images.ndim == expected_ndims + 1:
  102. # Batch of images
  103. images = list(images)
  104. elif images.ndim == expected_ndims:
  105. # Single image
  106. images = [images]
  107. else:
  108. raise ValueError(
  109. f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
  110. f" {images.ndim} dimensions."
  111. )
  112. return images
  113. raise ValueError(
  114. "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, paddle.Tensor "
  115. f"but got {type(images)}."
  116. )
  117. def to_numpy_array(img) -> np.ndarray:
  118. if not is_valid_image(img):
  119. raise ValueError(f"Invalid image type: {type(img)}")
  120. if isinstance(img, PIL.Image.Image):
  121. return np.array(img)
  122. return to_numpy(img)
  123. def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
  124. """
  125. Infers the channel dimension format of `image`.
  126. Args:
  127. image (`np.ndarray`):
  128. The image to infer the channel dimension of.
  129. Returns:
  130. The channel dimension of the image.
  131. """
  132. if image.ndim == 3:
  133. first_dim, last_dim = 0, 2
  134. elif image.ndim == 4:
  135. first_dim, last_dim = 1, 3
  136. else:
  137. raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
  138. if image.shape[first_dim] in (1, 3):
  139. return ChannelDimension.FIRST
  140. elif image.shape[last_dim] in (1, 3):
  141. return ChannelDimension.LAST
  142. raise ValueError("Unable to infer channel dimension format")
  143. def get_channel_dimension_axis(image: np.ndarray) -> int:
  144. """
  145. Returns the channel dimension axis of the image.
  146. Args:
  147. image (`np.ndarray`):
  148. The image to get the channel dimension axis of.
  149. Returns:
  150. The channel dimension axis of the image.
  151. """
  152. channel_dim = infer_channel_dimension_format(image)
  153. if channel_dim == ChannelDimension.FIRST:
  154. return image.ndim - 3
  155. elif channel_dim == ChannelDimension.LAST:
  156. return image.ndim - 1
  157. raise ValueError(f"Unsupported data format: {channel_dim}")
  158. def get_image_size(
  159. image: np.ndarray, channel_dim: ChannelDimension = None
  160. ) -> Tuple[int, int]:
  161. """
  162. Returns the (height, width) dimensions of the image.
  163. Args:
  164. image (`np.ndarray`):
  165. The image to get the dimensions of.
  166. channel_dim (`ChannelDimension`, *optional*):
  167. Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
  168. Returns:
  169. A tuple of the image's height and width.
  170. """
  171. if channel_dim is None:
  172. channel_dim = infer_channel_dimension_format(image)
  173. if channel_dim == ChannelDimension.FIRST:
  174. return image.shape[-2], image.shape[-1]
  175. elif channel_dim == ChannelDimension.LAST:
  176. return image.shape[-3], image.shape[-2]
  177. else:
  178. raise ValueError(f"Unsupported data format: {channel_dim}")
  179. def convert_to_rgb(image: ImageInput) -> ImageInput:
  180. """
  181. Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
  182. as is.
  183. Args:
  184. image (Image):
  185. The image to convert.
  186. """
  187. if not isinstance(image, PIL.Image.Image):
  188. return image
  189. image = image.convert("RGB")
  190. return image
  191. def to_channel_dimension_format(
  192. image: np.ndarray,
  193. channel_dim: Union[ChannelDimension, str],
  194. input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
  195. ) -> np.ndarray:
  196. """
  197. Converts `image` to the channel dimension format specified by `channel_dim`.
  198. Args:
  199. image (`numpy.ndarray`):
  200. The image to have its channel dimension set.
  201. channel_dim (`ChannelDimension`):
  202. The channel dimension format to use.
  203. Returns:
  204. `np.ndarray`: The image with the channel dimension set to `channel_dim`.
  205. """
  206. if not isinstance(image, np.ndarray):
  207. raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
  208. if input_channel_dim is None:
  209. input_channel_dim = infer_channel_dimension_format(image)
  210. target_channel_dim = ChannelDimension(channel_dim)
  211. if input_channel_dim == target_channel_dim:
  212. return image
  213. if target_channel_dim == ChannelDimension.FIRST:
  214. image = image.transpose((2, 0, 1))
  215. elif target_channel_dim == ChannelDimension.LAST:
  216. image = image.transpose((1, 2, 0))
  217. else:
  218. raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
  219. return image
  220. class BatchFeature(UserDict):
  221. r"""
  222. Holds the feature extractor specific `__call__` methods.
  223. This class is derived from a python dictionary and can be used as a dictionary.
  224. Args:
  225. data (`dict`):
  226. Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
  227. etc.).
  228. tensor_type (`Union[None, str, TensorType]`, *optional*):
  229. You can give a tensor_type here to convert the lists of integers in Paddle/Numpy Tensors at
  230. initialization.
  231. """
  232. def __init__(
  233. self,
  234. data: Optional[Dict[str, Any]] = None,
  235. tensor_type: Union[None, str, TensorType] = None,
  236. ):
  237. super().__init__(data)
  238. self.convert_to_tensors(tensor_type=tensor_type)
  239. def __getitem__(self, item: str):
  240. """
  241. If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
  242. etc.).
  243. """
  244. if isinstance(item, str):
  245. return self.data[item]
  246. else:
  247. raise KeyError(
  248. "Indexing with integers is not available when using Python based feature extractors"
  249. )
  250. def __getattr__(self, item: str):
  251. try:
  252. return self.data[item]
  253. except KeyError:
  254. raise AttributeError
  255. def __getstate__(self):
  256. return {"data": self.data}
  257. def __setstate__(self, state):
  258. if "data" in state:
  259. self.data = state["data"]
  260. def keys(self):
  261. return self.data.keys()
  262. def values(self):
  263. return self.data.values()
  264. def items(self):
  265. return self.data.items()
  266. def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
  267. """
  268. Convert the inner content to tensors.
  269. Args:
  270. tensor_type (`str` or [`TensorType`], *optional*):
  271. The type of tensors to use. If `str`, should be one of the values of the enum [`TensorType`]. If
  272. `None`, no modification is done.
  273. """
  274. if tensor_type is None:
  275. return self
  276. # Convert to TensorType
  277. if not isinstance(tensor_type, TensorType):
  278. tensor_type = TensorType(tensor_type)
  279. # Get a function reference for the correct framework
  280. if tensor_type == TensorType.PADDLE:
  281. as_tensor = paddle.to_tensor
  282. is_tensor = paddle.is_tensor
  283. else:
  284. as_tensor = np.asarray
  285. def is_tensor(x):
  286. return isinstance(x, np.ndarray)
  287. # Do the tensor conversion in batch
  288. for key, value in self.items():
  289. try:
  290. if not is_tensor(value):
  291. tensor = as_tensor(value)
  292. self[key] = tensor
  293. except: # noqa E722
  294. if key == "overflowing_tokens":
  295. raise ValueError(
  296. "Unable to create tensor returning overflowing tokens of different lengths. "
  297. "Please see if a fast version of this tokenizer is available to have this feature available."
  298. )
  299. raise ValueError(
  300. "Unable to create tensor, you should probably activate truncation and/or padding "
  301. "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
  302. )
  303. return self
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the `padding` argument in [`PretrainedTokenizerBase.__call__`]. Useful for tab-completion in an
    IDE.

    Only the string values are defined here; the padding behaviour for each member lives in the code that consumes
    `padding`, not in this file.
    """

    # NOTE(review): by name, pad to the longest sequence in the batch — confirm in the padding implementation.
    LONGEST = "longest"
    # NOTE(review): by name, pad to a caller-given `max_length` — confirm in the padding implementation.
    MAX_LENGTH = "max_length"
    # NOTE(review): by name, no padding is applied — confirm in the padding implementation.
    DO_NOT_PAD = "do_not_pad"
  312. def extract_vision_info(
  313. conversations: Union[List[dict], List[List[dict]]]
  314. ) -> List[dict]:
  315. vision_infos = []
  316. if isinstance(conversations[0], dict):
  317. conversations = [conversations]
  318. for conversation in conversations:
  319. for message in conversation:
  320. if isinstance(message["content"], list):
  321. for ele in message["content"]:
  322. if (
  323. "image" in ele
  324. or "image_url" in ele
  325. or ele["type"] in ("image", "image_url")
  326. ):
  327. vision_infos.append(ele)
  328. return vision_infos
  329. def process_vision_info(
  330. conversations: Union[List[dict], List[List[dict]]],
  331. ) -> Tuple[
  332. Union[List[Image.Image], None, List[Union[paddle.Tensor, List[Image.Image]]], None]
  333. ]:
  334. vision_infos = extract_vision_info(conversations)
  335. image_inputs = []
  336. for vision_info in vision_infos:
  337. if "image" in vision_info or "image_url" in vision_info:
  338. image_inputs.append(fetch_image(vision_info))
  339. else:
  340. raise ValueError("image, image_url should in content.")
  341. if len(image_inputs) == 0:
  342. image_inputs = None
  343. return image_inputs
  344. def fetch_image(
  345. ele: Dict[str, Union[str, Image.Image]],
  346. size_factor: int,
  347. min_pixels: int,
  348. max_pixels: int,
  349. max_ratio: float,
  350. ) -> Image.Image:
  351. if not isinstance(ele, dict):
  352. ele = {"image": ele}
  353. if "image" in ele:
  354. image = ele["image"]
  355. else:
  356. image = ele["image_url"]
  357. image_obj = None
  358. if isinstance(image, Image.Image):
  359. image_obj = image
  360. elif isinstance(image, np.ndarray):
  361. image_obj = Image.fromarray(image)
  362. elif image.startswith("http://") or image.startswith("https://"):
  363. image_obj = Image.open(requests.get(image, stream=True).raw)
  364. elif image.startswith("file://"):
  365. image_obj = Image.open(image[7:])
  366. elif image.startswith("data:image"):
  367. data = image.split(";", 1)[1]
  368. if data.startswith("base64,"):
  369. data = base64.b64decode(data[7:])
  370. image_obj = Image.open(BytesIO(data))
  371. else:
  372. image_obj = Image.open(image)
  373. if image_obj is None:
  374. raise ValueError(
  375. f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
  376. )
  377. image = image_obj.convert("RGB")
  378. # resize
  379. if "resized_height" in ele and "resized_width" in ele:
  380. resized_height, resized_width = smart_resize(
  381. ele["resized_height"],
  382. ele["resized_width"],
  383. factor=size_factor,
  384. min_pixels=min_pixels,
  385. max_pixels=max_pixels,
  386. max_ratio=max_ratio,
  387. )
  388. else:
  389. width, height = image.size # Image, not tensor
  390. min_pixels = ele.get("min_pixels", min_pixels)
  391. max_pixels = ele.get("max_pixels", max_pixels)
  392. resized_height, resized_width = smart_resize(
  393. height,
  394. width,
  395. factor=size_factor,
  396. min_pixels=min_pixels,
  397. max_pixels=max_pixels,
  398. max_ratio=max_ratio,
  399. )
  400. image = image.resize((resized_width, resized_height))
  401. return image
  402. def round_by_factor(number: int, factor: int) -> int:
  403. """Returns the closest integer to 'number' that is divisible by 'factor'."""
  404. return round(number / factor) * factor
  405. def ceil_by_factor(number: int, factor: int) -> int:
  406. """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
  407. return math.ceil(number / factor) * factor
  408. def floor_by_factor(number: int, factor: int) -> int:
  409. """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
  410. return math.floor(number / factor) * factor
  411. def smart_resize(
  412. height: int,
  413. width: int,
  414. factor: int,
  415. min_pixels: int,
  416. max_pixels: int,
  417. max_ratio: float,
  418. ) -> Tuple[int, int]:
  419. """
  420. Rescales the image so that the following conditions are met:
  421. 1. Both dimensions (height and width) are divisible by 'factor'.
  422. 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
  423. 3. The aspect ratio of the image is maintained as closely as possible.
  424. """
  425. if max(height, width) / min(height, width) > max_ratio:
  426. raise ValueError(
  427. f"absolute aspect ratio must be smaller than {max_ratio}, got {max(height, width) / min(height, width)}"
  428. )
  429. h_bar = max(factor, round_by_factor(height, factor))
  430. w_bar = max(factor, round_by_factor(width, factor))
  431. if h_bar * w_bar > max_pixels:
  432. beta = math.sqrt((height * width) / max_pixels)
  433. h_bar = floor_by_factor(height / beta, factor)
  434. w_bar = floor_by_factor(width / beta, factor)
  435. elif h_bar * w_bar < min_pixels:
  436. beta = math.sqrt(min_pixels / (height * width))
  437. h_bar = ceil_by_factor(height * beta, factor)
  438. w_bar = ceil_by_factor(width * beta, factor)
  439. return h_bar, w_bar
  440. def make_batched_images(images) -> List[List[ImageInput]]:
  441. """
  442. Accepts images in list or nested list format, and makes a list of images for preprocessing.
  443. Args:
  444. images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
  445. The input image.
  446. Returns:
  447. list: A list of images.
  448. """
  449. if (
  450. isinstance(images, (list, tuple))
  451. and isinstance(images[0], (list, tuple))
  452. and is_valid_image(images[0][0])
  453. ):
  454. return [img for img_list in images for img in img_list]
  455. elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
  456. return images
  457. elif is_valid_image(images):
  458. return [images]
  459. raise ValueError(f"Could not make batched images from {images}")