compare.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
# Environment-variable configuration that controls whether the GPU is used.
# Documentation: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
  17. import argparse
  18. import os
  19. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  20. import os.path as osp
  21. import cv2
  22. import re
  23. import xml.etree.ElementTree as ET
  24. import paddlex as pdx
  25. def parse_xml_file(xml_file):
  26. tree = ET.parse(xml_file)
  27. pattern = re.compile('<object>', re.IGNORECASE)
  28. obj_match = pattern.findall(str(ET.tostringlist(tree.getroot())))
  29. if len(obj_match) == 0:
  30. return False
  31. obj_tag = obj_match[0][1:-1]
  32. objs = tree.findall(obj_tag)
  33. pattern = re.compile('<size>', re.IGNORECASE)
  34. size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
  35. size_element = tree.find(size_tag)
  36. pattern = re.compile('<width>', re.IGNORECASE)
  37. width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
  38. im_w = float(size_element.find(width_tag).text)
  39. pattern = re.compile('<height>', re.IGNORECASE)
  40. height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
  41. im_h = float(size_element.find(height_tag).text)
  42. gt_bbox = []
  43. gt_class = []
  44. for i, obj in enumerate(objs):
  45. pattern = re.compile('<name>', re.IGNORECASE)
  46. name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  47. cname = obj.find(name_tag).text.strip()
  48. gt_class.append(cname)
  49. pattern = re.compile('<difficult>', re.IGNORECASE)
  50. diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  51. try:
  52. _difficult = int(obj.find(diff_tag).text)
  53. except Exception:
  54. _difficult = 0
  55. pattern = re.compile('<bndbox>', re.IGNORECASE)
  56. box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
  57. box_element = obj.find(box_tag)
  58. pattern = re.compile('<xmin>', re.IGNORECASE)
  59. xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
  60. x1 = float(box_element.find(xmin_tag).text)
  61. pattern = re.compile('<ymin>', re.IGNORECASE)
  62. ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
  63. y1 = float(box_element.find(ymin_tag).text)
  64. pattern = re.compile('<xmax>', re.IGNORECASE)
  65. xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
  66. x2 = float(box_element.find(xmax_tag).text)
  67. pattern = re.compile('<ymax>', re.IGNORECASE)
  68. ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
  69. y2 = float(box_element.find(ymax_tag).text)
  70. x1 = max(0, x1)
  71. y1 = max(0, y1)
  72. if im_w > 0.5 and im_h > 0.5:
  73. x2 = min(im_w - 1, x2)
  74. y2 = min(im_h - 1, y2)
  75. gt_bbox.append([x1, y1, x2, y2])
  76. gts = []
  77. for bbox, name in zip(gt_bbox, gt_class):
  78. x1, y1, x2, y2 = bbox
  79. w = x2 - x1 + 1
  80. h = y2 - y1 + 1
  81. gt = {
  82. 'category_id': 0,
  83. 'category': name,
  84. 'bbox': [x1, y1, w, h],
  85. 'score': 1
  86. }
  87. gts.append(gt)
  88. return gts
  89. if __name__ == '__main__':
  90. parser = argparse.ArgumentParser(description=__doc__)
  91. parser.add_argument(
  92. "--model_dir",
  93. default="./output/faster_rcnn_r50_vd_dcn/best_model/",
  94. type=str,
  95. help="The model directory path.")
  96. parser.add_argument(
  97. "--dataset_dir",
  98. default="./aluminum_inspection",
  99. type=str,
  100. help="The VOC-format dataset directory path.")
  101. parser.add_argument(
  102. "--save_dir",
  103. default="./visualize/compare",
  104. type=str,
  105. help="The directory path of result.")
  106. parser.add_argument(
  107. "--score_threshold",
  108. default=0.1,
  109. type=float,
  110. help="The predicted bbox whose score is lower than score_threshold is filtered."
  111. )
  112. args = parser.parse_args()
  113. if not os.path.exists(args.save_dir):
  114. os.makedirs(args.save_dir)
  115. file_list = osp.join(args.dataset_dir, 'val_list.txt')
  116. model = pdx.load_model(args.model_dir)
  117. with open(file_list, 'r') as fr:
  118. while True:
  119. line = fr.readline()
  120. if not line:
  121. break
  122. img_file, xml_file = [osp.join(args.dataset_dir, x) \
  123. for x in line.strip().split()[:2]]
  124. if not osp.exists(img_file):
  125. continue
  126. if not osp.exists(xml_file):
  127. continue
  128. res = model.predict(img_file)
  129. gts = parse_xml_file(xml_file)
  130. det_vis = pdx.det.visualize(
  131. img_file, res, threshold=args.score_threshold, save_dir=None)
  132. if gts == False:
  133. gts = cv2.imread(img_file)
  134. else:
  135. gt_vis = pdx.det.visualize(
  136. img_file,
  137. gts,
  138. threshold=args.score_threshold,
  139. save_dir=None)
  140. vis = cv2.hconcat([gt_vis, det_vis])
  141. cv2.imwrite(
  142. os.path.join(args.save_dir, os.path.split(img_file)[-1]), vis)
  143. print('The comparison has been made for {}'.format(img_file))
  144. print(
  145. "The visualized ground-truths and predictions are saved in {}. Ground-truth is on the left, prediciton is on the right".
  146. format(args.save_dir))