cal_tp_fp.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 环境变量配置,用于控制是否使用GPU
# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
  17. import os
  18. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  19. import os.path as osp
  20. import numpy as np
  21. import matplotlib
  22. matplotlib.use('Agg')
  23. import matplotlib.pyplot as plt
  24. import paddlex as pdx
  25. data_dir = 'aluminum_inspection/'
  26. positive_file_list = 'aluminum_inspection/val_list.txt'
  27. negative_dir = 'aluminum_inspection/val_wu_xia_ci'
  28. model_dir = 'output/faster_rcnn_r50_vd_dcn/best_model/'
  29. save_dir = 'visualize/faster_rcnn_r50_vd_dcn'
  30. if not osp.exists(save_dir):
  31. os.makedirs(save_dir)
  32. tp = np.zeros((101, 1))
  33. fp = np.zeros((101, 1))
  34. # 导入模型
  35. model = pdx.load_model(model_dir)
  36. # 计算图片级召回率
  37. print(
  38. "Begin to calculate image-level recall rate of positive images. Please wait for a moment..."
  39. )
  40. positive_num = 0
  41. with open(positive_file_list, 'r') as fr:
  42. while True:
  43. line = fr.readline()
  44. if not line:
  45. break
  46. img_file, xml_file = [osp.join(data_dir, x) \
  47. for x in line.strip().split()[:2]]
  48. if not osp.exists(img_file):
  49. continue
  50. if not osp.exists(xml_file):
  51. continue
  52. positive_num += 1
  53. results = model.predict(img_file)
  54. scores = list()
  55. for res in results:
  56. scores.append(res['score'])
  57. if len(scores) > 0:
  58. tp[0:int(np.round(max(scores) / 0.01)), 0] += 1
  59. tp = tp / positive_num
  60. # 计算图片级误检率
  61. print(
  62. "Begin to calculate image-level false-positive rate of background images. Please wait for a moment..."
  63. )
  64. negative_num = 0
  65. for file in os.listdir(negative_dir):
  66. file = osp.join(negative_dir, file)
  67. results = model.predict(file)
  68. negative_num += 1
  69. scores = list()
  70. for res in results:
  71. scores.append(res['score'])
  72. if len(scores) > 0:
  73. fp[0:int(np.round(max(scores) / 0.01)), 0] += 1
  74. fp = fp / negative_num
  75. # 保存结果
  76. tp_fp_list_file = osp.join(save_dir, 'tp_fp_list.txt')
  77. with open(tp_fp_list_file, 'w') as f:
  78. f.write("| score | recall rate | false-positive rate |\n")
  79. f.write("| -- | -- | -- |\n")
  80. for i in range(100):
  81. f.write("| {:2f} | {:2f} | {:2f} |\n".format(0.01 * i, tp[i, 0], fp[
  82. i, 0]))
  83. print("The numerical score-recall_rate-false_positive_rate is saved as {}".
  84. format(tp_fp_list_file))
  85. plt.subplot(1, 2, 1)
  86. plt.title("image-level false_positive-recall")
  87. plt.xlabel("recall")
  88. plt.ylabel("false_positive")
  89. plt.xlim(0, 1)
  90. plt.ylim(0, 1)
  91. plt.grid(linestyle='--', linewidth=1)
  92. plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
  93. my_x_ticks = np.arange(0, 1, 0.1)
  94. my_y_ticks = np.arange(0, 1, 0.1)
  95. plt.xticks(my_x_ticks, fontsize=5)
  96. plt.yticks(my_y_ticks, fontsize=5)
  97. plt.plot(tp, fp, color='b', label="image level", linewidth=1)
  98. plt.legend(loc="lower left", fontsize=5)
  99. plt.subplot(1, 2, 2)
  100. plt.title("score-recall")
  101. plt.xlabel('recall')
  102. plt.ylabel('score')
  103. plt.xlim(0, 1)
  104. plt.ylim(0, 1)
  105. plt.grid(linestyle='--', linewidth=1)
  106. plt.xticks(my_x_ticks, fontsize=5)
  107. plt.yticks(my_y_ticks, fontsize=5)
  108. plt.plot(
  109. tp, np.arange(0, 1.01, 0.01), color='b', label="image level", linewidth=1)
  110. plt.legend(loc="lower left", fontsize=5)
  111. tp_fp_chart_file = os.path.join(save_dir, "image-level_tp_fp.png")
  112. plt.savefig(tp_fp_chart_file, dpi=800)
  113. plt.close()
  114. print("The diagrammatic score-recall_rate-false_positive_rate is saved as {}".
  115. format(tp_fp_chart_file))