segmentation.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import yaml
  16. import os.path as osp
  17. import numpy as np
  18. from paddlex.utils import get_encoding
  19. class Evaluator(object):
  20. def __init__(self, model_path):
  21. model_yml = osp.join(model_path, "model.yml")
  22. with open(model_yml, encoding=get_encoding(model_yml)) as f:
  23. model_info = yaml.load(f.read(), Loader=yaml.Loader)
  24. self.labels = model_info['_Attributes']['labels']
  25. eval_details_file = osp.join(model_path, 'eval_details.json')
  26. with open(
  27. eval_details_file, 'r',
  28. encoding=get_encoding(eval_details_file)) as f:
  29. eval_details = json.load(f)
  30. self.confusion_matrix = np.array(eval_details['confusion_matrix'])
  31. self.num_classes = len(self.confusion_matrix)
  32. def cal_iou(self):
  33. '''计算IoU。
  34. '''
  35. category_iou = []
  36. mean_iou = 0
  37. vji = np.sum(self.confusion_matrix, axis=1)
  38. vij = np.sum(self.confusion_matrix, axis=0)
  39. for c in range(self.num_classes):
  40. total = vji[c] + vij[c] - self.confusion_matrix[c][c]
  41. if total == 0:
  42. iou = 0
  43. else:
  44. iou = float(self.confusion_matrix[c][c]) / total
  45. mean_iou += iou
  46. category_iou.append(iou)
  47. mean_iou = float(mean_iou) / float(self.num_classes)
  48. return np.array(category_iou), mean_iou
  49. def cal_acc(self):
  50. '''计算Acc。
  51. '''
  52. total = self.confusion_matrix.sum()
  53. total_tp = 0
  54. for c in range(self.num_classes):
  55. total_tp += self.confusion_matrix[c][c]
  56. if total == 0:
  57. mean_acc = 0
  58. else:
  59. mean_acc = float(total_tp) / total
  60. vij = np.sum(self.confusion_matrix, axis=0)
  61. category_acc = []
  62. for c in range(self.num_classes):
  63. if vij[c] == 0:
  64. acc = 0
  65. else:
  66. acc = self.confusion_matrix[c][c] / float(vij[c])
  67. category_acc.append(acc)
  68. return np.array(category_acc), mean_acc
  69. def cal_confusion_matrix(self):
  70. '''计算混淆矩阵。
  71. '''
  72. return self.confusion_matrix
  73. def cal_precision_recall(self):
  74. '''计算precision、recall.
  75. '''
  76. self.precision_recall = dict()
  77. for i in range(len(self.labels)):
  78. label_name = self.labels[i]
  79. if np.isclose(np.sum(self.confusion_matrix[i, :]), 0, atol=1e-6):
  80. recall = -1
  81. else:
  82. total_gt = np.sum(self.confusion_matrix[i, :]) + 1e-06
  83. recall = self.confusion_matrix[i, i] / total_gt
  84. if np.isclose(np.sum(self.confusion_matrix[:, i]), 0, atol=1e-6):
  85. precision = -1
  86. else:
  87. total_pred = np.sum(self.confusion_matrix[:, i]) + 1e-06
  88. precision = self.confusion_matrix[i, i] / total_pred
  89. self.precision_recall[label_name] = {
  90. 'precision': precision,
  91. 'recall': recall
  92. }
  93. return self.precision_recall
  94. def generate_report(self):
  95. '''生成评估报告。
  96. '''
  97. category_iou, mean_iou = self.cal_iou()
  98. category_acc, mean_acc = self.cal_acc()
  99. category_iou_dict = {}
  100. for i in range(len(category_iou)):
  101. category_iou_dict[self.labels[i]] = category_iou[i]
  102. report = dict()
  103. report['Confusion_Matrix'] = self.cal_confusion_matrix()
  104. report['Mean_IoU'] = mean_iou
  105. report['Mean_Acc'] = mean_acc
  106. report['PRIoU'] = self.cal_precision_recall()
  107. for key in report['PRIoU']:
  108. report['PRIoU'][key]["iou"] = category_iou_dict[key]
  109. report['label_list'] = self.labels
  110. return report