processors.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from typing import List, Optional

import numpy as np

from ....utils.deps import class_requires_deps
from ...utils.benchmark import benchmark
  18. @benchmark.timeit
  19. @class_requires_deps("opencv-contrib-python")
  20. class ResizeVideo:
  21. """Resizes frames of a video to a specified target size.
  22. This class provides functionality to resize each frame of a video to
  23. a specified square dimension (height and width are equal).
  24. Attributes:
  25. target_size (int): The desired size (in pixels) for both the height
  26. and width of each frame in the video.
  27. """
  28. def __init__(self, target_size: int = 224) -> None:
  29. """Initializes the ResizeVideo with a target size.
  30. Args:
  31. target_size (int): The desired size in pixels for the output
  32. frames. Defaults to 224.
  33. """
  34. super().__init__()
  35. self.target_size = target_size
  36. def resize(self, video: List) -> List:
  37. """Resizes all frames of a single video.
  38. Args:
  39. video (list): A list of segments, where each segment is a list
  40. of frames represented as numpy arrays.
  41. Returns:
  42. list: The input video with each frame resized to the target size.
  43. Raises:
  44. NotImplementedError: If a frame is not an instance of numpy.ndarray.
  45. """
  46. import cv2
  47. num_seg = len(video)
  48. seg_len = len(video[0])
  49. for i in range(num_seg):
  50. for j in range(seg_len):
  51. img = video[i][j]
  52. if isinstance(img, np.ndarray):
  53. h, w, _ = img.shape
  54. else:
  55. raise NotImplementedError(
  56. "Currently, only numpy.ndarray frames are supported."
  57. )
  58. video[i][j] = cv2.resize(
  59. img,
  60. (self.target_size, self.target_size),
  61. interpolation=cv2.INTER_LINEAR,
  62. )
  63. return video
  64. def __call__(self, videos: List) -> List:
  65. """Resizes frames of multiple videos.
  66. Args:
  67. videos (list): A list containing multiple videos, where each video
  68. is a list of segments, and each segment is a list of frames.
  69. Returns:
  70. list: A list of videos with each frame resized to the target size.
  71. """
  72. return [self.resize(video) for video in videos]
  73. @benchmark.timeit
  74. class Image2Array:
  75. """Convert a sequence of images to a numpy array with optional transposition."""
  76. def __init__(self, data_format: str = "tchw") -> None:
  77. """
  78. Initializes the Image2Array class.
  79. Args:
  80. data_format (str): The format to transpose to, either 'tchw' or 'cthw'.
  81. Raises:
  82. AssertionError: If data_format is not one of the allowed values.
  83. """
  84. super().__init__()
  85. assert data_format in [
  86. "tchw",
  87. "cthw",
  88. ], f"Target format must be in ['tchw', 'cthw'], but got {data_format}"
  89. self.data_format = data_format
  90. def img2array(self, video: List) -> List:
  91. """
  92. Converts a list of video frames to a numpy array, with frames transposed.
  93. Args:
  94. video (List): A list of frames represented as numpy arrays.
  95. Returns:
  96. List: A numpy array with the video frames transposed and concatenated.
  97. """
  98. # Transpose each image from HWC to CHW format
  99. num_seg = len(video)
  100. for i in range(num_seg):
  101. video_one = video[i]
  102. video_one = [img.transpose([2, 0, 1]) for img in video_one]
  103. video_one = np.concatenate(
  104. [np.expand_dims(img, axis=1) for img in video_one], axis=1
  105. )
  106. video[i] = video_one
  107. return video
  108. def __call__(self, videos: List[List[np.ndarray]]) -> List[np.ndarray]:
  109. """
  110. Process videos by converting each video to a transposed numpy array.
  111. Args:
  112. videos (List[List[np.ndarray]]): A list of videos, where each video is a list
  113. of frames represented as numpy arrays.
  114. Returns:
  115. List[np.ndarray]: A list of processed videos with transposed frames.
  116. """
  117. return [self.img2array(video) for video in videos]
  118. @benchmark.timeit
  119. class NormalizeVideo:
  120. """
  121. A class to normalize video frames by scaling the pixel values.
  122. """
  123. def __init__(self, scale: float = 255.0) -> None:
  124. """
  125. Initializes the NormalizeVideo class.
  126. Args:
  127. scale (float): The scale factor to normalize the frames, usually the max pixel value.
  128. """
  129. super().__init__()
  130. self.scale = scale
  131. def normalize_video(self, video: List[np.ndarray]) -> List[np.ndarray]:
  132. """
  133. Normalizes a sequence of images by scaling the pixel values.
  134. Args:
  135. video (List[np.ndarray]): A list of frames, where each frame is a numpy array to be normalized.
  136. Returns:
  137. List[np.ndarray]: The normalized video frames as a list of numpy arrays.
  138. """
  139. num_seg = len(video) # Number of frames in the video
  140. for i in range(num_seg):
  141. # Convert frame to float32 and scale pixel values
  142. video[i] = video[i].astype(np.float32) / self.scale
  143. # Expand dimensions if needed
  144. video[i] = np.expand_dims(video[i], axis=0)
  145. return video
  146. def __call__(self, videos: List[List[np.ndarray]]) -> List[List[np.ndarray]]:
  147. """
  148. Apply normalization to a list of videos.
  149. Args:
  150. videos (List[List[np.ndarray]]): A list of videos, where each video is a list of frames
  151. represented as numpy arrays.
  152. Returns:
  153. List[List[np.ndarray]]: A list of normalized videos, each represented as a list of normalized frames.
  154. """
  155. return [self.normalize_video(video) for video in videos]
  156. def convert2cpu(gpu_matrix):
  157. float_32_g = gpu_matrix.astype("float32")
  158. return float_32_g.cpu()
  159. def convert2cpu_long(gpu_matrix):
  160. int_64_g = gpu_matrix.astype("int64")
  161. return int_64_g.cpu()
