segmentation.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import yaml
  16. import os.path as osp
  17. import numpy as np
  18. class Evaluator(object):
  19. def __init__(self, model_path):
  20. with open(osp.join(model_path, "model.yml")) as f:
  21. model_info = yaml.load(f.read(), Loader=yaml.Loader)
  22. self.labels = model_info['_Attributes']['labels']
  23. with open(osp.join(model_path, 'eval_details.json'), 'r') as f:
  24. eval_details = json.load(f)
  25. self.confusion_matrix = np.array(eval_details['confusion_matrix'])
  26. self.num_classes = len(self.confusion_matrix)
  27. def cal_iou(self):
  28. '''计算IoU。
  29. '''
  30. category_iou = []
  31. mean_iou = 0
  32. vji = np.sum(self.confusion_matrix, axis=1)
  33. vij = np.sum(self.confusion_matrix, axis=0)
  34. for c in range(self.num_classes):
  35. total = vji[c] + vij[c] - self.confusion_matrix[c][c]
  36. if total == 0:
  37. iou = 0
  38. else:
  39. iou = float(self.confusion_matrix[c][c]) / total
  40. mean_iou += iou
  41. category_iou.append(iou)
  42. mean_iou = float(mean_iou) / float(self.num_classes)
  43. return np.array(category_iou), mean_iou
  44. def cal_acc(self):
  45. '''计算Acc。
  46. '''
  47. total = self.confusion_matrix.sum()
  48. total_tp = 0
  49. for c in range(self.num_classes):
  50. total_tp += self.confusion_matrix[c][c]
  51. if total == 0:
  52. mean_acc = 0
  53. else:
  54. mean_acc = float(total_tp) / total
  55. vij = np.sum(self.confusion_matrix, axis=0)
  56. category_acc = []
  57. for c in range(self.num_classes):
  58. if vij[c] == 0:
  59. acc = 0
  60. else:
  61. acc = self.confusion_matrix[c][c] / float(vij[c])
  62. category_acc.append(acc)
  63. return np.array(category_acc), mean_acc
  64. def cal_confusion_matrix(self):
  65. '''计算混淆矩阵。
  66. '''
  67. return self.confusion_matrix
  68. def cal_precision_recall(self):
  69. '''计算precision、recall.
  70. '''
  71. self.precision_recall = dict()
  72. for i in range(len(self.labels)):
  73. label_name = self.labels[i]
  74. if np.isclose(np.sum(self.confusion_matrix[i, :]), 0, atol=1e-6):
  75. recall = -1
  76. else:
  77. total_gt = np.sum(self.confusion_matrix[i, :]) + 1e-06
  78. recall = self.confusion_matrix[i, i] / total_gt
  79. if np.isclose(np.sum(self.confusion_matrix[:, i]), 0, atol=1e-6):
  80. precision = -1
  81. else:
  82. total_pred = np.sum(self.confusion_matrix[:, i]) + 1e-06
  83. precision = self.confusion_matrix[i, i] / total_pred
  84. self.precision_recall[label_name] = {
  85. 'precision': precision,
  86. 'recall': recall
  87. }
  88. return self.precision_recall
  89. def generate_report(self):
  90. '''生成评估报告。
  91. '''
  92. category_iou, mean_iou = self.cal_iou()
  93. category_acc, mean_acc = self.cal_acc()
  94. category_iou_dict = {}
  95. for i in range(len(category_iou)):
  96. category_iou_dict[self.labels[i]] = category_iou[i]
  97. report = dict()
  98. report['Confusion_Matrix'] = self.cal_confusion_matrix()
  99. report['Mean_IoU'] = mean_iou
  100. report['Mean_Acc'] = mean_acc
  101. report['PRIoU'] = self.cal_precision_recall()
  102. for key in report['PRIoU']:
  103. report['PRIoU'][key]["iou"] = category_iou_dict[key]
  104. report['label_list'] = self.labels
  105. return report