topk_eval.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import json
  16. import argparse
  17. from paddle import nn
  18. import paddle
  19. from ....utils import logging
  20. def parse_args():
  21. """Parse all arguments """
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument(
  24. '--prediction_json_path', type=str, default='./pre_res.json')
  25. parser.add_argument('--gt_val_path', type=str, default='./val.txt')
  26. parser.add_argument('--image_dir', type=str)
  27. parser.add_argument('--num_classes', type=int)
  28. args = parser.parse_args()
  29. return args
  30. class AvgMetrics(nn.Layer):
  31. """ Average metrics """
  32. def __init__(self):
  33. super().__init__()
  34. self.avg_meters = {}
  35. @property
  36. def avg(self):
  37. """ Return average value of each metric """
  38. if self.avg_meters:
  39. for metric_key in self.avg_meters:
  40. return self.avg_meters[metric_key].avg
  41. @property
  42. def avg_info(self):
  43. """ Return a formatted string of average values and names """
  44. return ", ".join(
  45. [self.avg_meters[key].avg_info for key in self.avg_meters])
  46. class TopkAcc(AvgMetrics):
  47. """ Top-k accuracy metric """
  48. def __init__(self, topk=(1, 5)):
  49. super().__init__()
  50. assert isinstance(topk, (int, list, tuple))
  51. if isinstance(topk, int):
  52. topk = [topk]
  53. self.topk = topk
  54. self.warned = False
  55. def forward(self, x, label):
  56. """ forward function """
  57. if isinstance(x, dict):
  58. x = x["logits"]
  59. output_dims = x.shape[-1]
  60. metric_dict = dict()
  61. for idx, k in enumerate(self.topk):
  62. if output_dims < k:
  63. if not self.warned:
  64. msg = f"The output dims({output_dims}) is less than k({k}), so the Top-{k} metric is meaningless."
  65. logging.info(msg)
  66. self.warned = True
  67. metric_dict[f"top{k}"] = 1
  68. else:
  69. metric_dict[f"top{k}"] = paddle.metric.accuracy(
  70. x, label, k=k).item()
  71. return metric_dict
  72. def prase_pt_info(pt_info, num_classes):
  73. """ Parse prediction information to probability vector """
  74. pre_list = [0.0] * num_classes
  75. for idx, val in zip(pt_info["class_ids"], pt_info["scores"]):
  76. pre_list[idx] = val
  77. return pre_list
  78. def main(args):
  79. """ main function """
  80. with open(args.prediction_json_path, 'r') as fp:
  81. predication_result = json.load(fp)
  82. gt_info = {}
  83. pred = []
  84. label = []
  85. for line in open(args.gt_val_path):
  86. img_file, gt_label = line.strip().split(" ")
  87. img_file = img_file.split('/')[-1]
  88. gt_info[img_file] = int(gt_label)
  89. for pt_info in predication_result:
  90. img_file = os.path.relpath(pt_info['file_name'], args.image_dir)
  91. pred.append(prase_pt_info(pt_info, args.num_classes))
  92. label.append([gt_info[img_file]])
  93. metric_dict = TopkAcc()(paddle.to_tensor(pred), paddle.to_tensor(label))
  94. logging.info(metric_dict)
  95. if __name__ == "__main__":
  96. args = parse_args()
  97. main(args)