processors.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. from typing import List, Sequence, Union, Optional, Tuple
  17. import numpy as np
  18. import cv2
  19. import lazy_paddle as paddle
  20. from ...utils.benchmark import benchmark
  21. @benchmark.timeit
  22. class ResizeVideo:
  23. """Resizes frames of a video to a specified target size.
  24. This class provides functionality to resize each frame of a video to
  25. a specified square dimension (height and width are equal).
  26. Attributes:
  27. target_size (int): The desired size (in pixels) for both the height
  28. and width of each frame in the video.
  29. """
  30. def __init__(self, target_size: int = 224) -> None:
  31. """Initializes the ResizeVideo with a target size.
  32. Args:
  33. target_size (int): The desired size in pixels for the output
  34. frames. Defaults to 224.
  35. """
  36. super().__init__()
  37. self.target_size = target_size
  38. def resize(self, video: List) -> List:
  39. """Resizes all frames of a single video.
  40. Args:
  41. video (list): A list of segments, where each segment is a list
  42. of frames represented as numpy arrays.
  43. Returns:
  44. list: The input video with each frame resized to the target size.
  45. Raises:
  46. NotImplementedError: If a frame is not an instance of numpy.ndarray.
  47. """
  48. num_seg = len(video)
  49. seg_len = len(video[0])
  50. for i in range(num_seg):
  51. for j in range(seg_len):
  52. img = video[i][j]
  53. if isinstance(img, np.ndarray):
  54. h, w, _ = img.shape
  55. else:
  56. raise NotImplementedError(
  57. "Currently, only numpy.ndarray frames are supported."
  58. )
  59. video[i][j] = cv2.resize(
  60. img,
  61. (self.target_size, self.target_size),
  62. interpolation=cv2.INTER_LINEAR,
  63. )
  64. return video
  65. def __call__(self, videos: List) -> List:
  66. """Resizes frames of multiple videos.
  67. Args:
  68. videos (list): A list containing multiple videos, where each video
  69. is a list of segments, and each segment is a list of frames.
  70. Returns:
  71. list: A list of videos with each frame resized to the target size.
  72. """
  73. return [self.resize(video) for video in videos]
  74. @benchmark.timeit
  75. class Image2Array:
  76. """Convert a sequence of images to a numpy array with optional transposition."""
  77. def __init__(self, data_format: str = "tchw") -> None:
  78. """
  79. Initializes the Image2Array class.
  80. Args:
  81. data_format (str): The format to transpose to, either 'tchw' or 'cthw'.
  82. Raises:
  83. AssertionError: If data_format is not one of the allowed values.
  84. """
  85. super().__init__()
  86. assert data_format in [
  87. "tchw",
  88. "cthw",
  89. ], f"Target format must be in ['tchw', 'cthw'], but got {data_format}"
  90. self.data_format = data_format
  91. def img2array(self, video: List) -> List:
  92. """
  93. Converts a list of video frames to a numpy array, with frames transposed.
  94. Args:
  95. video (List): A list of frames represented as numpy arrays.
  96. Returns:
  97. List: A numpy array with the video frames transposed and concatenated.
  98. """
  99. # Transpose each image from HWC to CHW format
  100. num_seg = len(video)
  101. for i in range(num_seg):
  102. video_one = video[i]
  103. video_one = [img.transpose([2, 0, 1]) for img in video_one]
  104. video_one = np.concatenate(
  105. [np.expand_dims(img, axis=1) for img in video_one], axis=1
  106. )
  107. video[i] = video_one
  108. return video
  109. def __call__(self, videos: List[List[np.ndarray]]) -> List[np.ndarray]:
  110. """
  111. Process videos by converting each video to a transposed numpy array.
  112. Args:
  113. videos (List[List[np.ndarray]]): A list of videos, where each video is a list
  114. of frames represented as numpy arrays.
  115. Returns:
  116. List[np.ndarray]: A list of processed videos with transposed frames.
  117. """
  118. return [self.img2array(video) for video in videos]
  119. @benchmark.timeit
  120. class NormalizeVideo:
  121. """
  122. A class to normalize video frames by scaling the pixel values.
  123. """
  124. def __init__(self, scale: float = 255.0) -> None:
  125. """
  126. Initializes the NormalizeVideo class.
  127. Args:
  128. scale (float): The scale factor to normalize the frames, usually the max pixel value.
  129. """
  130. super().__init__()
  131. self.scale = scale
  132. def normalize_video(self, video: List[np.ndarray]) -> List[np.ndarray]:
  133. """
  134. Normalizes a sequence of images by scaling the pixel values.
  135. Args:
  136. video (List[np.ndarray]): A list of frames, where each frame is a numpy array to be normalized.
  137. Returns:
  138. List[np.ndarray]: The normalized video frames as a list of numpy arrays.
  139. """
  140. num_seg = len(video) # Number of frames in the video
  141. for i in range(num_seg):
  142. # Convert frame to float32 and scale pixel values
  143. video[i] = video[i].astype(np.float32) / self.scale
  144. # Expand dimensions if needed
  145. video[i] = np.expand_dims(video[i], axis=0)
  146. return video
  147. def __call__(self, videos: List[List[np.ndarray]]) -> List[List[np.ndarray]]:
  148. """
  149. Apply normalization to a list of videos.
  150. Args:
  151. videos (List[List[np.ndarray]]): A list of videos, where each video is a list of frames
  152. represented as numpy arrays.
  153. Returns:
  154. List[List[np.ndarray]]: A list of normalized videos, each represented as a list of normalized frames.
  155. """
  156. return [self.normalize_video(video) for video in videos]
  157. def convert2cpu(gpu_matrix):
  158. float_32_g = gpu_matrix.astype("float32")
  159. return float_32_g.cpu()
  160. def convert2cpu_long(gpu_matrix):
  161. int_64_g = gpu_matrix.astype("int64")
  162. return int_64_g.cpu()
  163. def get_region_boxes(
  164. output,
  165. conf_thresh=0.005,
  166. num_classes=24,
  167. anchors=[
  168. 0.70458,
  169. 1.18803,
  170. 1.26654,
  171. 2.55121,
  172. 1.59382,
  173. 4.08321,
  174. 2.30548,
  175. 4.94180,
  176. 3.52332,
  177. 5.91979,
  178. ],
  179. num_anchors=5,
  180. only_objectness=1,
  181. ):
  182. """
  183. Processes the output of a neural network to extract bounding box predictions.
  184. Args:
  185. output (Tensor): The output tensor from the neural network.
  186. conf_thresh (float): The confidence threshold for filtering predictions. Default is 0.005.
  187. num_classes (int): The number of classes for classification. Default is 24.
  188. anchors (List[float]): A list of anchor box dimensions used in the model. Default is a list
  189. of 10 predefined anchor values.
  190. num_anchors (int): The number of anchor boxes used in the model. Default is 5.
  191. only_objectness (int): If set to 1, only objectness scores are considered for filtering. Default is 1.
  192. Returns:
  193. all_box(List[List[float]]): A list of predicted bounding boxes for each image in the batch.
  194. """
  195. anchor_step = len(anchors) // num_anchors
  196. if output.dim() == 3:
  197. output = output.unsqueeze(0)
  198. batch = output.shape[0]
  199. assert output.shape[1] == (5 + num_classes) * num_anchors
  200. h = output.shape[2]
  201. w = output.shape[3]
  202. all_boxes = []
  203. output = paddle.reshape(output, [batch * num_anchors, 5 + num_classes, h * w])
  204. output = paddle.transpose(output, (1, 0, 2))
  205. output = paddle.reshape(output, [5 + num_classes, batch * num_anchors * h * w])
  206. grid_x = paddle.linspace(0, w - 1, w)
  207. grid_x = paddle.tile(grid_x, [h, 1])
  208. grid_x = paddle.tile(grid_x, [batch * num_anchors, 1, 1])
  209. grid_x = paddle.reshape(grid_x, [batch * num_anchors * h * w]).cuda()
  210. grid_y = paddle.linspace(0, h - 1, h)
  211. grid_y = paddle.tile(grid_y, [w, 1]).t()
  212. grid_y = paddle.tile(grid_y, [batch * num_anchors, 1, 1])
  213. grid_y = paddle.reshape(grid_y, [batch * num_anchors * h * w]).cuda()
  214. sigmoid = paddle.nn.Sigmoid()
  215. xs = sigmoid(output[0]) + grid_x
  216. ys = sigmoid(output[1]) + grid_y
  217. anchor_w = paddle.to_tensor(anchors)
  218. anchor_w = paddle.reshape(anchor_w, [num_anchors, anchor_step])
  219. anchor_w = paddle.index_select(
  220. anchor_w, index=paddle.to_tensor(np.array([0]).astype("int32")), axis=1
  221. )
  222. anchor_h = paddle.to_tensor(anchors)
  223. anchor_h = paddle.reshape(anchor_h, [num_anchors, anchor_step])
  224. anchor_h = paddle.index_select(
  225. anchor_h, index=paddle.to_tensor(np.array([1]).astype("int32")), axis=1
  226. )
  227. anchor_w = paddle.tile(anchor_w, [batch, 1])
  228. anchor_w = paddle.tile(anchor_w, [1, 1, h * w])
  229. anchor_w = paddle.reshape(anchor_w, [batch * num_anchors * h * w]).cuda()
  230. anchor_h = paddle.tile(anchor_h, [batch, 1])
  231. anchor_h = paddle.tile(anchor_h, [1, 1, h * w])
  232. anchor_h = paddle.reshape(anchor_h, [batch * num_anchors * h * w]).cuda()
  233. ws = paddle.exp(output[2]) * anchor_w
  234. hs = paddle.exp(output[3]) * anchor_h
  235. det_confs = sigmoid(output[4])
  236. cls_confs = paddle.to_tensor(output[5 : 5 + num_classes], stop_gradient=True)
  237. cls_confs = paddle.transpose(cls_confs, [1, 0])
  238. s = paddle.nn.Softmax()
  239. cls_confs = paddle.to_tensor(s(cls_confs))
  240. cls_max_confs = paddle.max(cls_confs, axis=1)
  241. cls_max_ids = paddle.argmax(cls_confs, axis=1)
  242. cls_max_confs = paddle.reshape(cls_max_confs, [-1])
  243. cls_max_ids = paddle.reshape(cls_max_ids, [-1])
  244. sz_hw = h * w
  245. sz_hwa = sz_hw * num_anchors
  246. det_confs = convert2cpu(det_confs)
  247. cls_max_confs = convert2cpu(cls_max_confs)
  248. cls_max_ids = convert2cpu_long(cls_max_ids)
  249. xs = convert2cpu(xs)
  250. ys = convert2cpu(ys)
  251. ws = convert2cpu(ws)
  252. hs = convert2cpu(hs)
  253. for b in range(batch):
  254. boxes = []
  255. for cy in range(h):
  256. for cx in range(w):
  257. for i in range(num_anchors):
  258. ind = b * sz_hwa + i * sz_hw + cy * w + cx
  259. det_conf = det_confs[ind]
  260. if only_objectness:
  261. conf = det_confs[ind]
  262. else:
  263. conf = det_confs[ind] * cls_max_confs[ind]
  264. if conf > conf_thresh:
  265. bcx = xs[ind]
  266. bcy = ys[ind]
  267. bw = ws[ind]
  268. bh = hs[ind]
  269. cls_max_conf = cls_max_confs[ind]
  270. cls_max_id = cls_max_ids[ind]
  271. box = [
  272. bcx / w,
  273. bcy / h,
  274. bw / w,
  275. bh / h,
  276. det_conf,
  277. cls_max_conf,
  278. cls_max_id,
  279. ]
  280. boxes.append(box)
  281. all_boxes.append(boxes)
  282. return all_boxes
  283. def nms(boxes, nms_thresh):
  284. """
  285. Performs non-maximum suppression on the input boxes based on their IoUs.
  286. """
  287. if len(boxes) == 0:
  288. return boxes
  289. det_confs = paddle.zeros([len(boxes)])
  290. for i in range(len(boxes)):
  291. det_confs[i] = 1 - boxes[i][4]
  292. sortIds = paddle.argsort(det_confs)
  293. out_boxes = []
  294. for i in range(len(boxes)):
  295. box_i = boxes[sortIds[i]]
  296. if box_i[4] > 0:
  297. out_boxes.append(box_i)
  298. for j in range(i + 1, len(boxes)):
  299. box_j = boxes[sortIds[j]]
  300. if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh:
  301. box_j[4] = 0
  302. return out_boxes
  303. def bbox_iou(box1, box2, x1y1x2y2=True):
  304. """
  305. Returns the Intersection over Union (IoU) of two bounding boxes.
  306. """
  307. if x1y1x2y2:
  308. mx = min(box1[0], box2[0])
  309. Mx = max(box1[2], box2[2])
  310. my = min(box1[1], box2[1])
  311. My = max(box1[3], box2[3])
  312. w1 = box1[2] - box1[0]
  313. h1 = box1[3] - box1[1]
  314. w2 = box2[2] - box2[0]
  315. h2 = box2[3] - box2[1]
  316. else:
  317. mx = min(float(box1[0] - box1[2] / 2.0), float(box2[0] - box2[2] / 2.0))
  318. Mx = max(float(box1[0] + box1[2] / 2.0), float(box2[0] + box2[2] / 2.0))
  319. my = min(float(box1[1] - box1[3] / 2.0), float(box2[1] - box2[3] / 2.0))
  320. My = max(float(box1[1] + box1[3] / 2.0), float(box2[1] + box2[3] / 2.0))
  321. w1 = box1[2]
  322. h1 = box1[3]
  323. w2 = box2[2]
  324. h2 = box2[3]
  325. uw = Mx - mx
  326. uh = My - my
  327. cw = w1 + w2 - uw
  328. ch = h1 + h2 - uh
  329. carea = 0
  330. if cw <= 0 or ch <= 0:
  331. return paddle.to_tensor(0.0)
  332. area1 = w1 * h1
  333. area2 = w2 * h2
  334. carea = cw * ch
  335. uarea = area1 + area2 - carea
  336. return carea / uarea
  337. @benchmark.timeit
  338. class DetVideoPostProcess:
  339. """
  340. A class used to perform post-processing on detection results in videos.
  341. """
  342. def __init__(
  343. self,
  344. label_list: List[str] = [],
  345. ) -> None:
  346. """
  347. Args:
  348. labels : List[str]
  349. A list of labels or class names associated with the detection results.
  350. """
  351. super().__init__()
  352. self.labels = label_list
  353. def postprocess(self, pred: List, nms_thresh: float, score_thresh: float) -> List:
  354. font = cv2.FONT_HERSHEY_SIMPLEX
  355. num_seg = len(pred)
  356. pred_all = []
  357. for i in range(num_seg):
  358. outputs = pred[i]
  359. for out in outputs:
  360. preds = []
  361. out = paddle.to_tensor(out)
  362. all_boxes = get_region_boxes(out, num_classes=len(self.labels))
  363. for i in range(out.shape[0]):
  364. boxes = all_boxes[i]
  365. boxes = nms(boxes, nms_thresh)
  366. for box in boxes:
  367. x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
  368. y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
  369. x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
  370. y2 = round(float(box[1] + box[3] / 2.0) * 240.0)
  371. det_conf = float(box[4])
  372. for j in range((len(box) - 5) // 2):
  373. cls_conf = float(box[5 + 2 * j].item())
  374. prob = det_conf * cls_conf
  375. if prob > score_thresh:
  376. preds.append(
  377. [[x1, y1, x2, y2], prob, self.labels[int(box[6])]]
  378. )
  379. pred_all.append(preds)
  380. return pred_all
  381. def __call__(self, preds: List, nms_thresh, score_thresh) -> List:
  382. return [self.postprocess(pred, nms_thresh, score_thresh) for pred in preds]