mot_metrics.py
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import copy
  19. import sys
  20. import math
  21. from collections import defaultdict
  22. import numpy as np
  23. import paddle
  24. import paddle.nn.functional as F
  25. from paddlex.ppdet.modeling.bbox_utils import bbox_iou_np_expand
  26. from .map_utils import ap_per_class
  27. from .metrics import Metric
  28. from .munkres import Munkres
  29. from paddlex.ppdet.utils.logger import setup_logger
  30. logger = setup_logger(__name__)
  31. __all__ = ['MOTEvaluator', 'MOTMetric', 'JDEDetMetric', 'KITTIMOTMetric']
  32. def read_mot_results(filename, is_gt=False, is_ignore=False):
  33. valid_labels = {1}
  34. ignore_labels = {2, 7, 8, 12}
  35. results_dict = dict()
  36. if os.path.isfile(filename):
  37. with open(filename, 'r') as f:
  38. for line in f.readlines():
  39. linelist = line.split(',')
  40. if len(linelist) < 7:
  41. continue
  42. fid = int(linelist[0])
  43. if fid < 1:
  44. continue
  45. results_dict.setdefault(fid, list())
  46. box_size = float(linelist[4]) * float(linelist[5])
  47. if is_gt:
  48. if 'MOT16-' in filename or 'MOT17-' in filename or 'MOT15-' in filename or 'MOT20-' in filename:
  49. label = int(float(linelist[7]))
  50. mark = int(float(linelist[6]))
  51. if mark == 0 or label not in valid_labels:
  52. continue
  53. score = 1
  54. elif is_ignore:
  55. if 'MOT16-' in filename or 'MOT17-' in filename or 'MOT15-' in filename or 'MOT20-' in filename:
  56. label = int(float(linelist[7]))
  57. vis_ratio = float(linelist[8])
  58. if label not in ignore_labels and vis_ratio >= 0:
  59. continue
  60. else:
  61. continue
  62. score = 1
  63. else:
  64. score = float(linelist[6])
  65. tlwh = tuple(map(float, linelist[2:6]))
  66. target_id = int(linelist[1])
  67. results_dict[fid].append((tlwh, target_id, score))
  68. return results_dict
  69. """
  70. MOT dataset label list, see in https://motchallenge.net
  71. labels={'ped', ... % 1
  72. 'person_on_vhcl', ... % 2
  73. 'car', ... % 3
  74. 'bicycle', ... % 4
  75. 'mbike', ... % 5
  76. 'non_mot_vhcl', ... % 6
  77. 'static_person', ... % 7
  78. 'distractor', ... % 8
  79. 'occluder', ... % 9
  80. 'occluder_on_grnd', ... % 10
  81. 'occluder_full', ... % 11
  82. 'reflection', ... % 12
  83. 'crowd' ... % 13
  84. };
  85. """
  86. def unzip_objs(objs):
  87. if len(objs) > 0:
  88. tlwhs, ids, scores = zip(*objs)
  89. else:
  90. tlwhs, ids, scores = [], [], []
  91. tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
  92. return tlwhs, ids, scores
  93. class MOTEvaluator(object):
  94. def __init__(self, data_root, seq_name, data_type):
  95. self.data_root = data_root
  96. self.seq_name = seq_name
  97. self.data_type = data_type
  98. self.load_annotations()
  99. self.reset_accumulator()
  100. def load_annotations(self):
  101. assert self.data_type == 'mot'
  102. gt_filename = os.path.join(self.data_root, self.seq_name, 'gt',
  103. 'gt.txt')
  104. self.gt_frame_dict = read_mot_results(gt_filename, is_gt=True)
  105. self.gt_ignore_frame_dict = read_mot_results(
  106. gt_filename, is_ignore=True)
  107. def reset_accumulator(self):
  108. import motmetrics as mm
  109. mm.lap.default_solver = 'lap'
  110. self.acc = mm.MOTAccumulator(auto_id=True)
  111. def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
  112. import motmetrics as mm
  113. mm.lap.default_solver = 'lap'
  114. # results
  115. trk_tlwhs = np.copy(trk_tlwhs)
  116. trk_ids = np.copy(trk_ids)
  117. # gts
  118. gt_objs = self.gt_frame_dict.get(frame_id, [])
  119. gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
  120. # ignore boxes
  121. ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
  122. ignore_tlwhs = unzip_objs(ignore_objs)[0]
  123. # remove ignored results
  124. keep = np.ones(len(trk_tlwhs), dtype=bool)
  125. iou_distance = mm.distances.iou_matrix(
  126. ignore_tlwhs, trk_tlwhs, max_iou=0.5)
  127. if len(iou_distance) > 0:
  128. match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
  129. match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
  130. match_ious = iou_distance[match_is, match_js]
  131. match_js = np.asarray(match_js, dtype=int)
  132. match_js = match_js[np.logical_not(np.isnan(match_ious))]
  133. keep[match_js] = False
  134. trk_tlwhs = trk_tlwhs[keep]
  135. trk_ids = trk_ids[keep]
  136. # get distance matrix
  137. iou_distance = mm.distances.iou_matrix(
  138. gt_tlwhs, trk_tlwhs, max_iou=0.5)
  139. # acc
  140. self.acc.update(gt_ids, trk_ids, iou_distance)
  141. if rtn_events and iou_distance.size > 0 and hasattr(self.acc,
  142. 'last_mot_events'):
  143. events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics
  144. else:
  145. events = None
  146. return events
  147. def eval_file(self, filename):
  148. self.reset_accumulator()
  149. result_frame_dict = read_mot_results(filename, is_gt=False)
  150. frames = sorted(list(set(result_frame_dict.keys())))
  151. for frame_id in frames:
  152. trk_objs = result_frame_dict.get(frame_id, [])
  153. trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
  154. self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
  155. return self.acc
  156. @staticmethod
  157. def get_summary(accs,
  158. names,
  159. metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1',
  160. 'precision', 'recall')):
  161. import motmetrics as mm
  162. mm.lap.default_solver = 'lap'
  163. names = copy.deepcopy(names)
  164. if metrics is None:
  165. metrics = mm.metrics.motchallenge_metrics
  166. metrics = copy.deepcopy(metrics)
  167. mh = mm.metrics.create()
  168. summary = mh.compute_many(
  169. accs, metrics=metrics, names=names, generate_overall=True)
  170. return summary
  171. @staticmethod
  172. def save_summary(summary, filename):
  173. import pandas as pd
  174. # pandas removed ExcelWriter.save() in 2.0; the context manager writes the file on both old and new versions
  175. with pd.ExcelWriter(filename) as writer:
  176. summary.to_excel(writer)
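# Illustrative sketch (not called anywhere): evaluate one tracked sequence with
# MOTEvaluator and render the motmetrics summary. The directory layout
# (data_root/seq_name/gt/gt.txt plus a result txt in MOT format) follows
# load_annotations() above; requires the optional motmetrics package.
def _example_mot_evaluator(data_root, seq_name, result_filename):
    import motmetrics as mm
    evaluator = MOTEvaluator(data_root, seq_name, data_type='mot')
    acc = evaluator.eval_file(result_filename)
    summary = MOTEvaluator.get_summary([acc], [seq_name])
    return mm.io.render_summary(
        summary, namemap=mm.io.motchallenge_metric_names)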
  177. class MOTMetric(Metric):
  178. def __init__(self, save_summary=False):
  179. self.save_summary = save_summary
  180. self.MOTEvaluator = MOTEvaluator
  181. self.result_root = None
  182. self.reset()
  183. def reset(self):
  184. self.accs = []
  185. self.seqs = []
  186. def update(self, data_root, seq, data_type, result_root, result_filename):
  187. evaluator = self.MOTEvaluator(data_root, seq, data_type)
  188. self.accs.append(evaluator.eval_file(result_filename))
  189. self.seqs.append(seq)
  190. self.result_root = result_root
  191. def accumulate(self):
  192. import motmetrics as mm
  193. import openpyxl
  194. metrics = mm.metrics.motchallenge_metrics
  195. mh = mm.metrics.create()
  196. summary = self.MOTEvaluator.get_summary(self.accs, self.seqs, metrics)
  197. self.strsummary = mm.io.render_summary(
  198. summary,
  199. formatters=mh.formatters,
  200. namemap=mm.io.motchallenge_metric_names)
  201. if self.save_summary:
  202. self.MOTEvaluator.save_summary(
  203. summary, os.path.join(self.result_root, 'summary.xlsx'))
  204. def log(self):
  205. print(self.strsummary)
  206. def get_results(self):
  207. return self.strsummary
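# Sketch of the MOTMetric lifecycle over several sequences (paths and sequence
# names are hypothetical); it mirrors the reset -> update -> accumulate -> log
# pattern shared by the Metric subclasses in this module.
def _example_mot_metric(data_root, result_root, seqs=('MOT16-02', 'MOT16-04')):
    metric = MOTMetric(save_summary=False)
    for seq in seqs:
        result_filename = os.path.join(result_root, '{}.txt'.format(seq))
        metric.update(data_root, seq, 'mot', result_root, result_filename)
    metric.accumulate()
    metric.log()
    return metric.get_results()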
  208. class JDEDetMetric(Metric):
  209. # Note this detection AP metric is different from COCOMetric or VOCMetric,
  210. # and the bbox coordinates are not scaled back to the original image
  211. def __init__(self, overlap_thresh=0.5):
  212. self.overlap_thresh = overlap_thresh
  213. self.reset()
  214. def reset(self):
  215. self.AP_accum = np.zeros(1)
  216. self.AP_accum_count = np.zeros(1)
  217. def update(self, inputs, outputs):
  218. bboxes = outputs['bbox'][:, 2:].numpy()
  219. scores = outputs['bbox'][:, 1].numpy()
  220. labels = outputs['bbox'][:, 0].numpy()
  221. bbox_lengths = outputs['bbox_num'].numpy()
  222. if bboxes.shape[0] == 1 and bboxes.sum() == 0.0:
  223. return
  224. gt_boxes = inputs['gt_bbox'].numpy()[0]
  225. gt_labels = inputs['gt_class'].numpy()[0]
  226. if gt_labels.shape[0] == 0:
  227. return
  228. correct = []
  229. detected = []
  230. for i in range(bboxes.shape[0]):
  231. obj_pred = 0
  232. pred_bbox = bboxes[i].reshape(1, 4)
  233. # Compute iou with target boxes
  234. iou = bbox_iou_np_expand(pred_bbox, gt_boxes, x1y1x2y2=True)[0]
  235. # Extract index of largest overlap
  236. best_i = np.argmax(iou)
  237. # If overlap exceeds threshold and classification is correct mark as correct
  238. if iou[best_i] > self.overlap_thresh and obj_pred == gt_labels[
  239. best_i] and best_i not in detected:
  240. correct.append(1)
  241. detected.append(best_i)
  242. else:
  243. correct.append(0)
  244. # Compute Average Precision (AP) per class
  245. target_cls = list(gt_labels.T[0])
  246. AP, AP_class, R, P = ap_per_class(
  247. tp=correct,
  248. conf=scores,
  249. pred_cls=np.zeros_like(scores),
  250. target_cls=target_cls)
  251. self.AP_accum_count += np.bincount(AP_class, minlength=1)
  252. self.AP_accum += np.bincount(AP_class, minlength=1, weights=AP)
  253. def accumulate(self):
  254. logger.info("Accumulating evaluatation results...")
  255. self.map_stat = self.AP_accum[0] / (self.AP_accum_count[0] + 1E-16)
  256. def log(self):
  257. map_stat = 100. * self.map_stat
  258. logger.info("mAP({:.2f}) = {:.2f}%".format(self.overlap_thresh,
  259. map_stat))
  260. def get_results(self):
  261. return self.map_stat
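# Toy sketch for JDEDetMetric with one synthetic prediction/ground-truth pair.
# The tensor layouts ([class, score, x1, y1, x2, y2] per predicted box, ground
# truth boxes in x1y1x2y2) are inferred from update() above and are assumptions,
# not a documented API contract.
def _example_jde_det_metric():
    metric = JDEDetMetric(overlap_thresh=0.5)
    outputs = {
        'bbox': paddle.to_tensor([[0., 0.9, 10., 10., 50., 50.]]),
        'bbox_num': paddle.to_tensor([1]),
    }
    inputs = {
        'gt_bbox': paddle.to_tensor([[[12., 12., 48., 52.]]]),
        'gt_class': paddle.to_tensor([[[0]]]),
    }
    metric.update(inputs, outputs)
    metric.accumulate()
    metric.log()
    return metric.get_results()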
  262. """
  263. The following code is borrowed from https://github.com/xingyizhou/CenterTrack/blob/master/src/tools/eval_kitti_track/evaluate_tracking.py
  264. """
  265. class tData:
  266. """
  267. Utility class to load data.
  268. """
  269. def __init__(self,frame=-1,obj_type="unset",truncation=-1,occlusion=-1,\
  270. obs_angle=-10,x1=-1,y1=-1,x2=-1,y2=-1,w=-1,h=-1,l=-1,\
  271. X=-1000,Y=-1000,Z=-1000,yaw=-10,score=-1000,track_id=-1):
  272. """
  273. Constructor, initializes the object given the parameters.
  274. """
  275. self.frame = frame
  276. self.track_id = track_id
  277. self.obj_type = obj_type
  278. self.truncation = truncation
  279. self.occlusion = occlusion
  280. self.obs_angle = obs_angle
  281. self.x1 = x1
  282. self.y1 = y1
  283. self.x2 = x2
  284. self.y2 = y2
  285. self.w = w
  286. self.h = h
  287. self.l = l
  288. self.X = X
  289. self.Y = Y
  290. self.Z = Z
  291. self.yaw = yaw
  292. self.score = score
  293. self.ignored = False
  294. self.valid = False
  295. self.tracker = -1
  296. def __str__(self):
  297. attrs = vars(self)
  298. return '\n'.join("%s: %s" % item for item in attrs.items())
  299. class KITTIEvaluation(object):
  300. """ KITTI tracking statistics (CLEAR MOT, id-switches, fragments, ML/PT/MT, precision/recall)
  301. MOTA - Multi-object tracking accuracy in [0,100]
  302. MOTP - Multi-object tracking precision in [0,100] (3D) / [td,100] (2D)
  303. MOTAL - Multi-object tracking accuracy in [0,100] with log10(id-switches)
  304. id-switches - number of id switches
  305. fragments - number of fragmentations
  306. MT, PT, ML - number of mostly tracked, partially tracked and mostly lost trajectories
  307. recall - recall = percentage of detected targets
  308. precision - precision = percentage of correctly detected targets
  309. FAR - number of false alarms per frame
  310. falsepositives - number of false positives (FP)
  311. missed - number of missed targets (FN)
  312. """
  313. def __init__(self, result_path, gt_path, min_overlap=0.5, max_truncation = 0,\
  314. min_height = 25, max_occlusion = 2, cls="car",\
  315. n_frames=[], seqs=[], n_sequences=0):
  316. # get number of sequences and
  317. # get number of frames per sequence from test mapping
  318. # (created while extracting the benchmark)
  319. self.gt_path = os.path.join(gt_path, "label_02")
  320. self.n_frames = n_frames
  321. self.sequence_name = seqs
  322. self.n_sequences = n_sequences
  323. self.cls = cls # class to evaluate, i.e. pedestrian or car
  324. self.result_path = result_path
  325. # statistics and numbers for evaluation
  326. self.n_gt = 0 # number of ground truth detections minus ignored false negatives and true positives
  327. self.n_igt = 0 # number of ignored ground truth detections
  328. self.n_gts = [
  329. ] # number of ground truth detections minus ignored false negatives and true positives PER SEQUENCE
  330. self.n_igts = [
  331. ] # number of ground ignored truth detections PER SEQUENCE
  332. self.n_gt_trajectories = 0
  333. self.n_gt_seq = []
  334. self.n_tr = 0 # number of tracker detections minus ignored tracker detections
  335. self.n_trs = [
  336. ] # number of tracker detections minus ignored tracker detections PER SEQUENCE
  337. self.n_itr = 0 # number of ignored tracker detections
  338. self.n_itrs = [] # number of ignored tracker detections PER SEQUENCE
  339. self.n_igttr = 0 # number of ignored ground truth detections where the corresponding associated tracker detection is also ignored
  340. self.n_tr_trajectories = 0
  341. self.n_tr_seq = []
  342. self.MOTA = 0
  343. self.MOTP = 0
  344. self.MOTAL = 0
  345. self.MODA = 0
  346. self.MODP = 0
  347. self.MODP_t = []
  348. self.recall = 0
  349. self.precision = 0
  350. self.F1 = 0
  351. self.FAR = 0
  352. self.total_cost = 0
  353. self.itp = 0 # number of ignored true positives
  354. self.itps = [] # number of ignored true positives PER SEQUENCE
  355. self.tp = 0 # number of true positives including ignored true positives!
  356. self.tps = [
  357. ] # number of true positives including ignored true positives PER SEQUENCE
  358. self.fn = 0 # number of false negatives WITHOUT ignored false negatives
  359. self.fns = [
  360. ] # number of false negatives WITHOUT ignored false negatives PER SEQUENCE
  361. self.ifn = 0 # number of ignored false negatives
  362. self.ifns = [] # number of ignored false negatives PER SEQUENCE
  363. self.fp = 0 # number of false positives
  364. # a bit tricky, the number of ignored false negatives and ignored true positives
  365. # is subtracted, but if both tracker detection and ground truth detection
  366. # are ignored this number is added again to avoid double counting
  367. self.fps = [] # above PER SEQUENCE
  368. self.mme = 0
  369. self.fragments = 0
  370. self.id_switches = 0
  371. self.MT = 0
  372. self.PT = 0
  373. self.ML = 0
  374. self.min_overlap = min_overlap # minimum bounding box overlap for 3rd party metrics
  375. self.max_truncation = max_truncation # maximum truncation of an object for evaluation
  376. self.max_occlusion = max_occlusion # maximum occlusion of an object for evaluation
  377. self.min_height = min_height # minimum height of an object for evaluation
  378. self.n_sample_points = 500
  379. # this should be enough to hold all groundtruth trajectories
  380. # is expanded if necessary and reduced in any case
  381. self.gt_trajectories = [[] for x in range(self.n_sequences)]
  382. self.ign_trajectories = [[] for x in range(self.n_sequences)]
  383. def loadGroundtruth(self):
  384. try:
  385. self._loadData(
  386. self.gt_path, cls=self.cls, loading_groundtruth=True)
  387. except IOError:
  388. return False
  389. return True
  390. def loadTracker(self):
  391. try:
  392. if not self._loadData(
  393. self.result_path, cls=self.cls, loading_groundtruth=False):
  394. return False
  395. except IOError:
  396. return False
  397. return True
  398. def _loadData(self,
  399. root_dir,
  400. cls,
  401. min_score=-1000,
  402. loading_groundtruth=False):
  403. """
  404. Generic loader for ground truth and tracking data.
  405. Use loadGroundtruth() or loadTracker() to load this data.
  406. Loads detections in KITTI format from textfiles.
  407. """
  408. # construct objectDetections object to hold detection data
  409. t_data = tData()
  410. data = []
  411. eval_2d = True
  412. eval_3d = True
  413. seq_data = []
  414. n_trajectories = 0
  415. n_trajectories_seq = []
  416. for seq, s_name in enumerate(self.sequence_name):
  417. i = 0
  418. filename = os.path.join(root_dir, "%s.txt" % s_name)
  419. f = open(filename, "r")
  420. f_data = [
  421. [] for x in range(self.n_frames[seq])
  422. ] # current set has only 1059 entries, sufficient length is checked anyway
  423. ids = []
  424. n_in_seq = 0
  425. id_frame_cache = []
  426. for line in f:
  427. # KITTI tracking benchmark data format:
  428. # (frame,tracklet_id,objectType,truncation,occlusion,alpha,x1,y1,x2,y2,h,w,l,X,Y,Z,ry)
  429. line = line.strip()
  430. fields = line.split(" ")
  431. # classes that should be loaded (neighboring classes are loaded but ignored in evaluation)
  432. if "car" in cls.lower():
  433. classes = ["car", "van"]
  434. elif "pedestrian" in cls.lower():
  435. classes = ["pedestrian", "person_sitting"]
  436. else:
  437. classes = [cls.lower()]
  438. classes += ["dontcare"]
  439. if not any([s for s in classes if s in fields[2].lower()]):
  440. continue
  441. # get fields from table
  442. t_data.frame = int(float(fields[0])) # frame
  443. t_data.track_id = int(float(fields[1])) # id
  444. t_data.obj_type = fields[
  445. 2].lower() # object type [car, pedestrian, cyclist, ...]
  446. t_data.truncation = int(
  447. float(fields[3])) # truncation [-1,0,1,2]
  448. t_data.occlusion = int(
  449. float(fields[4])) # occlusion [-1,0,1,2]
  450. t_data.obs_angle = float(fields[5]) # observation angle [rad]
  451. t_data.x1 = float(fields[6]) # left [px]
  452. t_data.y1 = float(fields[7]) # top [px]
  453. t_data.x2 = float(fields[8]) # right [px]
  454. t_data.y2 = float(fields[9]) # bottom [px]
  455. t_data.h = float(fields[10]) # height [m]
  456. t_data.w = float(fields[11]) # width [m]
  457. t_data.l = float(fields[12]) # length [m]
  458. t_data.X = float(fields[13]) # X [m]
  459. t_data.Y = float(fields[14]) # Y [m]
  460. t_data.Z = float(fields[15]) # Z [m]
  461. t_data.yaw = float(fields[16]) # yaw angle [rad]
  462. if not loading_groundtruth:
  463. if len(fields) == 17:
  464. t_data.score = -1
  465. elif len(fields) == 18:
  466. t_data.score = float(fields[17]) # detection score
  467. else:
  468. logger.info("file is not in KITTI format")
  469. return
  470. # do not consider objects marked as invalid
  471. if t_data.track_id == -1 and t_data.obj_type != "dontcare":
  472. continue
  473. idx = t_data.frame
  474. # check if length for frame data is sufficient
  475. if idx >= len(f_data):
  476. print("extend f_data", idx, len(f_data))
  477. f_data += [[] for x in range(max(500, idx - len(f_data)))]
  478. try:
  479. id_frame = (t_data.frame, t_data.track_id)
  480. if id_frame in id_frame_cache and not loading_groundtruth:
  481. logger.info(
  482. "track ids are not unique for sequence %d: frame %d"
  483. % (seq, t_data.frame))
  484. logger.info(
  485. "track id %d occured at least twice for this frame"
  486. % t_data.track_id)
  487. logger.info("Exiting...")
  488. #continue # this allows to evaluate non-unique result files
  489. return False
  490. id_frame_cache.append(id_frame)
  491. f_data[t_data.frame].append(copy.copy(t_data))
  492. except:
  493. print(len(f_data), idx)
  494. raise
  495. if t_data.track_id not in ids and t_data.obj_type != "dontcare":
  496. ids.append(t_data.track_id)
  497. n_trajectories += 1
  498. n_in_seq += 1
  499. # check if uploaded data provides information for 2D and 3D evaluation
  500. if not loading_groundtruth and eval_2d is True and (
  501. t_data.x1 == -1 or t_data.x2 == -1 or
  502. t_data.y1 == -1 or t_data.y2 == -1):
  503. eval_2d = False
  504. if not loading_groundtruth and eval_3d is True and (
  505. t_data.X == -1000 or t_data.Y == -1000 or
  506. t_data.Z == -1000):
  507. eval_3d = False
  508. # only add existing frames
  509. n_trajectories_seq.append(n_in_seq)
  510. seq_data.append(f_data)
  511. f.close()
  512. if not loading_groundtruth:
  513. self.tracker = seq_data
  514. self.n_tr_trajectories = n_trajectories
  515. self.eval_2d = eval_2d
  516. self.eval_3d = eval_3d
  517. self.n_tr_seq = n_trajectories_seq
  518. if self.n_tr_trajectories == 0:
  519. return False
  520. else:
  521. # split ground truth and DontCare areas
  522. self.dcareas = []
  523. self.groundtruth = []
  524. for seq_idx in range(len(seq_data)):
  525. seq_gt = seq_data[seq_idx]
  526. s_g, s_dc = [], []
  527. for f in range(len(seq_gt)):
  528. all_gt = seq_gt[f]
  529. g, dc = [], []
  530. for gg in all_gt:
  531. if gg.obj_type == "dontcare":
  532. dc.append(gg)
  533. else:
  534. g.append(gg)
  535. s_g.append(g)
  536. s_dc.append(dc)
  537. self.dcareas.append(s_dc)
  538. self.groundtruth.append(s_g)
  539. self.n_gt_seq = n_trajectories_seq
  540. self.n_gt_trajectories = n_trajectories
  541. return True
  542. def boxoverlap(self, a, b, criterion="union"):
  543. """
  544. boxoverlap computes intersection over union for bbox a and b in KITTI format.
  545. If the criterion is 'union', overlap = (a inter b) / (a union b).
  546. If the criterion is 'a', overlap = (a inter b) / a, where b should be a dontcare area.
  547. """
  548. x1 = max(a.x1, b.x1)
  549. y1 = max(a.y1, b.y1)
  550. x2 = min(a.x2, b.x2)
  551. y2 = min(a.y2, b.y2)
  552. w = x2 - x1
  553. h = y2 - y1
  554. if w <= 0. or h <= 0.:
  555. return 0.
  556. inter = w * h
  557. aarea = (a.x2 - a.x1) * (a.y2 - a.y1)
  558. barea = (b.x2 - b.x1) * (b.y2 - b.y1)
  559. # intersection over union overlap
  560. if criterion.lower() == "union":
  561. o = inter / float(aarea + barea - inter)
  562. elif criterion.lower() == "a":
  563. o = float(inter) / float(aarea)
  564. else:
  565. raise TypeError("Unkown type for criterion")
  566. return o
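    # For example, boxes (x1, y1, x2, y2) = (0, 0, 10, 10) and (5, 0, 15, 10)
    # intersect on a 5x10 region, so criterion "union" gives
    # 50 / (100 + 100 - 50) = 1/3, while criterion "a" (used against DontCare
    # areas) gives 50 / 100 = 0.5.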
  567. def compute3rdPartyMetrics(self):
  568. """
  569. Computes the metrics defined in
  570. - Stiefelhagen 2008: Evaluating Multiple Object Tracking Performance: The CLEAR MOT Metrics
  571. MOTA, MOTAL, MOTP
  572. - Nevatia 2008: Global Data Association for Multi-Object Tracking Using Network Flows
  573. MT/PT/ML
  574. """
  575. # construct Munkres object for Hungarian Method association
  576. hm = Munkres()
  577. max_cost = 1e9
  578. # go through all frames and associate ground truth and tracker results
  579. # groundtruth and tracker contain lists for every single frame containing lists of KITTI format detections
  580. fr, ids = 0, 0
  581. for seq_idx in range(len(self.groundtruth)):
  582. seq_gt = self.groundtruth[seq_idx]
  583. seq_dc = self.dcareas[seq_idx] # don't care areas
  584. seq_tracker = self.tracker[seq_idx]
  585. seq_trajectories = defaultdict(list)
  586. seq_ignored = defaultdict(list)
  587. # statistics over the current sequence, check the corresponding
  588. # variable comments in __init__ to get their meaning
  589. seqtp = 0
  590. seqitp = 0
  591. seqfn = 0
  592. seqifn = 0
  593. seqfp = 0
  594. seqigt = 0
  595. seqitr = 0
  596. last_ids = [[], []]
  597. n_gts = 0
  598. n_trs = 0
  599. for f in range(len(seq_gt)):
  600. g = seq_gt[f]
  601. dc = seq_dc[f]
  602. t = seq_tracker[f]
  603. # counting total number of ground truth and tracker objects
  604. self.n_gt += len(g)
  605. self.n_tr += len(t)
  606. n_gts += len(g)
  607. n_trs += len(t)
  608. # use hungarian method to associate, using boxoverlap 0..1 as cost
  609. # build cost matrix
  610. cost_matrix = []
  611. this_ids = [[], []]
  612. for gg in g:
  613. # save current ids
  614. this_ids[0].append(gg.track_id)
  615. this_ids[1].append(-1)
  616. gg.tracker = -1
  617. gg.id_switch = 0
  618. gg.fragmentation = 0
  619. cost_row = []
  620. for tt in t:
  621. # overlap == 1 means cost == 0
  622. c = 1 - self.boxoverlap(gg, tt)
  623. # gating for boxoverlap
  624. if c <= self.min_overlap:
  625. cost_row.append(c)
  626. else:
  627. cost_row.append(max_cost) # = 1e9
  628. cost_matrix.append(cost_row)
  629. # all ground truth trajectories are initially not associated
  630. # extend groundtruth trajectories lists (merge lists)
  631. seq_trajectories[gg.track_id].append(-1)
  632. seq_ignored[gg.track_id].append(False)
  633. if len(g) == 0:
  634. cost_matrix = [[]]
  635. # associate
  636. association_matrix = hm.compute(cost_matrix)
  637. # tmp variables for sanity checks and MODP computation
  638. tmptp = 0
  639. tmpfp = 0
  640. tmpfn = 0
  641. tmpc = 0 # this will sum up the overlaps for all true positives
  642. tmpcs = [0] * len(
  643. g) # this will save the overlaps for all true positives
  644. # the reason is that some true positives might be ignored
  645. # later such that the corresponding overlaps can
  646. # be subtracted from tmpc for MODP computation
  647. # mapping for tracker ids and ground truth ids
  648. for row, col in association_matrix:
  649. # apply gating on boxoverlap
  650. c = cost_matrix[row][col]
  651. if c < max_cost:
  652. g[row].tracker = t[col].track_id
  653. this_ids[1][row] = t[col].track_id
  654. t[col].valid = True
  655. g[row].distance = c
  656. self.total_cost += 1 - c
  657. tmpc += 1 - c
  658. tmpcs[row] = 1 - c
  659. seq_trajectories[g[row].track_id][-1] = t[col].track_id
  660. # true positives are only valid associations
  661. self.tp += 1
  662. tmptp += 1
  663. else:
  664. g[row].tracker = -1
  665. self.fn += 1
  666. tmpfn += 1
  667. # associate tracker and DontCare areas
  668. # ignore tracker in neighboring classes
  669. nignoredtracker = 0 # number of ignored tracker detections
  670. ignoredtrackers = dict() # will associate the track_id with -1
  671. # if it is not ignored and 1 if it is
  672. # ignored;
  673. # this is used to avoid double counting ignored
  674. # cases, see the next loop
  675. for tt in t:
  676. ignoredtrackers[tt.track_id] = -1
  677. # ignore detection if it belongs to a neighboring class or is
  678. # smaller or equal to the minimum height
  679. tt_height = abs(tt.y1 - tt.y2)
  680. if ((self.cls == "car" and tt.obj_type == "van") or
  681. (self.cls == "pedestrian" and
  682. tt.obj_type == "person_sitting") or
  683. tt_height <= self.min_height) and not tt.valid:
  684. nignoredtracker += 1
  685. tt.ignored = True
  686. ignoredtrackers[tt.track_id] = 1
  687. continue
  688. for d in dc:
  689. overlap = self.boxoverlap(tt, d, "a")
  690. if overlap > 0.5 and not tt.valid:
  691. tt.ignored = True
  692. nignoredtracker += 1
  693. ignoredtrackers[tt.track_id] = 1
  694. break
  695. # check for ignored FN/TP (truncation or neighboring object class)
  696. ignoredfn = 0 # the number of ignored false negatives
  697. nignoredtp = 0 # the number of ignored true positives
  698. nignoredpairs = 0 # the number of ignored pairs, i.e. a true positive
  699. # which is ignored but where the associated tracker
  700. # detection has already been ignored
  701. gi = 0
  702. for gg in g:
  703. if gg.tracker < 0:
  704. if gg.occlusion>self.max_occlusion or gg.truncation>self.max_truncation\
  705. or (self.cls=="car" and gg.obj_type=="van") or (self.cls=="pedestrian" and gg.obj_type=="person_sitting"):
  706. seq_ignored[gg.track_id][-1] = True
  707. gg.ignored = True
  708. ignoredfn += 1
  709. elif gg.tracker >= 0:
  710. if gg.occlusion>self.max_occlusion or gg.truncation>self.max_truncation\
  711. or (self.cls=="car" and gg.obj_type=="van") or (self.cls=="pedestrian" and gg.obj_type=="person_sitting"):
  712. seq_ignored[gg.track_id][-1] = True
  713. gg.ignored = True
  714. nignoredtp += 1
  715. # if the associated tracker detection is already ignored,
  716. # we want to avoid double counting ignored detections
  717. if ignoredtrackers[gg.tracker] > 0:
  718. nignoredpairs += 1
  719. # for computing MODP, the overlaps from ignored detections
  720. # are subtracted
  721. tmpc -= tmpcs[gi]
  722. gi += 1
  723. # the below might be confusing, check the comments in __init__
  724. # to see what the individual statistics represent
  725. # correct TP by number of ignored TP due to truncation
  726. # ignored TP are shown as tracked in visualization
  727. tmptp -= nignoredtp
  728. # count the number of ignored true positives
  729. self.itp += nignoredtp
  730. # adjust the number of ground truth objects considered
  731. self.n_gt -= (ignoredfn + nignoredtp)
  732. # count the number of ignored ground truth objects
  733. self.n_igt += ignoredfn + nignoredtp
  734. # count the number of ignored tracker objects
  735. self.n_itr += nignoredtracker
  736. # count the number of ignored pairs, i.e. associated tracker and
  737. # ground truth objects that are both ignored
  738. self.n_igttr += nignoredpairs
  739. # false negatives = associated gt bboxes exceeding association threshold + non-associated gt bboxes
  740. tmpfn += len(g) - len(association_matrix) - ignoredfn
  741. self.fn += len(g) - len(association_matrix) - ignoredfn
  742. self.ifn += ignoredfn
  743. # false positives = tracker bboxes - associated tracker bboxes
  744. # mismatches (mme_t)
  745. tmpfp += len(
  746. t) - tmptp - nignoredtracker - nignoredtp + nignoredpairs
  747. self.fp += len(
  748. t) - tmptp - nignoredtracker - nignoredtp + nignoredpairs
  749. # update sequence data
  750. seqtp += tmptp
  751. seqitp += nignoredtp
  752. seqfp += tmpfp
  753. seqfn += tmpfn
  754. seqifn += ignoredfn
  755. seqigt += ignoredfn + nignoredtp
  756. seqitr += nignoredtracker
  757. # sanity checks
  758. # - the number of true positives minus ignored true positives
  759. # should be greater or equal to 0
  760. # - the number of false negatives should be greater or equal to 0
  761. # - the number of false positives needs to be greater or equal to 0
  762. # otherwise ignored detections might be counted double
  763. # - the number of counted true positives (plus ignored ones)
  764. # and the number of counted false negatives (plus ignored ones)
  765. # should match the total number of ground truth objects
  766. # - the number of counted true positives (plus ignored ones)
  767. # and the number of counted false positives
  768. # plus the number of ignored tracker detections should
  769. # match the total number of tracker detections; note that
  770. # nignoredpairs is subtracted here to avoid double counting
  771. # of ignored detections in nignoredtp and nignoredtracker
  772. if tmptp < 0:
  773. print(tmptp, nignoredtp)
  774. raise NameError("Something went wrong! TP is negative")
  775. if tmpfn < 0:
  776. print(tmpfn,
  777. len(g),
  778. len(association_matrix), ignoredfn, nignoredpairs)
  779. raise NameError("Something went wrong! FN is negative")
  780. if tmpfp < 0:
  781. print(tmpfp,
  782. len(t), tmptp, nignoredtracker, nignoredtp,
  783. nignoredpairs)
  784. raise NameError("Something went wrong! FP is negative")
  785. if tmptp + tmpfn != len(g) - ignoredfn - nignoredtp:
  786. print("seqidx", seq_idx)
  787. print("frame ", f)
  788. print("TP ", tmptp)
  789. print("FN ", tmpfn)
  790. print("FP ", tmpfp)
  791. print("nGT ", len(g))
  792. print("nAss ", len(association_matrix))
  793. print("ign GT", ignoredfn)
  794. print("ign TP", nignoredtp)
  795. raise NameError(
  796. "Something went wrong! nGroundtruth is not TP+FN")
  797. if tmptp + tmpfp + nignoredtp + nignoredtracker - nignoredpairs != len(
  798. t):
  799. print(seq_idx, f, len(t), tmptp, tmpfp)
  800. print(len(association_matrix), association_matrix)
  801. raise NameError(
  802. "Something went wrong! nTracker is not TP+FP")
  803. # check for id switches or fragmentations
  804. for i, tt in enumerate(this_ids[0]):
  805. if tt in last_ids[0]:
  806. idx = last_ids[0].index(tt)
  807. tid = this_ids[1][i]
  808. lid = last_ids[1][idx]
  809. if tid != lid and lid != -1 and tid != -1:
  810. if g[i].truncation < self.max_truncation:
  811. g[i].id_switch = 1
  812. ids += 1
  813. if tid != lid and lid != -1:
  814. if g[i].truncation < self.max_truncation:
  815. g[i].fragmentation = 1
  816. fr += 1
  817. # save current index
  818. last_ids = this_ids
  819. # compute MOTP_t
  820. MODP_t = 1
  821. if tmptp != 0:
  822. MODP_t = tmpc / float(tmptp)
  823. self.MODP_t.append(MODP_t)
  824. # remove empty lists for current gt trajectories
  825. self.gt_trajectories[seq_idx] = seq_trajectories
  826. self.ign_trajectories[seq_idx] = seq_ignored
  827. # gather statistics for "per sequence" statistics.
  828. self.n_gts.append(n_gts)
  829. self.n_trs.append(n_trs)
  830. self.tps.append(seqtp)
  831. self.itps.append(seqitp)
  832. self.fps.append(seqfp)
  833. self.fns.append(seqfn)
  834. self.ifns.append(seqifn)
  835. self.n_igts.append(seqigt)
  836. self.n_itrs.append(seqitr)
  837. # compute MT/PT/ML, fragments, idswitches for all groundtruth trajectories
  838. n_ignored_tr_total = 0
  839. for seq_idx, (
  840. seq_trajectories, seq_ignored
  841. ) in enumerate(zip(self.gt_trajectories, self.ign_trajectories)):
  842. if len(seq_trajectories) == 0:
  843. continue
  844. tmpMT, tmpML, tmpPT, tmpId_switches, tmpFragments = [0] * 5
  845. n_ignored_tr = 0
  846. for g, ign_g in zip(seq_trajectories.values(),
  847. seq_ignored.values()):
  848. # all frames of this gt trajectory are ignored
  849. if all(ign_g):
  850. n_ignored_tr += 1
  851. n_ignored_tr_total += 1
  852. continue
  853. # all frames of this gt trajectory are not assigned to any detections
  854. if all([this == -1 for this in g]):
  855. tmpML += 1
  856. self.ML += 1
  857. continue
  858. # compute tracked frames in trajectory
  859. last_id = g[0]
  860. # first detection (necessary to be in gt_trajectories) is always tracked
  861. tracked = 1 if g[0] >= 0 else 0
  862. lgt = 0 if ign_g[0] else 1
  863. for f in range(1, len(g)):
  864. if ign_g[f]:
  865. last_id = -1
  866. continue
  867. lgt += 1
  868. if last_id != g[f] and last_id != -1 and g[f] != -1 and g[
  869. f - 1] != -1:
  870. tmpId_switches += 1
  871. self.id_switches += 1
  872. if f < len(g) - 1 and g[f - 1] != g[
  873. f] and last_id != -1 and g[f] != -1 and g[f +
  874. 1] != -1:
  875. tmpFragments += 1
  876. self.fragments += 1
  877. if g[f] != -1:
  878. tracked += 1
  879. last_id = g[f]
  880. # handle last frame; tracked state is handled in for loop (g[f]!=-1)
  881. if len(g) > 1 and g[f - 1] != g[f] and last_id != -1 and g[
  882. f] != -1 and not ign_g[f]:
  883. tmpFragments += 1
  884. self.fragments += 1
  885. # compute MT/PT/ML
  886. tracking_ratio = tracked / float(len(g) - sum(ign_g))
  887. if tracking_ratio > 0.8:
  888. tmpMT += 1
  889. self.MT += 1
  890. elif tracking_ratio < 0.2:
  891. tmpML += 1
  892. self.ML += 1
  893. else: # 0.2 <= tracking_ratio <= 0.8
  894. tmpPT += 1
  895. self.PT += 1
  896. if (self.n_gt_trajectories - n_ignored_tr_total) == 0:
  897. self.MT = 0.
  898. self.PT = 0.
  899. self.ML = 0.
  900. else:
  901. self.MT /= float(self.n_gt_trajectories - n_ignored_tr_total)
  902. self.PT /= float(self.n_gt_trajectories - n_ignored_tr_total)
  903. self.ML /= float(self.n_gt_trajectories - n_ignored_tr_total)
  904. # precision/recall etc.
  905. if (self.fp + self.tp) == 0 or (self.tp + self.fn) == 0:
  906. self.recall = 0.
  907. self.precision = 0.
  908. else:
  909. self.recall = self.tp / float(self.tp + self.fn)
  910. self.precision = self.tp / float(self.fp + self.tp)
  911. if (self.recall + self.precision) == 0:
  912. self.F1 = 0.
  913. else:
  914. self.F1 = 2. * (self.precision * self.recall) / (
  915. self.precision + self.recall)
  916. if sum(self.n_frames) == 0:
  917. self.FAR = "n/a"
  918. else:
  919. self.FAR = self.fp / float(sum(self.n_frames))
  920. # compute CLEARMOT
  921. if self.n_gt == 0:
  922. self.MOTA = -float("inf")
  923. self.MODA = -float("inf")
  924. else:
  925. self.MOTA = 1 - (self.fn + self.fp + self.id_switches
  926. ) / float(self.n_gt)
  927. self.MODA = 1 - (self.fn + self.fp) / float(self.n_gt)
  928. if self.tp == 0:
  929. self.MOTP = float("inf")
  930. else:
  931. self.MOTP = self.total_cost / float(self.tp)
  932. if self.n_gt != 0:
  933. if self.id_switches == 0:
  934. self.MOTAL = 1 - (self.fn + self.fp + self.id_switches
  935. ) / float(self.n_gt)
  936. else:
  937. self.MOTAL = 1 - (self.fn + self.fp +
  938. math.log10(self.id_switches)
  939. ) / float(self.n_gt)
  940. else:
  941. self.MOTAL = -float("inf")
  942. if sum(self.n_frames) == 0:
  943. self.MODP = "n/a"
  944. else:
  945. self.MODP = sum(self.MODP_t) / float(sum(self.n_frames))
  946. return True
  947. def createSummary(self):
  948. summary = ""
  949. summary += "tracking evaluation summary".center(80, "=") + "\n"
  950. summary += self.printEntry("Multiple Object Tracking Accuracy (MOTA)",
  951. self.MOTA) + "\n"
  952. summary += self.printEntry("Multiple Object Tracking Precision (MOTP)",
  953. self.MOTP) + "\n"
  954. summary += self.printEntry("Multiple Object Tracking Accuracy (MOTAL)",
  955. self.MOTAL) + "\n"
  956. summary += self.printEntry("Multiple Object Detection Accuracy (MODA)",
  957. self.MODA) + "\n"
  958. summary += self.printEntry(
  959. "Multiple Object Detection Precision (MODP)", self.MODP) + "\n"
  960. summary += "\n"
  961. summary += self.printEntry("Recall", self.recall) + "\n"
  962. summary += self.printEntry("Precision", self.precision) + "\n"
  963. summary += self.printEntry("F1", self.F1) + "\n"
  964. summary += self.printEntry("False Alarm Rate", self.FAR) + "\n"
  965. summary += "\n"
  966. summary += self.printEntry("Mostly Tracked", self.MT) + "\n"
  967. summary += self.printEntry("Partly Tracked", self.PT) + "\n"
  968. summary += self.printEntry("Mostly Lost", self.ML) + "\n"
  969. summary += "\n"
  970. summary += self.printEntry("True Positives", self.tp) + "\n"
  971. #summary += self.printEntry("True Positives per Sequence", self.tps) + "\n"
  972. summary += self.printEntry("Ignored True Positives", self.itp) + "\n"
  973. #summary += self.printEntry("Ignored True Positives per Sequence", self.itps) + "\n"
  974. summary += self.printEntry("False Positives", self.fp) + "\n"
  975. #summary += self.printEntry("False Positives per Sequence", self.fps) + "\n"
  976. summary += self.printEntry("False Negatives", self.fn) + "\n"
  977. #summary += self.printEntry("False Negatives per Sequence", self.fns) + "\n"
  978. summary += self.printEntry("ID-switches", self.id_switches) + "\n"
  979. self.fp = self.fp / self.n_gt
  980. self.fn = self.fn / self.n_gt
  981. self.id_switches = self.id_switches / self.n_gt
  982. summary += self.printEntry("False Positives Ratio", self.fp) + "\n"
  983. #summary += self.printEntry("False Positives per Sequence", self.fps) + "\n"
  984. summary += self.printEntry("False Negatives Ratio", self.fn) + "\n"
  985. #summary += self.printEntry("False Negatives per Sequence", self.fns) + "\n"
  986. summary += self.printEntry("Ignored False Negatives Ratio",
  987. self.ifn) + "\n"
  988. #summary += self.printEntry("Ignored False Negatives per Sequence", self.ifns) + "\n"
  989. summary += self.printEntry("Missed Targets", self.fn) + "\n"
  990. summary += self.printEntry("ID-switches", self.id_switches) + "\n"
  991. summary += self.printEntry("Fragmentations", self.fragments) + "\n"
  992. summary += "\n"
  993. summary += self.printEntry("Ground Truth Objects (Total)", self.n_gt +
  994. self.n_igt) + "\n"
  995. #summary += self.printEntry("Ground Truth Objects (Total) per Sequence", self.n_gts) + "\n"
  996. summary += self.printEntry("Ignored Ground Truth Objects",
  997. self.n_igt) + "\n"
  998. #summary += self.printEntry("Ignored Ground Truth Objects per Sequence", self.n_igts) + "\n"
  999. summary += self.printEntry("Ground Truth Trajectories",
  1000. self.n_gt_trajectories) + "\n"
  1001. summary += "\n"
  1002. summary += self.printEntry("Tracker Objects (Total)", self.n_tr) + "\n"
  1003. #summary += self.printEntry("Tracker Objects (Total) per Sequence", self.n_trs) + "\n"
  1004. summary += self.printEntry("Ignored Tracker Objects",
  1005. self.n_itr) + "\n"
  1006. #summary += self.printEntry("Ignored Tracker Objects per Sequence", self.n_itrs) + "\n"
  1007. summary += self.printEntry("Tracker Trajectories",
  1008. self.n_tr_trajectories) + "\n"
  1009. #summary += "\n"
  1010. #summary += self.printEntry("Ignored Tracker Objects with Associated Ignored Ground Truth Objects", self.n_igttr) + "\n"
  1011. summary += "=" * 80
  1012. return summary
  1013. def printEntry(self, key, val, width=(70, 10)):
  1014. """
  1015. Pretty print an entry in a table fashion.
  1016. """
  1017. s_out = key.ljust(width[0])
  1018. if type(val) == int:
  1019. s = "%%%dd" % width[1]
  1020. s_out += s % val
  1021. elif type(val) == float:
  1022. s = "%%%df" % (width[1])
  1023. s_out += s % val
  1024. else:
  1025. s_out += ("%s" % val).rjust(width[1])
  1026. return s_out
  1027. def saveToStats(self, save_summary):
  1028. """
  1029. Save the statistics to a whitespace-separated file.
  1030. """
  1031. summary = self.createSummary()
  1032. if save_summary:
  1033. filename = os.path.join(self.result_path,
  1034. "summary_%s.txt" % self.cls)
  1035. dump = open(filename, "w+")
  1036. dump.write(summary)
  1037. dump.close()
  1038. return summary
  1039. class KITTIMOTMetric(Metric):
  1040. def __init__(self, save_summary=True):
  1041. self.save_summary = save_summary
  1042. self.MOTEvaluator = KITTIEvaluation
  1043. self.result_root = None
  1044. self.reset()
  1045. def reset(self):
  1046. self.seqs = []
  1047. self.n_sequences = 0
  1048. self.n_frames = []
  1049. self.strsummary = ''
  1050. def update(self, data_root, seq, data_type, result_root, result_filename):
  1051. assert data_type == 'kitti', "data_type should be 'kitti'"
  1052. self.result_root = result_root
  1053. self.gt_path = data_root
  1054. gt_path = '{}/label_02/{}.txt'.format(data_root, seq)
  1055. gt = open(gt_path, "r")
  1056. max_frame = 0
  1057. for line in gt:
  1058. line = line.strip()
  1059. line_list = line.split(" ")
  1060. if int(line_list[0]) > max_frame:
  1061. max_frame = int(line_list[0])
  1062. rs = open(result_filename, "r")
  1063. for line in rs:
  1064. line = line.strip()
  1065. line_list = line.split(" ")
  1066. if int(line_list[0]) > max_frame:
  1067. max_frame = int(line_list[0])
  1068. gt.close()
  1069. rs.close()
  1070. self.n_frames.append(max_frame + 1)
  1071. self.seqs.append(seq)
  1072. self.n_sequences += 1
  1073. def accumulate(self):
  1074. logger.info("Processing Result for KITTI Tracking Benchmark")
  1075. e = self.MOTEvaluator(result_path=self.result_root, gt_path=self.gt_path,\
  1076. n_frames=self.n_frames, seqs=self.seqs, n_sequences=self.n_sequences)
  1077. try:
  1078. if not e.loadTracker():
  1079. return
  1080. logger.info("Loading Results - Success")
  1081. logger.info("Evaluate Object Class: %s" % c.upper())
  1082. except:
  1083. logger.info("Caught exception while loading result data.")
  1084. if not e.loadGroundtruth():
  1085. raise ValueError("Ground truth not found.")
  1086. logger.info("Loading Groundtruth - Success")
  1087. # sanity checks
  1088. if len(e.groundtruth) != len(e.tracker):
  1089. logger.info(
  1090. "The uploaded data does not provide results for every sequence."
  1091. )
  1092. return False
  1093. logger.info("Loaded %d Sequences." % len(e.groundtruth))
  1094. logger.info("Start Evaluation...")
  1095. if e.compute3rdPartyMetrics():
  1096. self.strsummary = e.saveToStats(self.save_summary)
  1097. else:
  1098. logger.info(
  1099. "There seem to be no true positives or false positives at all in the submitted data."
  1100. )
  1101. def log(self):
  1102. print(self.strsummary)
  1103. def get_results(self):
  1104. return self.strsummary
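# Illustrative end-to-end sketch for KITTIMOTMetric (paths and sequence ids are
# hypothetical): data_root is expected to contain label_02/<seq>.txt ground
# truth and result_root a <seq>.txt tracking result in KITTI format.
def _example_kitti_mot_metric(data_root, result_root, seqs=('0000', '0001')):
    metric = KITTIMOTMetric(save_summary=False)
    for seq in seqs:
        result_filename = os.path.join(result_root, '{}.txt'.format(seq))
        metric.update(data_root, seq, 'kitti', result_root, result_filename)
    metric.accumulate()
    metric.log()
    return metric.get_results()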