classification.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import yaml
import os.path as osp
import numpy as np
from sklearn.metrics import confusion_matrix, roc_curve, auc


class Evaluator(object):
    def __init__(self, model_path, topk=5):
        # Load the exported model's meta information and the prediction
        # details saved during evaluation.
        with open(osp.join(model_path, "model.yml")) as f:
            model_info = yaml.load(f.read(), Loader=yaml.Loader)
        with open(osp.join(model_path, 'eval_details.json'), 'r') as f:
            eval_details = json.load(f)
        self.topk = topk
        self.labels = model_info['_Attributes']['labels']
        self.true_labels = np.array(eval_details['true_labels'])
        self.pred_scores = np.array(eval_details['pred_scores'])
        label_ids_list = list(range(len(self.labels)))
        # Label ids that never occur in the ground truth; their per-class
        # metrics are undefined and reported as -1.0.
        self.no_appear_label_ids = set(label_ids_list) - set(
            self.true_labels.tolist())

    def cal_confusion_matrix(self):
        '''Compute the confusion matrix from the top-1 predictions.
        '''
        pred_labels = np.argsort(self.pred_scores)[:, -1:].flatten()
        cm = confusion_matrix(
            self.true_labels.tolist(),
            pred_labels.tolist(),
            labels=list(range(len(self.labels))))
        return cm

    def cal_precision_recall_F1(self):
        '''Compute per-class precision, recall and F1, plus their weighted averages.
        '''
        out = {}
        out_avg = {}
        out_avg['precision'] = 0.0
        out_avg['recall'] = 0.0
        out_avg['F1'] = 0.0
        pred_labels = np.argsort(self.pred_scores)[:, -1:].flatten()
        for label_id in range(len(self.labels)):
            out[self.labels[label_id]] = {}
            # Classes absent from the ground truth get the placeholder -1.0.
            if label_id in self.no_appear_label_ids:
                out[self.labels[label_id]]['precision'] = -1.0
                out[self.labels[label_id]]['recall'] = -1.0
                out[self.labels[label_id]]['F1'] = -1.0
                continue
            pred_index = np.where(pred_labels == label_id)[0].tolist()
            tp = np.sum(
                self.true_labels[pred_index] == pred_labels[pred_index])
            tp_fp = len(pred_index)
            tp_fn = len(np.where(self.true_labels == label_id)[0].tolist())
            out[self.labels[label_id]]['precision'] = tp * 1.0 / tp_fp
            out[self.labels[label_id]]['recall'] = tp * 1.0 / tp_fn
            # F1 = 2 * TP / ((TP + FP) + (TP + FN)), equivalent to
            # 2 * precision * recall / (precision + recall).
            out[self.labels[label_id]]['F1'] = 2 * tp * 1.0 / (tp_fp + tp_fn)
            # Weight the averages by each class's share of the ground truth.
            ratio = tp_fn * 1.0 / self.true_labels.shape[0]
            out_avg['precision'] += out[self.labels[label_id]][
                'precision'] * ratio
            out_avg['recall'] += out[self.labels[label_id]]['recall'] * ratio
            out_avg['F1'] += out[self.labels[label_id]]['F1'] * ratio
        return out, out_avg

    def cal_auc(self):
        '''Compute per-class AUC with a one-vs-rest ROC curve.
        '''
        out = {}
        for label_id in range(len(self.labels)):
            # Use the scores of the current class as the positive-class scores.
            part_pred_scores = self.pred_scores[:, label_id:label_id + 1]
            part_pred_scores = part_pred_scores.flatten()
            fpr, tpr, thresholds = roc_curve(
                self.true_labels, part_pred_scores, pos_label=label_id)
            label_auc = auc(fpr, tpr)
            if label_id in self.no_appear_label_ids:
                out[self.labels[label_id]] = -1.0
                continue
            out[self.labels[label_id]] = label_auc
        return out

    def cal_accuracy(self):
        '''Compute top-1 and top-k accuracy.
        '''
        out = {}
        k = min(self.topk, len(self.labels))
        pred_top1_label = np.argsort(self.pred_scores)[:, -1]
        pred_topk_label = np.argsort(self.pred_scores)[:, -k:]
        acc1 = sum(pred_top1_label == self.true_labels) / len(self.true_labels)
        # A sample counts as a top-k hit if its true label appears anywhere
        # among its k highest-scoring predictions.
        acck = sum([
            np.isin(x, y) for x, y in zip(self.true_labels, pred_topk_label)
        ]) / len(self.true_labels)
        out['acc1'] = acc1
        out['acck'] = acck
        out['k'] = k
        return out

    def generate_report(self):
        '''Generate the evaluation report.
        '''
        report = dict()
        report['Confusion_Matrix'] = self.cal_confusion_matrix()
        report['PRF1_average'] = {}
        report['PRF1'], report['PRF1_average'][
            'over_all'] = self.cal_precision_recall_F1()
        auc_out = self.cal_auc()
        for k, v in auc_out.items():
            report['PRF1'][k]['auc'] = v
        acc = self.cal_accuracy()
        report["Acc1"] = acc["acc1"]
        report["Acck"] = acc["acck"]
        report["topk"] = acc["k"]
        report['label_list'] = self.labels
        return report
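

# A minimal usage sketch, not part of the original module: it assumes an
# exported model directory that contains both "model.yml" and
# "eval_details.json"; the path below is hypothetical.
if __name__ == '__main__':
    evaluator = Evaluator('./output/mobilenetv2/best_model', topk=5)
    report = evaluator.generate_report()
    print('top-1 accuracy: {}'.format(report['Acc1']))
    print('top-{} accuracy: {}'.format(report['topk'], report['Acck']))
    for label, metrics in report['PRF1'].items():
        print(label, metrics)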