# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import UserDict
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import paddle
import PIL.Image
from packaging import version

from ...common.tokenizer.tokenizer_utils_base import ExplicitEnum


def is_paddle_tensor(tensor):
    return paddle.is_tensor(tensor)


def to_numpy(obj):
    """
    Convert a Paddle tensor, NumPy array or Python list to a NumPy array.
    """
    if isinstance(obj, (dict, UserDict)):
        return {k: to_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return np.array(obj)
    elif is_paddle_tensor(obj):
        return obj.detach().cpu().numpy()
    else:
        return obj
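

# Illustrative usage sketch (not part of the original module; values are arbitrary examples):
#
#     to_numpy(paddle.ones([2, 2]))          # -> np.ndarray of shape (2, 2)
#     to_numpy({"a": [1, 2], "b": [3, 4]})   # -> {"a": array([1, 2]), "b": array([3, 4])}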


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PILImageResampling = PIL.Image.Resampling
else:
    PILImageResampling = PIL.Image


ImageInput = Union[
    "PIL.Image.Image",
    np.ndarray,
    "paddle.Tensor",
    List["PIL.Image.Image"],
    List[np.ndarray],
    List["paddle.Tensor"],
]  # noqa

TextInput = str


class ChannelDimension(ExplicitEnum):
    FIRST = "channels_first"
    LAST = "channels_last"


class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PretrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.
    """

    PADDLE = "pd"
    NUMPY = "np"


def is_valid_image(img):
    return (
        isinstance(img, PIL.Image.Image)
        or isinstance(img, np.ndarray)
        or is_paddle_tensor(img)
    )


def valid_images(imgs):
    # If we have a list of images, make sure every image is valid
    if isinstance(imgs, (list, tuple)):
        for img in imgs:
            if not valid_images(img):
                return False
    # If not a list or tuple, we have been given a single image or a batched tensor of images
    elif not is_valid_image(imgs):
        return False
    return True


def is_batched(img):
    if isinstance(img, (list, tuple)):
        return is_valid_image(img[0])
    return False
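

# Illustrative usage sketch (not part of the original module): `valid_images` accepts
# single images, nested lists, and batched tensors, while `is_batched` only treats
# lists/tuples of valid images as batches.
#
#     valid_images(np.zeros((3, 224, 224)))                               # -> True
#     valid_images([PIL.Image.new("RGB", (8, 8)), np.zeros((8, 8, 3))])   # -> True
#     is_batched([np.zeros((8, 8, 3))])                                   # -> True
#     is_batched(np.zeros((2, 8, 8, 3)))                                  # -> False (an array, not a list)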


def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    """
    Ensure that the input is a list of images. If the input is a single image, it is converted to a list of
    length 1. If the input is a batch of images, it is converted to a list of images.

    Args:
        images (`ImageInput`):
            Image or images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. If the input image has a different number of
            dimensions, an error is raised.
    """
    if is_batched(images):
        return images

    # Either the input is a single image, in which case we create a list of length 1
    if isinstance(images, PIL.Image.Image):
        # PIL images are never batched
        return [images]

    if is_valid_image(images):
        if images.ndim == expected_ndims + 1:
            # Batch of images
            images = list(images)
        elif images.ndim == expected_ndims:
            # Single image
            images = [images]
        else:
            raise ValueError(
                f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
                f" {images.ndim} dimensions."
            )
        return images
    raise ValueError(
        "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray or paddle.Tensor, "
        f"but got {type(images)}."
    )
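

# Illustrative usage sketch (not part of the original module; shapes are arbitrary
# examples): a single 3D array becomes a one-element list, while a 4D array is split
# along its first axis into a list of images.
#
#     make_list_of_images(np.zeros((224, 224, 3)))          # -> [array of shape (224, 224, 3)]
#     make_list_of_images(np.zeros((2, 224, 224, 3)))       # -> list of 2 arrays, each (224, 224, 3)
#     make_list_of_images(PIL.Image.new("RGB", (32, 32)))   # -> [<PIL.Image.Image>]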


def to_numpy_array(img) -> np.ndarray:
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    if isinstance(img, PIL.Image.Image):
        return np.array(img)
    return to_numpy(img)


def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.

    Returns:
        The channel dimension of the image.
    """
    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    if image.shape[first_dim] in (1, 3):
        return ChannelDimension.FIRST
    elif image.shape[last_dim] in (1, 3):
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
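

# Illustrative usage sketch (not part of the original module): the channel axis is
# detected by looking for a dimension of size 1 or 3 at the expected positions.
#
#     infer_channel_dimension_format(np.zeros((3, 224, 224)))      # -> ChannelDimension.FIRST
#     infer_channel_dimension_format(np.zeros((224, 224, 3)))      # -> ChannelDimension.LAST
#     infer_channel_dimension_format(np.zeros((8, 3, 224, 224)))   # -> ChannelDimension.FIRST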


def get_channel_dimension_axis(image: np.ndarray) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.

    Returns:
        The channel dimension axis of the image.
    """
    channel_dim = infer_channel_dimension_format(image)
    if channel_dim == ChannelDimension.FIRST:
        return image.ndim - 3
    elif channel_dim == ChannelDimension.LAST:
        return image.ndim - 1
    raise ValueError(f"Unsupported data format: {channel_dim}")
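

# Illustrative usage sketch (not part of the original module):
#
#     get_channel_dimension_axis(np.zeros((3, 224, 224)))      # -> 0
#     get_channel_dimension_axis(np.zeros((8, 224, 224, 3)))   # -> 3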


def get_image_size(
    image: np.ndarray, channel_dim: Optional[ChannelDimension] = None
) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    if channel_dim == ChannelDimension.FIRST:
        return image.shape[-2], image.shape[-1]
    elif channel_dim == ChannelDimension.LAST:
        return image.shape[-3], image.shape[-2]
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")
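

# Illustrative usage sketch (not part of the original module):
#
#     get_image_size(np.zeros((3, 480, 640)))   # -> (480, 640), channel dim inferred as FIRST
#     get_image_size(np.zeros((480, 640, 3)))   # -> (480, 640), channel dim inferred as LAST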


def convert_to_rgb(image: ImageInput) -> ImageInput:
    """
    Converts an image to RGB format. Only converts if the image is of type `PIL.Image.Image`; otherwise the image
    is returned as is.

    Args:
        image (`ImageInput`):
            The image to convert.
    """
    if not isinstance(image, PIL.Image.Image):
        return image

    image = image.convert("RGB")
    return image


def to_channel_dimension_format(
    image: np.ndarray,
    channel_dim: Union[ChannelDimension, str],
    input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
) -> np.ndarray:
    """
    Converts `image` to the channel dimension format specified by `channel_dim`.

    Args:
        image (`numpy.ndarray`):
            The image to have its channel dimension set.
        channel_dim (`ChannelDimension`):
            The channel dimension format to use.
        input_channel_dim (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If `None`, it is inferred from the image.

    Returns:
        `np.ndarray`: The image with the channel dimension set to `channel_dim`.
    """
    if not isinstance(image, np.ndarray):
        raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

    if input_channel_dim is None:
        input_channel_dim = infer_channel_dimension_format(image)

    target_channel_dim = ChannelDimension(channel_dim)
    if input_channel_dim == target_channel_dim:
        return image

    if target_channel_dim == ChannelDimension.FIRST:
        image = image.transpose((2, 0, 1))
    elif target_channel_dim == ChannelDimension.LAST:
        image = image.transpose((1, 2, 0))
    else:
        raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
    return image
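

# Illustrative usage sketch (not part of the original module): note that the transpose
# axes above assume a single 3D image.
#
#     hwc = np.zeros((480, 640, 3))
#     to_channel_dimension_format(hwc, ChannelDimension.FIRST).shape   # -> (3, 480, 640)
#     to_channel_dimension_format(hwc, "channels_last") is hwc         # -> True (already in the target format)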


class BatchFeature(UserDict):
    r"""
    Holds the output of the feature extractor specific `__call__` methods.

    This class is derived from a python dictionary and can be used as a dictionary.

    Args:
        data (`dict`):
            Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values',
            'attention_mask', etc.).
        tensor_type (`Union[None, str, TensorType]`, *optional*):
            You can give a tensor_type here to convert the lists of integers into Paddle/NumPy tensors at
            initialization.
    """

    def __init__(
        self,
        data: Optional[Dict[str, Any]] = None,
        tensor_type: Union[None, str, TensorType] = None,
    ):
        super().__init__(data)
        self.convert_to_tensors(tensor_type=tensor_type)

    def __getitem__(self, item: str):
        """
        If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
        etc.).
        """
        if isinstance(item, str):
            return self.data[item]
        else:
            raise KeyError("Indexing with integers is not available when using Python based feature extractors")

    def __getattr__(self, item: str):
        try:
            return self.data[item]
        except KeyError:
            raise AttributeError

    def __getstate__(self):
        return {"data": self.data}

    def __setstate__(self, state):
        if "data" in state:
            self.data = state["data"]

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`TensorType`]. If
                `None`, no modification is done.
        """
        if tensor_type is None:
            return self

        # Convert to TensorType
        if not isinstance(tensor_type, TensorType):
            tensor_type = TensorType(tensor_type)

        # Get a function reference for the correct framework
        if tensor_type == TensorType.PADDLE:
            as_tensor = paddle.to_tensor
            is_tensor = paddle.is_tensor
        else:
            as_tensor = np.asarray

            def is_tensor(x):
                return isinstance(x, np.ndarray)

        # Do the tensor conversion in batch
        for key, value in self.items():
            try:
                if not is_tensor(value):
                    tensor = as_tensor(value)
                    self[key] = tensor
            except:  # noqa E722
                if key == "overflowing_tokens":
                    raise ValueError(
                        "Unable to create tensor returning overflowing tokens of different lengths. "
                        "Please see if a fast version of this tokenizer is available to have this feature."
                    )
                raise ValueError(
                    "Unable to create tensor, you should probably activate truncation and/or padding "
                    "with 'padding=True' and 'truncation=True' to have batched tensors with the same length."
                )

        return self
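

# Illustrative usage sketch (not part of the original module): `BatchFeature` behaves
# like a dict and can convert its values to Paddle tensors at construction time.
#
#     feats = BatchFeature({"pixel_values": [[0.1, 0.2], [0.3, 0.4]]}, tensor_type="pd")
#     feats["pixel_values"]      # paddle.Tensor with shape [2, 2]
#     feats.pixel_values         # same tensor, via attribute access
#     list(feats.keys())         # -> ['pixel_values']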