model.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/vision/tracking/pptracking/model.h"
  15. #include "ultra_infer/vision/tracking/pptracking/letter_box_resize.h"
  16. #include "yaml-cpp/yaml.h"
  17. namespace ultra_infer {
  18. namespace vision {
  19. namespace tracking {
  20. PPTracking::PPTracking(const std::string &model_file,
  21. const std::string &params_file,
  22. const std::string &config_file,
  23. const RuntimeOption &custom_option,
  24. const ModelFormat &model_format) {
  25. config_file_ = config_file;
  26. valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
  27. valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
  28. runtime_option = custom_option;
  29. runtime_option.model_format = model_format;
  30. runtime_option.model_file = model_file;
  31. runtime_option.params_file = params_file;
  32. initialized = Initialize();
  33. }
  34. bool PPTracking::BuildPreprocessPipelineFromConfig() {
  35. processors_.clear();
  36. YAML::Node cfg;
  37. try {
  38. cfg = YAML::LoadFile(config_file_);
  39. } catch (YAML::BadFile &e) {
  40. FDERROR << "Failed to load yaml file " << config_file_
  41. << ", maybe you should check this file." << std::endl;
  42. return false;
  43. }
  44. // Get draw_threshold for visualization
  45. if (cfg["draw_threshold"].IsDefined()) {
  46. draw_threshold_ = cfg["draw_threshold"].as<float>();
  47. } else {
  48. FDERROR << "Please set draw_threshold." << std::endl;
  49. return false;
  50. }
  51. // Get config for tracker
  52. if (cfg["tracker"].IsDefined()) {
  53. if (cfg["tracker"]["conf_thres"].IsDefined()) {
  54. conf_thresh_ = cfg["tracker"]["conf_thres"].as<float>();
  55. } else {
  56. std::cerr << "Please set conf_thres in tracker." << std::endl;
  57. return false;
  58. }
  59. if (cfg["tracker"]["min_box_area"].IsDefined()) {
  60. min_box_area_ = cfg["tracker"]["min_box_area"].as<float>();
  61. }
  62. if (cfg["tracker"]["tracked_thresh"].IsDefined()) {
  63. tracked_thresh_ = cfg["tracker"]["tracked_thresh"].as<float>();
  64. }
  65. }
  66. processors_.push_back(std::make_shared<BGR2RGB>());
  67. for (const auto &op : cfg["Preprocess"]) {
  68. std::string op_name = op["type"].as<std::string>();
  69. if (op_name == "Resize") {
  70. bool keep_ratio = op["keep_ratio"].as<bool>();
  71. auto target_size = op["target_size"].as<std::vector<int>>();
  72. int interp = op["interp"].as<int>();
  73. FDASSERT(target_size.size() == 2,
  74. "Require size of target_size be 2, but now it's %lu.",
  75. target_size.size());
  76. if (!keep_ratio) {
  77. int width = target_size[1];
  78. int height = target_size[0];
  79. processors_.push_back(
  80. std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
  81. } else {
  82. int min_target_size = std::min(target_size[0], target_size[1]);
  83. int max_target_size = std::max(target_size[0], target_size[1]);
  84. std::vector<int> max_size;
  85. if (max_target_size > 0) {
  86. max_size.push_back(max_target_size);
  87. max_size.push_back(max_target_size);
  88. }
  89. processors_.push_back(std::make_shared<ResizeByShort>(
  90. min_target_size, interp, true, max_size));
  91. }
  92. } else if (op_name == "LetterBoxResize") {
  93. auto target_size = op["target_size"].as<std::vector<int>>();
  94. FDASSERT(target_size.size() == 2,
  95. "Require size of target_size be 2, but now it's %lu.",
  96. target_size.size());
  97. std::vector<float> color{127.0f, 127.0f, 127.0f};
  98. if (op["fill_value"].IsDefined()) {
  99. color = op["fill_value"].as<std::vector<float>>();
  100. }
  101. processors_.push_back(
  102. std::make_shared<LetterBoxResize>(target_size, color));
  103. } else if (op_name == "NormalizeImage") {
  104. auto mean = op["mean"].as<std::vector<float>>();
  105. auto std = op["std"].as<std::vector<float>>();
  106. bool is_scale = true;
  107. if (op["is_scale"]) {
  108. is_scale = op["is_scale"].as<bool>();
  109. }
  110. std::string norm_type = "mean_std";
  111. if (op["norm_type"]) {
  112. norm_type = op["norm_type"].as<std::string>();
  113. }
  114. if (norm_type != "mean_std") {
  115. std::fill(mean.begin(), mean.end(), 0.0);
  116. std::fill(std.begin(), std.end(), 1.0);
  117. }
  118. processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
  119. } else if (op_name == "Permute") {
  120. // Do nothing, do permute as the last operation
  121. continue;
  122. // processors_.push_back(std::make_shared<HWC2CHW>());
  123. } else if (op_name == "Pad") {
  124. auto size = op["size"].as<std::vector<int>>();
  125. auto value = op["fill_value"].as<std::vector<float>>();
  126. processors_.push_back(std::make_shared<Cast>("float"));
  127. processors_.push_back(
  128. std::make_shared<PadToSize>(size[1], size[0], value));
  129. } else if (op_name == "PadStride") {
  130. auto stride = op["stride"].as<int>();
  131. processors_.push_back(
  132. std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
  133. } else {
  134. FDERROR << "Unexcepted preprocess operator: " << op_name << "."
  135. << std::endl;
  136. return false;
  137. }
  138. }
  139. processors_.push_back(std::make_shared<HWC2CHW>());
  140. FuseTransforms(&processors_);
  141. return true;
  142. }
  143. bool PPTracking::Initialize() {
  144. if (!BuildPreprocessPipelineFromConfig()) {
  145. FDERROR << "Failed to build preprocess pipeline from configuration file."
  146. << std::endl;
  147. return false;
  148. }
  149. if (!InitRuntime()) {
  150. FDERROR << "Failed to initialize ultra_infer backend." << std::endl;
  151. return false;
  152. }
  153. // create JDETracker instance
  154. jdeTracker_ = std::unique_ptr<JDETracker>(new JDETracker);
  155. return true;
  156. }
  157. bool PPTracking::Predict(cv::Mat *img, MOTResult *result) {
  158. Mat mat(*img);
  159. std::vector<FDTensor> input_tensors;
  160. if (!Preprocess(&mat, &input_tensors)) {
  161. FDERROR << "Failed to preprocess input image." << std::endl;
  162. return false;
  163. }
  164. std::vector<FDTensor> output_tensors;
  165. if (!Infer(input_tensors, &output_tensors)) {
  166. FDERROR << "Failed to inference." << std::endl;
  167. return false;
  168. }
  169. if (!Postprocess(output_tensors, result)) {
  170. FDERROR << "Failed to post process." << std::endl;
  171. return false;
  172. }
  173. return true;
  174. }
  175. bool PPTracking::Preprocess(Mat *mat, std::vector<FDTensor> *outputs) {
  176. int origin_w = mat->Width();
  177. int origin_h = mat->Height();
  178. for (size_t i = 0; i < processors_.size(); ++i) {
  179. if (!(*(processors_[i].get()))(mat)) {
  180. FDERROR << "Failed to process image data in " << processors_[i]->Name()
  181. << "." << std::endl;
  182. return false;
  183. }
  184. }
  185. // LetterBoxResize(mat);
  186. // Normalize::Run(mat,mean_,scale_,is_scale_);
  187. // HWC2CHW::Run(mat);
  188. Cast::Run(mat, "float");
  189. outputs->resize(3);
  190. // image_shape
  191. (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(0).name);
  192. float *shape = static_cast<float *>((*outputs)[0].MutableData());
  193. shape[0] = mat->Height();
  194. shape[1] = mat->Width();
  195. // image
  196. (*outputs)[1].name = InputInfoOfRuntime(1).name;
  197. mat->ShareWithTensor(&((*outputs)[1]));
  198. (*outputs)[1].ExpandDim(0);
  199. // scale
  200. (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(2).name);
  201. float *scale = static_cast<float *>((*outputs)[2].MutableData());
  202. scale[0] = mat->Height() * 1.0 / origin_h;
  203. scale[1] = mat->Width() * 1.0 / origin_w;
  204. return true;
  205. }
  206. void FilterDets(const float conf_thresh, const cv::Mat &dets,
  207. std::vector<int> *index) {
  208. for (int i = 0; i < dets.rows; ++i) {
  209. float score = *dets.ptr<float>(i, 4);
  210. if (score > conf_thresh) {
  211. index->push_back(i);
  212. }
  213. }
  214. }
  215. bool PPTracking::Postprocess(std::vector<FDTensor> &infer_result,
  216. MOTResult *result) {
  217. auto bbox_shape = infer_result[0].shape;
  218. auto bbox_data = static_cast<float *>(infer_result[0].Data());
  219. auto emb_shape = infer_result[1].shape;
  220. auto emb_data = static_cast<float *>(infer_result[1].Data());
  221. cv::Mat dets(bbox_shape[0], 6, CV_32FC1, bbox_data);
  222. cv::Mat emb(bbox_shape[0], emb_shape[1], CV_32FC1, emb_data);
  223. result->Clear();
  224. std::vector<Track> tracks;
  225. std::vector<int> valid;
  226. FilterDets(conf_thresh_, dets, &valid);
  227. cv::Mat new_dets, new_emb;
  228. for (int i = 0; i < valid.size(); ++i) {
  229. new_dets.push_back(dets.row(valid[i]));
  230. new_emb.push_back(emb.row(valid[i]));
  231. }
  232. jdeTracker_->update(new_dets, new_emb, &tracks);
  233. if (tracks.size() == 0) {
  234. std::array<int, 4> box = {
  235. int(*dets.ptr<float>(0, 0)), int(*dets.ptr<float>(0, 1)),
  236. int(*dets.ptr<float>(0, 2)), int(*dets.ptr<float>(0, 3))};
  237. result->boxes.push_back(box);
  238. result->ids.push_back(1);
  239. result->scores.push_back(*dets.ptr<float>(0, 4));
  240. } else {
  241. std::vector<Track>::iterator titer;
  242. for (titer = tracks.begin(); titer != tracks.end(); ++titer) {
  243. if (titer->score < tracked_thresh_) {
  244. continue;
  245. } else {
  246. float w = titer->ltrb[2] - titer->ltrb[0];
  247. float h = titer->ltrb[3] - titer->ltrb[1];
  248. bool vertical = w / h > 1.6;
  249. float area = w * h;
  250. if (area > min_box_area_ && !vertical) {
  251. std::array<int, 4> box = {int(titer->ltrb[0]), int(titer->ltrb[1]),
  252. int(titer->ltrb[2]), int(titer->ltrb[3])};
  253. result->boxes.push_back(box);
  254. result->ids.push_back(titer->id);
  255. result->scores.push_back(titer->score);
  256. }
  257. }
  258. }
  259. }
  260. if (!is_record_trail_)
  261. return true;
  262. int nums = result->boxes.size();
  263. for (int i = 0; i < nums; i++) {
  264. float center_x = (result->boxes[i][0] + result->boxes[i][2]) / 2;
  265. float center_y = (result->boxes[i][1] + result->boxes[i][3]) / 2;
  266. int id = result->ids[i];
  267. recorder_->Add(id, {int(center_x), int(center_y)});
  268. }
  269. return true;
  270. }
  271. void PPTracking::BindRecorder(TrailRecorder *recorder) {
  272. recorder_ = recorder;
  273. is_record_trail_ = true;
  274. }
  275. void PPTracking::UnbindRecorder() {
  276. is_record_trail_ = false;
  277. std::map<int, std::vector<std::array<int, 2>>>::iterator iter;
  278. for (iter = recorder_->records.begin(); iter != recorder_->records.end();
  279. iter++) {
  280. iter->second.clear();
  281. iter->second.shrink_to_fit();
  282. }
  283. recorder_->records.clear();
  284. std::map<int, std::vector<std::array<int, 2>>>().swap(recorder_->records);
  285. recorder_ = nullptr;
  286. }
  287. } // namespace tracking
  288. } // namespace vision
  289. } // namespace ultra_infer