detection.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import copy
  16. import collections
  17. import math
  18. def eval_detection(
  19. model,
  20. data_dir,
  21. ann_file,
  22. conf_threshold=None,
  23. nms_iou_threshold=None,
  24. plot=False,
  25. batch_size=1,
  26. ):
  27. from .utils import CocoDetection
  28. from .utils import COCOMetric
  29. import cv2
  30. from tqdm import trange
  31. import time
  32. if conf_threshold is not None or nms_iou_threshold is not None:
  33. assert (
  34. conf_threshold is not None and nms_iou_threshold is not None
  35. ), "The conf_threshold and nms_iou_threshold should be setted at the same time"
  36. assert isinstance(
  37. conf_threshold, (float, int)
  38. ), "The conf_threshold:{} need to be int or float".format(conf_threshold)
  39. assert isinstance(
  40. nms_iou_threshold, (float, int)
  41. ), "The nms_iou_threshold:{} need to be int or float".format(nms_iou_threshold)
  42. eval_dataset = CocoDetection(data_dir=data_dir, ann_file=ann_file, shuffle=False)
  43. all_image_info = eval_dataset.file_list
  44. image_num = eval_dataset.num_samples
  45. eval_dataset.data_fields = {
  46. "im_id",
  47. "image_shape",
  48. "image",
  49. "gt_bbox",
  50. "gt_class",
  51. "is_crowd",
  52. }
  53. eval_metric = COCOMetric(
  54. coco_gt=copy.deepcopy(eval_dataset.coco_gt), classwise=False
  55. )
  56. scores = collections.OrderedDict()
  57. twenty_percent_image_num = math.ceil(image_num * 0.2)
  58. start_time = 0
  59. end_time = 0
  60. average_inference_time = 0
  61. im_list = list()
  62. im_id_list = list()
  63. for image_info, i in zip(
  64. all_image_info, trange(image_num, desc="Inference Progress")
  65. ):
  66. if i == twenty_percent_image_num:
  67. start_time = time.time()
  68. im = cv2.imread(image_info["image"])
  69. im_id = image_info["im_id"]
  70. if batch_size == 1:
  71. if conf_threshold is None and nms_iou_threshold is None:
  72. result = model.predict(im.copy())
  73. else:
  74. result = model.predict(im, conf_threshold, nms_iou_threshold)
  75. pred = {
  76. "bbox": [
  77. [c] + [s] + b
  78. for b, s, c in zip(result.boxes, result.scores, result.label_ids)
  79. ],
  80. "bbox_num": len(result.boxes),
  81. "im_id": im_id,
  82. }
  83. eval_metric.update(im_id, pred)
  84. else:
  85. im_list.append(im)
  86. im_id_list.append(im_id)
  87. # If the batch_size is not satisfied, the remaining pictures are formed into a batch
  88. if (i + 1) % batch_size != 0 and i != image_num - 1:
  89. continue
  90. if conf_threshold is None and nms_iou_threshold is None:
  91. results = model.batch_predict(im_list)
  92. else:
  93. model.postprocessor.conf_threshold = conf_threshold
  94. model.postprocessor.nms_threshold = nms_iou_threshold
  95. results = model.batch_predict(im_list)
  96. for k in range(len(im_list)):
  97. pred = {
  98. "bbox": [
  99. [c] + [s] + b
  100. for b, s, c in zip(
  101. results[k].boxes, results[k].scores, results[k].label_ids
  102. )
  103. ],
  104. "bbox_num": len(results[k].boxes),
  105. "im_id": im_id_list[k],
  106. }
  107. eval_metric.update(im_id_list[k], pred)
  108. im_list.clear()
  109. im_id_list.clear()
  110. if i == image_num - 1:
  111. end_time = time.time()
  112. average_inference_time = round(
  113. (end_time - start_time) / (image_num - twenty_percent_image_num), 4
  114. )
  115. eval_metric.accumulate()
  116. eval_details = eval_metric.details
  117. scores.update(eval_metric.get())
  118. scores.update({"average_inference_time(s)": average_inference_time})
  119. eval_metric.reset()
  120. return scores