keypoint_metrics.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. from collections import defaultdict, OrderedDict
  17. import numpy as np
  18. from pycocotools.coco import COCO
  19. from pycocotools.cocoeval import COCOeval
  20. from ..modeling.keypoint_utils import oks_nms
  21. from scipy.io import loadmat, savemat
# Public evaluators exported by this module.
__all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']
  23. class KeyPointTopDownCOCOEval(object):
  24. '''
  25. Adapted from
  26. https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
  27. Copyright (c) Microsoft, under the MIT License.
  28. '''
  29. def __init__(self,
  30. anno_file,
  31. num_samples,
  32. num_joints,
  33. output_eval,
  34. iou_type='keypoints',
  35. in_vis_thre=0.2,
  36. oks_thre=0.9):
  37. super(KeyPointTopDownCOCOEval, self).__init__()
  38. self.coco = COCO(anno_file)
  39. self.num_samples = num_samples
  40. self.num_joints = num_joints
  41. self.iou_type = iou_type
  42. self.in_vis_thre = in_vis_thre
  43. self.oks_thre = oks_thre
  44. self.output_eval = output_eval
  45. self.res_file = os.path.join(output_eval, "keypoints_results.json")
  46. self.reset()
  47. def reset(self):
  48. self.results = {
  49. 'all_preds': np.zeros(
  50. (self.num_samples, self.num_joints, 3), dtype=np.float32),
  51. 'all_boxes': np.zeros((self.num_samples, 6)),
  52. 'image_path': []
  53. }
  54. self.eval_results = {}
  55. self.idx = 0
  56. def update(self, inputs, outputs):
  57. kpts, _ = outputs['keypoint'][0]
  58. num_images = inputs['image'].shape[0]
  59. self.results['all_preds'][self.idx:self.idx + num_images, :, 0:
  60. 3] = kpts[:, :, 0:3]
  61. self.results['all_boxes'][self.idx:self.idx + num_images, 0:
  62. 2] = inputs['center'].numpy()[:, 0:2]
  63. self.results['all_boxes'][self.idx:self.idx + num_images, 2:
  64. 4] = inputs['scale'].numpy()[:, 0:2]
  65. self.results['all_boxes'][self.idx:self.idx + num_images, 4] = np.prod(
  66. inputs['scale'].numpy() * 200, 1)
  67. self.results['all_boxes'][self.idx:self.idx + num_images,
  68. 5] = np.squeeze(inputs['score'].numpy())
  69. self.results['image_path'].extend(inputs['im_id'].numpy())
  70. self.idx += num_images
  71. def _write_coco_keypoint_results(self, keypoints):
  72. data_pack = [{
  73. 'cat_id': 1,
  74. 'cls': 'person',
  75. 'ann_type': 'keypoints',
  76. 'keypoints': keypoints
  77. }]
  78. results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
  79. if not os.path.exists(self.output_eval):
  80. os.makedirs(self.output_eval)
  81. with open(self.res_file, 'w') as f:
  82. json.dump(results, f, sort_keys=True, indent=4)
  83. try:
  84. json.load(open(self.res_file))
  85. except Exception:
  86. content = []
  87. with open(self.res_file, 'r') as f:
  88. for line in f:
  89. content.append(line)
  90. content[-1] = ']'
  91. with open(self.res_file, 'w') as f:
  92. for c in content:
  93. f.write(c)
  94. def _coco_keypoint_results_one_category_kernel(self, data_pack):
  95. cat_id = data_pack['cat_id']
  96. keypoints = data_pack['keypoints']
  97. cat_results = []
  98. for img_kpts in keypoints:
  99. if len(img_kpts) == 0:
  100. continue
  101. _key_points = np.array(
  102. [img_kpts[k]['keypoints'] for k in range(len(img_kpts))])
  103. _key_points = _key_points.reshape(_key_points.shape[0], -1)
  104. result = [{
  105. 'image_id': img_kpts[k]['image'],
  106. 'category_id': cat_id,
  107. 'keypoints': _key_points[k].tolist(),
  108. 'score': img_kpts[k]['score'],
  109. 'center': list(img_kpts[k]['center']),
  110. 'scale': list(img_kpts[k]['scale'])
  111. } for k in range(len(img_kpts))]
  112. cat_results.extend(result)
  113. return cat_results
  114. def get_final_results(self, preds, all_boxes, img_path):
  115. _kpts = []
  116. for idx, kpt in enumerate(preds):
  117. _kpts.append({
  118. 'keypoints': kpt,
  119. 'center': all_boxes[idx][0:2],
  120. 'scale': all_boxes[idx][2:4],
  121. 'area': all_boxes[idx][4],
  122. 'score': all_boxes[idx][5],
  123. 'image': int(img_path[idx])
  124. })
  125. # image x person x (keypoints)
  126. kpts = defaultdict(list)
  127. for kpt in _kpts:
  128. kpts[kpt['image']].append(kpt)
  129. # rescoring and oks nms
  130. num_joints = preds.shape[1]
  131. in_vis_thre = self.in_vis_thre
  132. oks_thre = self.oks_thre
  133. oks_nmsed_kpts = []
  134. for img in kpts.keys():
  135. img_kpts = kpts[img]
  136. for n_p in img_kpts:
  137. box_score = n_p['score']
  138. kpt_score = 0
  139. valid_num = 0
  140. for n_jt in range(0, num_joints):
  141. t_s = n_p['keypoints'][n_jt][2]
  142. if t_s > in_vis_thre:
  143. kpt_score = kpt_score + t_s
  144. valid_num = valid_num + 1
  145. if valid_num != 0:
  146. kpt_score = kpt_score / valid_num
  147. # rescoring
  148. n_p['score'] = kpt_score * box_score
  149. keep = oks_nms([img_kpts[i] for i in range(len(img_kpts))],
  150. oks_thre)
  151. if len(keep) == 0:
  152. oks_nmsed_kpts.append(img_kpts)
  153. else:
  154. oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])
  155. self._write_coco_keypoint_results(oks_nmsed_kpts)
  156. def accumulate(self):
  157. self.get_final_results(self.results['all_preds'],
  158. self.results['all_boxes'],
  159. self.results['image_path'])
  160. coco_dt = self.coco.loadRes(self.res_file)
  161. coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
  162. coco_eval.params.useSegm = None
  163. coco_eval.evaluate()
  164. coco_eval.accumulate()
  165. coco_eval.summarize()
  166. keypoint_stats = []
  167. for ind in range(len(coco_eval.stats)):
  168. keypoint_stats.append((coco_eval.stats[ind]))
  169. self.eval_results['keypoint'] = keypoint_stats
  170. def log(self):
  171. stats_names = [
  172. 'AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5',
  173. 'AR .75', 'AR (M)', 'AR (L)'
  174. ]
  175. num_values = len(stats_names)
  176. print(' '.join(['| {}'.format(name) for name in stats_names]) + ' |')
  177. print('|---' * (num_values + 1) + '|')
  178. print(' '.join([
  179. '| {:.3f}'.format(value) for value in self.eval_results['keypoint']
  180. ]) + ' |')
  181. def get_results(self):
  182. return self.eval_results
  183. class KeyPointTopDownMPIIEval(object):
  184. def __init__(self,
  185. anno_file,
  186. num_samples,
  187. num_joints,
  188. output_eval,
  189. oks_thre=0.9):
  190. super(KeyPointTopDownMPIIEval, self).__init__()
  191. self.ann_file = anno_file
  192. self.reset()
  193. def reset(self):
  194. self.results = []
  195. self.eval_results = {}
  196. self.idx = 0
  197. def update(self, inputs, outputs):
  198. kpts, _ = outputs['keypoint'][0]
  199. num_images = inputs['image'].shape[0]
  200. results = {}
  201. results['preds'] = kpts[:, :, 0:3]
  202. results['boxes'] = np.zeros((num_images, 6))
  203. results['boxes'][:, 0:2] = inputs['center'].numpy()[:, 0:2]
  204. results['boxes'][:, 2:4] = inputs['scale'].numpy()[:, 0:2]
  205. results['boxes'][:, 4] = np.prod(inputs['scale'].numpy() * 200, 1)
  206. results['boxes'][:, 5] = np.squeeze(inputs['score'].numpy())
  207. results['image_path'] = inputs['image_file']
  208. self.results.append(results)
  209. def accumulate(self):
  210. self.eval_results = self.evaluate(self.results)
  211. def log(self):
  212. for item, value in self.eval_results.items():
  213. print("{} : {}".format(item, value))
  214. def get_results(self):
  215. return self.eval_results
  216. def evaluate(self, outputs, savepath=None):
  217. """Evaluate PCKh for MPII dataset. Adapted from
  218. https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
  219. Copyright (c) Microsoft, under the MIT License.
  220. Args:
  221. outputs(list(preds, boxes)):
  222. * preds (np.ndarray[N,K,3]): The first two dimensions are
  223. coordinates, score is the third dimension of the array.
  224. * boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
  225. , scale[1],area, score]
  226. Returns:
  227. dict: PCKh for each joint
  228. """
  229. kpts = []
  230. for output in outputs:
  231. preds = output['preds']
  232. batch_size = preds.shape[0]
  233. for i in range(batch_size):
  234. kpts.append({'keypoints': preds[i]})
  235. preds = np.stack([kpt['keypoints'] for kpt in kpts])
  236. # convert 0-based index to 1-based index,
  237. # and get the first two dimensions.
  238. preds = preds[..., :2] + 1.0
  239. if savepath is not None:
  240. pred_file = os.path.join(savepath, 'pred.mat')
  241. savemat(pred_file, mdict={'preds': preds})
  242. SC_BIAS = 0.6
  243. threshold = 0.5
  244. gt_file = os.path.join(
  245. os.path.dirname(self.ann_file), 'mpii_gt_val.mat')
  246. gt_dict = loadmat(gt_file)
  247. dataset_joints = gt_dict['dataset_joints']
  248. jnt_missing = gt_dict['jnt_missing']
  249. pos_gt_src = gt_dict['pos_gt_src']
  250. headboxes_src = gt_dict['headboxes_src']
  251. pos_pred_src = np.transpose(preds, [1, 2, 0])
  252. head = np.where(dataset_joints == 'head')[1][0]
  253. lsho = np.where(dataset_joints == 'lsho')[1][0]
  254. lelb = np.where(dataset_joints == 'lelb')[1][0]
  255. lwri = np.where(dataset_joints == 'lwri')[1][0]
  256. lhip = np.where(dataset_joints == 'lhip')[1][0]
  257. lkne = np.where(dataset_joints == 'lkne')[1][0]
  258. lank = np.where(dataset_joints == 'lank')[1][0]
  259. rsho = np.where(dataset_joints == 'rsho')[1][0]
  260. relb = np.where(dataset_joints == 'relb')[1][0]
  261. rwri = np.where(dataset_joints == 'rwri')[1][0]
  262. rkne = np.where(dataset_joints == 'rkne')[1][0]
  263. rank = np.where(dataset_joints == 'rank')[1][0]
  264. rhip = np.where(dataset_joints == 'rhip')[1][0]
  265. jnt_visible = 1 - jnt_missing
  266. uv_error = pos_pred_src - pos_gt_src
  267. uv_err = np.linalg.norm(uv_error, axis=1)
  268. headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
  269. headsizes = np.linalg.norm(headsizes, axis=0)
  270. headsizes *= SC_BIAS
  271. scale = headsizes * np.ones((len(uv_err), 1), dtype=np.float32)
  272. scaled_uv_err = uv_err / scale
  273. scaled_uv_err = scaled_uv_err * jnt_visible
  274. jnt_count = np.sum(jnt_visible, axis=1)
  275. less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
  276. PCKh = 100. * np.sum(less_than_threshold, axis=1) / jnt_count
  277. # save
  278. rng = np.arange(0, 0.5 + 0.01, 0.01)
  279. pckAll = np.zeros((len(rng), 16), dtype=np.float32)
  280. for r, threshold in enumerate(rng):
  281. less_than_threshold = (scaled_uv_err <= threshold) * jnt_visible
  282. pckAll[r, :] = 100. * np.sum(less_than_threshold,
  283. axis=1) / jnt_count
  284. PCKh = np.ma.array(PCKh, mask=False)
  285. PCKh.mask[6:8] = True
  286. jnt_count = np.ma.array(jnt_count, mask=False)
  287. jnt_count.mask[6:8] = True
  288. jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)
  289. name_value = [ #noqa
  290. ('Head', PCKh[head]),
  291. ('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
  292. ('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
  293. ('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
  294. ('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
  295. ('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
  296. ('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
  297. ('PCKh', np.sum(PCKh * jnt_ratio)),
  298. ('PCKh@0.1', np.sum(pckAll[11, :] * jnt_ratio))
  299. ]
  300. name_value = OrderedDict(name_value)
  301. return name_value
  302. def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
  303. """sort kpts and remove the repeated ones."""
  304. kpts = sorted(kpts, key=lambda x: x[key])
  305. num = len(kpts)
  306. for i in range(num - 1, 0, -1):
  307. if kpts[i][key] == kpts[i - 1][key]:
  308. del kpts[i]
  309. return kpts