# compare.py
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. # 环境变量配置,用于控制是否使用GPU
  16. # 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
  17. import os
  18. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  19. import os.path as osp
  20. import cv2
  21. import re
  22. import xml.etree.ElementTree as ET
  23. import paddlex as pdx
  24. data_dir = 'aluminum_inspection/'
  25. file_list = 'aluminum_inspection/val_list.txt'
  26. model_dir = 'output/faster_rcnn_r50_vd_dcn/best_model/'
  27. save_dir = './visualize/compare'
  28. # 设置置信度阈值
  29. score_threshold = 0.1
  30. if not os.path.exists(save_dir):
  31. os.makedirs(save_dir)
  32. model = pdx.load_model(model_dir)
  33. with open(file_list, 'r') as fr:
  34. while True:
  35. line = fr.readline()
  36. if not line:
  37. break
  38. img_file, xml_file = [osp.join(data_dir, x) \
  39. for x in line.strip().split()[:2]]
  40. if not osp.exists(img_file):
  41. continue
  42. if not osp.exists(xml_file):
  43. continue
  44. res = model.predict(img_file)
  45. det_vis = pdx.det.visualize(
  46. img_file, res, threshold=score_threshold, save_dir=None)
  47. tree = ET.parse(xml_file)
  48. pattern = re.compile('<object>', re.IGNORECASE)
  49. obj_match = pattern.findall(str(ET.tostringlist(tree.getroot())))
  50. if len(obj_match) == 0:
  51. continue
  52. obj_tag = obj_match[0][1:-1]
  53. objs = tree.findall(obj_tag)
  54. pattern = re.compile('<size>', re.IGNORECASE)
  55. size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:
  56. -1]
  57. size_element = tree.find(size_tag)
  58. pattern = re.compile('<width>', re.IGNORECASE)
  59. width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:
  60. -1]
  61. im_w = float(size_element.find(width_tag).text)
  62. pattern = re.compile('<height>', re.IGNORECASE)
  63. height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:
  64. -1]
  65. im_h = float(size_element.find(height_tag).text)
  66. gt_bbox = []
  67. gt_class = []
  68. for i, obj in enumerate(objs):
  69. pattern = re.compile('<name>', re.IGNORECASE)
  70. name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  71. cname = obj.find(name_tag).text.strip()
  72. gt_class.append(cname)
  73. pattern = re.compile('<difficult>', re.IGNORECASE)
  74. diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  75. try:
  76. _difficult = int(obj.find(diff_tag).text)
  77. except Exception:
  78. _difficult = 0
  79. pattern = re.compile('<bndbox>', re.IGNORECASE)
  80. box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  81. box_element = obj.find(box_tag)
  82. pattern = re.compile('<xmin>', re.IGNORECASE)
  83. xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][
  84. 1:-1]
  85. x1 = float(box_element.find(xmin_tag).text)
  86. pattern = re.compile('<ymin>', re.IGNORECASE)
  87. ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][
  88. 1:-1]
  89. y1 = float(box_element.find(ymin_tag).text)
  90. pattern = re.compile('<xmax>', re.IGNORECASE)
  91. xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][
  92. 1:-1]
  93. x2 = float(box_element.find(xmax_tag).text)
  94. pattern = re.compile('<ymax>', re.IGNORECASE)
  95. ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][
  96. 1:-1]
  97. y2 = float(box_element.find(ymax_tag).text)
  98. x1 = max(0, x1)
  99. y1 = max(0, y1)
  100. if im_w > 0.5 and im_h > 0.5:
  101. x2 = min(im_w - 1, x2)
  102. y2 = min(im_h - 1, y2)
  103. gt_bbox.append([x1, y1, x2, y2])
  104. gts = []
  105. for bbox, name in zip(gt_bbox, gt_class):
  106. x1, y1, x2, y2 = bbox
  107. w = x2 - x1 + 1
  108. h = y2 - y1 + 1
  109. gt = {
  110. 'category_id': 0,
  111. 'category': name,
  112. 'bbox': [x1, y1, w, h],
  113. 'score': 1
  114. }
  115. gts.append(gt)
  116. gt_vis = pdx.det.visualize(
  117. img_file, gts, threshold=score_threshold, save_dir=None)
  118. vis = cv2.hconcat([gt_vis, det_vis])
  119. cv2.imwrite(os.path.join(save_dir, os.path.split(img_file)[-1]), vis)
  120. print('The comparison has been made for {}'.format(img_file))
  121. print(
  122. "The visualized ground-truths and predictions are saved in {}. Ground-truth is on the left, prediciton is on the right".
  123. format(save_dir))