trajectory.h 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // The code is based on:
  15. // https://github.com/CnybTseng/JDE/blob/master/platforms/common/trajectory.h
  16. // The copyright of CnybTseng/JDE is as follows:
  17. // MIT License
  18. #pragma once
  19. #include "opencv2/video/tracking.hpp"
  20. #include "ultra_infer/ultra_infer_model.h"
  21. #include <opencv2/core/core.hpp>
  22. #include <opencv2/highgui/highgui.hpp>
  23. #include <opencv2/imgproc/imgproc.hpp>
  24. #include <vector>
  25. namespace ultra_infer {
  26. namespace vision {
  27. namespace tracking {
  28. typedef enum { New = 0, Tracked = 1, Lost = 2, Removed = 3 } TrajectoryState;
  29. class Trajectory;
  30. typedef std::vector<Trajectory> TrajectoryPool;
  31. typedef std::vector<Trajectory>::iterator TrajectoryPoolIterator;
  32. typedef std::vector<Trajectory *> TrajectoryPtrPool;
  33. typedef std::vector<Trajectory *>::iterator TrajectoryPtrPoolIterator;
  34. class ULTRAINFER_DECL TKalmanFilter : public cv::KalmanFilter {
  35. public:
  36. TKalmanFilter(void);
  37. virtual ~TKalmanFilter(void) {}
  38. virtual void init(const cv::Mat &measurement);
  39. virtual const cv::Mat &predict();
  40. virtual const cv::Mat &correct(const cv::Mat &measurement);
  41. virtual void project(cv::Mat *mean, cv::Mat *covariance) const;
  42. private:
  43. float std_weight_position;
  44. float std_weight_velocity;
  45. };
  46. inline TKalmanFilter::TKalmanFilter(void) : cv::KalmanFilter(8, 4) {
  47. cv::KalmanFilter::transitionMatrix = cv::Mat::eye(8, 8, CV_32F);
  48. for (int i = 0; i < 4; ++i)
  49. cv::KalmanFilter::transitionMatrix.at<float>(i, i + 4) = 1;
  50. cv::KalmanFilter::measurementMatrix = cv::Mat::eye(4, 8, CV_32F);
  51. std_weight_position = 1 / 20.f;
  52. std_weight_velocity = 1 / 160.f;
  53. }
  54. class ULTRAINFER_DECL Trajectory : public TKalmanFilter {
  55. public:
  56. Trajectory();
  57. Trajectory(const cv::Vec4f &ltrb, float score, const cv::Mat &embedding);
  58. Trajectory(const Trajectory &other);
  59. Trajectory &operator=(const Trajectory &rhs);
  60. virtual ~Trajectory(void) {}
  61. int next_id(int &nt);
  62. virtual const cv::Mat &predict(void);
  63. virtual void update(Trajectory *traj, int timestamp,
  64. bool update_embedding = true);
  65. virtual void activate(int &cnt, int timestamp);
  66. virtual void reactivate(Trajectory *traj, int &cnt, int timestamp,
  67. bool newid = false);
  68. virtual void mark_lost(void);
  69. virtual void mark_removed(void);
  70. friend TrajectoryPool operator+(const TrajectoryPool &a,
  71. const TrajectoryPool &b);
  72. friend TrajectoryPool operator+(const TrajectoryPool &a,
  73. const TrajectoryPtrPool &b);
  74. friend TrajectoryPool &operator+=(TrajectoryPool &a, // NOLINT
  75. const TrajectoryPtrPool &b);
  76. friend TrajectoryPool operator-(const TrajectoryPool &a,
  77. const TrajectoryPool &b);
  78. friend TrajectoryPool &operator-=(TrajectoryPool &a, // NOLINT
  79. const TrajectoryPool &b);
  80. friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a,
  81. const TrajectoryPtrPool &b);
  82. friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a,
  83. TrajectoryPool *b);
  84. friend TrajectoryPtrPool operator-(const TrajectoryPtrPool &a,
  85. const TrajectoryPtrPool &b);
  86. friend cv::Mat embedding_distance(const TrajectoryPool &a,
  87. const TrajectoryPool &b);
  88. friend cv::Mat embedding_distance(const TrajectoryPtrPool &a,
  89. const TrajectoryPtrPool &b);
  90. friend cv::Mat embedding_distance(const TrajectoryPtrPool &a,
  91. const TrajectoryPool &b);
  92. friend cv::Mat mahalanobis_distance(const TrajectoryPool &a,
  93. const TrajectoryPool &b);
  94. friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a,
  95. const TrajectoryPtrPool &b);
  96. friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a,
  97. const TrajectoryPool &b);
  98. friend cv::Mat iou_distance(const TrajectoryPool &a, const TrajectoryPool &b);
  99. friend cv::Mat iou_distance(const TrajectoryPtrPool &a,
  100. const TrajectoryPtrPool &b);
  101. friend cv::Mat iou_distance(const TrajectoryPtrPool &a,
  102. const TrajectoryPool &b);
  103. private:
  104. void update_embedding(const cv::Mat &embedding);
  105. public:
  106. TrajectoryState state;
  107. cv::Vec4f ltrb;
  108. cv::Mat smooth_embedding;
  109. int id;
  110. bool is_activated;
  111. int timestamp;
  112. int starttime;
  113. float score;
  114. private:
  115. // int count=0;
  116. cv::Vec4f xyah;
  117. cv::Mat current_embedding;
  118. float eta;
  119. int length;
  120. };
  121. inline cv::Vec4f ltrb2xyah(const cv::Vec4f &ltrb) {
  122. cv::Vec4f xyah;
  123. xyah[0] = (ltrb[0] + ltrb[2]) * 0.5f;
  124. xyah[1] = (ltrb[1] + ltrb[3]) * 0.5f;
  125. xyah[3] = ltrb[3] - ltrb[1];
  126. xyah[2] = (ltrb[2] - ltrb[0]) / xyah[3];
  127. return xyah;
  128. }
  129. inline Trajectory::Trajectory()
  130. : state(New), ltrb(cv::Vec4f()), smooth_embedding(cv::Mat()), id(0),
  131. is_activated(false), timestamp(0), starttime(0), score(0), eta(0.9),
  132. length(0) {}
  133. inline Trajectory::Trajectory(const cv::Vec4f &ltrb_, float score_,
  134. const cv::Mat &embedding)
  135. : state(New), ltrb(ltrb_), smooth_embedding(cv::Mat()), id(0),
  136. is_activated(false), timestamp(0), starttime(0), score(score_), eta(0.9),
  137. length(0) {
  138. xyah = ltrb2xyah(ltrb);
  139. update_embedding(embedding);
  140. }
  141. inline Trajectory::Trajectory(const Trajectory &other)
  142. : state(other.state), ltrb(other.ltrb), id(other.id),
  143. is_activated(other.is_activated), timestamp(other.timestamp),
  144. starttime(other.starttime), xyah(other.xyah), score(other.score),
  145. eta(other.eta), length(other.length) {
  146. other.smooth_embedding.copyTo(smooth_embedding);
  147. other.current_embedding.copyTo(current_embedding);
  148. // copy state in KalmanFilter
  149. other.statePre.copyTo(cv::KalmanFilter::statePre);
  150. other.statePost.copyTo(cv::KalmanFilter::statePost);
  151. other.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre);
  152. other.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost);
  153. }
  154. inline Trajectory &Trajectory::operator=(const Trajectory &rhs) {
  155. this->state = rhs.state;
  156. this->ltrb = rhs.ltrb;
  157. rhs.smooth_embedding.copyTo(this->smooth_embedding);
  158. this->id = rhs.id;
  159. this->is_activated = rhs.is_activated;
  160. this->timestamp = rhs.timestamp;
  161. this->starttime = rhs.starttime;
  162. this->xyah = rhs.xyah;
  163. this->score = rhs.score;
  164. rhs.current_embedding.copyTo(this->current_embedding);
  165. this->eta = rhs.eta;
  166. this->length = rhs.length;
  167. // copy state in KalmanFilter
  168. rhs.statePre.copyTo(cv::KalmanFilter::statePre);
  169. rhs.statePost.copyTo(cv::KalmanFilter::statePost);
  170. rhs.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre);
  171. rhs.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost);
  172. return *this;
  173. }
  174. inline int Trajectory::next_id(int &cnt) {
  175. ++cnt;
  176. return cnt;
  177. }
  178. inline void Trajectory::mark_lost(void) { state = Lost; }
  179. inline void Trajectory::mark_removed(void) { state = Removed; }
  180. } // namespace tracking
  181. } // namespace vision
  182. } // namespace ultra_infer