segmentation.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm import trange
import numpy as np
import collections
import os
import math
import time


def eval_segmentation(model, data_dir, batch_size=1):
    import cv2
    from .utils import Cityscapes
    from .utils import f1_score, calculate_area, mean_iou, accuracy, kappa
    assert os.path.isdir(data_dir), \
        "The data_dir:{} is not a directory.".format(data_dir)
    eval_dataset = Cityscapes(dataset_root=data_dir, mode="val")
    file_list = eval_dataset.file_list
    image_num = eval_dataset.num_samples
    num_classes = eval_dataset.num_classes
    intersect_area_all = 0
    pred_area_all = 0
    label_area_all = 0
    conf_mat_all = []
    # Timing starts after the first 20% of images, so warm-up iterations
    # do not skew the average inference time.
    twenty_percent_image_num = math.ceil(image_num * 0.2)
    start_time = 0
    end_time = 0
    average_inference_time = 0
    im_list = []
    label_list = []
    for image_label_path, i in zip(
            file_list, trange(image_num, desc="Inference Progress")):
        if i == twenty_percent_image_num:
            start_time = time.time()
        im = cv2.imread(image_label_path[0])
        label = cv2.imread(image_label_path[1], cv2.IMREAD_GRAYSCALE)
        label_list.append(label)
        if batch_size == 1:
            result = model.predict(im)
            results = [result]
        else:
            im_list.append(im)
            # Keep accumulating images until a full batch is collected; the
            # remaining images at the end form a final, possibly smaller, batch.
            if (i + 1) % batch_size != 0 and i != image_num - 1:
                continue
            results = model.batch_predict(im_list)
        if i == image_num - 1:
            end_time = time.time()
            average_inference_time = round(
                (end_time - start_time) / (image_num - twenty_percent_image_num),
                4)
        # Accumulate per-image intersection/prediction/label areas so the
        # metrics below are computed over the whole validation set.
        for result, label in zip(results, label_list):
            pred = np.array(result.label_map).reshape(result.shape[0],
                                                      result.shape[1])
            intersect_area, pred_area, label_area = calculate_area(
                pred, label, num_classes)
            intersect_area_all = intersect_area_all + intersect_area
            pred_area_all = pred_area_all + pred_area
            label_area_all = label_area_all + label_area
        im_list.clear()
        label_list.clear()
    class_iou, miou = mean_iou(intersect_area_all, pred_area_all,
                               label_area_all)
    class_acc, oacc = accuracy(intersect_area_all, pred_area_all)
    kappa_res = kappa(intersect_area_all, pred_area_all, label_area_all)
    category_f1score = f1_score(intersect_area_all, pred_area_all,
                                label_area_all)
    eval_metrics = collections.OrderedDict(
        zip([
            "miou",
            "category_iou",
            "oacc",
            "category_acc",
            "kappa",
            "category_F1-score",
            "average_inference_time(s)",
        ], [
            miou,
            class_iou,
            oacc,
            class_acc,
            kappa_res,
            category_f1score,
            average_inference_time,
        ]))
    return eval_metrics
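

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Any model
# object works as long as it exposes predict()/batch_predict() and returns
# results that carry `label_map` and `shape`; the FastDeploy PaddleSeg model
# and the file paths below are assumptions, not fixed requirements. Because
# eval_segmentation uses a relative import internally, run this as a module
# (e.g. `python -m <package>.segmentation`) rather than as a bare script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import fastdeploy as fd  # assumed to be installed

    seg_model = fd.vision.segmentation.PaddleSegModel(
        "model.pdmodel", "model.pdiparams", "deploy.yaml")
    metrics = eval_segmentation(seg_model, data_dir="./cityscapes", batch_size=1)
    for name, value in metrics.items():
        print(name, value)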