topk_eval.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
  11. import os
  12. import json
  13. import argparse
  14. from paddle import nn
  15. import paddle
  16. from ....utils import logging
  17. def parse_args():
  18. """Parse all arguments """
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument(
  21. '--prediction_json_path', type=str, default='./pre_res.json')
  22. parser.add_argument('--gt_val_path', type=str, default='./val.txt')
  23. parser.add_argument('--image_dir', type=str)
  24. parser.add_argument('--num_classes', type=int)
  25. args = parser.parse_args()
  26. return args
  27. class AvgMetrics(nn.Layer):
  28. """ Average metrics """
  29. def __init__(self):
  30. super().__init__()
  31. self.avg_meters = {}
  32. @property
  33. def avg(self):
  34. """ Return average value of each metric """
  35. if self.avg_meters:
  36. for metric_key in self.avg_meters:
  37. return self.avg_meters[metric_key].avg
  38. @property
  39. def avg_info(self):
  40. """ Return a formatted string of average values and names """
  41. return ", ".join(
  42. [self.avg_meters[key].avg_info for key in self.avg_meters])
  43. class TopkAcc(AvgMetrics):
  44. """ Top-k accuracy metric """
  45. def __init__(self, topk=(1, 5)):
  46. super().__init__()
  47. assert isinstance(topk, (int, list, tuple))
  48. if isinstance(topk, int):
  49. topk = [topk]
  50. self.topk = topk
  51. self.warned = False
  52. def forward(self, x, label):
  53. """ forward function """
  54. if isinstance(x, dict):
  55. x = x["logits"]
  56. output_dims = x.shape[-1]
  57. metric_dict = dict()
  58. for idx, k in enumerate(self.topk):
  59. if output_dims < k:
  60. if not self.warned:
  61. msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless."
  62. logging.info(msg)
  63. self.warned = True
  64. metric_dict[f"top{k}"] = 1
  65. else:
  66. metric_dict[f"top{k}"] = paddle.metric.accuracy(
  67. x, label, k=k).item()
  68. return metric_dict
  69. def prase_pt_info(pt_info, num_classes):
  70. """ Parse prediction information to probability vector """
  71. pre_list = [0.0] * num_classes
  72. for idx, val in zip(pt_info["class_ids"], pt_info["scores"]):
  73. pre_list[idx] = val
  74. return pre_list
  75. def main(args):
  76. """ main function """
  77. with open(args.prediction_json_path, 'r') as fp:
  78. predication_result = json.load(fp)
  79. gt_info = {}
  80. pred = []
  81. label = []
  82. for line in open(args.gt_val_path):
  83. img_file, gt_label = line.strip().split(" ")
  84. img_file = img_file.split('/')[-1]
  85. gt_info[img_file] = int(gt_label)
  86. for pt_info in predication_result:
  87. img_file = os.path.relpath(pt_info['file_name'], args.image_dir)
  88. pred.append(prase_pt_info(pt_info, args.num_classes))
  89. label.append([gt_info[img_file]])
  90. metric_dict = TopkAcc()(paddle.to_tensor(pred), paddle.to_tensor(label))
  91. logging.info(metric_dict)
  92. if __name__ == "__main__":
  93. args = parse_args()
  94. main(args)