  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include <iostream>
  16. #include <fstream>
  17. #include <map>
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include "glog/logging.h"
  22. #include "NvInfer.h"
  23. #include "NvInferRuntime.h"
  24. #include "NvOnnxConfig.h"
  25. #include "NvOnnxParser.h"
  26. #include "model_deploy/common/include/output_struct.h"
  27. #include "model_deploy/engine/include/engine.h"
  28. #include "model_deploy/common/include/base_model.h"
  29. #include "model_deploy/engine/include/tensorrt_buffers.h"
  30. namespace PaddleDeploy {
  31. using Severity = nvinfer1::ILogger::Severity;
  32. struct InferDeleter {
  33. template <typename T> void operator()(T *obj) const {
  34. if (obj) {
  35. obj->destroy();
  36. }
  37. }
  38. };
  39. // A logger for create TensorRT infer builder.
  40. class NaiveLogger : public nvinfer1::ILogger {
  41. public:
  42. explicit NaiveLogger(Severity severity = Severity::kWARNING)
  43. : mReportableSeverity(severity) {}
  44. void log(nvinfer1::ILogger::Severity severity, const char *msg) override {
  45. switch (severity) {
  46. case Severity::kINFO:
  47. LOG(INFO) << msg;
  48. break;
  49. case Severity::kWARNING:
  50. LOG(WARNING) << msg;
  51. break;
  52. case Severity::kINTERNAL_ERROR:
  53. std::cout << "kINTERNAL_ERROR:" << msg << std::endl;
  54. break;
  55. case Severity::kERROR:
  56. LOG(ERROR) << msg;
  57. break;
  58. case Severity::kVERBOSE:
  59. // std::cout << "kVERBOSE:" << msg << std::endl;
  60. break;
  61. default:
  62. // std::cout << "default:" << msg << std::endl;
  63. break;
  64. }
  65. }
  66. static NaiveLogger &Global() {
  67. static NaiveLogger *x = new NaiveLogger;
  68. return *x;
  69. }
  70. ~NaiveLogger() override {}
  71. Severity mReportableSeverity;
  72. };
// Inference engine backed by NVIDIA TensorRT: builds or loads a CUDA engine
// from an ONNX model and runs inference through an execution context.
// All member functions are declared here and defined in the matching .cpp.
class TensorRTInferenceEngine : public InferEngine {
  // unique_ptr alias whose deleter releases TensorRT objects via destroy().
  template <typename T>
  using InferUniquePtr = std::unique_ptr<T, InferDeleter>;

 public:
  // Initialize the engine from the deployment configuration.
  // Returns false on failure.
  bool Init(const InferenceConfig& engine_config);

  // Run one inference pass: consume input_blobs, write results into
  // *output_blobs. Returns false on failure.
  bool Infer(const std::vector<DataBlob>& input_blobs,
             std::vector<DataBlob>* output_blobs);

  // Built or deserialized TensorRT engine; shared_ptr so the execution
  // context can co-own it.
  std::shared_ptr<nvinfer1::ICudaEngine> engine_{nullptr};
  // Execution context created from engine_.
  std::shared_ptr<nvinfer1::IExecutionContext> context_;
  // Logger handed to TensorRT builder/runtime objects.
  NaiveLogger logger_;

 private:
  // Feed input_blobs into the TensorRT buffers — presumably copies host
  // data into the device-side bindings; confirm against the .cpp.
  void FeedInput(const std::vector<DataBlob>& input_blobs,
                 const TensorRT::BufferManager& buffers);

  // Serialize the built engine to fileName. Returns false on failure.
  bool SaveEngine(const nvinfer1::ICudaEngine& engine,
                  const std::string& fileName);

  // Deserialize an engine — `engine` is presumably a path to a serialized
  // engine file (TODO confirm). DLACore selects a DLA core when >= 0;
  // -1 means no DLA. Returns nullptr on failure.
  nvinfer1::ICudaEngine* LoadEngine(const std::string& engine,
                                    int DLACore = -1);

  // Parse the ONNX model under model_dir and build the engine.
  void ParseONNXModel(const std::string& model_dir);

  // Per-model YAML configuration used during Init.
  YAML::Node yaml_config_;
};
  93. } // namespace PaddleDeploy