mot_eval_utils.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. import copy
  17. import motmetrics as mm
  18. mm.lap.default_solver = 'lap'
  19. __all__ = [
  20. 'read_mot_results',
  21. 'unzip_objs',
  22. 'MOTEvaluator',
  23. ]
  24. def read_mot_results(filename, is_gt=False, is_ignore=False):
  25. valid_labels = {1}
  26. ignore_labels = {2, 7, 8, 12}
  27. results_dict = dict()
  28. if os.path.isfile(filename):
  29. with open(filename, 'r') as f:
  30. for line in f.readlines():
  31. linelist = line.split(',')
  32. if len(linelist) < 7:
  33. continue
  34. fid = int(linelist[0])
  35. if fid < 1:
  36. continue
  37. results_dict.setdefault(fid, list())
  38. box_size = float(linelist[4]) * float(linelist[5])
  39. if is_gt:
  40. if 'MOT16-' in filename or 'MOT17-' in filename:
  41. label = int(float(linelist[7]))
  42. mark = int(float(linelist[6]))
  43. if mark == 0 or label not in valid_labels:
  44. continue
  45. score = 1
  46. elif is_ignore:
  47. if 'MOT16-' in filename or 'MOT17-' in filename:
  48. label = int(float(linelist[7]))
  49. vis_ratio = float(linelist[8])
  50. if label not in ignore_labels and vis_ratio >= 0:
  51. continue
  52. else:
  53. continue
  54. score = 1
  55. else:
  56. score = float(linelist[6])
  57. tlwh = tuple(map(float, linelist[2:6]))
  58. target_id = int(linelist[1])
  59. results_dict[fid].append((tlwh, target_id, score))
  60. return results_dict
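"""
MOTChallenge ground-truth class labels, read from column 8 (linelist[7]) of
MOT16/MOT17 gt.txt files. read_mot_results() keeps label 1 ('ped') as a valid
target and treats labels 2, 7, 8 and 12 ('person_on_vhcl', 'static_person',
'distractor', 'reflection') as ignore regions.

labels={'ped', ... % 1
'person_on_vhcl', ... % 2
'car', ... % 3
'bicycle', ... % 4
'mbike', ... % 5
'non_mot_vhcl', ... % 6
'static_person', ... % 7
'distractor', ... % 8
'occluder', ... % 9
'occluder_on_grnd', ... % 10
'occluder_full', ... % 11
'reflection', ... % 12
'crowd' ... % 13
};
"""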
  77. def unzip_objs(objs):
  78. if len(objs) > 0:
  79. tlwhs, ids, scores = zip(*objs)
  80. else:
  81. tlwhs, ids, scores = [], [], []
  82. tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
  83. return tlwhs, ids, scores
  84. class MOTEvaluator(object):
  85. def __init__(self, data_root, seq_name, data_type):
  86. self.data_root = data_root
  87. self.seq_name = seq_name
  88. self.data_type = data_type
  89. self.load_annotations()
  90. self.reset_accumulator()
  91. def load_annotations(self):
  92. assert self.data_type == 'mot'
  93. gt_filename = os.path.join(self.data_root, self.seq_name, 'gt',
  94. 'gt.txt')
  95. self.gt_frame_dict = read_mot_results(gt_filename, is_gt=True)
  96. self.gt_ignore_frame_dict = read_mot_results(
  97. gt_filename, is_ignore=True)
  98. def reset_accumulator(self):
  99. self.acc = mm.MOTAccumulator(auto_id=True)
  100. def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
  101. # results
  102. trk_tlwhs = np.copy(trk_tlwhs)
  103. trk_ids = np.copy(trk_ids)
  104. # gts
  105. gt_objs = self.gt_frame_dict.get(frame_id, [])
  106. gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
  107. # ignore boxes
  108. ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
  109. ignore_tlwhs = unzip_objs(ignore_objs)[0]
  110. # remove ignored results
  111. keep = np.ones(len(trk_tlwhs), dtype=bool)
  112. iou_distance = mm.distances.iou_matrix(
  113. ignore_tlwhs, trk_tlwhs, max_iou=0.5)
  114. if len(iou_distance) > 0:
  115. match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
  116. match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
  117. match_ious = iou_distance[match_is, match_js]
  118. match_js = np.asarray(match_js, dtype=int)
  119. match_js = match_js[np.logical_not(np.isnan(match_ious))]
  120. keep[match_js] = False
  121. trk_tlwhs = trk_tlwhs[keep]
  122. trk_ids = trk_ids[keep]
  123. # get distance matrix
  124. iou_distance = mm.distances.iou_matrix(
  125. gt_tlwhs, trk_tlwhs, max_iou=0.5)
  126. # acc
  127. self.acc.update(gt_ids, trk_ids, iou_distance)
  128. if rtn_events and iou_distance.size > 0 and hasattr(self.acc,
  129. 'last_mot_events'):
  130. events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
  131. else:
  132. events = None
  133. return events
  134. def eval_file(self, filename):
  135. self.reset_accumulator()
  136. result_frame_dict = read_mot_results(filename, is_gt=False)
  137. frames = sorted(list(set(result_frame_dict.keys())))
  138. for frame_id in frames:
  139. trk_objs = result_frame_dict.get(frame_id, [])
  140. trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
  141. self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
  142. return self.acc
  143. @staticmethod
  144. def get_summary(accs,
  145. names,
  146. metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1',
  147. 'precision', 'recall')):
  148. names = copy.deepcopy(names)
  149. if metrics is None:
  150. metrics = mm.metrics.motchallenge_metrics
  151. metrics = copy.deepcopy(metrics)
  152. mh = mm.metrics.create()
  153. summary = mh.compute_many(
  154. accs, metrics=metrics, names=names, generate_overall=True)
  155. return summary
  156. @staticmethod
  157. def save_summary(summary, filename):
  158. import pandas as pd
  159. writer = pd.ExcelWriter(filename)
  160. summary.to_excel(writer)
  161. writer.save()