tracker.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import cv2
  19. import glob
  20. import paddle
  21. import numpy as np
  22. from paddlex.ppdet.core.workspace import create
  23. from paddlex.ppdet.utils.checkpoint import load_weight, load_pretrain_weight
  24. from paddlex.ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_box
  25. from paddlex.ppdet.modeling.mot.utils import Timer, load_det_results
  26. from paddlex.ppdet.modeling.mot import visualization as mot_vis
  27. from paddlex.ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric
  28. from paddlex.ppdet.utils import stats
  29. from .callbacks import Callback, ComposeCallback
  30. from paddlex.ppdet.utils.logger import setup_logger
  31. logger = setup_logger(__name__)
  32. __all__ = ['Tracker']
  33. class Tracker(object):
  34. def __init__(self, cfg, mode='eval'):
  35. self.cfg = cfg
  36. assert mode.lower() in ['test', 'eval'], \
  37. "mode should be 'test' or 'eval'"
  38. self.mode = mode.lower()
  39. self.optimizer = None
  40. # build MOT data loader
  41. self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]
  42. # build model
  43. self.model = create(cfg.architecture)
  44. self.status = {}
  45. self.start_epoch = 0
  46. # initial default callbacks
  47. self._init_callbacks()
  48. # initial default metrics
  49. self._init_metrics()
  50. self._reset_metrics()
  51. def _init_callbacks(self):
  52. self._callbacks = []
  53. self._compose_callback = None
  54. def _init_metrics(self):
  55. if self.mode in ['test']:
  56. self._metrics = []
  57. return
  58. if self.cfg.metric == 'MOT':
  59. self._metrics = [MOTMetric(), ]
  60. elif self.cfg.metric == 'KITTI':
  61. self._metrics = [KITTIMOTMetric(), ]
  62. else:
  63. logger.warning("Metric not support for metric type {}".format(
  64. self.cfg.metric))
  65. self._metrics = []
  66. def _reset_metrics(self):
  67. for metric in self._metrics:
  68. metric.reset()
  69. def register_callbacks(self, callbacks):
  70. callbacks = [h for h in list(callbacks) if h is not None]
  71. for c in callbacks:
  72. assert isinstance(c, Callback), \
  73. "metrics shoule be instances of subclass of Metric"
  74. self._callbacks.extend(callbacks)
  75. self._compose_callback = ComposeCallback(self._callbacks)
  76. def register_metrics(self, metrics):
  77. metrics = [m for m in list(metrics) if m is not None]
  78. for m in metrics:
  79. assert isinstance(m, Metric), \
  80. "metrics shoule be instances of subclass of Metric"
  81. self._metrics.extend(metrics)
  82. def load_weights_jde(self, weights):
  83. load_weight(self.model, weights, self.optimizer)
  84. def load_weights_sde(self, det_weights, reid_weights):
  85. if self.model.detector:
  86. load_weight(self.model.detector, det_weights)
  87. load_weight(self.model.reid, reid_weights)
  88. else:
  89. load_weight(self.model.reid, reid_weights, self.optimizer)
  90. def _eval_seq_jde(self,
  91. dataloader,
  92. save_dir=None,
  93. show_image=False,
  94. frame_rate=30,
  95. draw_threshold=0):
  96. if save_dir:
  97. if not os.path.exists(save_dir): os.makedirs(save_dir)
  98. tracker = self.model.tracker
  99. tracker.max_time_lost = int(frame_rate / 30.0 * tracker.track_buffer)
  100. timer = Timer()
  101. results = []
  102. frame_id = 0
  103. self.status['mode'] = 'track'
  104. self.model.eval()
  105. for step_id, data in enumerate(dataloader):
  106. self.status['step_id'] = step_id
  107. if frame_id % 40 == 0:
  108. logger.info('Processing frame {} ({:.2f} fps)'.format(
  109. frame_id, 1. / max(1e-5, timer.average_time)))
  110. # forward
  111. timer.tic()
  112. pred_dets, pred_embs = self.model(data)
  113. online_targets = self.model.tracker.update(pred_dets, pred_embs)
  114. online_tlwhs, online_ids = [], []
  115. online_scores = []
  116. for t in online_targets:
  117. tlwh = t.tlwh
  118. tid = t.track_id
  119. tscore = t.score
  120. if tscore < draw_threshold: continue
  121. vertical = tlwh[2] / tlwh[3] > 1.6
  122. if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
  123. online_tlwhs.append(tlwh)
  124. online_ids.append(tid)
  125. online_scores.append(tscore)
  126. timer.toc()
  127. # save results
  128. results.append(
  129. (frame_id + 1, online_tlwhs, online_scores, online_ids))
  130. self.save_results(data, frame_id, online_ids, online_tlwhs,
  131. online_scores, timer.average_time, show_image,
  132. save_dir)
  133. frame_id += 1
  134. return results, frame_id, timer.average_time, timer.calls
  135. def _eval_seq_sde(self,
  136. dataloader,
  137. save_dir=None,
  138. show_image=False,
  139. frame_rate=30,
  140. det_file='',
  141. draw_threshold=0):
  142. if save_dir:
  143. if not os.path.exists(save_dir): os.makedirs(save_dir)
  144. tracker = self.model.tracker
  145. use_detector = False if not self.model.detector else True
  146. timer = Timer()
  147. results = []
  148. frame_id = 0
  149. self.status['mode'] = 'track'
  150. self.model.eval()
  151. self.model.reid.eval()
  152. if not use_detector:
  153. dets_list = load_det_results(det_file, len(dataloader))
  154. logger.info('Finish loading detection results file {}.'.format(
  155. det_file))
  156. for step_id, data in enumerate(dataloader):
  157. self.status['step_id'] = step_id
  158. if frame_id % 40 == 0:
  159. logger.info('Processing frame {} ({:.2f} fps)'.format(
  160. frame_id, 1. / max(1e-5, timer.average_time)))
  161. ori_image = data['ori_image']
  162. input_shape = data['image'].shape[2:]
  163. im_shape = data['im_shape']
  164. scale_factor = data['scale_factor']
  165. timer.tic()
  166. if not use_detector:
  167. dets = dets_list[frame_id]
  168. bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32')
  169. pred_scores = paddle.to_tensor(dets['score'], dtype='float32')
  170. if pred_scores < draw_threshold: continue
  171. if bbox_tlwh.shape[0] > 0:
  172. pred_bboxes = paddle.concat(
  173. (bbox_tlwh[:, 0:2],
  174. bbox_tlwh[:, 2:4] + bbox_tlwh[:, 0:2]),
  175. axis=1)
  176. else:
  177. pred_bboxes = []
  178. pred_scores = []
  179. else:
  180. outs = self.model.detector(data)
  181. if outs['bbox_num'] > 0:
  182. pred_bboxes = scale_coords(outs['bbox'][:, 2:],
  183. input_shape, im_shape,
  184. scale_factor)
  185. pred_scores = outs['bbox'][:, 1:2]
  186. else:
  187. pred_bboxes = []
  188. pred_scores = []
  189. pred_bboxes = clip_box(pred_bboxes, input_shape, im_shape,
  190. scale_factor)
  191. bbox_tlwh = paddle.concat(
  192. (pred_bboxes[:, 0:2],
  193. pred_bboxes[:, 2:4] - pred_bboxes[:, 0:2] + 1),
  194. axis=1)
  195. crops, pred_scores = get_crops(
  196. pred_bboxes, ori_image, pred_scores, w=64, h=192)
  197. crops = paddle.to_tensor(crops)
  198. pred_scores = paddle.to_tensor(pred_scores)
  199. data.update({'crops': crops})
  200. features = self.model(data)
  201. features = features.numpy()
  202. detections = [
  203. Detection(tlwh, score, feat)
  204. for tlwh, score, feat in zip(bbox_tlwh, pred_scores, features)
  205. ]
  206. self.model.tracker.predict()
  207. online_targets = self.model.tracker.update(detections)
  208. online_tlwhs = []
  209. online_scores = []
  210. online_ids = []
  211. for track in online_targets:
  212. if not track.is_confirmed() or track.time_since_update > 1:
  213. continue
  214. online_tlwhs.append(track.to_tlwh())
  215. online_scores.append(1.0)
  216. online_ids.append(track.track_id)
  217. timer.toc()
  218. # save results
  219. results.append(
  220. (frame_id + 1, online_tlwhs, online_scores, online_ids))
  221. self.save_results(data, frame_id, online_ids, online_tlwhs,
  222. online_scores, timer.average_time, show_image,
  223. save_dir)
  224. frame_id += 1
  225. return results, frame_id, timer.average_time, timer.calls
  226. def mot_evaluate(self,
  227. data_root,
  228. seqs,
  229. output_dir,
  230. data_type='mot',
  231. model_type='JDE',
  232. save_images=False,
  233. save_videos=False,
  234. show_image=False,
  235. det_results_dir=''):
  236. if not os.path.exists(output_dir): os.makedirs(output_dir)
  237. result_root = os.path.join(output_dir, 'mot_results')
  238. if not os.path.exists(result_root): os.makedirs(result_root)
  239. assert data_type in ['mot', 'kitti'], \
  240. "data_type should be 'mot' or 'kitti'"
  241. assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
  242. "model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"
  243. # run tracking
  244. n_frame = 0
  245. timer_avgs, timer_calls = [], []
  246. for seq in seqs:
  247. if not os.path.isdir(os.path.join(data_root, seq)):
  248. continue
  249. infer_dir = os.path.join(data_root, seq, 'img1')
  250. seqinfo = os.path.join(data_root, seq, 'seqinfo.ini')
  251. if not os.path.exists(seqinfo) or not os.path.exists(
  252. infer_dir) or not os.path.isdir(infer_dir):
  253. continue
  254. save_dir = os.path.join(
  255. output_dir, 'mot_outputs',
  256. seq) if save_images or save_videos else None
  257. logger.info('start seq: {}'.format(seq))
  258. images = self.get_infer_images(infer_dir)
  259. self.dataset.set_images(images)
  260. dataloader = create('EvalMOTReader')(self.dataset, 0)
  261. result_filename = os.path.join(result_root, '{}.txt'.format(seq))
  262. meta_info = open(seqinfo).read()
  263. frame_rate = int(meta_info[meta_info.find('frameRate') + 10:
  264. meta_info.find('\nseqLength')])
  265. with paddle.no_grad():
  266. if model_type in ['JDE', 'FairMOT']:
  267. results, nf, ta, tc = self._eval_seq_jde(
  268. dataloader,
  269. save_dir=save_dir,
  270. show_image=show_image,
  271. frame_rate=frame_rate)
  272. elif model_type in ['DeepSORT']:
  273. results, nf, ta, tc = self._eval_seq_sde(
  274. dataloader,
  275. save_dir=save_dir,
  276. show_image=show_image,
  277. frame_rate=frame_rate,
  278. det_file=os.path.join(det_results_dir,
  279. '{}.txt'.format(seq)))
  280. else:
  281. raise ValueError(model_type)
  282. self.write_mot_results(result_filename, results, data_type)
  283. n_frame += nf
  284. timer_avgs.append(ta)
  285. timer_calls.append(tc)
  286. if save_videos:
  287. output_video_path = os.path.join(save_dir, '..',
  288. '{}_vis.mp4'.format(seq))
  289. cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
  290. save_dir, output_video_path)
  291. os.system(cmd_str)
  292. logger.info('Save video in {}.'.format(output_video_path))
  293. logger.info('Evaluate seq: {}'.format(seq))
  294. # update metrics
  295. for metric in self._metrics:
  296. metric.update(data_root, seq, data_type, result_root,
  297. result_filename)
  298. timer_avgs = np.asarray(timer_avgs)
  299. timer_calls = np.asarray(timer_calls)
  300. all_time = np.dot(timer_avgs, timer_calls)
  301. avg_time = all_time / np.sum(timer_calls)
  302. logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(
  303. all_time, 1.0 / avg_time))
  304. # accumulate metric to log out
  305. for metric in self._metrics:
  306. metric.accumulate()
  307. metric.log()
  308. # reset metric states for metric may performed multiple times
  309. self._reset_metrics()
  310. def get_infer_images(self, infer_dir):
  311. assert infer_dir is None or os.path.isdir(infer_dir), \
  312. "{} is not a directory".format(infer_dir)
  313. images = set()
  314. assert os.path.isdir(infer_dir), \
  315. "infer_dir {} is not a directory".format(infer_dir)
  316. exts = ['jpg', 'jpeg', 'png', 'bmp']
  317. exts += [ext.upper() for ext in exts]
  318. for ext in exts:
  319. images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
  320. images = list(images)
  321. images.sort()
  322. assert len(images) > 0, "no image found in {}".format(infer_dir)
  323. logger.info("Found {} inference images in total.".format(len(images)))
  324. return images
  325. def mot_predict(self,
  326. video_file,
  327. frame_rate,
  328. image_dir,
  329. output_dir,
  330. data_type='mot',
  331. model_type='JDE',
  332. save_images=False,
  333. save_videos=True,
  334. show_image=False,
  335. det_results_dir='',
  336. draw_threshold=0.5):
  337. assert video_file is not None or image_dir is not None, \
  338. "--video_file or --image_dir should be set."
  339. assert video_file is None or os.path.isfile(video_file), \
  340. "{} is not a file".format(video_file)
  341. assert image_dir is None or os.path.isdir(image_dir), \
  342. "{} is not a directory".format(image_dir)
  343. if not os.path.exists(output_dir): os.makedirs(output_dir)
  344. result_root = os.path.join(output_dir, 'mot_results')
  345. if not os.path.exists(result_root): os.makedirs(result_root)
  346. assert data_type in ['mot', 'kitti'], \
  347. "data_type should be 'mot' or 'kitti'"
  348. assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
  349. "model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"
  350. # run tracking
  351. if video_file:
  352. seq = video_file.split('/')[-1].split('.')[0]
  353. self.dataset.set_video(video_file, frame_rate)
  354. logger.info('Starting tracking video {}'.format(video_file))
  355. elif image_dir:
  356. seq = image_dir.split('/')[-1].split('.')[0]
  357. images = [
  358. '{}/{}'.format(image_dir, x) for x in os.listdir(image_dir)
  359. ]
  360. images.sort()
  361. self.dataset.set_images(images)
  362. logger.info('Starting tracking folder {}, found {} images'.format(
  363. image_dir, len(images)))
  364. else:
  365. raise ValueError('--video_file or --image_dir should be set.')
  366. save_dir = os.path.join(output_dir, 'mot_outputs',
  367. seq) if save_images or save_videos else None
  368. dataloader = create('TestMOTReader')(self.dataset, 0)
  369. result_filename = os.path.join(result_root, '{}.txt'.format(seq))
  370. if frame_rate == -1:
  371. frame_rate = self.dataset.frame_rate
  372. with paddle.no_grad():
  373. if model_type in ['JDE', 'FairMOT']:
  374. results, nf, ta, tc = self._eval_seq_jde(
  375. dataloader,
  376. save_dir=save_dir,
  377. show_image=show_image,
  378. frame_rate=frame_rate,
  379. draw_threshold=draw_threshold)
  380. elif model_type in ['DeepSORT']:
  381. results, nf, ta, tc = self._eval_seq_sde(
  382. dataloader,
  383. save_dir=save_dir,
  384. show_image=show_image,
  385. frame_rate=frame_rate,
  386. det_file=os.path.join(det_results_dir,
  387. '{}.txt'.format(seq)),
  388. draw_threshold=draw_threshold)
  389. else:
  390. raise ValueError(model_type)
  391. self.write_mot_results(result_filename, results, data_type)
  392. if save_videos:
  393. output_video_path = os.path.join(save_dir, '..',
  394. '{}_vis.mp4'.format(seq))
  395. cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
  396. save_dir, output_video_path)
  397. os.system(cmd_str)
  398. logger.info('Save video in {}'.format(output_video_path))
  399. def write_mot_results(self, filename, results, data_type='mot'):
  400. if data_type in ['mot', 'mcmot', 'lab']:
  401. save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n'
  402. elif data_type == 'kitti':
  403. save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
  404. else:
  405. raise ValueError(data_type)
  406. with open(filename, 'w') as f:
  407. for frame_id, tlwhs, tscores, track_ids in results:
  408. if data_type == 'kitti':
  409. frame_id -= 1
  410. for tlwh, score, track_id in zip(tlwhs, tscores, track_ids):
  411. if track_id < 0:
  412. continue
  413. x1, y1, w, h = tlwh
  414. x2, y2 = x1 + w, y1 + h
  415. line = save_format.format(
  416. frame=frame_id,
  417. id=track_id,
  418. x1=x1,
  419. y1=y1,
  420. x2=x2,
  421. y2=y2,
  422. w=w,
  423. h=h,
  424. score=score)
  425. f.write(line)
  426. logger.info('MOT results save in {}'.format(filename))
  427. def save_results(self, data, frame_id, online_ids, online_tlwhs,
  428. online_scores, average_time, show_image, save_dir):
  429. if show_image or save_dir is not None:
  430. assert 'ori_image' in data
  431. img0 = data['ori_image'].numpy()[0]
  432. online_im = mot_vis.plot_tracking(
  433. img0,
  434. online_tlwhs,
  435. online_ids,
  436. online_scores,
  437. frame_id=frame_id,
  438. fps=1. / average_time)
  439. if show_image:
  440. cv2.imshow('online_im', online_im)
  441. if save_dir is not None:
  442. cv2.imwrite(
  443. os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
  444. online_im)