classify.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os
  16. import time
  17. import collections
  18. def topk_accuracy(topk_list, label_list):
  19. match_array = np.logical_or.reduce(topk_list == label_list, axis=1)
  20. topk_acc_score = match_array.sum() / match_array.shape[0]
  21. return topk_acc_score
  22. def eval_classify(model, image_file_path, label_file_path, topk=5):
  23. from tqdm import trange
  24. import cv2
  25. import math
  26. result_list = []
  27. label_list = []
  28. image_label_dict = {}
  29. assert os.path.isdir(
  30. image_file_path
  31. ), "The image_file_path:{} is not a directory.".format(image_file_path)
  32. assert os.path.isfile(
  33. label_file_path
  34. ), "The label_file_path:{} is not a file.".format(label_file_path)
  35. assert isinstance(topk, int), "The tok:{} is not int type".format(topk)
  36. with open(label_file_path, "r") as file:
  37. lines = file.readlines()
  38. for line in lines:
  39. items = line.strip().split()
  40. image_name = items[0]
  41. label = items[1]
  42. image_label_dict[image_name] = int(label)
  43. images_num = len(image_label_dict)
  44. twenty_percent_images_num = math.ceil(images_num * 0.2)
  45. start_time = 0
  46. end_time = 0
  47. average_inference_time = 0
  48. scores = collections.OrderedDict()
  49. for (image, label), i in zip(
  50. image_label_dict.items(), trange(images_num, desc="Inference Progress")
  51. ):
  52. if i == twenty_percent_images_num:
  53. start_time = time.time()
  54. label_list.append([label])
  55. image_path = os.path.join(image_file_path, image)
  56. im = cv2.imread(image_path)
  57. result = model.predict(im, topk)
  58. result_list.append(result.label_ids)
  59. if i == images_num - 1:
  60. end_time = time.time()
  61. average_inference_time = round(
  62. (end_time - start_time) / (images_num - twenty_percent_images_num), 4
  63. )
  64. topk_acc_score = topk_accuracy(np.array(result_list), np.array(label_list))
  65. if topk == 1:
  66. scores.update({"topk1": topk_acc_score})
  67. scores.update({"topk1_average_inference_time(s)": average_inference_time})
  68. elif topk == 5:
  69. scores.update({"topk5": topk_acc_score})
  70. scores.update({"topk5_average_inference_time(s)": average_inference_time})
  71. return scores