classification.py

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import yaml
import os.path as osp
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, auc
from paddlex.utils import get_encoding


class Evaluator(object):
    def __init__(self, model_path, topk=5):
        # Load model metadata (including the label list) from model.yml.
        model_yml = osp.join(model_path, "model.yml")
        with open(model_yml, encoding=get_encoding(model_yml)) as f:
            model_info = yaml.load(f.read(), Loader=yaml.Loader)
        # Load the per-sample ground truths and prediction scores dumped
        # during evaluation.
        eval_details_file = osp.join(model_path, 'eval_details.json')
        with open(
                eval_details_file, 'r',
                encoding=get_encoding(eval_details_file)) as f:
            eval_details = json.load(f)
        self.topk = topk
        self.labels = model_info['_Attributes']['labels']
        self.true_labels = np.array(eval_details['true_labels'])
        self.pred_scores = np.array(eval_details['pred_scores'])
        # Labels that never occur in the ground truth; their metrics are
        # undefined and reported as -1.0 below.
        label_ids_list = list(range(len(self.labels)))
        self.no_appear_label_ids = set(label_ids_list) - set(
            self.true_labels.tolist())

    def cal_confusion_matrix(self):
        '''Compute the confusion matrix.
        '''
        # Top-1 predicted label for every sample.
        pred_labels = np.argsort(self.pred_scores)[:, -1:].flatten()
        cm = confusion_matrix(
            self.true_labels.tolist(),
            pred_labels.tolist(),
            labels=list(range(len(self.labels))))
        return cm

    def cal_precision_recall_F1(self):
        '''Compute per-class precision, recall and F1, plus their
        support-weighted averages.
        '''
        out = {}
        out_avg = {}
        out_avg['precision'] = 0.0
        out_avg['recall'] = 0.0
        out_avg['F1'] = 0.0
        pred_labels = np.argsort(self.pred_scores)[:, -1:].flatten()
        for label_id in range(len(self.labels)):
            out[self.labels[label_id]] = {}
            if label_id in self.no_appear_label_ids:
                # No ground-truth sample of this class: metrics undefined.
                out[self.labels[label_id]]['precision'] = -1.0
                out[self.labels[label_id]]['recall'] = -1.0
                out[self.labels[label_id]]['F1'] = -1.0
                continue
            pred_index = np.where(pred_labels == label_id)[0].tolist()
            tp = np.sum(
                self.true_labels[pred_index] == pred_labels[pred_index])
            tp_fp = len(pred_index)
            tp_fn = len(np.where(self.true_labels == label_id)[0].tolist())
            # Guard against a class that is never predicted (tp_fp == 0).
            out[self.labels[label_id]]['precision'] = (
                tp * 1.0 / tp_fp if tp_fp > 0 else 0.0)
            out[self.labels[label_id]]['recall'] = tp * 1.0 / tp_fn
            # F1 = 2 * P * R / (P + R) simplifies to 2 * tp / (tp_fp + tp_fn).
            out[self.labels[label_id]]['F1'] = 2 * tp * 1.0 / (tp_fp + tp_fn)
            # Weight each class by its share of the ground-truth samples.
            ratio = tp_fn * 1.0 / self.true_labels.shape[0]
            out_avg['precision'] += out[self.labels[label_id]][
                'precision'] * ratio
            out_avg['recall'] += out[self.labels[label_id]]['recall'] * ratio
            out_avg['F1'] += out[self.labels[label_id]]['F1'] * ratio
        return out, out_avg

    def cal_auc(self):
        '''Compute the per-class AUC.
        '''
        out = {}
        for label_id in range(len(self.labels)):
            if label_id in self.no_appear_label_ids:
                # No ground-truth sample of this class: AUC undefined, so
                # skip before computing an ill-defined ROC curve.
                out[self.labels[label_id]] = -1.0
                continue
            part_pred_scores = self.pred_scores[:, label_id:label_id + 1]
            part_pred_scores = part_pred_scores.flatten()
            fpr, tpr, thresholds = roc_curve(
                self.true_labels, part_pred_scores, pos_label=label_id)
            label_auc = auc(fpr, tpr)
            out[self.labels[label_id]] = label_auc
        return out

    def cal_accuracy(self):
        '''Compute top-1 and top-k accuracy.
        '''
        out = {}
        # Clip k so that it never exceeds the number of classes.
        k = min(self.topk, len(self.labels))
        pred_top1_label = np.argsort(self.pred_scores)[:, -1]
        pred_topk_label = np.argsort(self.pred_scores)[:, -k:]
        acc1 = sum(pred_top1_label == self.true_labels) / len(self.true_labels)
        acck = sum([
            np.isin(x, y) for x, y in zip(self.true_labels, pred_topk_label)
        ]) / len(self.true_labels)
        out['acc1'] = acc1
        out['acck'] = acck
        out['k'] = k
        return out

    def generate_report(self):
        '''Generate the evaluation report.
        '''
        report = dict()
        report['Confusion_Matrix'] = self.cal_confusion_matrix()
        report['PRF1_average'] = {}
        report['PRF1'], report['PRF1_average'][
            'over_all'] = self.cal_precision_recall_F1()
        # Use a local name that does not shadow sklearn's auc() import.
        auc_out = self.cal_auc()
        for k, v in auc_out.items():
            report['PRF1'][k]['auc'] = v
        acc = self.cal_accuracy()
        report["Acc1"] = acc["acc1"]
        report["Acck"] = acc["acck"]
        report["topk"] = acc["k"]
        report['label_list'] = self.labels
        return report
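

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. It assumes
    # './output/cls_model' is a hypothetical PaddleX classification model
    # directory containing the model.yml and eval_details.json files that
    # Evaluator reads above.
    evaluator = Evaluator('./output/cls_model', topk=5)
    report = evaluator.generate_report()
    print('top-1 accuracy:', report['Acc1'])
    print('top-%d accuracy: %s' % (report['topk'], report['Acck']))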