keypoint_operators.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # function:
  15. # operators to process sample,
  16. # eg: decode/resize/crop image
  17. from __future__ import absolute_import
  18. try:
  19. from collections.abc import Sequence
  20. except Exception:
  21. from collections import Sequence
  22. import cv2
  23. import numpy as np
  24. import math
  25. import copy
  26. import os
  27. from ...modeling.keypoint_utils import get_affine_mat_kernel, warp_affine_joints, get_affine_transform, affine_transform
  28. from paddlex.ppdet.core.workspace import serializable
  29. from paddlex.ppdet.utils.logger import setup_logger
  30. logger = setup_logger(__name__)
  31. registered_ops = []
  32. __all__ = [
  33. 'RandomAffine', 'KeyPointFlip', 'TagGenerate', 'ToHeatmaps',
  34. 'NormalizePermute', 'EvalAffine', 'RandomFlipHalfBodyTransform',
  35. 'TopDownAffine', 'ToHeatmapsTopDown', 'TopDownEvalAffine'
  36. ]
  37. def register_keypointop(cls):
  38. return serializable(cls)
  39. @register_keypointop
  40. class KeyPointFlip(object):
  41. """Get the fliped image by flip_prob. flip the coords also
  42. the left coords and right coords should exchange while flip, for the right keypoint will be left keypoint after image fliped
  43. Args:
  44. flip_permutation (list[17]): the left-right exchange order list corresponding to [0,1,2,...,16]
  45. hmsize (list[2]): output heatmap's shape list of different scale outputs of higherhrnet
  46. flip_prob (float): the ratio whether to flip the image
  47. records(dict): the dict contained the image, mask and coords
  48. Returns:
  49. records(dict): contain the image, mask and coords after tranformed
  50. """
  51. def __init__(self, flip_permutation, hmsize, flip_prob=0.5):
  52. super(KeyPointFlip, self).__init__()
  53. assert isinstance(flip_permutation, Sequence)
  54. self.flip_permutation = flip_permutation
  55. self.flip_prob = flip_prob
  56. self.hmsize = hmsize
  57. def __call__(self, records):
  58. image = records['image']
  59. kpts_lst = records['joints']
  60. mask_lst = records['mask']
  61. flip = np.random.random() < self.flip_prob
  62. if flip:
  63. image = image[:, ::-1]
  64. for idx, hmsize in enumerate(self.hmsize):
  65. if len(mask_lst) > idx:
  66. mask_lst[idx] = mask_lst[idx][:, ::-1]
  67. if kpts_lst[idx].ndim == 3:
  68. kpts_lst[idx] = kpts_lst[idx][:, self.flip_permutation]
  69. else:
  70. kpts_lst[idx] = kpts_lst[idx][self.flip_permutation]
  71. kpts_lst[idx][..., 0] = hmsize - kpts_lst[idx][..., 0]
  72. kpts_lst[idx] = kpts_lst[idx].astype(np.int64)
  73. kpts_lst[idx][kpts_lst[idx][..., 0] >= hmsize, 2] = 0
  74. kpts_lst[idx][kpts_lst[idx][..., 1] >= hmsize, 2] = 0
  75. kpts_lst[idx][kpts_lst[idx][..., 0] < 0, 2] = 0
  76. kpts_lst[idx][kpts_lst[idx][..., 1] < 0, 2] = 0
  77. records['image'] = image
  78. records['joints'] = kpts_lst
  79. records['mask'] = mask_lst
  80. return records
  81. def get_warp_matrix(theta, size_input, size_dst, size_target):
  82. """Calculate the transformation matrix under the constraint of unbiased.
  83. Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
  84. Data Processing for Human Pose Estimation (CVPR 2020).
  85. Args:
  86. theta (float): Rotation angle in degrees.
  87. size_input (np.ndarray): Size of input image [w, h].
  88. size_dst (np.ndarray): Size of output image [w, h].
  89. size_target (np.ndarray): Size of ROI in input plane [w, h].
  90. Returns:
  91. matrix (np.ndarray): A matrix for transformation.
  92. """
  93. theta = np.deg2rad(theta)
  94. matrix = np.zeros((2, 3), dtype=np.float32)
  95. scale_x = size_dst[0] / size_target[0]
  96. scale_y = size_dst[1] / size_target[1]
  97. matrix[0, 0] = math.cos(theta) * scale_x
  98. matrix[0, 1] = -math.sin(theta) * scale_x
  99. matrix[0, 2] = scale_x * (
  100. -0.5 * size_input[0] * math.cos(theta) + 0.5 * size_input[1] *
  101. math.sin(theta) + 0.5 * size_target[0])
  102. matrix[1, 0] = math.sin(theta) * scale_y
  103. matrix[1, 1] = math.cos(theta) * scale_y
  104. matrix[1, 2] = scale_y * (
  105. -0.5 * size_input[0] * math.sin(theta) - 0.5 * size_input[1] *
  106. math.cos(theta) + 0.5 * size_target[1])
  107. return matrix
  108. @register_keypointop
  109. class RandomAffine(object):
  110. """apply affine transform to image, mask and coords
  111. to achieve the rotate, scale and shift effect for training image
  112. Args:
  113. max_degree (float): the max abslute rotate degree to apply, transform range is [-max_degree, max_degree]
  114. max_scale (list[2]): the scale range to apply, transform range is [min, max]
  115. max_shift (float): the max abslute shift ratio to apply, transform range is [-max_shift*imagesize, max_shift*imagesize]
  116. hmsize (list[2]): output heatmap's shape list of different scale outputs of higherhrnet
  117. trainsize (int): the standard length used to train, the 'scale_type' of [h,w] will be resize to trainsize for standard
  118. scale_type (str): the length of [h,w] to used for trainsize, chosed between 'short' and 'long'
  119. records(dict): the dict contained the image, mask and coords
  120. Returns:
  121. records(dict): contain the image, mask and coords after tranformed
  122. """
  123. def __init__(self,
  124. max_degree=30,
  125. scale=[0.75, 1.5],
  126. max_shift=0.2,
  127. hmsize=[128, 256],
  128. trainsize=512,
  129. scale_type='short'):
  130. super(RandomAffine, self).__init__()
  131. self.max_degree = max_degree
  132. self.min_scale = scale[0]
  133. self.max_scale = scale[1]
  134. self.max_shift = max_shift
  135. self.hmsize = hmsize
  136. self.trainsize = trainsize
  137. self.scale_type = scale_type
  138. def _get_affine_matrix(self, center, scale, res, rot=0):
  139. """Generate transformation matrix."""
  140. h = scale
  141. t = np.zeros((3, 3), dtype=np.float32)
  142. t[0, 0] = float(res[1]) / h
  143. t[1, 1] = float(res[0]) / h
  144. t[0, 2] = res[1] * (-float(center[0]) / h + .5)
  145. t[1, 2] = res[0] * (-float(center[1]) / h + .5)
  146. t[2, 2] = 1
  147. if rot != 0:
  148. rot = -rot # To match direction of rotation from cropping
  149. rot_mat = np.zeros((3, 3), dtype=np.float32)
  150. rot_rad = rot * np.pi / 180
  151. sn, cs = np.sin(rot_rad), np.cos(rot_rad)
  152. rot_mat[0, :2] = [cs, -sn]
  153. rot_mat[1, :2] = [sn, cs]
  154. rot_mat[2, 2] = 1
  155. # Need to rotate around center
  156. t_mat = np.eye(3)
  157. t_mat[0, 2] = -res[1] / 2
  158. t_mat[1, 2] = -res[0] / 2
  159. t_inv = t_mat.copy()
  160. t_inv[:2, 2] *= -1
  161. t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
  162. return t
  163. def __call__(self, records):
  164. image = records['image']
  165. keypoints = records['joints']
  166. heatmap_mask = records['mask']
  167. degree = (np.random.random() * 2 - 1) * self.max_degree
  168. shape = np.array(image.shape[:2][::-1])
  169. center = center = np.array((np.array(shape) / 2))
  170. aug_scale = np.random.random() * (self.max_scale - self.min_scale
  171. ) + self.min_scale
  172. if self.scale_type == 'long':
  173. scale = max(shape[0], shape[1]) / 1.0
  174. elif self.scale_type == 'short':
  175. scale = min(shape[0], shape[1]) / 1.0
  176. else:
  177. raise ValueError('Unknown scale type: {}'.format(self.scale_type))
  178. roi_size = aug_scale * scale
  179. dx = int(0)
  180. dy = int(0)
  181. if self.max_shift > 0:
  182. dx = np.random.randint(-self.max_shift * roi_size,
  183. self.max_shift * roi_size)
  184. dy = np.random.randint(-self.max_shift * roi_size,
  185. self.max_shift * roi_size)
  186. center += np.array([dx, dy])
  187. input_size = 2 * center
  188. keypoints[..., :2] *= shape
  189. heatmap_mask *= 255
  190. kpts_lst = []
  191. mask_lst = []
  192. image_affine_mat = self._get_affine_matrix(
  193. center, roi_size, (self.trainsize, self.trainsize), degree)[:2]
  194. image = cv2.warpAffine(
  195. image,
  196. image_affine_mat, (self.trainsize, self.trainsize),
  197. flags=cv2.INTER_LINEAR)
  198. for hmsize in self.hmsize:
  199. kpts = copy.deepcopy(keypoints)
  200. mask_affine_mat = self._get_affine_matrix(
  201. center, roi_size, (hmsize, hmsize), degree)[:2]
  202. if heatmap_mask is not None:
  203. mask = cv2.warpAffine(heatmap_mask, mask_affine_mat,
  204. (hmsize, hmsize))
  205. mask = ((mask / 255) > 0.5).astype(np.float32)
  206. kpts[..., 0:2] = warp_affine_joints(kpts[..., 0:2].copy(),
  207. mask_affine_mat)
  208. kpts[np.trunc(kpts[..., 0]) >= hmsize, 2] = 0
  209. kpts[np.trunc(kpts[..., 1]) >= hmsize, 2] = 0
  210. kpts[np.trunc(kpts[..., 0]) < 0, 2] = 0
  211. kpts[np.trunc(kpts[..., 1]) < 0, 2] = 0
  212. kpts_lst.append(kpts)
  213. mask_lst.append(mask)
  214. records['image'] = image
  215. records['joints'] = kpts_lst
  216. records['mask'] = mask_lst
  217. return records
  218. @register_keypointop
  219. class EvalAffine(object):
  220. """apply affine transform to image
  221. resize the short of [h,w] to standard size for eval
  222. Args:
  223. size (int): the standard length used to train, the 'short' of [h,w] will be resize to trainsize for standard
  224. records(dict): the dict contained the image, mask and coords
  225. Returns:
  226. records(dict): contain the image, mask and coords after tranformed
  227. """
  228. def __init__(self, size, stride=64):
  229. super(EvalAffine, self).__init__()
  230. self.size = size
  231. self.stride = stride
  232. def __call__(self, records):
  233. image = records['image']
  234. mask = records['mask'] if 'mask' in records else None
  235. s = self.size
  236. h, w, _ = image.shape
  237. trans, size_resized = get_affine_mat_kernel(h, w, s, inv=False)
  238. image_resized = cv2.warpAffine(image, trans, size_resized)
  239. if mask is not None:
  240. mask = cv2.warpAffine(mask, trans, size_resized)
  241. records['mask'] = mask
  242. if 'joints' in records:
  243. del records['joints']
  244. records['image'] = image_resized
  245. return records
  246. @register_keypointop
  247. class NormalizePermute(object):
  248. def __init__(self,
  249. mean=[123.675, 116.28, 103.53],
  250. std=[58.395, 57.120, 57.375],
  251. is_scale=True):
  252. super(NormalizePermute, self).__init__()
  253. self.mean = mean
  254. self.std = std
  255. self.is_scale = is_scale
  256. def __call__(self, records):
  257. image = records['image']
  258. image = image.astype(np.float32)
  259. if self.is_scale:
  260. image /= 255.
  261. image = image.transpose((2, 0, 1))
  262. mean = np.array(self.mean, dtype=np.float32)
  263. std = np.array(self.std, dtype=np.float32)
  264. invstd = 1. / std
  265. for v, m, s in zip(image, mean, invstd):
  266. v.__isub__(m).__imul__(s)
  267. records['image'] = image
  268. return records
  269. @register_keypointop
  270. class TagGenerate(object):
  271. """record gt coords for aeloss to sample coords value in tagmaps
  272. Args:
  273. num_joints (int): the keypoint numbers of dataset to train
  274. num_people (int): maxmum people to support for sample aeloss
  275. records(dict): the dict contained the image, mask and coords
  276. Returns:
  277. records(dict): contain the gt coords used in tagmap
  278. """
  279. def __init__(self, num_joints, max_people=30):
  280. super(TagGenerate, self).__init__()
  281. self.max_people = max_people
  282. self.num_joints = num_joints
  283. def __call__(self, records):
  284. kpts_lst = records['joints']
  285. kpts = kpts_lst[0]
  286. tagmap = np.zeros(
  287. (self.max_people, self.num_joints, 4), dtype=np.int64)
  288. inds = np.where(kpts[..., 2] > 0)
  289. p, j = inds[0], inds[1]
  290. visible = kpts[inds]
  291. # tagmap is [p, j, 3], where last dim is j, y, x
  292. tagmap[p, j, 0] = j
  293. tagmap[p, j, 1] = visible[..., 1] # y
  294. tagmap[p, j, 2] = visible[..., 0] # x
  295. tagmap[p, j, 3] = 1
  296. records['tagmap'] = tagmap
  297. del records['joints']
  298. return records
  299. @register_keypointop
  300. class ToHeatmaps(object):
  301. """to generate the gaussin heatmaps of keypoint for heatmap loss
  302. Args:
  303. num_joints (int): the keypoint numbers of dataset to train
  304. hmsize (list[2]): output heatmap's shape list of different scale outputs of higherhrnet
  305. sigma (float): the std of gaussin kernel genereted
  306. records(dict): the dict contained the image, mask and coords
  307. Returns:
  308. records(dict): contain the heatmaps used to heatmaploss
  309. """
  310. def __init__(self, num_joints, hmsize, sigma=None):
  311. super(ToHeatmaps, self).__init__()
  312. self.num_joints = num_joints
  313. self.hmsize = np.array(hmsize)
  314. if sigma is None:
  315. sigma = hmsize[0] // 64
  316. self.sigma = sigma
  317. r = 6 * sigma + 3
  318. x = np.arange(0, r, 1, np.float32)
  319. y = x[:, None]
  320. x0, y0 = 3 * sigma + 1, 3 * sigma + 1
  321. self.gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
  322. def __call__(self, records):
  323. kpts_lst = records['joints']
  324. mask_lst = records['mask']
  325. for idx, hmsize in enumerate(self.hmsize):
  326. mask = mask_lst[idx]
  327. kpts = kpts_lst[idx]
  328. heatmaps = np.zeros((self.num_joints, hmsize, hmsize))
  329. inds = np.where(kpts[..., 2] > 0)
  330. visible = kpts[inds].astype(np.int64)[..., :2]
  331. ul = np.round(visible - 3 * self.sigma - 1)
  332. br = np.round(visible + 3 * self.sigma + 2)
  333. sul = np.maximum(0, -ul)
  334. sbr = np.minimum(hmsize, br) - ul
  335. dul = np.clip(ul, 0, hmsize - 1)
  336. dbr = np.clip(br, 0, hmsize)
  337. for i in range(len(visible)):
  338. dx1, dy1 = dul[i]
  339. dx2, dy2 = dbr[i]
  340. sx1, sy1 = sul[i]
  341. sx2, sy2 = sbr[i]
  342. heatmaps[inds[1][i], dy1:dy2, dx1:dx2] = np.maximum(
  343. self.gaussian[sy1:sy2, sx1:sx2],
  344. heatmaps[inds[1][i], dy1:dy2, dx1:dx2])
  345. records['heatmap_gt{}x'.format(idx + 1)] = heatmaps
  346. records['mask_{}x'.format(idx + 1)] = mask
  347. del records['mask']
  348. return records
  349. @register_keypointop
  350. class RandomFlipHalfBodyTransform(object):
  351. """apply data augment to image and coords
  352. to achieve the flip, scale, rotate and half body transform effect for training image
  353. Args:
  354. trainsize (list):[w, h], Image target size
  355. upper_body_ids (list): The upper body joint ids
  356. flip_pairs (list): The left-right joints exchange order list
  357. pixel_std (int): The pixel std of the scale
  358. scale (float): The scale factor to transform the image
  359. rot (int): The rotate factor to transform the image
  360. num_joints_half_body (int): The joints threshold of the half body transform
  361. prob_half_body (float): The threshold of the half body transform
  362. flip (bool): Whether to flip the image
  363. Returns:
  364. records(dict): contain the image and coords after tranformed
  365. """
  366. def __init__(self,
  367. trainsize,
  368. upper_body_ids,
  369. flip_pairs,
  370. pixel_std,
  371. scale=0.35,
  372. rot=40,
  373. num_joints_half_body=8,
  374. prob_half_body=0.3,
  375. flip=True,
  376. rot_prob=0.6):
  377. super(RandomFlipHalfBodyTransform, self).__init__()
  378. self.trainsize = trainsize
  379. self.upper_body_ids = upper_body_ids
  380. self.flip_pairs = flip_pairs
  381. self.pixel_std = pixel_std
  382. self.scale = scale
  383. self.rot = rot
  384. self.num_joints_half_body = num_joints_half_body
  385. self.prob_half_body = prob_half_body
  386. self.flip = flip
  387. self.aspect_ratio = trainsize[0] * 1.0 / trainsize[1]
  388. self.rot_prob = rot_prob
  389. def halfbody_transform(self, joints, joints_vis):
  390. upper_joints = []
  391. lower_joints = []
  392. for joint_id in range(joints.shape[0]):
  393. if joints_vis[joint_id][0] > 0:
  394. if joint_id in self.upper_body_ids:
  395. upper_joints.append(joints[joint_id])
  396. else:
  397. lower_joints.append(joints[joint_id])
  398. if np.random.randn() < 0.5 and len(upper_joints) > 2:
  399. selected_joints = upper_joints
  400. else:
  401. selected_joints = lower_joints if len(
  402. lower_joints) > 2 else upper_joints
  403. if len(selected_joints) < 2:
  404. return None, None
  405. selected_joints = np.array(selected_joints, dtype=np.float32)
  406. center = selected_joints.mean(axis=0)[:2]
  407. left_top = np.amin(selected_joints, axis=0)
  408. right_bottom = np.amax(selected_joints, axis=0)
  409. w = right_bottom[0] - left_top[0]
  410. h = right_bottom[1] - left_top[1]
  411. if w > self.aspect_ratio * h:
  412. h = w * 1.0 / self.aspect_ratio
  413. elif w < self.aspect_ratio * h:
  414. w = h * self.aspect_ratio
  415. scale = np.array(
  416. [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
  417. dtype=np.float32)
  418. scale = scale * 1.5
  419. return center, scale
  420. def flip_joints(self, joints, joints_vis, width, matched_parts):
  421. joints[:, 0] = width - joints[:, 0] - 1
  422. for pair in matched_parts:
  423. joints[pair[0], :], joints[pair[1], :] = \
  424. joints[pair[1], :], joints[pair[0], :].copy()
  425. joints_vis[pair[0], :], joints_vis[pair[1], :] = \
  426. joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
  427. return joints * joints_vis, joints_vis
  428. def __call__(self, records):
  429. image = records['image']
  430. joints = records['joints']
  431. joints_vis = records['joints_vis']
  432. c = records['center']
  433. s = records['scale']
  434. r = 0
  435. if (np.sum(joints_vis[:, 0]) > self.num_joints_half_body and
  436. np.random.rand() < self.prob_half_body):
  437. c_half_body, s_half_body = self.halfbody_transform(joints,
  438. joints_vis)
  439. if c_half_body is not None and s_half_body is not None:
  440. c, s = c_half_body, s_half_body
  441. sf = self.scale
  442. rf = self.rot
  443. s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
  444. r = np.clip(np.random.randn() * rf, -rf * 2,
  445. rf * 2) if np.random.random() <= self.rot_prob else 0
  446. if self.flip and np.random.random() <= 0.5:
  447. image = image[:, ::-1, :]
  448. joints, joints_vis = self.flip_joints(
  449. joints, joints_vis, image.shape[1], self.flip_pairs)
  450. c[0] = image.shape[1] - c[0] - 1
  451. records['image'] = image
  452. records['joints'] = joints
  453. records['joints_vis'] = joints_vis
  454. records['center'] = c
  455. records['scale'] = s
  456. records['rotate'] = r
  457. return records
  458. @register_keypointop
  459. class TopDownAffine(object):
  460. """apply affine transform to image and coords
  461. Args:
  462. trainsize (list): [w, h], the standard size used to train
  463. records(dict): the dict contained the image and coords
  464. Returns:
  465. records (dict): contain the image and coords after tranformed
  466. """
  467. def __init__(self, trainsize):
  468. self.trainsize = trainsize
  469. def __call__(self, records):
  470. image = records['image']
  471. joints = records['joints']
  472. joints_vis = records['joints_vis']
  473. rot = records['rotate'] if "rotate" in records else 0
  474. trans = get_affine_transform(records['center'], records['scale'] * 200,
  475. rot, self.trainsize)
  476. image = cv2.warpAffine(
  477. image,
  478. trans, (int(self.trainsize[0]), int(self.trainsize[1])),
  479. flags=cv2.INTER_LINEAR)
  480. for i in range(joints.shape[0]):
  481. if joints_vis[i, 0] > 0.0:
  482. joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
  483. records['image'] = image
  484. records['joints'] = joints
  485. return records
  486. @register_keypointop
  487. class TopDownEvalAffine(object):
  488. """apply affine transform to image and coords
  489. Args:
  490. trainsize (list): [w, h], the standard size used to train
  491. records(dict): the dict contained the image and coords
  492. Returns:
  493. records (dict): contain the image and coords after tranformed
  494. """
  495. def __init__(self, trainsize):
  496. self.trainsize = trainsize
  497. def __call__(self, records):
  498. image = records['image']
  499. rot = 0
  500. imshape = records['im_shape'][::-1]
  501. center = imshape / 2.
  502. scale = imshape
  503. trans = get_affine_transform(center, scale, rot, self.trainsize)
  504. image = cv2.warpAffine(
  505. image,
  506. trans, (int(self.trainsize[0]), int(self.trainsize[1])),
  507. flags=cv2.INTER_LINEAR)
  508. records['image'] = image
  509. return records
  510. @register_keypointop
  511. class ToHeatmapsTopDown(object):
  512. """to generate the gaussin heatmaps of keypoint for heatmap loss
  513. Args:
  514. hmsize (list): [w, h] output heatmap's size
  515. sigma (float): the std of gaussin kernel genereted
  516. records(dict): the dict contained the image and coords
  517. Returns:
  518. records (dict): contain the heatmaps used to heatmaploss
  519. """
  520. def __init__(self, hmsize, sigma):
  521. super(ToHeatmapsTopDown, self).__init__()
  522. self.hmsize = np.array(hmsize)
  523. self.sigma = sigma
  524. def __call__(self, records):
  525. joints = records['joints']
  526. joints_vis = records['joints_vis']
  527. num_joints = joints.shape[0]
  528. image_size = np.array(
  529. [records['image'].shape[1], records['image'].shape[0]])
  530. target_weight = np.ones((num_joints, 1), dtype=np.float32)
  531. target_weight[:, 0] = joints_vis[:, 0]
  532. target = np.zeros(
  533. (num_joints, self.hmsize[1], self.hmsize[0]), dtype=np.float32)
  534. tmp_size = self.sigma * 3
  535. for joint_id in range(num_joints):
  536. feat_stride = image_size / self.hmsize
  537. mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
  538. mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
  539. # Check that any part of the gaussian is in-bounds
  540. ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
  541. br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
  542. if ul[0] >= self.hmsize[0] or ul[1] >= self.hmsize[1] or br[
  543. 0] < 0 or br[1] < 0:
  544. # If not, just return the image as is
  545. target_weight[joint_id] = 0
  546. continue
  547. # # Generate gaussian
  548. size = 2 * tmp_size + 1
  549. x = np.arange(0, size, 1, np.float32)
  550. y = x[:, np.newaxis]
  551. x0 = y0 = size // 2
  552. # The gaussian is not normalized, we want the center value to equal 1
  553. g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * self.sigma**2))
  554. # Usable gaussian range
  555. g_x = max(0, -ul[0]), min(br[0], self.hmsize[0]) - ul[0]
  556. g_y = max(0, -ul[1]), min(br[1], self.hmsize[1]) - ul[1]
  557. # Image range
  558. img_x = max(0, ul[0]), min(br[0], self.hmsize[0])
  559. img_y = max(0, ul[1]), min(br[1], self.hmsize[1])
  560. v = target_weight[joint_id]
  561. if v > 0.5:
  562. target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[
  563. 0]:g_y[1], g_x[0]:g_x[1]]
  564. records['target'] = target
  565. records['target_weight'] = target_weight
  566. del records['joints'], records['joints_vis']
  567. return records