def get_region_boxes(
    output,
    conf_thresh=0.005,
    num_classes=24,
    # NOTE(review): mutable list default. It is only read below (len,
    # paddle.to_tensor), never mutated, so sharing it across calls is harmless.
    anchors=[
        0.70458,
        1.18803,
        1.26654,
        2.55121,
        1.59382,
        4.08321,
        2.30548,
        4.94180,
        3.52332,
        5.91979,
    ],
    num_anchors=5,
    only_objectness=1,
):
    """
    Decodes a region-layer output map into per-image lists of candidate boxes.

    For every batch item, grid cell (cy, cx) and anchor, the box centre is
    ``sigmoid(offset) + cell index`` and its size is ``exp(value) * anchor``.
    All four values are then divided by the grid width/height. Only boxes
    whose confidence is strictly greater than ``conf_thresh`` are kept.

    Args:
        output (Tensor): Paddle tensor of shape
            [batch, (5 + num_classes) * num_anchors, h, w]. A 3-D tensor
            without the batch axis is also accepted and treated as batch 1.
            Along the channel axis, each anchor owns a contiguous block of
            ``5 + num_classes`` channels: x offset, y offset, log width,
            log height, objectness logit, then class logits.
        conf_thresh (float): The confidence threshold for filtering predictions. Default is 0.005.
        num_classes (int): The number of classes for classification. Default is 24.
        anchors (List[float]): Flattened anchor sizes, ``len(anchors) // num_anchors``
            values per anchor. The first value is used as the width and the
            second as the height, in grid-cell units. The default holds 5
            (w, h) pairs.
        num_anchors (int): The number of anchor boxes used in the model. Default is 5.
        only_objectness (int): If truthy, the confidence is the objectness
            score alone. Otherwise it is objectness times the best class
            probability. Default is 1.

    Returns:
        List[List[list]]: One list per batch item. Each box is
        ``[cx, cy, w, h, det_conf, cls_max_conf, cls_max_id]``, where
        cx/cy/w/h are normalized by the grid size. The entries are obtained
        by indexing CPU paddle tensors, so they are presumably tensor elements
        rather than Python floats; callers convert them with ``float()`` /
        ``.item()``.

    Raises:
        AssertionError: If the channel count is not ``(5 + num_classes) * num_anchors``.
    """
    import paddle
    # Number of values per anchor in the flat ``anchors`` list.
    anchor_step = len(anchors) // num_anchors
    if output.dim() == 3:
        output = output.unsqueeze(0)
    batch = output.shape[0]
    assert output.shape[1] == (5 + num_classes) * num_anchors
    h = output.shape[2]
    w = output.shape[3]
    all_boxes = []
    # [B, A*(5+C), h, w] -> [5+C, B*A*h*w]: one row per predicted quantity,
    # with columns ordered by (batch, anchor, cy, cx).
    output = paddle.reshape(output, [batch * num_anchors, 5 + num_classes, h * w])
    output = paddle.transpose(output, (1, 0, 2))
    output = paddle.reshape(output, [5 + num_classes, batch * num_anchors * h * w])
    # Column index (cx) of each cell, laid out in the same column order.
    grid_x = paddle.linspace(0, w - 1, w)
    grid_x = paddle.tile(grid_x, [h, 1])
    grid_x = paddle.tile(grid_x, [batch * num_anchors, 1, 1])
    grid_x = paddle.reshape(grid_x, [batch * num_anchors * h * w])
    # Row index (cy) of each cell; transposing the tiled matrix puts cy on rows.
    grid_y = paddle.linspace(0, h - 1, h)
    grid_y = paddle.tile(grid_y, [w, 1]).t()
    grid_y = paddle.tile(grid_y, [batch * num_anchors, 1, 1])
    grid_y = paddle.reshape(grid_y, [batch * num_anchors * h * w])
    sigmoid = paddle.nn.Sigmoid()
    # Box centres in grid-cell units.
    xs = sigmoid(output[0]) + grid_x
    ys = sigmoid(output[1]) + grid_y
    # Per-anchor width (column 0) and height (column 1), broadcast to every
    # (batch, anchor, cell) column.
    anchor_w = paddle.to_tensor(anchors)
    anchor_w = paddle.reshape(anchor_w, [num_anchors, anchor_step])
    anchor_w = paddle.index_select(
        anchor_w, index=paddle.to_tensor(np.array([0]).astype("int32")), axis=1
    )
    anchor_h = paddle.to_tensor(anchors)
    anchor_h = paddle.reshape(anchor_h, [num_anchors, anchor_step])
    anchor_h = paddle.index_select(
        anchor_h, index=paddle.to_tensor(np.array([1]).astype("int32")), axis=1
    )
    anchor_w = paddle.tile(anchor_w, [batch, 1])
    anchor_w = paddle.tile(anchor_w, [1, 1, h * w])
    anchor_w = paddle.reshape(anchor_w, [batch * num_anchors * h * w])
    anchor_h = paddle.tile(anchor_h, [batch, 1])
    anchor_h = paddle.tile(anchor_h, [1, 1, h * w])
    anchor_h = paddle.reshape(anchor_h, [batch * num_anchors * h * w])
    # Box sizes in grid-cell units.
    ws = paddle.exp(output[2]) * anchor_w
    hs = paddle.exp(output[3]) * anchor_h
    det_confs = sigmoid(output[4])
    # Class probabilities: after the transpose, rows are columns of ``output``
    # and Softmax (default axis=-1) normalizes over the classes.
    cls_confs = paddle.to_tensor(output[5 : 5 + num_classes], stop_gradient=True)
    cls_confs = paddle.transpose(cls_confs, [1, 0])
    s = paddle.nn.Softmax()
    cls_confs = paddle.to_tensor(s(cls_confs))
    cls_max_confs = paddle.max(cls_confs, axis=1)
    cls_max_ids = paddle.argmax(cls_confs, axis=1)
    cls_max_confs = paddle.reshape(cls_max_confs, [-1])
    cls_max_ids = paddle.reshape(cls_max_ids, [-1])
    sz_hw = h * w
    sz_hwa = sz_hw * num_anchors
    det_confs = convert2cpu(det_confs)
    cls_max_confs = convert2cpu(cls_max_confs)
    cls_max_ids = convert2cpu_long(cls_max_ids)
    xs = convert2cpu(xs)
    ys = convert2cpu(ys)
    ws = convert2cpu(ws)
    hs = convert2cpu(hs)
    # NOTE(review): these loops index CPU tensors one element at a time,
    # which is slow for large maps. Converting to numpy once would be much
    # faster, but it would change the element types of the returned boxes.
    for b in range(batch):
        boxes = []
        for cy in range(h):
            for cx in range(w):
                for i in range(num_anchors):
                    # Flat column of (batch b, anchor i, cy, cx), matching the
                    # reshape/transpose above.
                    ind = b * sz_hwa + i * sz_hw + cy * w + cx
                    det_conf = det_confs[ind]
                    if only_objectness:
                        conf = det_confs[ind]
                    else:
                        conf = det_confs[ind] * cls_max_confs[ind]
                    if conf > conf_thresh:
                        bcx = xs[ind]
                        bcy = ys[ind]
                        bw = ws[ind]
                        bh = hs[ind]
                        cls_max_conf = cls_max_confs[ind]
                        cls_max_id = cls_max_ids[ind]
                        # Normalize by the grid size.
                        box = [
                            bcx / w,
                            bcy / h,
                            bw / w,
                            bh / h,
                            det_conf,
                            cls_max_conf,
                            cls_max_id,
                        ]
                        boxes.append(box)
        all_boxes.append(boxes)
    return all_boxes
  283. def nms(boxes, nms_thresh):
  284. """
  285. Performs non-maximum suppression on the input boxes based on their IoUs.
  286. """
  287. import paddle
  288. if len(boxes) == 0:
  289. return boxes
  290. det_confs = paddle.zeros([len(boxes)])
  291. for i in range(len(boxes)):
  292. det_confs[i] = 1 - boxes[i][4]
  293. sortIds = paddle.argsort(det_confs)
  294. out_boxes = []
  295. for i in range(len(boxes)):
  296. box_i = boxes[sortIds[i]]
  297. if box_i[4] > 0:
  298. out_boxes.append(box_i)
  299. for j in range(i + 1, len(boxes)):
  300. box_j = boxes[sortIds[j]]
  301. if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh:
  302. box_j[4] = 0
  303. return out_boxes
  304. def bbox_iou(box1, box2, x1y1x2y2=True):
  305. """
  306. Returns the Intersection over Union (IoU) of two bounding boxes.
  307. """
  308. import paddle
  309. if x1y1x2y2:
  310. mx = min(box1[0], box2[0])
  311. Mx = max(box1[2], box2[2])
  312. my = min(box1[1], box2[1])
  313. My = max(box1[3], box2[3])
  314. w1 = box1[2] - box1[0]
  315. h1 = box1[3] - box1[1]
  316. w2 = box2[2] - box2[0]
  317. h2 = box2[3] - box2[1]
  318. else:
  319. mx = min(float(box1[0] - box1[2] / 2.0), float(box2[0] - box2[2] / 2.0))
  320. Mx = max(float(box1[0] + box1[2] / 2.0), float(box2[0] + box2[2] / 2.0))
  321. my = min(float(box1[1] - box1[3] / 2.0), float(box2[1] - box2[3] / 2.0))
  322. My = max(float(box1[1] + box1[3] / 2.0), float(box2[1] + box2[3] / 2.0))
  323. w1 = box1[2]
  324. h1 = box1[3]
  325. w2 = box2[2]
  326. h2 = box2[3]
  327. uw = Mx - mx
  328. uh = My - my
  329. cw = w1 + w2 - uw
  330. ch = h1 + h2 - uh
  331. carea = 0
  332. if cw <= 0 or ch <= 0:
  333. return paddle.to_tensor(0.0)
  334. area1 = w1 * h1
  335. area2 = w2 * h2
  336. carea = cw * ch
  337. uarea = area1 + area2 - carea
  338. return carea / uarea
  339. @benchmark.timeit
  340. class DetVideoPostProcess:
  341. """
  342. A class used to perform post-processing on detection results in videos.
  343. """
  344. def __init__(
  345. self,
  346. label_list: List[str] = [],
  347. ) -> None:
  348. """
  349. Args:
  350. labels : List[str]
  351. A list of labels or class names associated with the detection results.
  352. """
  353. super().__init__()
  354. self.labels = label_list
  355. def postprocess(self, pred: List, nms_thresh: float, score_thresh: float) -> List:
  356. import paddle
  357. num_seg = len(pred)
  358. pred_all = []
  359. for i in range(num_seg):
  360. outputs = pred[i]
  361. for out in outputs:
  362. preds = []
  363. out = paddle.to_tensor(out)
  364. all_boxes = get_region_boxes(out, num_classes=len(self.labels))
  365. for i in range(out.shape[0]):
  366. boxes = all_boxes[i]
  367. boxes = nms(boxes, nms_thresh)
  368. for box in boxes:
  369. x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
  370. y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
  371. x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
  372. y2 = round(float(box[1] + box[3] / 2.0) * 240.0)
  373. det_conf = float(box[4])
  374. for j in range((len(box) - 5) // 2):
  375. cls_conf = float(box[5 + 2 * j].item())
  376. prob = det_conf * cls_conf
  377. if prob > score_thresh:
  378. preds.append(
  379. [[x1, y1, x2, y2], prob, self.labels[int(box[6])]]
  380. )
  381. pred_all.append(preds)
  382. return pred_all
  383. def __call__(self, preds: List, nms_thresh, score_thresh) -> List:
  384. return [self.postprocess(pred, nms_thresh, score_thresh) for pred in preds]