processors.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from typing import List, Sequence, Tuple, Union, Optional
  16. import cv2
  17. import numpy as np
  18. from numpy import ndarray
  19. from ..object_detection.processors import get_affine_transform
  20. from ...utils.benchmark import benchmark
# Scalar type accepted for coordinate / size components (e.g. the center and
# scale arguments of TopDownAffine.apply).
Number = Union[int, float]
# Keypoint post-processing result: a list of dicts, one per instance, as
# returned by KptPostProcess.apply (keys "keypoints" and "kpt_score").
Kpts = List[dict]
  23. def get_warp_matrix(
  24. theta: float, size_input: ndarray, size_dst: ndarray, size_target: ndarray
  25. ) -> ndarray:
  26. """This code is based on
  27. https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
  28. Calculate the transformation matrix under the constraint of unbiased.
  29. Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
  30. Data Processing for Human Pose Estimation (CVPR 2020).
  31. Args:
  32. theta (float): Rotation angle in degrees.
  33. size_input (np.ndarray): Size of input image [w, h].
  34. size_dst (np.ndarray): Size of output image [w, h].
  35. size_target (np.ndarray): Size of ROI in input plane [w, h].
  36. Returns:
  37. matrix (np.ndarray): A matrix for transformation.
  38. """
  39. theta = np.deg2rad(theta)
  40. matrix = np.zeros((2, 3), dtype=np.float32)
  41. scale_x = size_dst[0] / size_target[0]
  42. scale_y = size_dst[1] / size_target[1]
  43. matrix[0, 0] = np.cos(theta) * scale_x
  44. matrix[0, 1] = -np.sin(theta) * scale_x
  45. matrix[0, 2] = scale_x * (
  46. -0.5 * size_input[0] * np.cos(theta)
  47. + 0.5 * size_input[1] * np.sin(theta)
  48. + 0.5 * size_target[0]
  49. )
  50. matrix[1, 0] = np.sin(theta) * scale_y
  51. matrix[1, 1] = np.cos(theta) * scale_y
  52. matrix[1, 2] = scale_y * (
  53. -0.5 * size_input[0] * np.sin(theta)
  54. - 0.5 * size_input[1] * np.cos(theta)
  55. + 0.5 * size_target[1]
  56. )
  57. return matrix
  58. @benchmark.timeit
  59. class TopDownAffine:
  60. """refer to https://github.com/open-mmlab/mmpose/blob/71ec36ebd63c475ab589afc817868e749a61491f/mmpose/datasets/transforms/topdown_transforms.py#L13
  61. Get the bbox image as the model input by affine transform.
  62. Args:
  63. input_size (Tuple[int, int]): The input image size of the model in
  64. [w, h]. The bbox region will be cropped and resize to `input_size`
  65. use_udp (bool): Whether use unbiased data processing. See
  66. `UDP (CVPR 2020)`_ for details. Defaults to ``False``
  67. .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
  68. """
  69. def __init__(self, input_size: Tuple[int, int], use_udp: bool = False):
  70. assert (
  71. all([isinstance(i, int) for i in input_size]) and len(input_size) == 2
  72. ), f"Invalid input_size {input_size}"
  73. self.input_size = input_size
  74. self.use_udp = use_udp
  75. def apply(
  76. self,
  77. img: ndarray,
  78. center: Optional[Union[Tuple[Number, Number], ndarray]] = None,
  79. scale: Optional[Union[Tuple[Number, Number], ndarray]] = None,
  80. ) -> Tuple[ndarray, ndarray, ndarray]:
  81. """Applies a wrapaffine to the input image based on the specified center, scale.
  82. Args:
  83. img (ndarray): The input image as a NumPy ndarray.
  84. center (Optional[Union[Tuple[Number, Number], ndarray]], optional): Center of the bounding box (x, y)
  85. scale (Optional[Union[Tuple[Number, Number], ndarray]], optional): Scale of the bounding box
  86. wrt [width, height].
  87. Returns:
  88. Tuple[ndarray, ndarray, ndarray]: The transformed image,
  89. the center used for the transformation, and the scale used for the transformation.
  90. """
  91. rot = 0
  92. imshape = np.array(img.shape[:2][::-1])
  93. if isinstance(center, Sequence):
  94. center = np.array(center)
  95. if isinstance(scale, Sequence):
  96. scale = np.array(scale)
  97. center = center if center is not None else imshape / 2.0
  98. scale = scale if scale is not None else imshape
  99. if self.use_udp:
  100. trans = get_warp_matrix(
  101. rot,
  102. center * 2.0,
  103. [self.input_size[0] - 1.0, self.input_size[1] - 1.0],
  104. scale,
  105. )
  106. img = cv2.warpAffine(
  107. img,
  108. trans,
  109. (int(self.input_size[0]), int(self.input_size[1])),
  110. flags=cv2.INTER_LINEAR,
  111. )
  112. else:
  113. trans = get_affine_transform(center, scale, rot, self.input_size)
  114. img = cv2.warpAffine(
  115. img,
  116. trans,
  117. (int(self.input_size[0]), int(self.input_size[1])),
  118. flags=cv2.INTER_LINEAR,
  119. )
  120. return img, center, scale
  121. def __call__(self, datas: List[dict]) -> List[dict]:
  122. for data in datas:
  123. ori_img = data["img"]
  124. if "ori_img" not in data:
  125. data["ori_img"] = ori_img
  126. if "ori_img_size" not in data:
  127. data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
  128. img, center, scale = self.apply(
  129. ori_img, data.get("center", None), data.get("scale", None)
  130. )
  131. data["img"] = img
  132. data["center"] = center
  133. data["scale"] = scale
  134. img_size = [img.shape[1], img.shape[0]]
  135. data["img_size"] = img_size # [size_w, size_h]
  136. return datas
  137. def affine_transform(pt: ndarray, t: ndarray):
  138. """Apply an affine transformation to a 2D point.
  139. Args:
  140. pt (numpy.ndarray): A 2D point represented as a 2-element array.
  141. t (numpy.ndarray): A 3x3 affine transformation matrix.
  142. Returns:
  143. numpy.ndarray: The transformed 2D point.
  144. """
  145. new_pt = np.array([pt[0], pt[1], 1.0]).T
  146. new_pt = np.dot(t, new_pt)
  147. return new_pt[:2]
  148. def transform_preds(
  149. coords: ndarray,
  150. center: Tuple[float, float],
  151. scale: Tuple[float, float],
  152. output_size: Tuple[int, int],
  153. ) -> ndarray:
  154. """Transform coordinates to the target space using an affine transformation.
  155. Args:
  156. coords (numpy.ndarray): Original coordinates, shape (N, 2).
  157. center (tuple): Center point for the transformation.
  158. scale (tuple): Scale factor for the transformation.
  159. output_size (tuple): Size of the output space.
  160. Returns:
  161. numpy.ndarray: Transformed coordinates, shape (N, 2).
  162. """
  163. target_coords = np.zeros(coords.shape)
  164. trans = get_affine_transform(center, scale, 0, output_size, inv=1)
  165. for p in range(coords.shape[0]):
  166. target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
  167. return target_coords
  168. @benchmark.timeit
  169. class KptPostProcess:
  170. """Save Result Transform"""
  171. def __init__(self, use_dark=True):
  172. self.use_dark = use_dark
  173. def apply(self, heatmap: ndarray, center: ndarray, scale: ndarray) -> Kpts:
  174. """apply"""
  175. # TODO: add batch support
  176. heatmap, center, scale = heatmap[None, ...], center[None, ...], scale[None, ...]
  177. preds, maxvals = self.get_final_preds(heatmap, center, scale)
  178. keypoints, scores = np.concatenate((preds, maxvals), axis=-1), np.mean(
  179. maxvals.squeeze(-1), axis=1
  180. )
  181. return [
  182. {"keypoints": kpt, "kpt_score": score}
  183. for kpt, score in zip(keypoints, scores)
  184. ]
  185. def __call__(self, batch_outputs: List[dict], datas: List[dict]) -> List[Kpts]:
  186. """Apply the post-processing to a batch of outputs.
  187. Args:
  188. batch_outputs (List[dict]): The list of detection outputs.
  189. datas (List[dict]): The list of input data.
  190. Returns:
  191. List[dict]: The list of post-processed keypoints.
  192. """
  193. return [
  194. self.apply(output["heatmap"], data["center"], data["scale"])
  195. for data, output in zip(datas, batch_outputs)
  196. ]
  197. def get_final_preds(
  198. self, heatmaps: ndarray, center: ndarray, scale: ndarray, kernelsize: int = 3
  199. ):
  200. """the highest heatvalue location with a quarter offset in the
  201. direction from the highest response to the second highest response.
  202. Args:
  203. heatmaps (numpy.ndarray): The predicted heatmaps
  204. center (numpy.ndarray): The boxes center
  205. scale (numpy.ndarray): The scale factor
  206. Returns:
  207. preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
  208. maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
  209. """
  210. coords, maxvals = self.get_max_preds(heatmaps)
  211. heatmap_height = heatmaps.shape[2]
  212. heatmap_width = heatmaps.shape[3]
  213. if self.use_dark:
  214. coords = self.dark_postprocess(heatmaps, coords, kernelsize)
  215. else:
  216. for n in range(coords.shape[0]):
  217. for p in range(coords.shape[1]):
  218. hm = heatmaps[n][p]
  219. px = int(math.floor(coords[n][p][0] + 0.5))
  220. py = int(math.floor(coords[n][p][1] + 0.5))
  221. if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
  222. diff = np.array(
  223. [
  224. hm[py][px + 1] - hm[py][px - 1],
  225. hm[py + 1][px] - hm[py - 1][px],
  226. ]
  227. )
  228. coords[n][p] += np.sign(diff) * 0.25
  229. preds = coords.copy()
  230. # Transform back
  231. for i in range(coords.shape[0]):
  232. preds[i] = transform_preds(
  233. coords[i], center[i], scale[i], [heatmap_width, heatmap_height]
  234. )
  235. return preds, maxvals
  236. def get_max_preds(self, heatmaps: ndarray) -> Tuple[ndarray, ndarray]:
  237. """get predictions from score maps
  238. Args:
  239. heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
  240. Returns:
  241. preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
  242. maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints
  243. """
  244. assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray"
  245. assert heatmaps.ndim == 4, "batch_images should be 4-ndim"
  246. batch_size = heatmaps.shape[0]
  247. num_joints = heatmaps.shape[1]
  248. width = heatmaps.shape[3]
  249. heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
  250. idx = np.argmax(heatmaps_reshaped, 2)
  251. maxvals = np.amax(heatmaps_reshaped, 2)
  252. maxvals = maxvals.reshape((batch_size, num_joints, 1))
  253. idx = idx.reshape((batch_size, num_joints, 1))
  254. preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
  255. preds[:, :, 0] = (preds[:, :, 0]) % width
  256. preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
  257. pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
  258. pred_mask = pred_mask.astype(np.float32)
  259. preds *= pred_mask
  260. return preds, maxvals
  261. def gaussian_blur(self, heatmap: ndarray, kernel: int) -> ndarray:
  262. border = (kernel - 1) // 2
  263. batch_size = heatmap.shape[0]
  264. num_joints = heatmap.shape[1]
  265. height = heatmap.shape[2]
  266. width = heatmap.shape[3]
  267. for i in range(batch_size):
  268. for j in range(num_joints):
  269. origin_max = np.max(heatmap[i, j])
  270. dr = np.zeros((height + 2 * border, width + 2 * border))
  271. dr[border:-border, border:-border] = heatmap[i, j].copy()
  272. dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
  273. heatmap[i, j] = dr[border:-border, border:-border].copy()
  274. heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
  275. return heatmap
  276. def dark_parse(self, hm: ndarray, coord: ndarray):
  277. heatmap_height = hm.shape[0]
  278. heatmap_width = hm.shape[1]
  279. px = int(coord[0])
  280. py = int(coord[1])
  281. if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
  282. dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
  283. dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
  284. dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
  285. dxy = 0.25 * (
  286. hm[py + 1][px + 1]
  287. - hm[py - 1][px + 1]
  288. - hm[py + 1][px - 1]
  289. + hm[py - 1][px - 1]
  290. )
  291. dyy = 0.25 * (hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px])
  292. derivative = np.matrix([[dx], [dy]])
  293. hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
  294. if dxx * dyy - dxy**2 != 0:
  295. hessianinv = hessian.I
  296. offset = -hessianinv * derivative
  297. offset = np.squeeze(np.array(offset.T), axis=0)
  298. coord += offset
  299. return coord
  300. def dark_postprocess(
  301. self, hm: ndarray, coords: ndarray, kernelsize: int
  302. ) -> ndarray:
  303. """
  304. refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
  305. """
  306. hm = self.gaussian_blur(hm, kernelsize)
  307. hm = np.maximum(hm, 1e-10)
  308. hm = np.log(hm)
  309. for n in range(coords.shape[0]):
  310. for p in range(coords.shape[1]):
  311. coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
  312. return coords