deepsort_tracker.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
This code is borrowed from https://github.com/nwojke/deep_sort/blob/master/deep_sort/tracker.py
  16. """
  17. import numpy as np
  18. from ..matching.deepsort_matching import NearestNeighborDistanceMetric
  19. from ..matching.deepsort_matching import iou_cost, min_cost_matching, matching_cascade, gate_cost_matrix
  20. from .base_sde_tracker import Track
  21. from paddlex.ppdet.core.workspace import register, serializable
  22. from paddlex.ppdet.utils.logger import setup_logger
  23. logger = setup_logger(__name__)
  24. __all__ = ['DeepSORTTracker']
  25. @register
  26. @serializable
  27. class DeepSORTTracker(object):
  28. __inject__ = ['motion']
  29. """
  30. DeepSORT tracker
  31. Args:
  32. img_size (list): input image size, [h, w]
  33. budget (int): If not None, fix samples per class to at most this number.
  34. Removes the oldest samples when the budget is reached.
  35. max_age (int): maximum number of missed misses before a track is deleted
  36. n_init (float): Number of frames that a track remains in initialization
  37. phase. Number of consecutive detections before the track is confirmed.
  38. The track state is set to `Deleted` if a miss occurs within the first
  39. `n_init` frames.
  40. metric_type (str): either "euclidean" or "cosine", the distance metric
  41. used for measurement to track association.
  42. matching_threshold (float): samples with larger distance are
  43. considered an invalid match.
  44. max_iou_distance (float): max iou distance threshold
  45. motion (object): KalmanFilter instance
  46. """
  47. def __init__(self,
  48. img_size=[608, 1088],
  49. budget=100,
  50. max_age=30,
  51. n_init=3,
  52. metric_type='cosine',
  53. matching_threshold=0.2,
  54. max_iou_distance=0.7,
  55. motion='KalmanFilter'):
  56. self.img_size = img_size
  57. self.max_age = max_age
  58. self.n_init = n_init
  59. self.metric = NearestNeighborDistanceMetric(metric_type,
  60. matching_threshold, budget)
  61. self.max_iou_distance = max_iou_distance
  62. self.motion = motion
  63. self.tracks = []
  64. self._next_id = 1
  65. def predict(self):
  66. """
  67. Propagate track state distributions one time step forward.
  68. This function should be called once every time step, before `update`.
  69. """
  70. for track in self.tracks:
  71. track.predict(self.motion)
  72. def update(self, detections):
  73. """
  74. Perform measurement update and track management.
  75. Args:
  76. detections (list): List[ppdet.modeling.mot.utils.Detection]
  77. A list of detections at the current time step.
  78. """
  79. # Run matching cascade.
  80. matches, unmatched_tracks, unmatched_detections = \
  81. self._match(detections)
  82. # Update track set.
  83. for track_idx, detection_idx in matches:
  84. self.tracks[track_idx].update(self.motion,
  85. detections[detection_idx])
  86. for track_idx in unmatched_tracks:
  87. self.tracks[track_idx].mark_missed()
  88. for detection_idx in unmatched_detections:
  89. self._initiate_track(detections[detection_idx])
  90. self.tracks = [t for t in self.tracks if not t.is_deleted()]
  91. # Update distance metric.
  92. active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
  93. features, targets = [], []
  94. for track in self.tracks:
  95. if not track.is_confirmed():
  96. continue
  97. features += track.features
  98. targets += [track.track_id for _ in track.features]
  99. track.features = []
  100. self.metric.partial_fit(
  101. np.asarray(features), np.asarray(targets), active_targets)
  102. output_stracks = self.tracks
  103. return output_stracks
  104. def _match(self, detections):
  105. def gated_metric(tracks, dets, track_indices, detection_indices):
  106. features = np.array([dets[i].feature for i in detection_indices])
  107. targets = np.array([tracks[i].track_id for i in track_indices])
  108. cost_matrix = self.metric.distance(features, targets)
  109. cost_matrix = gate_cost_matrix(self.motion, cost_matrix, tracks,
  110. dets, track_indices,
  111. detection_indices)
  112. return cost_matrix
  113. # Split track set into confirmed and unconfirmed tracks.
  114. confirmed_tracks = [
  115. i for i, t in enumerate(self.tracks) if t.is_confirmed()
  116. ]
  117. unconfirmed_tracks = [
  118. i for i, t in enumerate(self.tracks) if not t.is_confirmed()
  119. ]
  120. # Associate confirmed tracks using appearance features.
  121. matches_a, unmatched_tracks_a, unmatched_detections = \
  122. matching_cascade(
  123. gated_metric, self.metric.matching_threshold, self.max_age,
  124. self.tracks, detections, confirmed_tracks)
  125. # Associate remaining tracks together with unconfirmed tracks using IOU.
  126. iou_track_candidates = unconfirmed_tracks + [
  127. k for k in unmatched_tracks_a
  128. if self.tracks[k].time_since_update == 1
  129. ]
  130. unmatched_tracks_a = [
  131. k for k in unmatched_tracks_a
  132. if self.tracks[k].time_since_update != 1
  133. ]
  134. matches_b, unmatched_tracks_b, unmatched_detections = \
  135. min_cost_matching(
  136. iou_cost, self.max_iou_distance, self.tracks,
  137. detections, iou_track_candidates, unmatched_detections)
  138. matches = matches_a + matches_b
  139. unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
  140. return matches, unmatched_tracks, unmatched_detections
  141. def _initiate_track(self, detection):
  142. mean, covariance = self.motion.initiate(detection.to_xyah())
  143. self.tracks.append(
  144. Track(mean, covariance, self._next_id, self.n_init, self.max_age,
  145. detection.feature))
  146. self._next_id += 1