tracker.cc 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // The code is based on:
  15. // https://github.com/CnybTseng/JDE/blob/master/platforms/common/jdetracker.cpp
  16. // The copyright of CnybTseng/JDE is as follows:
  17. // MIT License
#include <limits.h>
#include <stdio.h>

#include <algorithm>
#include <limits>
#include <map>
#include <vector>

#include "ultra_infer/vision/tracking/pptracking/lapjv.h"
#include "ultra_infer/vision/tracking/pptracking/tracker.h"
  24. #define mat2vec4f(m) \
  25. cv::Vec4f(*m.ptr<float>(0, 0), *m.ptr<float>(0, 1), *m.ptr<float>(0, 2), \
  26. *m.ptr<float>(0, 3))
  27. namespace ultra_infer {
  28. namespace vision {
  29. namespace tracking {
  30. static std::map<int, float> chi2inv95 = {
  31. {1, 3.841459f}, {2, 5.991465f}, {3, 7.814728f},
  32. {4, 9.487729f}, {5, 11.070498f}, {6, 12.591587f},
  33. {7, 14.067140f}, {8, 15.507313f}, {9, 16.918978f}};
  34. JDETracker::JDETracker()
  35. : timestamp(0), max_lost_time(30), lambda(0.98f), det_thresh(0.3f) {}
  36. bool JDETracker::update(const cv::Mat &dets, const cv::Mat &emb,
  37. std::vector<Track> *tracks) {
  38. ++timestamp;
  39. TrajectoryPool candidates(dets.rows);
  40. for (int i = 0; i < dets.rows; ++i) {
  41. float score = *dets.ptr<float>(i, 1);
  42. const cv::Mat &ltrb_ = dets(cv::Rect(2, i, 4, 1));
  43. cv::Vec4f ltrb = mat2vec4f(ltrb_);
  44. const cv::Mat &embedding = emb(cv::Rect(0, i, emb.cols, 1));
  45. candidates[i] = Trajectory(ltrb, score, embedding);
  46. }
  47. TrajectoryPtrPool tracked_trajectories;
  48. TrajectoryPtrPool unconfirmed_trajectories;
  49. for (size_t i = 0; i < this->tracked_trajectories.size(); ++i) {
  50. if (this->tracked_trajectories[i].is_activated)
  51. tracked_trajectories.push_back(&this->tracked_trajectories[i]);
  52. else
  53. unconfirmed_trajectories.push_back(&this->tracked_trajectories[i]);
  54. }
  55. TrajectoryPtrPool trajectory_pool =
  56. tracked_trajectories + &(this->lost_trajectories);
  57. for (size_t i = 0; i < trajectory_pool.size(); ++i)
  58. trajectory_pool[i]->predict();
  59. Match matches;
  60. std::vector<int> mismatch_row;
  61. std::vector<int> mismatch_col;
  62. cv::Mat cost = motion_distance(trajectory_pool, candidates);
  63. linear_assignment(cost, 0.7f, &matches, &mismatch_row, &mismatch_col);
  64. MatchIterator miter;
  65. TrajectoryPtrPool activated_trajectories;
  66. TrajectoryPtrPool retrieved_trajectories;
  67. for (miter = matches.begin(); miter != matches.end(); miter++) {
  68. Trajectory *pt = trajectory_pool[miter->first];
  69. Trajectory &ct = candidates[miter->second];
  70. if (pt->state == Tracked) {
  71. pt->update(&ct, timestamp);
  72. activated_trajectories.push_back(pt);
  73. } else {
  74. pt->reactivate(&ct, count, timestamp);
  75. retrieved_trajectories.push_back(pt);
  76. }
  77. }
  78. TrajectoryPtrPool next_candidates(mismatch_col.size());
  79. for (size_t i = 0; i < mismatch_col.size(); ++i)
  80. next_candidates[i] = &candidates[mismatch_col[i]];
  81. TrajectoryPtrPool next_trajectory_pool;
  82. for (size_t i = 0; i < mismatch_row.size(); ++i) {
  83. int j = mismatch_row[i];
  84. if (trajectory_pool[j]->state == Tracked)
  85. next_trajectory_pool.push_back(trajectory_pool[j]);
  86. }
  87. cost = iou_distance(next_trajectory_pool, next_candidates);
  88. linear_assignment(cost, 0.5f, &matches, &mismatch_row, &mismatch_col);
  89. for (miter = matches.begin(); miter != matches.end(); miter++) {
  90. Trajectory *pt = next_trajectory_pool[miter->first];
  91. Trajectory *ct = next_candidates[miter->second];
  92. if (pt->state == Tracked) {
  93. pt->update(ct, timestamp);
  94. activated_trajectories.push_back(pt);
  95. } else {
  96. pt->reactivate(ct, count, timestamp);
  97. retrieved_trajectories.push_back(pt);
  98. }
  99. }
  100. TrajectoryPtrPool lost_trajectories;
  101. for (size_t i = 0; i < mismatch_row.size(); ++i) {
  102. Trajectory *pt = next_trajectory_pool[mismatch_row[i]];
  103. if (pt->state != Lost) {
  104. pt->mark_lost();
  105. lost_trajectories.push_back(pt);
  106. }
  107. }
  108. TrajectoryPtrPool nnext_candidates(mismatch_col.size());
  109. for (size_t i = 0; i < mismatch_col.size(); ++i)
  110. nnext_candidates[i] = next_candidates[mismatch_col[i]];
  111. cost = iou_distance(unconfirmed_trajectories, nnext_candidates);
  112. linear_assignment(cost, 0.7f, &matches, &mismatch_row, &mismatch_col);
  113. for (miter = matches.begin(); miter != matches.end(); miter++) {
  114. unconfirmed_trajectories[miter->first]->update(
  115. nnext_candidates[miter->second], timestamp);
  116. activated_trajectories.push_back(unconfirmed_trajectories[miter->first]);
  117. }
  118. TrajectoryPtrPool removed_trajectories;
  119. for (size_t i = 0; i < mismatch_row.size(); ++i) {
  120. unconfirmed_trajectories[mismatch_row[i]]->mark_removed();
  121. removed_trajectories.push_back(unconfirmed_trajectories[mismatch_row[i]]);
  122. }
  123. for (size_t i = 0; i < mismatch_col.size(); ++i) {
  124. if (nnext_candidates[mismatch_col[i]]->score < det_thresh)
  125. continue;
  126. nnext_candidates[mismatch_col[i]]->activate(count, timestamp);
  127. activated_trajectories.push_back(nnext_candidates[mismatch_col[i]]);
  128. }
  129. for (size_t i = 0; i < this->lost_trajectories.size(); ++i) {
  130. Trajectory &lt = this->lost_trajectories[i];
  131. if (timestamp - lt.timestamp > max_lost_time) {
  132. lt.mark_removed();
  133. removed_trajectories.push_back(&lt);
  134. }
  135. }
  136. TrajectoryPoolIterator piter;
  137. for (piter = this->tracked_trajectories.begin();
  138. piter != this->tracked_trajectories.end();) {
  139. if (piter->state != Tracked)
  140. piter = this->tracked_trajectories.erase(piter);
  141. else
  142. ++piter;
  143. }
  144. this->tracked_trajectories += activated_trajectories;
  145. this->tracked_trajectories += retrieved_trajectories;
  146. this->lost_trajectories -= this->tracked_trajectories;
  147. this->lost_trajectories += lost_trajectories;
  148. this->lost_trajectories -= this->removed_trajectories;
  149. this->removed_trajectories += removed_trajectories;
  150. remove_duplicate_trajectory(&this->tracked_trajectories,
  151. &this->lost_trajectories);
  152. tracks->clear();
  153. for (size_t i = 0; i < this->tracked_trajectories.size(); ++i) {
  154. if (this->tracked_trajectories[i].is_activated) {
  155. Track track = {this->tracked_trajectories[i].id,
  156. this->tracked_trajectories[i].score,
  157. this->tracked_trajectories[i].ltrb};
  158. tracks->push_back(track);
  159. }
  160. }
  161. return 0;
  162. }
  163. cv::Mat JDETracker::motion_distance(const TrajectoryPtrPool &a,
  164. const TrajectoryPool &b) {
  165. if (0 == a.size() || 0 == b.size())
  166. return cv::Mat(a.size(), b.size(), CV_32F);
  167. cv::Mat edists = embedding_distance(a, b);
  168. cv::Mat mdists = mahalanobis_distance(a, b);
  169. cv::Mat fdists = lambda * edists + (1 - lambda) * mdists;
  170. const float gate_thresh = chi2inv95[4];
  171. for (int i = 0; i < fdists.rows; ++i) {
  172. for (int j = 0; j < fdists.cols; ++j) {
  173. if (*mdists.ptr<float>(i, j) > gate_thresh)
  174. *fdists.ptr<float>(i, j) = FLT_MAX;
  175. }
  176. }
  177. return fdists;
  178. }
  179. void JDETracker::linear_assignment(const cv::Mat &cost, float cost_limit,
  180. Match *matches,
  181. std::vector<int> *mismatch_row,
  182. std::vector<int> *mismatch_col) {
  183. matches->clear();
  184. mismatch_row->clear();
  185. mismatch_col->clear();
  186. if (cost.empty()) {
  187. for (int i = 0; i < cost.rows; ++i)
  188. mismatch_row->push_back(i);
  189. for (int i = 0; i < cost.cols; ++i)
  190. mismatch_col->push_back(i);
  191. return;
  192. }
  193. float opt = 0;
  194. cv::Mat x(cost.rows, 1, CV_32S);
  195. cv::Mat y(cost.cols, 1, CV_32S);
  196. lapjv_internal(cost, true, cost_limit, reinterpret_cast<int *>(x.data),
  197. reinterpret_cast<int *>(y.data));
  198. for (int i = 0; i < x.rows; ++i) {
  199. int j = *x.ptr<int>(i);
  200. if (j >= 0)
  201. matches->insert({i, j});
  202. else
  203. mismatch_row->push_back(i);
  204. }
  205. for (int i = 0; i < y.rows; ++i) {
  206. int j = *y.ptr<int>(i);
  207. if (j < 0)
  208. mismatch_col->push_back(i);
  209. }
  210. return;
  211. }
  212. void JDETracker::remove_duplicate_trajectory(TrajectoryPool *a,
  213. TrajectoryPool *b,
  214. float iou_thresh) {
  215. if (a->size() == 0 || b->size() == 0)
  216. return;
  217. cv::Mat dist = iou_distance(*a, *b);
  218. cv::Mat mask = dist < iou_thresh;
  219. std::vector<cv::Point> idx;
  220. cv::findNonZero(mask, idx);
  221. std::vector<int> da;
  222. std::vector<int> db;
  223. for (size_t i = 0; i < idx.size(); ++i) {
  224. int ta = (*a)[idx[i].y].timestamp - (*a)[idx[i].y].starttime;
  225. int tb = (*b)[idx[i].x].timestamp - (*b)[idx[i].x].starttime;
  226. if (ta > tb)
  227. db.push_back(idx[i].x);
  228. else
  229. da.push_back(idx[i].y);
  230. }
  231. int id = 0;
  232. TrajectoryPoolIterator piter;
  233. for (piter = a->begin(); piter != a->end();) {
  234. std::vector<int>::iterator iter = find(da.begin(), da.end(), id++);
  235. if (iter != da.end())
  236. piter = a->erase(piter);
  237. else
  238. ++piter;
  239. }
  240. id = 0;
  241. for (piter = b->begin(); piter != b->end();) {
  242. std::vector<int>::iterator iter = find(db.begin(), db.end(), id++);
  243. if (iter != db.end())
  244. piter = b->erase(piter);
  245. else
  246. ++piter;
  247. }
  248. }
  249. } // namespace tracking
  250. } // namespace vision
  251. } // namespace ultra_infer