base_jde_tracker.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/multitracker.py
  16. """
  17. import numpy as np
  18. from collections import defaultdict
  19. from collections import deque, OrderedDict
  20. from ..matching import jde_matching as matching
  21. from paddlex.ppdet.core.workspace import register, serializable
  22. import warnings
  23. warnings.filterwarnings("ignore")
  24. __all__ = [
  25. 'TrackState',
  26. 'BaseTrack',
  27. 'STrack',
  28. 'joint_stracks',
  29. 'sub_stracks',
  30. 'remove_duplicate_stracks',
  31. ]
  32. class TrackState(object):
  33. New = 0
  34. Tracked = 1
  35. Lost = 2
  36. Removed = 3
  37. @register
  38. @serializable
  39. class BaseTrack(object):
  40. _count_dict = defaultdict(int) # support single class and multi classes
  41. track_id = 0
  42. is_activated = False
  43. state = TrackState.New
  44. history = OrderedDict()
  45. features = []
  46. curr_feature = None
  47. score = 0
  48. start_frame = 0
  49. frame_id = 0
  50. time_since_update = 0
  51. # multi-camera
  52. location = (np.inf, np.inf)
  53. @property
  54. def end_frame(self):
  55. return self.frame_id
  56. @staticmethod
  57. def next_id(cls_id):
  58. BaseTrack._count_dict[cls_id] += 1
  59. return BaseTrack._count_dict[cls_id]
  60. # @even: reset track id
  61. @staticmethod
  62. def init_count(num_classes):
  63. """
  64. Initiate _count for all object classes
  65. :param num_classes:
  66. """
  67. for cls_id in range(num_classes):
  68. BaseTrack._count_dict[cls_id] = 0
  69. @staticmethod
  70. def reset_track_count(cls_id):
  71. BaseTrack._count_dict[cls_id] = 0
  72. def activate(self, *args):
  73. raise NotImplementedError
  74. def predict(self):
  75. raise NotImplementedError
  76. def update(self, *args, **kwargs):
  77. raise NotImplementedError
  78. def mark_lost(self):
  79. self.state = TrackState.Lost
  80. def mark_removed(self):
  81. self.state = TrackState.Removed
  82. @register
  83. @serializable
  84. class STrack(BaseTrack):
  85. def __init__(self, tlwh, score, cls_id, buff_size=30, temp_feat=None):
  86. # wait activate
  87. self._tlwh = np.asarray(tlwh, dtype=np.float)
  88. self.score = score
  89. self.cls_id = cls_id
  90. self.track_len = 0
  91. self.kalman_filter = None
  92. self.mean, self.covariance = None, None
  93. self.is_activated = False
  94. self.use_reid = True if temp_feat is not None else False
  95. if self.use_reid:
  96. self.smooth_feat = None
  97. self.update_features(temp_feat)
  98. self.features = deque([], maxlen=buff_size)
  99. self.alpha = 0.9
  100. def update_features(self, feat):
  101. # L2 normalizing, this function has no use for BYTETracker
  102. feat /= np.linalg.norm(feat)
  103. self.curr_feat = feat
  104. if self.smooth_feat is None:
  105. self.smooth_feat = feat
  106. else:
  107. self.smooth_feat = self.alpha * self.smooth_feat + (
  108. 1.0 - self.alpha) * feat
  109. self.features.append(feat)
  110. self.smooth_feat /= np.linalg.norm(self.smooth_feat)
  111. def predict(self):
  112. mean_state = self.mean.copy()
  113. if self.state != TrackState.Tracked:
  114. mean_state[7] = 0
  115. self.mean, self.covariance = self.kalman_filter.predict(
  116. mean_state, self.covariance)
  117. @staticmethod
  118. def multi_predict(tracks, kalman_filter):
  119. if len(tracks) > 0:
  120. multi_mean = np.asarray([track.mean.copy() for track in tracks])
  121. multi_covariance = np.asarray(
  122. [track.covariance for track in tracks])
  123. for i, st in enumerate(tracks):
  124. if st.state != TrackState.Tracked:
  125. multi_mean[i][7] = 0
  126. multi_mean, multi_covariance = kalman_filter.multi_predict(
  127. multi_mean, multi_covariance)
  128. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  129. tracks[i].mean = mean
  130. tracks[i].covariance = cov
  131. def reset_track_id(self):
  132. self.reset_track_count(self.cls_id)
  133. def activate(self, kalman_filter, frame_id):
  134. """Start a new track"""
  135. self.kalman_filter = kalman_filter
  136. # update track id for the object class
  137. self.track_id = self.next_id(self.cls_id)
  138. self.mean, self.covariance = self.kalman_filter.initiate(
  139. self.tlwh_to_xyah(self._tlwh))
  140. self.track_len = 0
  141. self.state = TrackState.Tracked # set flag 'tracked'
  142. if frame_id == 1: # to record the first frame's detection result
  143. self.is_activated = True
  144. self.frame_id = frame_id
  145. self.start_frame = frame_id
  146. def re_activate(self, new_track, frame_id, new_id=False):
  147. self.mean, self.covariance = self.kalman_filter.update(
  148. self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh))
  149. if self.use_reid:
  150. self.update_features(new_track.curr_feat)
  151. self.track_len = 0
  152. self.state = TrackState.Tracked
  153. self.is_activated = True
  154. self.frame_id = frame_id
  155. if new_id: # update track id for the object class
  156. self.track_id = self.next_id(self.cls_id)
  157. def update(self, new_track, frame_id, update_feature=True):
  158. self.frame_id = frame_id
  159. self.track_len += 1
  160. new_tlwh = new_track.tlwh
  161. self.mean, self.covariance = self.kalman_filter.update(
  162. self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
  163. self.state = TrackState.Tracked # set flag 'tracked'
  164. self.is_activated = True # set flag 'activated'
  165. self.score = new_track.score
  166. if update_feature and self.use_reid:
  167. self.update_features(new_track.curr_feat)
  168. @property
  169. def tlwh(self):
  170. """Get current position in bounding box format `(top left x, top left y,
  171. width, height)`.
  172. """
  173. if self.mean is None:
  174. return self._tlwh.copy()
  175. ret = self.mean[:4].copy()
  176. ret[2] *= ret[3]
  177. ret[:2] -= ret[2:] / 2
  178. return ret
  179. @property
  180. def tlbr(self):
  181. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  182. `(top left, bottom right)`.
  183. """
  184. ret = self.tlwh.copy()
  185. ret[2:] += ret[:2]
  186. return ret
  187. @staticmethod
  188. def tlwh_to_xyah(tlwh):
  189. """Convert bounding box to format `(center x, center y, aspect ratio,
  190. height)`, where the aspect ratio is `width / height`.
  191. """
  192. ret = np.asarray(tlwh).copy()
  193. ret[:2] += ret[2:] / 2
  194. ret[2] /= ret[3]
  195. return ret
  196. def to_xyah(self):
  197. return self.tlwh_to_xyah(self.tlwh)
  198. @staticmethod
  199. def tlbr_to_tlwh(tlbr):
  200. ret = np.asarray(tlbr).copy()
  201. ret[2:] -= ret[:2]
  202. return ret
  203. @staticmethod
  204. def tlwh_to_tlbr(tlwh):
  205. ret = np.asarray(tlwh).copy()
  206. ret[2:] += ret[:2]
  207. return ret
  208. def __repr__(self):
  209. return 'OT_({}-{})_({}-{})'.format(self.cls_id, self.track_id,
  210. self.start_frame, self.end_frame)
  211. def joint_stracks(tlista, tlistb):
  212. exists = {}
  213. res = []
  214. for t in tlista:
  215. exists[t.track_id] = 1
  216. res.append(t)
  217. for t in tlistb:
  218. tid = t.track_id
  219. if not exists.get(tid, 0):
  220. exists[tid] = 1
  221. res.append(t)
  222. return res
  223. def sub_stracks(tlista, tlistb):
  224. stracks = {}
  225. for t in tlista:
  226. stracks[t.track_id] = t
  227. for t in tlistb:
  228. tid = t.track_id
  229. if stracks.get(tid, 0):
  230. del stracks[tid]
  231. return list(stracks.values())
  232. def remove_duplicate_stracks(stracksa, stracksb):
  233. pdist = matching.iou_distance(stracksa, stracksb)
  234. pairs = np.where(pdist < 0.15)
  235. dupa, dupb = list(), list()
  236. for p, q in zip(*pairs):
  237. timep = stracksa[p].frame_id - stracksa[p].start_frame
  238. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  239. if timep > timeq:
  240. dupb.append(q)
  241. else:
  242. dupa.append(p)
  243. resa = [t for i, t in enumerate(stracksa) if not i in dupa]
  244. resb = [t for i, t in enumerate(stracksb) if not i in dupb]
  245. return resa, resb