// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ultra_infer/runtime/backends/tensorrt/trt_backend.h"

#include <cstring>
#include <fstream>
#include <unordered_map>

#include "NvInferRuntime.h"
#include "ultra_infer/function/cuda_cast.h"
#include "ultra_infer/utils/utils.h"
#ifdef ENABLE_PADDLE2ONNX
#include "paddle2onnx/converter.h"
#endif

namespace ultra_infer {

FDTrtLogger *FDTrtLogger::logger = nullptr;
// Check whether the TensorRT engine can be built right now.
// If the model has dynamic input shapes, building requires predefined shape
// range information, which can be set with SetTrtInputShape(). If the shape
// range is not defined, the engine cannot be built yet; in that case the
// engine is built once data is fed, and the shape range is updated from the
// incoming shapes.
bool CanBuildEngine(
    const std::map<std::string, ShapeRangeInfo> &shape_range_info) {
  for (auto iter = shape_range_info.begin(); iter != shape_range_info.end();
       ++iter) {
    bool is_full_static = true;
    for (size_t i = 0; i < iter->second.shape.size(); ++i) {
      if (iter->second.shape[i] < 0) {
        is_full_static = false;
        break;
      }
    }

    if (is_full_static) {
      continue;
    }
    for (size_t i = 0; i < iter->second.shape.size(); ++i) {
      if (iter->second.min[i] < 0 || iter->second.max[i] < 0) {
        return false;
      }
    }
  }
  return true;
}
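
// Illustrative sketch (an assumption for documentation, not part of this
// backend): for a model with a dynamic input such as shape {-1, 3, -1, -1},
// CanBuildEngine() only returns true when a non-negative min/max (and
// optionally opt) shape is known for every dynamic dimension, e.g. through
// the shape maps that InitFromOnnx() reads from TrtBackendOption:
//   TrtBackendOption opt;
//   opt.min_shape["x"] = {1, 3, 224, 224};
//   opt.opt_shape["x"] = {1, 3, 320, 320};
//   opt.max_shape["x"] = {8, 3, 640, 640};
// Without such entries the engine is built lazily on the first Infer() call,
// using the shapes of the tensors actually fed in.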
bool TrtBackend::LoadTrtCache(const std::string &trt_engine_file) {
  cudaSetDevice(option_.gpu_id);

  std::string engine_buffer;
  if (!ReadBinaryFromFile(trt_engine_file, &engine_buffer)) {
    FDERROR << "Failed to load TensorRT Engine from " << trt_engine_file << "."
            << std::endl;
    return false;
  }

  FDUniquePtr<nvinfer1::IRuntime> runtime{
      nvinfer1::createInferRuntime(*FDTrtLogger::Get())};
  if (!runtime) {
    FDERROR << "Failed to call createInferRuntime()." << std::endl;
    return false;
  }
  engine_ = std::shared_ptr<nvinfer1::ICudaEngine>(
      runtime->deserializeCudaEngine(engine_buffer.data(),
                                     engine_buffer.size()),
      FDInferDeleter());
  if (!engine_) {
    FDERROR << "Failed to call deserializeCudaEngine()." << std::endl;
    return false;
  }

  context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
      engine_->createExecutionContext());
  GetInputOutputInfo();

  for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
    if (!engine_->bindingIsInput(i)) {
      continue;
    }
    auto min = ToVec(engine_->getProfileDimensions(
        i, 0, nvinfer1::OptProfileSelector::kMIN));
    auto max = ToVec(engine_->getProfileDimensions(
        i, 0, nvinfer1::OptProfileSelector::kMAX));
    auto name = std::string(engine_->getBindingName(i));
    auto iter = shape_range_info_.find(name);
    if (iter == shape_range_info_.end()) {
      FDERROR << "There's no input named '" << name << "' in the loaded model."
              << std::endl;
      return false;
    }
    iter->second.Update(min);
    iter->second.Update(max);
  }
  FDINFO << "Built TensorRT Engine from cache file: " << trt_engine_file
         << " with the shape range information below," << std::endl;
  for (const auto &item : shape_range_info_) {
    FDINFO << item.second << std::endl;
  }
  return true;
}
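
// Init() dispatches on the model format: Paddle models are converted to ONNX
// via InitFromPaddle before the TensorRT engine is created, while ONNX models
// go straight to InitFromOnnx. When model_from_memory_ is set,
// model_file/params_file already hold the serialized content instead of file
// paths, so no file reading is performed.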
bool TrtBackend::Init(const RuntimeOption &runtime_option) {
  auto trt_option = runtime_option.trt_option;
  trt_option.model_file = runtime_option.model_file;
  trt_option.params_file = runtime_option.params_file;
  trt_option.model_format = runtime_option.model_format;
  trt_option.gpu_id = runtime_option.device_id;
  trt_option.enable_pinned_memory = runtime_option.enable_pinned_memory;
  trt_option.external_stream_ = runtime_option.external_stream_;
  if (runtime_option.device != Device::GPU) {
    FDERROR << "TrtBackend only supports Device::GPU, but now it's "
            << runtime_option.device << "." << std::endl;
    return false;
  }
  if (runtime_option.model_format != ModelFormat::PADDLE &&
      runtime_option.model_format != ModelFormat::ONNX) {
    FDERROR
        << "TrtBackend only supports model format PADDLE/ONNX, but now it's "
        << runtime_option.model_format << "." << std::endl;
    return false;
  }
  if (runtime_option.model_format == ModelFormat::PADDLE) {
    if (runtime_option.model_from_memory_) {
      return InitFromPaddle(runtime_option.model_file,
                            runtime_option.params_file, trt_option);
    } else {
      std::string model_buffer;
      std::string params_buffer;
      FDASSERT(ReadBinaryFromFile(runtime_option.model_file, &model_buffer),
               "Failed to read model file %s.",
               runtime_option.model_file.c_str());
      FDASSERT(ReadBinaryFromFile(runtime_option.params_file, &params_buffer),
               "Failed to read parameters file %s.",
               runtime_option.params_file.c_str());
      return InitFromPaddle(model_buffer, params_buffer, trt_option);
    }
  } else {
    if (runtime_option.model_from_memory_) {
      return InitFromOnnx(runtime_option.model_file, trt_option);
    } else {
      std::string model_buffer;
      FDASSERT(ReadBinaryFromFile(runtime_option.model_file, &model_buffer),
               "Failed to read model file %s.",
               runtime_option.model_file.c_str());
      return InitFromOnnx(model_buffer, trt_option);
    }
  }
  return true;
}
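
// InitFromPaddle() converts a Paddle model (given as in-memory buffers) to
// ONNX with Paddle2ONNX, registering pool2d as the custom AdaptivePool2d op
// for the "tensorrt" deploy backend. Any INT8 calibration cache produced by
// the export is kept in calibration_str_, and models with external weights
// are written to model.onnx so TensorRT can parse them from disk.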
bool TrtBackend::InitFromPaddle(const std::string &model_buffer,
                                const std::string &params_buffer,
                                const TrtBackendOption &option, bool verbose) {
  if (initialized_) {
    FDERROR << "TrtBackend is already initialized, cannot initialize again."
            << std::endl;
    return false;
  }
  option_ = option;

#ifdef ENABLE_PADDLE2ONNX
  std::vector<paddle2onnx::CustomOp> ops;
  ops.resize(1);
  strcpy(ops[0].op_name, "pool2d");
  strcpy(ops[0].export_op_name, "AdaptivePool2d");
  char *model_content_ptr;
  int model_content_size = 0;
  char *calibration_cache_ptr;
  int calibration_cache_size = 0;
  if (!paddle2onnx::Export(model_buffer.c_str(), model_buffer.size(),
                           params_buffer.c_str(), params_buffer.size(),
                           &model_content_ptr, &model_content_size, 11, true,
                           verbose, true, true, true, ops.data(), 1, "tensorrt",
                           &calibration_cache_ptr, &calibration_cache_size, "",
                           &save_external_)) {
    FDERROR << "Error occurred while exporting the PaddlePaddle model to ONNX "
               "format."
            << std::endl;
    return false;
  }

  std::string onnx_model_proto(model_content_ptr,
                               model_content_ptr + model_content_size);
  delete[] model_content_ptr;
  model_content_ptr = nullptr;
  if (calibration_cache_size) {
    std::string calibration_str(
        calibration_cache_ptr, calibration_cache_ptr + calibration_cache_size);
    calibration_str_ = calibration_str;
    delete[] calibration_cache_ptr;
  }
  if (save_external_) {
    model_file_name_ = "model.onnx";
    std::fstream f(model_file_name_, std::ios::out);
    FDASSERT(f.is_open(), "Cannot open file %s to save the model.",
             model_file_name_.c_str());
    f << onnx_model_proto;
    f.close();
  }
  return InitFromOnnx(onnx_model_proto, option);
#else
  FDERROR << "This build was not compiled with the PaddlePaddle frontend; "
             "try calling `InitFromOnnx` instead."
          << std::endl;
  return false;
#endif
}
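
// InitFromOnnx() reads the input/output metadata from the ONNX model, builds
// the per-input ShapeRangeInfo table from the min/max/opt shapes in
// TrtBackendOption, sets up the CUDA stream (external or newly created), and
// then tries to create the TensorRT engine.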
bool TrtBackend::InitFromOnnx(const std::string &model_buffer,
                              const TrtBackendOption &option) {
  if (initialized_) {
    FDERROR << "TrtBackend is already initialized, cannot initialize again."
            << std::endl;
    return false;
  }
  option_ = option;
  cudaSetDevice(option_.gpu_id);

  std::string onnx_content = model_buffer;

  // Record the original output order, because the converted TensorRT network
  // may arrange its outputs in a different order.
  outputs_order_.clear();
  auto onnx_reader =
      paddle2onnx::OnnxReader(onnx_content.c_str(), onnx_content.size());
  for (int i = 0; i < onnx_reader.num_outputs; ++i) {
    std::string name(onnx_reader.outputs[i].name);
    outputs_order_[name] = i;
  }

  shape_range_info_.clear();
  inputs_desc_.clear();
  outputs_desc_.clear();
  inputs_desc_.resize(onnx_reader.num_inputs);
  outputs_desc_.resize(onnx_reader.num_outputs);
  for (int i = 0; i < onnx_reader.num_inputs; ++i) {
    std::string name(onnx_reader.inputs[i].name);
    std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
                               onnx_reader.inputs[i].shape +
                                   onnx_reader.inputs[i].rank);
    inputs_desc_[i].name = name;
    inputs_desc_[i].shape.assign(shape.begin(), shape.end());
    inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
    inputs_desc_[i].original_dtype =
        ReaderDtypeToFDDtype(onnx_reader.inputs[i].dtype);
    auto info = ShapeRangeInfo(shape);
    info.name = name;
    auto iter_min = option.min_shape.find(name);
    auto iter_max = option.max_shape.find(name);
    auto iter_opt = option.opt_shape.find(name);
    if (iter_min != option.min_shape.end()) {
      info.min.assign(iter_min->second.begin(), iter_min->second.end());
      info.max.assign(iter_max->second.begin(), iter_max->second.end());
      info.opt.assign(iter_opt->second.begin(), iter_opt->second.end());
    }
    shape_range_info_.insert(std::make_pair(name, info));
  }

  for (int i = 0; i < onnx_reader.num_outputs; ++i) {
    std::string name(onnx_reader.outputs[i].name);
    std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
                               onnx_reader.outputs[i].shape +
                                   onnx_reader.outputs[i].rank);
    outputs_desc_[i].name = name;
    outputs_desc_[i].shape.assign(shape.begin(), shape.end());
    outputs_desc_[i].dtype =
        ReaderDtypeToTrtDtype(onnx_reader.outputs[i].dtype);
    outputs_desc_[i].original_dtype =
        ReaderDtypeToFDDtype(onnx_reader.outputs[i].dtype);
  }

  if (option_.external_stream_) {
    stream_ = reinterpret_cast<cudaStream_t>(option_.external_stream_);
  } else {
    FDASSERT(cudaStreamCreate(&stream_) == 0,
             "[ERROR] Error occurs while calling cudaStreamCreate().");
  }

  if (save_external_) {
    onnx_content.clear();
    onnx_content = model_file_name_;
  }
  if (!CreateTrtEngineFromOnnx(onnx_content)) {
    FDERROR << "Failed to create TensorRT engine." << std::endl;
    return false;
  }
  initialized_ = true;
  return true;
}
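
// ShapeRangeInfoUpdated() feeds the shapes of the current inputs into the
// collected shape ranges; it returns a non-zero value when any input falls
// outside the range collected so far, which signals Infer() to rebuild the
// engine.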
int TrtBackend::ShapeRangeInfoUpdated(const std::vector<FDTensor> &inputs) {
  bool need_update_engine = false;
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto iter = shape_range_info_.find(inputs[i].name);
    if (iter == shape_range_info_.end()) {
      FDERROR << "There's no input named '" << inputs[i].name
              << "' in the loaded model." << std::endl;
      // Skip unknown inputs instead of dereferencing an invalid iterator.
      continue;
    }
    if (iter->second.Update(inputs[i].shape) == 1) {
      need_update_engine = true;
    }
  }
  return need_update_engine;
}
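
// Infer() runs one forward pass: it copies/binds the inputs to device
// buffers, allocates output buffers from the binding dimensions reported by
// the execution context, launches enqueueV2() on the backend stream, casts
// model outputs back to the originally declared dtype (e.g. INT32 -> INT64)
// when needed, and copies results back to the host when copy_to_fd is true.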
bool TrtBackend::Infer(std::vector<FDTensor> &inputs,
                       std::vector<FDTensor> *outputs, bool copy_to_fd) {
  if (inputs.size() != NumInputs()) {
    FDERROR << "Require " << NumInputs() << " inputs, but got "
            << inputs.size() << "." << std::endl;
    return false;
  }
  if (ShapeRangeInfoUpdated(inputs)) {
    // The new input shapes fall outside the predefined min/max shape range,
    // so the TensorRT engine has to be rebuilt.
    FDWARNING
        << "The TensorRT engine will be rebuilt whenever the shape range "
           "information changes; this may take a long time. Set a proper "
           "shape range before loading the model to avoid rebuilding, refer "
           "to "
           "https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/en/"
           "faq/tensorrt_tricks.md for more details."
        << std::endl;
    if (!BuildTrtEngine()) {
      FDERROR << "Failed to rebuild TensorRT engine." << std::endl;
      return false;
    }
  }

  RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
  cudaSetDevice(option_.gpu_id);
  SetInputs(inputs);
  AllocateOutputsBuffer(outputs, copy_to_fd);

  RUNTIME_PROFILE_LOOP_BEGIN(1)
  if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
    FDERROR << "Failed to Infer with TensorRT." << std::endl;
    return false;
  }
  RUNTIME_PROFILE_LOOP_END

  for (size_t i = 0; i < outputs->size(); ++i) {
    // If the final output tensor's dtype differs from the model output
    // tensor's dtype, cast the data to the final output's dtype.
    auto model_output_dtype =
        GetFDDataType(outputs_device_buffer_[(*outputs)[i].name].dtype());
    if ((*outputs)[i].dtype != model_output_dtype) {
      FDTensor output_tensor;
      output_tensor.SetExternalData(
          (*outputs)[i].shape, model_output_dtype,
          outputs_device_buffer_[(*outputs)[i].name].data(), Device::GPU);

      casted_output_tensors_[(*outputs)[i].name].Resize(
          (*outputs)[i].shape, (*outputs)[i].dtype, (*outputs)[i].name,
          Device::GPU);
      function::CudaCast(output_tensor,
                         &casted_output_tensors_[(*outputs)[i].name], stream_);
      if (!copy_to_fd) {
        (*outputs)[i].SetExternalData(
            (*outputs)[i].shape, model_output_dtype,
            casted_output_tensors_[(*outputs)[i].name].MutableData(),
            Device::GPU, option_.gpu_id);
      }
    } else {
      casted_output_tensors_[(*outputs)[i].name].SetExternalData(
          (*outputs)[i].shape, model_output_dtype,
          outputs_device_buffer_[(*outputs)[i].name].data(), Device::GPU);
    }
  }
  if (copy_to_fd) {
    for (size_t i = 0; i < outputs->size(); ++i) {
      FDASSERT(
          cudaMemcpyAsync((*outputs)[i].Data(),
                          casted_output_tensors_[(*outputs)[i].name].Data(),
                          (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
                          stream_) == 0,
          "[ERROR] Error occurs while copying memory from GPU to CPU.");
    }
    FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
             "[ERROR] Error occurs while synchronizing the CUDA stream.");
  }
  RUNTIME_PROFILE_LOOP_H2D_D2H_END
  return true;
}
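
// GetInputOutputInfo() rebuilds inputs_desc_/outputs_desc_ from the engine's
// bindings while preserving the original FD dtypes recorded earlier, and
// prepares a device buffer (plus a casting tensor for outputs) per binding.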
void TrtBackend::GetInputOutputInfo() {
  // Read the original dtypes from inputs_desc_ and outputs_desc_.
  std::unordered_map<std::string, FDDataType> inputs_original_dtype_map;
  std::unordered_map<std::string, FDDataType> outputs_original_dtype_map;
  for (size_t i = 0; i < inputs_desc_.size(); ++i) {
    inputs_original_dtype_map[inputs_desc_[i].name] =
        inputs_desc_[i].original_dtype;
  }
  for (size_t i = 0; i < outputs_desc_.size(); ++i) {
    outputs_original_dtype_map[outputs_desc_[i].name] =
        outputs_desc_[i].original_dtype;
  }

  // Re-read the tensor infos from the TRT engine and write them into
  // inputs_desc_ and outputs_desc_.
  std::vector<TrtValueInfo>().swap(inputs_desc_);
  std::vector<TrtValueInfo>().swap(outputs_desc_);
  inputs_desc_.clear();
  outputs_desc_.clear();
  auto num_binds = engine_->getNbBindings();
  for (auto i = 0; i < num_binds; ++i) {
    std::string name = std::string(engine_->getBindingName(i));
    auto shape = ToVec(engine_->getBindingDimensions(i));
    auto dtype = engine_->getBindingDataType(i);
    if (engine_->bindingIsInput(i)) {
      auto original_dtype = inputs_original_dtype_map.count(name)
                                ? inputs_original_dtype_map[name]
                                : GetFDDataType(dtype);
      inputs_desc_.emplace_back(
          TrtValueInfo{name, shape, dtype, original_dtype});
      inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
    } else {
      auto original_dtype = outputs_original_dtype_map.count(name)
                                ? outputs_original_dtype_map[name]
                                : GetFDDataType(dtype);
      outputs_desc_.emplace_back(
          TrtValueInfo{name, shape, dtype, original_dtype});
      outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
      casted_output_tensors_[name] = FDTensor();
    }
    io_name_index_[name] = i;
  }
  bindings_.resize(num_binds);
}
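
// SetInputs() binds every input tensor to its TensorRT binding slot. INT64
// inputs are narrowed to INT32 before binding: on GPU via a CudaCast kernel,
// on CPU by casting into a temporary buffer before the host-to-device copy.
// GPU tensors of other dtypes are bound in place without a copy, while CPU
// tensors are copied with a synchronous cudaMemcpy.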
void TrtBackend::SetInputs(const std::vector<FDTensor> &inputs) {
  for (const auto &item : inputs) {
    auto iter = io_name_index_.find(item.name);
    FDASSERT(iter != io_name_index_.end(),
             "TRTBackend SetInputs cannot find input name: %s",
             item.name.c_str());
    auto idx = iter->second;
    std::vector<int> shape(item.shape.begin(), item.shape.end());
    auto dims = ToDims(shape);
    context_->setBindingDimensions(idx, dims);

    if (item.device == Device::GPU) {
      if (item.dtype == FDDataType::INT64) {
        inputs_device_buffer_[item.name].resize(dims);
        FDTensor input_tensor;
        input_tensor.SetExternalData(item.shape, FDDataType::INT32,
                                     inputs_device_buffer_[item.name].data(),
                                     Device::GPU);
        function::CudaCast(item, &input_tensor, stream_);
      } else {
        // No copy: bind the GPU tensor's memory directly.
        inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
      }
    } else {
      // Allocate input buffer memory.
      inputs_device_buffer_[item.name].resize(dims);
      // Copy from CPU to GPU.
      // WARN: For the cudaMemcpyHostToDevice direction, cudaMemcpyAsync
      // requires page-locked host memory to avoid overlapping transfers.
      // FDTensor does not guarantee page-locked memory yet, so the synchronous
      // cudaMemcpy is used instead. Reference:
      // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creation-and-destruction
      if (item.dtype == FDDataType::INT64) {
        int64_t *data = static_cast<int64_t *>(const_cast<void *>(item.Data()));
        std::vector<int32_t> casted_data(data, data + item.Numel());
        FDASSERT(cudaMemcpy(inputs_device_buffer_[item.name].data(),
                            static_cast<void *>(casted_data.data()),
                            item.Nbytes() / 2, cudaMemcpyHostToDevice) == 0,
                 "Error occurs while copying memory from CPU to GPU.");
      } else {
        FDASSERT(cudaMemcpy(inputs_device_buffer_[item.name].data(),
                            item.Data(), item.Nbytes(),
                            cudaMemcpyHostToDevice) == 0,
                 "Error occurs while copying memory from CPU to GPU.");
      }
    }
    // Bind the input buffer.
    bindings_[idx] = inputs_device_buffer_[item.name].data();
  }
}
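
// AllocateOutputsBuffer() sizes the device buffers from the binding
// dimensions of the current execution context and maps each TensorRT output
// back to its original position in the model (outputs_order_), since the
// converted network may reorder outputs. With copy_to_fd the host tensors are
// resized so Infer() can copy results back; otherwise the outputs reference
// the GPU memory directly.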
void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor> *outputs,
                                       bool copy_to_fd) {
  if (outputs->size() != outputs_desc_.size()) {
    outputs->resize(outputs_desc_.size());
  }
  for (size_t i = 0; i < outputs_desc_.size(); ++i) {
    auto idx_iter = io_name_index_.find(outputs_desc_[i].name);
    FDASSERT(idx_iter != io_name_index_.end(),
             "TRTBackend cannot find output name: %s",
             outputs_desc_[i].name.c_str());
    auto idx = idx_iter->second;
    auto output_dims = context_->getBindingDimensions(idx);

    // Find the original index of the output.
    auto iter = outputs_order_.find(outputs_desc_[i].name);
    FDASSERT(
        iter != outputs_order_.end(),
        "Cannot find output: %s of tensorrt network from the original model.",
        outputs_desc_[i].name.c_str());
    auto ori_idx = iter->second;

    // Allocate the output buffer memory.
    outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);
    // Bind the output buffer.
    bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();

    // Set the user-facing output info.
    std::vector<int64_t> shape(output_dims.d,
                               output_dims.d + output_dims.nbDims);
    if (copy_to_fd) {
      (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
      (*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
                                 outputs_desc_[i].name);
    } else {
      (*outputs)[ori_idx].name = outputs_desc_[i].name;
      (*outputs)[ori_idx].SetExternalData(
          shape, outputs_desc_[i].original_dtype, bindings_[idx], Device::GPU,
          option_.gpu_id);
    }
  }
}
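
// BuildTrtEngine() creates a builder config (FP16/INT8 flags, workspace
// size), fills one optimization profile from shape_range_info_ (falling back
// to the max shape when no opt shape is given), builds a serialized plan, and
// deserializes it into engine_/context_. If serialize_file is set, the plan
// is also written to disk so later runs can skip the build.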
bool TrtBackend::BuildTrtEngine() {
  if (option_.enable_log_info) {
    FDTrtLogger::Get()->SetLog(true, true);
  }
  auto config =
      FDUniquePtr<nvinfer1::IBuilderConfig>(builder_->createBuilderConfig());
  if (!config) {
    FDERROR << "Failed to call createBuilderConfig()." << std::endl;
    return false;
  }

  if (option_.enable_fp16) {
    if (!builder_->platformHasFastFp16()) {
      FDWARNING << "FP16 is not supported on the current GPU, "
                   "FP32 will be used instead."
                << std::endl;
    } else {
      FDINFO << "[TrtBackend] Use FP16 to inference." << std::endl;
      config->setFlag(nvinfer1::BuilderFlag::kFP16);
    }
  }

  FDINFO << "Start building TensorRT Engine..." << std::endl;

  if (context_) {
    context_.reset();
    engine_.reset();
  }

  if (option_.max_batch_size >= 1) {
    builder_->setMaxBatchSize(option_.max_batch_size);
  }
  config->setMaxWorkspaceSize(option_.max_workspace_size);
  auto profile = builder_->createOptimizationProfile();
  for (const auto &item : shape_range_info_) {
    FDASSERT(
        profile->setDimensions(item.first.c_str(),
                               nvinfer1::OptProfileSelector::kMIN,
                               ToDims(item.second.min)),
        "[TrtBackend] Failed to set min_shape for input: %s in TrtBackend.",
        item.first.c_str());
    FDASSERT(
        profile->setDimensions(item.first.c_str(),
                               nvinfer1::OptProfileSelector::kMAX,
                               ToDims(item.second.max)),
        "[TrtBackend] Failed to set max_shape for input: %s in TrtBackend.",
        item.first.c_str());
    if (item.second.opt.size() == 0) {
      FDASSERT(
          profile->setDimensions(item.first.c_str(),
                                 nvinfer1::OptProfileSelector::kOPT,
                                 ToDims(item.second.max)),
          "[TrtBackend] Failed to set opt_shape for input: %s in TrtBackend.",
          item.first.c_str());
    } else {
      FDASSERT(
          item.second.opt.size() == item.second.shape.size(),
          "Require the dimension of opt in shape range information equal to "
          "dimension of input: %s in this model, but now it's %zu != %zu.",
          item.first.c_str(), item.second.opt.size(), item.second.shape.size());
      FDASSERT(
          profile->setDimensions(item.first.c_str(),
                                 nvinfer1::OptProfileSelector::kOPT,
                                 ToDims(item.second.opt)),
          "[TrtBackend] Failed to set opt_shape for input: %s in TrtBackend.",
          item.first.c_str());
    }
  }
  config->addOptimizationProfile(profile);

  if (calibration_str_.size()) {
    if (!builder_->platformHasFastInt8()) {
      FDWARNING << "INT8 is not supported on the current GPU, "
                   "FP32 will be used instead."
                << std::endl;
    } else {
      FDINFO << "[TrtBackend] Use INT8 to inference." << std::endl;
      config->setFlag(nvinfer1::BuilderFlag::kINT8);
      Int8EntropyCalibrator2 *calibrator =
          new Int8EntropyCalibrator2(calibration_str_);
      config->setInt8Calibrator(calibrator);
    }
  }

  FDUniquePtr<nvinfer1::IHostMemory> plan{
      builder_->buildSerializedNetwork(*network_, *config)};
  if (!plan) {
    FDERROR << "Failed to call buildSerializedNetwork()." << std::endl;
    return false;
  }
  FDUniquePtr<nvinfer1::IRuntime> runtime{
      nvinfer1::createInferRuntime(*FDTrtLogger::Get())};
  if (!runtime) {
    FDERROR << "Failed to call createInferRuntime()." << std::endl;
    return false;
  }
  engine_ = std::shared_ptr<nvinfer1::ICudaEngine>(
      runtime->deserializeCudaEngine(plan->data(), plan->size()),
      FDInferDeleter());
  if (!engine_) {
    FDERROR << "Failed to call deserializeCudaEngine()." << std::endl;
    return false;
  }

  context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
      engine_->createExecutionContext());
  GetInputOutputInfo();

  FDINFO << "TensorRT Engine is built successfully." << std::endl;
  if (option_.serialize_file != "") {
    FDINFO << "Serialize TensorRT Engine to local file "
           << option_.serialize_file << "." << std::endl;
    std::ofstream engine_file(option_.serialize_file.c_str(),
                              std::ios::binary | std::ios::out);
    if (!engine_file) {
      FDERROR << "Failed to open " << option_.serialize_file << " to write."
              << std::endl;
      return false;
    }
    engine_file.write(static_cast<char *>(plan->data()), plan->size());
    engine_file.close();
    FDINFO << "TensorRT Engine is serialized to local file "
           << option_.serialize_file
           << ", the model can be loaded from this serialized engine "
              "directly next time."
           << std::endl;
  }
  return true;
}
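
// CreateTrtEngineFromOnnx() sets up the TensorRT builder/network/parser for
// the ONNX model (or a path to it when external weights were saved), loads a
// previously serialized engine if option_.serialize_file exists, and builds
// the engine immediately only when all dynamic inputs already have a usable
// shape range; otherwise the ONNX buffer is kept and the build is deferred to
// the first Infer() call.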
bool TrtBackend::CreateTrtEngineFromOnnx(const std::string &onnx_model_buffer) {
  const auto explicitBatch =
      1U << static_cast<uint32_t>(
          nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);

  builder_ = FDUniquePtr<nvinfer1::IBuilder>(
      nvinfer1::createInferBuilder(*FDTrtLogger::Get()));
  if (!builder_) {
    FDERROR << "Failed to call createInferBuilder()." << std::endl;
    return false;
  }
  network_ = FDUniquePtr<nvinfer1::INetworkDefinition>(
      builder_->createNetworkV2(explicitBatch));
  if (!network_) {
    FDERROR << "Failed to call createNetworkV2()." << std::endl;
    return false;
  }
  parser_ = FDUniquePtr<nvonnxparser::IParser>(
      nvonnxparser::createParser(*network_, *FDTrtLogger::Get()));
  if (!parser_) {
    FDERROR << "Failed to call createParser()." << std::endl;
    return false;
  }

  bool parse_failed;
  if (save_external_) {
    // The buffer holds a file path when the model was saved with external
    // weights.
    parse_failed = !parser_->parseFromFile(onnx_model_buffer.c_str(), 0);
  } else {
    parse_failed =
        !parser_->parse(onnx_model_buffer.data(), onnx_model_buffer.size());
  }
  if (parse_failed) {
    FDERROR << "Failed to parse ONNX model by TensorRT." << std::endl;
    return false;
  }

  if (option_.serialize_file != "") {
    std::ifstream fin(option_.serialize_file, std::ios::binary | std::ios::in);
    if (fin) {
      FDINFO << "Detected a serialized TensorRT Engine file in "
             << option_.serialize_file << ", it will be loaded directly."
             << std::endl;
      fin.close();
      // Clear the memory buffer of the temporary member.
      std::string().swap(onnx_model_buffer_);
      return LoadTrtCache(option_.serialize_file);
    }
  }

  if (!CanBuildEngine(shape_range_info_)) {
    onnx_model_buffer_ = onnx_model_buffer;
    FDWARNING << "Cannot build the engine right now because dynamic input "
                 "shapes exist, listed below:"
              << std::endl;
    for (int i = 0; i < NumInputs(); ++i) {
      FDWARNING << "Input " << i << ": " << GetInputInfo(i) << std::endl;
    }
    FDWARNING
        << "UltraInfer will build the engine when inference runs with input "
           "data, and will also collect the input shape range information. "
           "Note that UltraInfer will rebuild the engine whenever a new "
           "input shape falls outside the collected shape range, which can "
           "be time consuming; refer to "
           "https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/en/"
           "faq/tensorrt_tricks.md for more details."
        << std::endl;
    initialized_ = true;
    return true;
  }

  if (!BuildTrtEngine()) {
    FDERROR << "Failed to build TensorRT engine." << std::endl;
    return false;
  }
  // Clear the memory buffer of the temporary member.
  std::string().swap(onnx_model_buffer_);
  return true;
}
TensorInfo TrtBackend::GetInputInfo(int index) {
  FDASSERT(index < NumInputs(),
           "The index: %d should be less than the number of inputs: %d.",
           index, NumInputs());
  TensorInfo info;
  info.name = inputs_desc_[index].name;
  info.shape.assign(inputs_desc_[index].shape.begin(),
                    inputs_desc_[index].shape.end());
  info.dtype = inputs_desc_[index].original_dtype;
  return info;
}

std::vector<TensorInfo> TrtBackend::GetInputInfos() {
  std::vector<TensorInfo> infos;
  for (size_t i = 0; i < inputs_desc_.size(); i++) {
    infos.emplace_back(GetInputInfo(i));
  }
  return infos;
}

TensorInfo TrtBackend::GetOutputInfo(int index) {
  FDASSERT(index < NumOutputs(),
           "The index: %d should be less than the number of outputs: %d.",
           index, NumOutputs());
  TensorInfo info;
  info.name = outputs_desc_[index].name;
  info.shape.assign(outputs_desc_[index].shape.begin(),
                    outputs_desc_[index].shape.end());
  info.dtype = outputs_desc_[index].original_dtype;
  return info;
}

std::vector<TensorInfo> TrtBackend::GetOutputInfos() {
  std::vector<TensorInfo> infos;
  for (size_t i = 0; i < outputs_desc_.size(); i++) {
    infos.emplace_back(GetOutputInfo(i));
  }
  return infos;
}
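
// Clone() shares the underlying engine (and therefore GPU memory) with the
// new backend when the target device matches option_.gpu_id, only creating a
// fresh execution context and stream. For a different device id the model is
// fully re-initialized on that device, so no memory is shared.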
std::unique_ptr<BaseBackend> TrtBackend::Clone(RuntimeOption &runtime_option,
                                               void *stream, int device_id) {
  std::unique_ptr<BaseBackend> new_backend = utils::make_unique<TrtBackend>();
  auto casted_backend = dynamic_cast<TrtBackend *>(new_backend.get());
  if (device_id > 0 && device_id != option_.gpu_id) {
    auto clone_option = option_;
    clone_option.gpu_id = device_id;
    clone_option.external_stream_ = stream;
    if (runtime_option.model_from_memory_) {
      FDASSERT(casted_backend->InitFromPaddle(runtime_option.model_file,
                                              runtime_option.params_file,
                                              clone_option),
               "Cloning model from Paddle failed while initializing "
               "TrtBackend.");
    } else {
      if (option_.model_format == ModelFormat::ONNX) {
        std::string model_buffer = "";
        FDASSERT(
            ReadBinaryFromFile(clone_option.model_file, &model_buffer),
            "Failed to read binary from model file while cloning TrtBackend.");
        FDASSERT(casted_backend->InitFromOnnx(model_buffer, clone_option),
                 "Cloning model from ONNX failed while initializing "
                 "TrtBackend.");
      } else {
        std::string model_buffer = "";
        std::string params_buffer = "";
        FDASSERT(
            ReadBinaryFromFile(clone_option.model_file, &model_buffer),
            "Failed to read binary from model file while cloning TrtBackend.");
        FDASSERT(ReadBinaryFromFile(clone_option.params_file, &params_buffer),
                 "Failed to read binary from parameter file while cloning "
                 "TrtBackend.");
        FDASSERT(casted_backend->InitFromPaddle(model_buffer, params_buffer,
                                                clone_option),
                 "Cloning model from Paddle failed while initializing "
                 "TrtBackend.");
      }
    }
    FDWARNING << "The target device id: " << device_id
              << " is different from the current device id: " << option_.gpu_id
              << ", cannot share memory with the current engine." << std::endl;
    return new_backend;
  }
  cudaSetDevice(option_.gpu_id);
  casted_backend->option_.gpu_id = option_.gpu_id;
  if (stream) {
    casted_backend->stream_ = reinterpret_cast<cudaStream_t>(stream);
  } else {
    FDASSERT(cudaStreamCreate(&casted_backend->stream_) == 0,
             "[ERROR] Error occurs while calling cudaStreamCreate() during "
             "Clone().");
  }
  casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
  casted_backend->outputs_desc_.assign(outputs_desc_.begin(),
                                       outputs_desc_.end());
  casted_backend->outputs_order_.insert(outputs_order_.begin(),
                                        outputs_order_.end());
  casted_backend->shape_range_info_.insert(shape_range_info_.begin(),
                                           shape_range_info_.end());
  casted_backend->engine_ = engine_;
  casted_backend->context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
      casted_backend->engine_->createExecutionContext());
  casted_backend->GetInputOutputInfo();
  FDINFO << "TrtBackend clone finished." << std::endl;
  return new_backend;
}

}  // namespace ultra_infer