# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
from typing import List, Sequence, Union, Optional, Tuple
import re
import numpy as np
import cv2
import math
import json
import tempfile
import lazy_paddle

from ...utils.benchmark import benchmark


@benchmark.timeit
class Scale:
    """Scale images."""

    def __init__(
        self,
        short_size: int,
        fixed_ratio: bool = True,
        keep_ratio: Union[bool, None] = None,
        do_round: bool = False,
    ) -> None:
        """
        Initializes the Scale class.

        Args:
            short_size (int): The target size for the shorter side of the image.
            fixed_ratio (bool): Whether to maintain a fixed aspect ratio of 4:3.
            keep_ratio (Union[bool, None]): Whether to keep the aspect ratio. Cannot be True if fixed_ratio is True.
            do_round (bool): Whether to round the scaling factor.
        """
        super().__init__()
        self.short_size = short_size
        assert (fixed_ratio and not keep_ratio) or (
            not fixed_ratio
        ), "fixed_ratio and keep_ratio cannot be true at the same time"
        self.fixed_ratio = fixed_ratio
        self.keep_ratio = keep_ratio
        self.do_round = do_round

    def scale(self, video: List[np.ndarray]) -> List[np.ndarray]:
        """
        Performs resize operations on a sequence of images.

        Args:
            video (List[np.ndarray]): List where each item is an image, as a numpy array.
                For example, [np.ndarray0, np.ndarray1, np.ndarray2, ...]

        Returns:
            List[np.ndarray]: List where each item is a np.ndarray after scaling.
        """
        imgs = video
        resized_imgs = []
        for i in range(len(imgs)):
            img = imgs[i]
            if isinstance(img, np.ndarray):
                h, w, _ = img.shape
            else:
                raise NotImplementedError
            if (w <= h and w == self.short_size) or (h <= w and h == self.short_size):
                resized_imgs.append(img)
                continue
            if w <= h:
                ow = self.short_size
                if self.fixed_ratio:
                    oh = int(self.short_size * 4.0 / 3.0)
                elif self.keep_ratio is False:
                    oh = self.short_size
                else:
                    scale_factor = self.short_size / w
                    oh = (
                        int(h * float(scale_factor) + 0.5)
                        if self.do_round
                        else int(h * self.short_size / w)
                    )
                    ow = (
                        int(w * float(scale_factor) + 0.5)
                        if self.do_round
                        else self.short_size
                    )
            else:
                oh = self.short_size
                if self.fixed_ratio:
                    ow = int(self.short_size * 4.0 / 3.0)
                elif self.keep_ratio is False:
                    ow = self.short_size
                else:
                    scale_factor = self.short_size / h
                    oh = (
                        int(h * float(scale_factor) + 0.5)
                        if self.do_round
                        else self.short_size
                    )
                    ow = (
                        int(w * float(scale_factor) + 0.5)
                        if self.do_round
                        else int(w * self.short_size / h)
                    )
            resized_imgs.append(
                cv2.resize(img, (ow, oh), interpolation=cv2.INTER_LINEAR)
            )
        imgs = resized_imgs
        return imgs

    def __call__(self, videos: List[np.ndarray]) -> List[np.ndarray]:
        """
        Apply the scaling operation to a list of videos.

        Args:
            videos (List[np.ndarray]): A list of videos, where each video is a sequence of images.

        Returns:
            List[np.ndarray]: A list of videos after scaling, where each video is a list of images.
        """
        return [self.scale(video) for video in videos]
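
# Usage sketch (illustrative only; the short side, frame shape, and frame count
# below are assumptions, not part of this module): resize every frame of a
# single-video batch so the short side becomes 224 while preserving the aspect
# ratio.
#
#     scale = Scale(short_size=224, fixed_ratio=False, keep_ratio=True)
#     frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(8)]
#     resized = scale([frames])[0]  # each frame becomes 224 x 298 (h x w)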


@benchmark.timeit
class CenterCrop:
    """Center crop images."""

    def __init__(self, target_size: int, do_round: bool = True) -> None:
        """
        Initializes the CenterCrop class.

        Args:
            target_size (int): The size of the cropped area.
            do_round (bool): Whether to round the crop coordinates.
        """
        super().__init__()
        self.target_size = target_size
        self.do_round = do_round

    def center_crop(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
        """
        Performs center crop operations on images.

        Args:
            imgs (List[np.ndarray]): A sequence of images (each a numpy array), or a 4-D
                paddle Tensor whose last two dimensions are height and width.

        Returns:
            List[np.ndarray]: A list of images after center cropping, or a cropped tensor
                when the input is a tensor.
        """
        crop_imgs = []
        th, tw = self.target_size, self.target_size
        if isinstance(imgs, lazy_paddle.Tensor):
            h, w = imgs.shape[-2:]
            x1 = int(round((w - tw) / 2.0)) if self.do_round else (w - tw) // 2
            y1 = int(round((h - th) / 2.0)) if self.do_round else (h - th) // 2
            crop_imgs = imgs[:, :, y1 : y1 + th, x1 : x1 + tw]
        else:
            for img in imgs:
                h, w, _ = img.shape
                assert (w >= self.target_size) and (
                    h >= self.target_size
                ), "image width ({}) and height ({}) should be larger than crop size ({})".format(
                    w, h, self.target_size
                )
                x1 = int(round((w - tw) / 2.0)) if self.do_round else (w - tw) // 2
                y1 = int(round((h - th) / 2.0)) if self.do_round else (h - th) // 2
                crop_imgs.append(img[y1 : y1 + th, x1 : x1 + tw])
        return crop_imgs

    def __call__(self, videos: List[np.ndarray]) -> List[np.ndarray]:
        """
        Apply the center crop operation to a list of videos.

        Args:
            videos (List[np.ndarray]): A list of videos, where each video is a sequence of images.

        Returns:
            List[np.ndarray]: A list of videos after center cropping.
        """
        return [self.center_crop(video) for video in videos]
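
# Usage sketch (illustrative only; the frame size is an assumption): crop the
# central 224 x 224 region from each frame of a video.
#
#     crop = CenterCrop(target_size=224)
#     frames = [np.zeros((256, 340, 3), dtype=np.uint8) for _ in range(8)]
#     cropped = crop([frames])[0]  # each frame becomes 224 x 224 x 3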


@benchmark.timeit
class Image2Array:
    """Convert a sequence of images to a numpy array with optional transposition."""

    def __init__(self, transpose: bool = True, data_format: str = "tchw") -> None:
        """
        Initializes the Image2Array class.

        Args:
            transpose (bool): Whether to transpose the resulting numpy array.
            data_format (str): The format to transpose to, either 'tchw' or 'cthw'.

        Raises:
            AssertionError: If data_format is not one of the allowed values.
        """
        super().__init__()
        assert data_format in [
            "tchw",
            "cthw",
        ], f"Target format must be in ['tchw', 'cthw'], but got {data_format}"
        self.transpose = transpose
        self.data_format = data_format

    def img2array(self, imgs: List[np.ndarray]) -> np.ndarray:
        """
        Converts a sequence of images to a numpy array and optionally transposes it.

        Args:
            imgs (List[np.ndarray]): A list of images to be converted to a numpy array.

        Returns:
            np.ndarray: A numpy array representation of the images.
        """
        t_imgs = np.stack(imgs).astype("float32")
        if self.transpose:
            if self.data_format == "tchw":
                t_imgs = t_imgs.transpose([0, 3, 1, 2])  # tchw
            else:
                t_imgs = t_imgs.transpose([3, 0, 1, 2])  # cthw
        return t_imgs

    def __call__(self, videos: List[np.ndarray]) -> List[np.ndarray]:
        """
        Apply the image to array conversion to a list of videos.

        Args:
            videos (List[Sequence[np.ndarray]]): A list of videos, where each video is a sequence of images.

        Returns:
            List[np.ndarray]: A list of numpy arrays, one for each video.
        """
        return [self.img2array(video) for video in videos]
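
# Usage sketch (illustrative only; the frame count and size are assumptions):
# stack 8 HWC frames into a single array and move the channel axis according to
# data_format.
#
#     to_array = Image2Array(transpose=True, data_format="tchw")
#     frames = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(8)]
#     arr = to_array([frames])[0]  # shape (8, 3, 224, 224), dtype float32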


@benchmark.timeit
class NormalizeVideo:
    """
    Normalize video frames by subtracting the mean and dividing by the standard deviation.
    """

    def __init__(
        self,
        mean: Sequence[float],
        std: Sequence[float],
        tensor_shape: Sequence[int] = [3, 1, 1],
        inplace: bool = False,
    ) -> None:
        """
        Initializes the NormalizeVideo class.

        Args:
            mean (Sequence[float]): The mean values for each channel.
            std (Sequence[float]): The standard deviation values for each channel.
            tensor_shape (Sequence[int]): The shape of the mean and std tensors.
            inplace (bool): Whether to perform normalization in place.
        """
        super().__init__()
        self.inplace = inplace
        if not inplace:
            self.mean = np.array(mean).reshape(tensor_shape).astype(np.float32)
            self.std = np.array(std).reshape(tensor_shape).astype(np.float32)
        else:
            self.mean = np.array(mean, dtype=np.float32)
            self.std = np.array(std, dtype=np.float32)

    def normalize_video(self, imgs: np.ndarray) -> np.ndarray:
        """
        Normalizes a sequence of images.

        Args:
            imgs (np.ndarray): A numpy array of images to be normalized.

        Returns:
            np.ndarray: The normalized images as a numpy array.
        """
        if self.inplace:
            n = len(imgs)
            h, w, c = imgs[0].shape
            norm_imgs = np.empty((n, h, w, c), dtype=np.float32)
            for i, img in enumerate(imgs):
                norm_imgs[i] = img
            for img in norm_imgs:  # [n, h, w, c]
                mean = np.float64(self.mean.reshape(1, -1))  # [1, 3]
                stdinv = 1 / np.float64(self.std.reshape(1, -1))  # [1, 3]
                cv2.subtract(img, mean, img)
                cv2.multiply(img, stdinv, img)
        else:
            norm_imgs = imgs / 255.0
            norm_imgs -= self.mean
            norm_imgs /= self.std
            imgs = norm_imgs
        imgs = np.expand_dims(imgs, axis=0).copy()
        return imgs

    def __call__(self, videos: List[np.ndarray]) -> List[np.ndarray]:
        """
        Apply normalization to a list of videos.

        Args:
            videos (List[np.ndarray]): A list of videos, where each video is a numpy array of images.

        Returns:
            List[np.ndarray]: A list of normalized videos as numpy arrays.
        """
        return [self.normalize_video(video) for video in videos]
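
# Usage sketch (illustrative only; the ImageNet statistics below are a common
# choice, not something this module prescribes): scale frames to [0, 1], then
# normalize per channel. With the default tensor_shape of [3, 1, 1], this
# expects input in (T, C, H, W) layout, e.g. the output of Image2Array.
#
#     normalize = NormalizeVideo(
#         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
#     )
#     arr = np.random.rand(8, 3, 224, 224).astype("float32") * 255
#     out = normalize([arr])[0]  # shape (1, 8, 3, 224, 224)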


@benchmark.timeit
class VideoClasTopk:
    """Applies a top-k transformation on video classification predictions."""

    def __init__(self, class_ids: Optional[Sequence[Union[str, int]]] = None) -> None:
        """
        Initializes the VideoClasTopk class.

        Args:
            class_ids (Optional[Sequence[Union[str, int]]]): A list of class labels corresponding to class indices.
        """
        super().__init__()
        self.class_id_map = self._parse_class_id_map(class_ids)

    def softmax(self, data: np.ndarray) -> np.ndarray:
        """
        Applies the softmax function to an array of data.

        Args:
            data (np.ndarray): An array of data for which to compute softmax.

        Returns:
            np.ndarray: The softmax-transformed data.
        """
        x_max = np.max(data, axis=-1, keepdims=True)
        e_x = np.exp(data - x_max)
        return e_x / np.sum(e_x, axis=-1, keepdims=True)

    def _parse_class_id_map(
        self, class_ids: Optional[Sequence[Union[str, int]]]
    ) -> Optional[dict]:
        """
        Parses a list of class IDs into a mapping from class index to class label.

        Args:
            class_ids (Optional[Sequence[Union[str, int]]]): A list of class labels.

        Returns:
            Optional[dict]: A dictionary mapping class indices to labels, or None if no class_ids are provided.
        """
        if class_ids is None:
            return None
        class_id_map = {id: str(lb) for id, lb in enumerate(class_ids)}
        return class_id_map

    def __call__(
        self, preds: np.ndarray, topk: int = 5
    ) -> Tuple[np.ndarray, List[np.ndarray], List[List[str]]]:
        """
        Selects the top-k predictions from the classification output.

        Args:
            preds (np.ndarray): Model output whose first element is a 2D array of prediction scores.
            topk (int): The number of top predictions to return.

        Returns:
            Tuple[np.ndarray, List[np.ndarray], List[List[str]]]: A tuple containing:
                - An array of indices of the top-k predictions.
                - A list of arrays of scores for the top-k predictions.
                - A list of lists of label names for the top-k predictions.
        """
        preds[0] = self.softmax(preds[0])
        indexes = preds[0].argsort(axis=1)[:, -topk:][:, ::-1].astype("int32")
        scores = [
            list(np.around(pred[index], decimals=5))
            for pred, index in zip(preds[0], indexes)
        ]
        label_names = [
            # Fall back to the raw index when no class_id map was provided.
            [self.class_id_map[i] if self.class_id_map else str(i) for i in index]
            for index in indexes
        ]
        return indexes, scores, label_names
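
# Usage sketch (illustrative only; the class names and logits are made up):
# convert raw logits for a 4-class problem into top-2 indices, scores, and
# label names.
#
#     topk = VideoClasTopk(class_ids=["archery", "bowling", "juggling", "surfing"])
#     logits = np.array([[2.0, 0.5, 1.0, 0.1]], dtype=np.float32)
#     indexes, scores, names = topk([logits], topk=2)
#     # indexes -> [[0, 2]], names -> [["archery", "juggling"]]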


@benchmark.timeit
class ToBatch:
    """A class for batching videos."""

    def __call__(self, videos: List[np.ndarray]) -> List[np.ndarray]:
        """Call method to stack videos into a batch.

        Args:
            videos (list of np.ndarrays): List of videos to process.

        Returns:
            list of np.ndarrays: List containing a stacked tensor of the videos.
        """
        return [np.concatenate(videos, axis=0).astype(dtype=np.float32, copy=False)]
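
# End-to-end sketch of a typical preprocessing chain (illustrative only; the
# sizes, statistics, and frame counts are assumptions): Scale -> CenterCrop ->
# Image2Array -> NormalizeVideo -> ToBatch, mirroring the order in which the
# processors above are usually composed for video classification.
#
#     videos = [[np.zeros((256, 340, 3), dtype=np.uint8) for _ in range(8)]]
#     videos = Scale(short_size=224, fixed_ratio=False, keep_ratio=True)(videos)
#     videos = CenterCrop(target_size=224)(videos)
#     videos = Image2Array(data_format="tchw")(videos)
#     videos = NormalizeVideo(
#         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
#     )(videos)
#     batch = ToBatch()(videos)  # [array of shape (1, 8, 3, 224, 224)]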