triton_engine.cpp

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "model_deploy/engine/include/triton_engine.h"

namespace nic = nvidia::inferenceserver::client;

// Print a diagnostic and abort if a Triton client call returns an error.
#define FAIL_IF_ERR(X, MSG)                                          \
  {                                                                  \
    nic::Error err = (X);                                            \
    if (!err.IsOk()) {                                               \
      std::cerr << "error: " << (MSG) << ": " << err << std::endl;   \
      exit(1);                                                       \
    }                                                                \
  }
namespace PaddleDeploy {

// Map the DataBlob integer dtype code to Triton's datatype string.
std::string DtypeToString(int64_t dtype) {
  if (dtype == 0) {
    return "FP32";
  } else if (dtype == 1) {
    return "INT64";
  } else if (dtype == 2) {
    return "INT32";
  } else if (dtype == 3) {
    return "UINT8";
  }
  return "INVALID";  // Unrecognized dtype code.
}

// Inverse mapping: Triton datatype string back to the integer code.
int DtypeToInt(std::string dtype) {
  if (dtype == "FP32") {
    return 0;
  } else if (dtype == "INT64") {
    return 1;
  } else if (dtype == "INT32") {
    return 2;
  } else if (dtype == "UINT8") {
    return 3;
  }
  return -1;  // Unrecognized datatype string.
}
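
// Note: the integer codes above follow the DataBlob dtype convention used
// throughout this file (0 = FP32, 1 = INT64, 2 = INT32, 3 = UINT8); the
// strings are the datatype names Triton expects on the wire.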
bool Model::TritonEngineInit(const TritonEngineConfig& engine_config) {
  infer_engine_ = std::make_shared<TritonInferenceEngine>();
  InferenceConfig config("triton");
  *(config.triton_config) = engine_config;
  return infer_engine_->Init(config);
}
// Copy the per-request options from the engine config into the Triton
// inference options used for every subsequent request.
void TritonInferenceEngine::ParseConfigs(const TritonEngineConfig& configs) {
  options_.model_name_ = configs.model_name_;
  options_.model_version_ = configs.model_version_;
  options_.request_id_ = configs.request_id_;
  options_.sequence_id_ = configs.sequence_id_;
  options_.sequence_start_ = configs.sequence_start_;
  options_.sequence_end_ = configs.sequence_end_;
  options_.priority_ = configs.priority_;
  options_.server_timeout_ = configs.server_timeout_;
  options_.client_timeout_ = configs.client_timeout_;
}
bool TritonInferenceEngine::Init(const InferenceConfig& configs) {
  const TritonEngineConfig& triton_configs = *(configs.triton_config);
  ParseConfigs(triton_configs);
  FAIL_IF_ERR(nic::InferenceServerHttpClient::Create(
                  &client_, triton_configs.url_, triton_configs.verbose_),
              "unable to create client for inference")
  return true;
}
nic::Error TritonInferenceEngine::GetModelMetaData(
    rapidjson::Document* model_metadata) {
  std::string model_metadata_str;
  FAIL_IF_ERR(client_->ModelMetadata(&model_metadata_str,
                                     options_.model_name_,
                                     options_.model_version_),
              "failed to get model metadata");
  model_metadata->Parse(model_metadata_str.c_str(), model_metadata_str.size());
  if (model_metadata->HasParseError()) {
    return nic::Error(
        "failed to parse JSON at " +
        std::to_string(model_metadata->GetErrorOffset()) + ": " +
        std::string(GetParseError_En(model_metadata->GetParseError())));
  }
  return nic::Error::Success;
}
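
// For reference, CreateOutput() below walks the "outputs" array of the model
// metadata returned by the Triton v2 HTTP API. The JSON has roughly this
// shape (field values here are illustrative only):
//
//   {
//     "name": "my_model",
//     "versions": ["1"],
//     "platform": "...",
//     "inputs":  [{"name": "x", "datatype": "FP32", "shape": [-1, 3, 224, 224]}],
//     "outputs": [{"name": "y", "datatype": "FP32", "shape": [-1, 1000]}]
//   }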
// Wrap each input DataBlob in a Triton InferInput pointing at the blob's
// raw bytes.
void TritonInferenceEngine::CreateInput(
    const std::vector<DataBlob>& input_blobs,
    std::vector<nic::InferInput*>* inputs) {
  for (size_t i = 0; i < input_blobs.size(); i++) {
    nic::InferInput* input;
    std::vector<int64_t> input_shape(input_blobs[i].shape.begin(),
                                     input_blobs[i].shape.end());
    nic::InferInput::Create(&input, input_blobs[i].name, input_shape,
                            DtypeToString(input_blobs[i].dtype));
    // AppendRaw takes the blob's raw byte pointer and its size in bytes.
    FAIL_IF_ERR(input->AppendRaw(
                    reinterpret_cast<const uint8_t*>(input_blobs[i].data.data()),
                    input_blobs[i].data.size()),
                "unable to set data for INPUT");
    inputs->push_back(input);
  }
}
// Request every output listed in the model metadata.
void TritonInferenceEngine::CreateOutput(
    const rapidjson::Document& model_metadata,
    std::vector<const nic::InferRequestedOutput*>* outputs) {
  const auto& output_itr = model_metadata.FindMember("outputs");
  for (rapidjson::Value::ConstValueIterator itr = output_itr->value.Begin();
       itr != output_itr->value.End(); ++itr) {
    auto output_name = (*itr)["name"].GetString();
    nic::InferRequestedOutput* output;
    nic::InferRequestedOutput::Create(&output, output_name);
    outputs->push_back(output);
  }
}
bool TritonInferenceEngine::Infer(const std::vector<DataBlob>& input_blobs,
                                  std::vector<DataBlob>* output_blobs) {
  rapidjson::Document model_metadata;
  GetModelMetaData(&model_metadata);

  std::vector<nic::InferInput*> inputs;
  CreateInput(input_blobs, &inputs);

  std::vector<const nic::InferRequestedOutput*> outputs;
  CreateOutput(model_metadata, &outputs);

  nic::InferResult* results;
  FAIL_IF_ERR(client_->Infer(&results, options_, inputs, outputs, headers_,
                             query_params_),
              "failed to send inference request");

  // Copy every requested output back into a DataBlob.
  for (const auto output : outputs) {
    std::string output_name = output->Name();
    DataBlob output_blob;
    output_blob.name = output_name;

    std::vector<int64_t> output_shape;
    results->Shape(output_name, &output_shape);
    for (auto shape : output_shape) {
      output_blob.shape.push_back(static_cast<int>(shape));
    }

    std::string output_dtype;
    results->Datatype(output_name, &output_dtype);
    output_blob.dtype = DtypeToInt(output_dtype);

    // TODO(my_username): set output.lod when batch_size > 1
    int size = std::accumulate(output_blob.shape.begin(),
                               output_blob.shape.end(), 1,
                               std::multiplies<int>());
    size_t output_byte_size;
    const uint8_t* output_data;
    results->RawData(output_blob.name, &output_data, &output_byte_size);

    // Size the destination buffer by dtype, then copy the raw bytes once.
    if (output_blob.dtype == 0) {         // FP32
      output_blob.data.resize(size * sizeof(float));
    } else if (output_blob.dtype == 1) {  // INT64
      output_blob.data.resize(size * sizeof(int64_t));
    } else if (output_blob.dtype == 2) {  // INT32
      output_blob.data.resize(size * sizeof(int));
    } else if (output_blob.dtype == 3) {  // UINT8
      output_blob.data.resize(size * sizeof(uint8_t));
    }
    memcpy(output_blob.data.data(), output_data, output_blob.data.size());

    output_blobs->push_back(std::move(output_blob));
  }
  return true;
}
}  // namespace PaddleDeploy
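
// Usage sketch (illustration only, not part of the original file). Field
// names on TritonEngineConfig and DataBlob are assumed from their use above;
// the server URL and model name are hypothetical.
//
//   PaddleDeploy::TritonEngineConfig config;
//   config.url_ = "localhost:8000";
//   config.model_name_ = "my_model";
//
//   PaddleDeploy::Model model;
//   model.TritonEngineInit(config);
//
//   PaddleDeploy::DataBlob input;
//   input.name = "x";
//   input.dtype = 0;  // FP32
//   input.shape = {1, 3, 224, 224};
//   input.data.resize(1 * 3 * 224 * 224 * sizeof(float));
//
//   std::vector<PaddleDeploy::DataBlob> outputs;
//   model.infer_engine_->Infer({input}, &outputs);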