# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
from numpy import ndarray

from ....utils.deps import class_requires_deps, is_dep_available
from ...utils.benchmark import benchmark
from ..object_detection.processors import get_affine_transform

if is_dep_available("opencv-contrib-python"):
    import cv2

Number = Union[int, float]
Kpts = List[dict]


def get_warp_matrix(
    theta: float, size_input: ndarray, size_dst: ndarray, size_target: ndarray
) -> ndarray:
    """This code is based on
    https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py

    Calculate the transformation matrix under the unbiased constraint.
    Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
    Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        theta (float): Rotation angle in degrees.
        size_input (np.ndarray): Size of input image [w, h].
        size_dst (np.ndarray): Size of output image [w, h].
        size_target (np.ndarray): Size of ROI in input plane [w, h].

    Returns:
        matrix (np.ndarray): A matrix for transformation.
    """
    theta = np.deg2rad(theta)
    matrix = np.zeros((2, 3), dtype=np.float32)
    scale_x = size_dst[0] / size_target[0]
    scale_y = size_dst[1] / size_target[1]
    matrix[0, 0] = np.cos(theta) * scale_x
    matrix[0, 1] = -np.sin(theta) * scale_x
    matrix[0, 2] = scale_x * (
        -0.5 * size_input[0] * np.cos(theta)
        + 0.5 * size_input[1] * np.sin(theta)
        + 0.5 * size_target[0]
    )
    matrix[1, 0] = np.sin(theta) * scale_y
    matrix[1, 1] = np.cos(theta) * scale_y
    matrix[1, 2] = scale_y * (
        -0.5 * size_input[0] * np.sin(theta)
        - 0.5 * size_input[1] * np.cos(theta)
        + 0.5 * size_target[1]
    )
    return matrix

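
# A minimal sanity check (illustrative, not part of the original file): with
# theta=0 and size_target equal to size_input, the UDP matrix reduces to a
# pure axis-aligned scale, so warping just resizes the image.
#
#   src = np.zeros((100, 200, 3), dtype=np.uint8)  # h=100, w=200
#   m = get_warp_matrix(
#       0.0,
#       size_input=np.array([200, 100]),
#       size_dst=np.array([192, 256]),
#       size_target=np.array([200, 100]),
#   )
#   dst = cv2.warpAffine(src, m, (192, 256), flags=cv2.INTER_LINEAR)
#   assert dst.shape == (256, 192, 3)
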

@benchmark.timeit
@class_requires_deps("opencv-contrib-python")
class TopDownAffine:
    """refer to https://github.com/open-mmlab/mmpose/blob/71ec36ebd63c475ab589afc817868e749a61491f/mmpose/datasets/transforms/topdown_transforms.py#L13

    Get the bbox image as the model input by affine transform.

    Args:
        input_size (Tuple[int, int]): The input image size of the model in
            [w, h]. The bbox region will be cropped and resized to `input_size`.
        use_udp (bool): Whether to use unbiased data processing. See
            `UDP (CVPR 2020)`_ for details. Defaults to ``False``.

    .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
    """

    def __init__(self, input_size: Tuple[int, int], use_udp: bool = False):
        assert (
            all(isinstance(i, int) for i in input_size) and len(input_size) == 2
        ), f"Invalid input_size {input_size}"
        self.input_size = input_size
        self.use_udp = use_udp

    def apply(
        self,
        img: ndarray,
        center: Optional[Union[Tuple[Number, Number], ndarray]] = None,
        scale: Optional[Union[Tuple[Number, Number], ndarray]] = None,
    ) -> Tuple[ndarray, ndarray, ndarray]:
        """Applies an affine warp (``cv2.warpAffine``) to the input image based
        on the specified center and scale.

        Args:
            img (ndarray): The input image as a NumPy ndarray.
            center (Optional[Union[Tuple[Number, Number], ndarray]], optional):
                Center of the bounding box (x, y).
            scale (Optional[Union[Tuple[Number, Number], ndarray]], optional):
                Scale of the bounding box wrt [width, height].

        Returns:
            Tuple[ndarray, ndarray, ndarray]: The transformed image, the center
                used for the transformation, and the scale used for the
                transformation.
        """
        rot = 0
        imshape = np.array(img.shape[:2][::-1])
        if isinstance(center, Sequence):
            center = np.array(center)
        if isinstance(scale, Sequence):
            scale = np.array(scale)
        center = center if center is not None else imshape / 2.0
        scale = scale if scale is not None else imshape
        if self.use_udp:
            trans = get_warp_matrix(
                rot,
                center * 2.0,
                [self.input_size[0] - 1.0, self.input_size[1] - 1.0],
                scale,
            )
            img = cv2.warpAffine(
                img,
                trans,
                (int(self.input_size[0]), int(self.input_size[1])),
                flags=cv2.INTER_LINEAR,
            )
        else:
            trans = get_affine_transform(center, scale, rot, self.input_size)
            img = cv2.warpAffine(
                img,
                trans,
                (int(self.input_size[0]), int(self.input_size[1])),
                flags=cv2.INTER_LINEAR,
            )
        return img, center, scale

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            ori_img = data["img"]
            if "ori_img" not in data:
                data["ori_img"] = ori_img
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
            img, center, scale = self.apply(
                ori_img, data.get("center", None), data.get("scale", None)
            )
            data["img"] = img
            data["center"] = center
            data["scale"] = scale
            img_size = [img.shape[1], img.shape[0]]
            data["img_size"] = img_size  # [size_w, size_h]
        return datas

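
# A minimal usage sketch (assumed inputs, not from the original file): crop a
# person box out of a frame and resize it to the model input size.
#
#   transform = TopDownAffine(input_size=(192, 256), use_udp=True)
#   data = {
#       "img": frame,              # HxWx3 uint8 image (hypothetical)
#       "center": (320.0, 240.0),  # bbox center (x, y)
#       "scale": (200.0, 400.0),   # bbox size [w, h]
#   }
#   data = transform([data])[0]
#   model_input = data["img"]      # shape (256, 192, 3)
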

def affine_transform(pt: ndarray, t: ndarray) -> ndarray:
    """Apply an affine transformation to a 2D point.

    Args:
        pt (numpy.ndarray): A 2D point represented as a 2-element array.
        t (numpy.ndarray): A 2x3 affine transformation matrix.

    Returns:
        numpy.ndarray: The transformed 2D point.
    """
    new_pt = np.array([pt[0], pt[1], 1.0]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]

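
# For example (illustrative values), a pure translation by (+10, -5) applied
# to the homogeneous point [x, y, 1]:
#
#   t = np.array([[1.0, 0.0, 10.0],
#                 [0.0, 1.0, -5.0]])
#   affine_transform(np.array([2.0, 3.0]), t)  # -> array([12., -2.])
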

def transform_preds(
    coords: ndarray,
    center: Tuple[float, float],
    scale: Tuple[float, float],
    output_size: Tuple[int, int],
) -> ndarray:
    """Transform coordinates to the target space using an affine transformation.

    Args:
        coords (numpy.ndarray): Original coordinates, shape (N, 2).
        center (tuple): Center point for the transformation.
        scale (tuple): Scale factor for the transformation.
        output_size (tuple): Size of the output space.

    Returns:
        numpy.ndarray: Transformed coordinates, shape (N, 2).
    """
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords

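
# Here the transform is built with ``inv=1``, i.e. it maps points decoded on
# the heatmap grid back into the original image. A hypothetical call, assuming
# the same center/scale conventions as ``TopDownAffine`` and a 48x64 heatmap:
#
#   kpts_img = transform_preds(
#       coords=np.array([[24.0, 32.0]]),  # keypoint in heatmap space
#       center=(320.0, 240.0),
#       scale=(200.0, 400.0),
#       output_size=[48, 64],
#   )
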

@benchmark.timeit
@class_requires_deps("opencv-contrib-python")
class KptPostProcess:
    """Decode predicted heatmaps into keypoint coordinates and scores."""

    def __init__(self, use_dark=True):
        self.use_dark = use_dark

    def apply(self, heatmap: ndarray, center: ndarray, scale: ndarray) -> Kpts:
        """Decode a single sample's heatmaps into keypoints and scores."""
        # TODO: add batch support
        heatmap, center, scale = heatmap[None, ...], center[None, ...], scale[None, ...]
        preds, maxvals = self.get_final_preds(heatmap, center, scale)
        keypoints, scores = np.concatenate((preds, maxvals), axis=-1), np.mean(
            maxvals.squeeze(-1), axis=1
        )
        return [
            {"keypoints": kpt, "kpt_score": score}
            for kpt, score in zip(keypoints, scores)
        ]

    def __call__(self, batch_outputs: List[dict], datas: List[dict]) -> List[Kpts]:
        """Apply the post-processing to a batch of outputs.

        Args:
            batch_outputs (List[dict]): The list of detection outputs.
            datas (List[dict]): The list of input data.

        Returns:
            List[Kpts]: The list of post-processed keypoints.
        """
        return [
            self.apply(output["heatmap"], data["center"], data["scale"])
            for data, output in zip(datas, batch_outputs)
        ]

    def get_final_preds(
        self, heatmaps: ndarray, center: ndarray, scale: ndarray, kernelsize: int = 3
    ):
        """Get the highest heat-value location, refined by a quarter offset in
        the direction from the highest response to the second highest response.

        Args:
            heatmaps (numpy.ndarray): The predicted heatmaps.
            center (numpy.ndarray): The boxes' centers.
            scale (numpy.ndarray): The scale factors.

        Returns:
            preds: numpy.ndarray([batch_size, num_joints, 2]), keypoint coords.
            maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum
                confidence of the keypoints.
        """
        coords, maxvals = self.get_max_preds(heatmaps)
        heatmap_height = heatmaps.shape[2]
        heatmap_width = heatmaps.shape[3]
        if self.use_dark:
            coords = self.dark_postprocess(heatmaps, coords, kernelsize)
        else:
            for n in range(coords.shape[0]):
                for p in range(coords.shape[1]):
                    hm = heatmaps[n][p]
                    px = int(math.floor(coords[n][p][0] + 0.5))
                    py = int(math.floor(coords[n][p][1] + 0.5))
                    if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
                        diff = np.array(
                            [
                                hm[py][px + 1] - hm[py][px - 1],
                                hm[py + 1][px] - hm[py - 1][px],
                            ]
                        )
                        coords[n][p] += np.sign(diff) * 0.25
        preds = coords.copy()
        # Transform back to the original image space
        for i in range(coords.shape[0]):
            preds[i] = transform_preds(
                coords[i], center[i], scale[i], [heatmap_width, heatmap_height]
            )
        return preds, maxvals

    def get_max_preds(self, heatmaps: ndarray) -> Tuple[ndarray, ndarray]:
        """Get predictions from score maps.

        Args:
            heatmaps: numpy.ndarray([batch_size, num_joints, height, width])

        Returns:
            preds: numpy.ndarray([batch_size, num_joints, 2]), keypoint coords.
            maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum
                confidence of the keypoints.
        """
        assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray"
        assert heatmaps.ndim == 4, "batch_images should be 4-ndim"
        batch_size = heatmaps.shape[0]
        num_joints = heatmaps.shape[1]
        width = heatmaps.shape[3]
        heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
        idx = np.argmax(heatmaps_reshaped, 2)
        maxvals = np.amax(heatmaps_reshaped, 2)
        maxvals = maxvals.reshape((batch_size, num_joints, 1))
        idx = idx.reshape((batch_size, num_joints, 1))
        preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
        preds[:, :, 0] = preds[:, :, 0] % width
        preds[:, :, 1] = np.floor(preds[:, :, 1] / width)
        pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
        pred_mask = pred_mask.astype(np.float32)
        preds *= pred_mask
        return preds, maxvals

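
    # A tiny worked example (not from the original file): a 1x1x2x3 batch whose
    # single heatmap peaks at flat index 4, i.e. (x, y) = (4 % 3, 4 // 3) = (1, 1):
    #
    #   hm = np.array([[[[0.0, 0.1, 0.0],
    #                    [0.2, 0.9, 0.1]]]])
    #   preds, maxvals = KptPostProcess().get_max_preds(hm)
    #   # preds   -> [[[1., 1.]]]
    #   # maxvals -> [[[0.9]]]
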
    def gaussian_blur(self, heatmap: ndarray, kernel: int) -> ndarray:
        border = (kernel - 1) // 2
        batch_size = heatmap.shape[0]
        num_joints = heatmap.shape[1]
        height = heatmap.shape[2]
        width = heatmap.shape[3]
        for i in range(batch_size):
            for j in range(num_joints):
                origin_max = np.max(heatmap[i, j])
                dr = np.zeros((height + 2 * border, width + 2 * border))
                dr[border:-border, border:-border] = heatmap[i, j].copy()
                dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
                heatmap[i, j] = dr[border:-border, border:-border].copy()
                heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
        return heatmap

    def dark_parse(self, hm: ndarray, coord: ndarray):
        heatmap_height = hm.shape[0]
        heatmap_width = hm.shape[1]
        px = int(coord[0])
        py = int(coord[1])
        if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
            dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
            dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
            dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
            dxy = 0.25 * (
                hm[py + 1][px + 1]
                - hm[py - 1][px + 1]
                - hm[py + 1][px - 1]
                + hm[py - 1][px - 1]
            )
            dyy = 0.25 * (hm[py + 2][px] - 2 * hm[py][px] + hm[py - 2][px])
            derivative = np.matrix([[dx], [dy]])
            hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
            if dxx * dyy - dxy**2 != 0:
                hessianinv = hessian.I
                offset = -hessianinv * derivative
                offset = np.squeeze(np.array(offset.T), axis=0)
                coord += offset
        return coord

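
    # dark_parse performs one Newton step on the log-heatmap: with the gradient
    # D = [dx, dy] and Hessian H = [[dxx, dxy], [dxy, dyy]] estimated by finite
    # differences, the sub-pixel refinement is
    #
    #   offset = -H^{-1} D
    #
    # which is exact when the heatmap is the log of a Gaussian peak
    # (DARK, CVPR 2020).
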
    def dark_postprocess(
        self, hm: ndarray, coords: ndarray, kernelsize: int
    ) -> ndarray:
        """Refine the coarse peak coordinates with DARK post-processing.

        refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
        """
        hm = self.gaussian_blur(hm, kernelsize)
        hm = np.maximum(hm, 1e-10)
        hm = np.log(hm)
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
        return coords
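
# A minimal end-to-end sketch (illustrative only: the heatmap is random here,
# while in the real pipeline it comes from the pose model's output):
#
#   if __name__ == "__main__":
#       rng = np.random.default_rng(0)
#       heatmap = rng.random((17, 64, 48)).astype(np.float32)  # [K, h, w]
#       center = np.array([320.0, 240.0])
#       scale = np.array([200.0, 400.0])
#       kpts = KptPostProcess(use_dark=True).apply(heatmap, center, scale)
#       print(kpts[0]["keypoints"].shape)  # (17, 3): x, y, confidence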