detection.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import collections
  16. import math
  17. def eval_detection(
  18. model,
  19. data_dir,
  20. ann_file,
  21. conf_threshold=None,
  22. nms_iou_threshold=None,
  23. plot=False,
  24. batch_size=1,
  25. ):
  26. from .utils import CocoDetection
  27. from .utils import COCOMetric
  28. import cv2
  29. from tqdm import trange
  30. import time
  31. if conf_threshold is not None or nms_iou_threshold is not None:
  32. assert (
  33. conf_threshold is not None and nms_iou_threshold is not None
  34. ), "The conf_threshold and nms_iou_threshold should be set at the same time"
  35. assert isinstance(
  36. conf_threshold, (float, int)
  37. ), "The conf_threshold:{} need to be int or float".format(conf_threshold)
  38. assert isinstance(
  39. nms_iou_threshold, (float, int)
  40. ), "The nms_iou_threshold:{} need to be int or float".format(nms_iou_threshold)
  41. eval_dataset = CocoDetection(data_dir=data_dir, ann_file=ann_file, shuffle=False)
  42. all_image_info = eval_dataset.file_list
  43. image_num = eval_dataset.num_samples
  44. eval_dataset.data_fields = {
  45. "im_id",
  46. "image_shape",
  47. "image",
  48. "gt_bbox",
  49. "gt_class",
  50. "is_crowd",
  51. }
  52. eval_metric = COCOMetric(
  53. coco_gt=copy.deepcopy(eval_dataset.coco_gt), classwise=False
  54. )
  55. scores = collections.OrderedDict()
  56. twenty_percent_image_num = math.ceil(image_num * 0.2)
  57. start_time = 0
  58. end_time = 0
  59. average_inference_time = 0
  60. im_list = list()
  61. im_id_list = list()
  62. for image_info, i in zip(
  63. all_image_info, trange(image_num, desc="Inference Progress")
  64. ):
  65. if i == twenty_percent_image_num:
  66. start_time = time.time()
  67. im = cv2.imread(image_info["image"])
  68. im_id = image_info["im_id"]
  69. if batch_size == 1:
  70. if conf_threshold is None and nms_iou_threshold is None:
  71. result = model.predict(im.copy())
  72. else:
  73. result = model.predict(im, conf_threshold, nms_iou_threshold)
  74. pred = {
  75. "bbox": [
  76. [c] + [s] + b
  77. for b, s, c in zip(result.boxes, result.scores, result.label_ids)
  78. ],
  79. "bbox_num": len(result.boxes),
  80. "im_id": im_id,
  81. }
  82. eval_metric.update(im_id, pred)
  83. else:
  84. im_list.append(im)
  85. im_id_list.append(im_id)
  86. # If the batch_size is not satisfied, the remaining pictures are formed into a batch
  87. if (i + 1) % batch_size != 0 and i != image_num - 1:
  88. continue
  89. if conf_threshold is None and nms_iou_threshold is None:
  90. results = model.batch_predict(im_list)
  91. else:
  92. model.postprocessor.conf_threshold = conf_threshold
  93. model.postprocessor.nms_threshold = nms_iou_threshold
  94. results = model.batch_predict(im_list)
  95. for k in range(len(im_list)):
  96. pred = {
  97. "bbox": [
  98. [c] + [s] + b
  99. for b, s, c in zip(
  100. results[k].boxes, results[k].scores, results[k].label_ids
  101. )
  102. ],
  103. "bbox_num": len(results[k].boxes),
  104. "im_id": im_id_list[k],
  105. }
  106. eval_metric.update(im_id_list[k], pred)
  107. im_list.clear()
  108. im_id_list.clear()
  109. if i == image_num - 1:
  110. end_time = time.time()
  111. average_inference_time = round(
  112. (end_time - start_time) / (image_num - twenty_percent_image_num), 4
  113. )
  114. eval_metric.accumulate()
  115. eval_details = eval_metric.details
  116. scores.update(eval_metric.get())
  117. scores.update({"average_inference_time(s)": average_inference_time})
  118. eval_metric.reset()
  119. return scores