processors.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np
from numpy import ndarray

from ..object_detection.processors import get_affine_transform

Number = Union[int, float]
Kpts = List[dict]


def get_warp_matrix(
    theta: float, size_input: ndarray, size_dst: ndarray, size_target: ndarray
) -> ndarray:
    """Calculate the transformation matrix under the unbiased constraint.

    Based on
    https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py

    Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
    Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        theta (float): Rotation angle in degrees.
        size_input (np.ndarray): Size of input image [w, h].
        size_dst (np.ndarray): Size of output image [w, h].
        size_target (np.ndarray): Size of ROI in input plane [w, h].

    Returns:
        matrix (np.ndarray): A 2x3 matrix for the affine transformation.
    """
    theta = np.deg2rad(theta)
    matrix = np.zeros((2, 3), dtype=np.float32)
    scale_x = size_dst[0] / size_target[0]
    scale_y = size_dst[1] / size_target[1]
    matrix[0, 0] = np.cos(theta) * scale_x
    matrix[0, 1] = -np.sin(theta) * scale_x
    matrix[0, 2] = scale_x * (
        -0.5 * size_input[0] * np.cos(theta)
        + 0.5 * size_input[1] * np.sin(theta)
        + 0.5 * size_target[0]
    )
    matrix[1, 0] = np.sin(theta) * scale_y
    matrix[1, 1] = np.cos(theta) * scale_y
    matrix[1, 2] = scale_y * (
        -0.5 * size_input[0] * np.sin(theta)
        - 0.5 * size_input[1] * np.cos(theta)
        + 0.5 * size_target[1]
    )
    return matrix
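
# Sanity check (illustrative; values chosen arbitrarily): with theta=0 the
# matrix reduces to a pure scale-plus-translation, e.g.
#
#   m = get_warp_matrix(
#       0.0, np.array([100, 200]), np.array([50, 100]), np.array([100, 200])
#   )
#
# gives m[0, 0] == m[1, 1] == 0.5 (the scale) and zero rotation/offset terms.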


class TopDownAffine:
    """Get the bbox image as the model input by affine transform.

    Refer to https://github.com/open-mmlab/mmpose/blob/71ec36ebd63c475ab589afc817868e749a61491f/mmpose/datasets/transforms/topdown_transforms.py#L13

    Args:
        input_size (Tuple[int, int]): The input image size of the model in
            [w, h]. The bbox region will be cropped and resized to `input_size`.
        use_udp (bool): Whether to use unbiased data processing. See
            `UDP (CVPR 2020)`_ for details. Defaults to ``False``.

    .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
    """

    def __init__(self, input_size: Tuple[int, int], use_udp: bool = False):
        assert (
            all(isinstance(i, int) for i in input_size) and len(input_size) == 2
        ), f"Invalid input_size {input_size}"
        self.input_size = input_size
        self.use_udp = use_udp

    def apply(
        self,
        img: ndarray,
        center: Optional[Union[Tuple[Number, Number], ndarray]] = None,
        scale: Optional[Union[Tuple[Number, Number], ndarray]] = None,
    ) -> Tuple[ndarray, ndarray, ndarray]:
        """Apply cv2.warpAffine to the input image based on the given center and scale.

        Args:
            img (ndarray): The input image as a NumPy ndarray.
            center (Optional[Union[Tuple[Number, Number], ndarray]], optional):
                Center of the bounding box (x, y). Defaults to the image center.
            scale (Optional[Union[Tuple[Number, Number], ndarray]], optional):
                Scale of the bounding box wrt [width, height]. Defaults to the
                image size.

        Returns:
            Tuple[ndarray, ndarray, ndarray]: The transformed image, the center
            used for the transformation, and the scale used for the
            transformation.
        """
        rot = 0
        imshape = np.array(img.shape[:2][::-1])
        if isinstance(center, Sequence):
            center = np.array(center)
        if isinstance(scale, Sequence):
            scale = np.array(scale)
        center = center if center is not None else imshape / 2.0
        scale = scale if scale is not None else imshape
        if self.use_udp:
            trans = get_warp_matrix(
                rot,
                center * 2.0,
                [self.input_size[0] - 1.0, self.input_size[1] - 1.0],
                scale,
            )
        else:
            trans = get_affine_transform(center, scale, rot, self.input_size)
        img = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[0]), int(self.input_size[1])),
            flags=cv2.INTER_LINEAR,
        )
        return img, center, scale

    def __call__(self, datas: List[dict]) -> List[dict]:
        for data in datas:
            ori_img = data["img"]
            if "ori_img" not in data:
                data["ori_img"] = ori_img
            if "ori_img_size" not in data:
                data["ori_img_size"] = [ori_img.shape[1], ori_img.shape[0]]
            img, center, scale = self.apply(
                ori_img, data.get("center", None), data.get("scale", None)
            )
            data["img"] = img
            data["center"] = center
            data["scale"] = scale
            data["img_size"] = [img.shape[1], img.shape[0]]  # [size_w, size_h]
        return datas
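
# Usage sketch (illustrative only; the image size and bbox center/scale here
# are assumptions):
#
#   affine = TopDownAffine(input_size=(192, 256), use_udp=True)
#   datas = affine([{"img": np.zeros((480, 640, 3), dtype=np.uint8),
#                    "center": (320.0, 240.0), "scale": (200.0, 267.0)}])
#   # datas[0]["img"].shape == (256, 192, 3); the recorded center/scale let
#   # KptPostProcess map heatmap coordinates back to the original image.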


def affine_transform(pt: ndarray, t: ndarray) -> ndarray:
    """Apply an affine transformation to a 2D point.

    Args:
        pt (numpy.ndarray): A 2D point represented as a 2-element array.
        t (numpy.ndarray): A 2x3 affine transformation matrix.

    Returns:
        numpy.ndarray: The transformed 2D point.
    """
    new_pt = np.array([pt[0], pt[1], 1.0])
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def transform_preds(
    coords: ndarray,
    center: Tuple[float, float],
    scale: Tuple[float, float],
    output_size: Tuple[int, int],
) -> ndarray:
    """Transform coordinates back to the original image space using the
    inverse affine transformation.

    Args:
        coords (numpy.ndarray): Original coordinates, shape (N, 2).
        center (tuple): Center point of the transformation.
        scale (tuple): Scale factor of the transformation.
        output_size (tuple): Size of the output space, [w, h].

    Returns:
        numpy.ndarray: Transformed coordinates, shape (N, 2).
    """
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords
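
# Illustrative call (values are assumptions): map 17 heatmap-space points on a
# 48x64 (w x h) grid back into image coordinates using the crop's center/scale:
#
#   img_coords = transform_preds(np.random.rand(17, 2) * [48, 64],
#                                np.array([320.0, 240.0]),
#                                np.array([200.0, 267.0]), [48, 64])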


class KptPostProcess:
    """Decode predicted heatmaps into keypoint coordinates and scores."""

    def __init__(self, use_dark: bool = True):
        self.use_dark = use_dark

    def apply(self, heatmap: ndarray, center: ndarray, scale: ndarray) -> Kpts:
        """Decode the heatmaps of a single sample."""
        # TODO: add batch support
        heatmap, center, scale = heatmap[None, ...], center[None, ...], scale[None, ...]
        preds, maxvals = self.get_final_preds(heatmap, center, scale)
        keypoints = np.concatenate((preds, maxvals), axis=-1)
        scores = np.mean(maxvals.squeeze(-1), axis=1)
        return [
            {"keypoints": kpt, "kpt_score": score}
            for kpt, score in zip(keypoints, scores)
        ]

    def __call__(self, batch_outputs: List[dict], datas: List[dict]) -> List[Kpts]:
        """Apply the post-processing to a batch of outputs.

        Args:
            batch_outputs (List[dict]): The list of model outputs, each with a
                "heatmap" entry.
            datas (List[dict]): The list of input data.

        Returns:
            List[Kpts]: The list of post-processed keypoints, one per sample.
        """
        return [
            self.apply(output["heatmap"], data["center"], data["scale"])
            for data, output in zip(datas, batch_outputs)
        ]
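
    # Illustrative decode (assumed shapes: a 17-joint model with 64x48
    # heatmaps; the random heatmap stands in for real model output):
    #
    #   post = KptPostProcess(use_dark=True)
    #   outputs = [{"heatmap": np.random.rand(17, 64, 48).astype(np.float32)}]
    #   datas = [{"center": np.array([320.0, 240.0]),
    #             "scale": np.array([200.0, 267.0])}]
    #   kpts = post(outputs, datas)
    #   # kpts[0][0]["keypoints"] has shape (17, 3): x, y, confidence.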

    def get_final_preds(
        self, heatmaps: ndarray, center: ndarray, scale: ndarray, kernelsize: int = 3
    ):
        """Get the final keypoint predictions: the highest heat-value location,
        shifted by a quarter pixel in the direction from the highest response
        to the second highest response.

        Args:
            heatmaps (numpy.ndarray): The predicted heatmaps.
            center (numpy.ndarray): The box centers.
            scale (numpy.ndarray): The scale factors.
            kernelsize (int): Gaussian kernel size used by DARK post-processing.

        Returns:
            preds: numpy.ndarray([batch_size, num_joints, 2]), keypoint coords.
            maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum
                confidence of the keypoints.
        """
        coords, maxvals = self.get_max_preds(heatmaps)
        heatmap_height = heatmaps.shape[2]
        heatmap_width = heatmaps.shape[3]
        if self.use_dark:
            coords = self.dark_postprocess(heatmaps, coords, kernelsize)
        else:
            for n in range(coords.shape[0]):
                for p in range(coords.shape[1]):
                    hm = heatmaps[n][p]
                    px = int(math.floor(coords[n][p][0] + 0.5))
                    py = int(math.floor(coords[n][p][1] + 0.5))
                    if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
                        diff = np.array(
                            [
                                hm[py][px + 1] - hm[py][px - 1],
                                hm[py + 1][px] - hm[py - 1][px],
                            ]
                        )
                        coords[n][p] += np.sign(diff) * 0.25
        preds = coords.copy()
        # Transform back to the original image space
        for i in range(coords.shape[0]):
            preds[i] = transform_preds(
                coords[i], center[i], scale[i], [heatmap_width, heatmap_height]
            )
        return preds, maxvals

    def get_max_preds(self, heatmaps: ndarray) -> Tuple[ndarray, ndarray]:
        """Get predictions from score maps.

        Args:
            heatmaps: numpy.ndarray([batch_size, num_joints, height, width])

        Returns:
            preds: numpy.ndarray([batch_size, num_joints, 2]), keypoint coords.
            maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum
                confidence of the keypoints.
        """
        assert isinstance(heatmaps, np.ndarray), "heatmaps should be numpy.ndarray"
        assert heatmaps.ndim == 4, "heatmaps should be 4-dim"
        batch_size = heatmaps.shape[0]
        num_joints = heatmaps.shape[1]
        width = heatmaps.shape[3]
        heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
        idx = np.argmax(heatmaps_reshaped, 2)
        maxvals = np.amax(heatmaps_reshaped, 2)
        maxvals = maxvals.reshape((batch_size, num_joints, 1))
        idx = idx.reshape((batch_size, num_joints, 1))
        preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
        # Convert flat argmax indices into (x, y) grid coordinates
        preds[:, :, 0] = preds[:, :, 0] % width
        preds[:, :, 1] = np.floor(preds[:, :, 1] / width)
        # Zero out predictions whose confidence is not positive
        pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)).astype(np.float32)
        preds *= pred_mask
        return preds, maxvals

    def gaussian_blur(self, heatmap: ndarray, kernel: int) -> ndarray:
        """Blur each heatmap in place while preserving its original maximum."""
        border = (kernel - 1) // 2
        batch_size = heatmap.shape[0]
        num_joints = heatmap.shape[1]
        height = heatmap.shape[2]
        width = heatmap.shape[3]
        for i in range(batch_size):
            for j in range(num_joints):
                origin_max = np.max(heatmap[i, j])
                # Pad, blur, then crop back to the original size
                dr = np.zeros((height + 2 * border, width + 2 * border))
                dr[border:-border, border:-border] = heatmap[i, j].copy()
                dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
                heatmap[i, j] = dr[border:-border, border:-border].copy()
                # Rescale so the peak value is unchanged by the blur
                heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
        return heatmap

    def dark_parse(self, hm: ndarray, coord: ndarray) -> ndarray:
        """Refine a coordinate via a second-order Taylor expansion of the
        (log-)heatmap around the peak: offset = -H^-1 * grad, where grad and
        H are estimated by finite differences.
        """
        heatmap_height = hm.shape[0]
        heatmap_width = hm.shape[1]
        px = int(coord[0])
        py = int(coord[1])
        if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
            # First derivatives (central differences)
            dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
            dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
            # Second derivatives
            dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
            dxy = 0.25 * (
                hm[py + 1][px + 1]
                - hm[py - 1][px + 1]
                - hm[py + 1][px - 1]
                + hm[py - 1][px - 1]
            )
            dyy = 0.25 * (hm[py + 2][px] - 2 * hm[py][px] + hm[py - 2][px])
            derivative = np.array([dx, dy])
            hessian = np.array([[dxx, dxy], [dxy, dyy]])
            if dxx * dyy - dxy**2 != 0:
                offset = -np.linalg.solve(hessian, derivative)
                coord += offset
        return coord

    def dark_postprocess(
        self, hm: ndarray, coords: ndarray, kernelsize: int
    ) -> ndarray:
        """Refine coordinates with DARK post-processing.

        Refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
        """
        # Smooth the heatmaps and work in log space so the Taylor expansion
        # in dark_parse is well-conditioned
        hm = self.gaussian_blur(hm, kernelsize)
        hm = np.maximum(hm, 1e-10)
        hm = np.log(hm)
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
        return coords
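
# Typical pipeline sketch (illustrative; "model" is a placeholder for the pose
# network, not part of this module):
#
#   preprocess = TopDownAffine(input_size=(192, 256), use_udp=False)
#   postprocess = KptPostProcess(use_dark=True)
#   datas = preprocess(datas)            # crops bboxes, records center/scale
#   # batch_outputs = model(...)         # each dict holds a "heatmap" array
#   # kpts = postprocess(batch_outputs, datas)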