centerpoint.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ultra_infer/vision/perception/paddle3d/centerpoint/centerpoint.h"

namespace ultra_infer {
namespace vision {
namespace perception {

Centerpoint::Centerpoint(const std::string &model_file,
                         const std::string &params_file,
                         const std::string &config_file,
                         const RuntimeOption &custom_option,
                         const ModelFormat &model_format)
    : preprocessor_(config_file) {
  valid_gpu_backends = {Backend::PDINFER};
  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  initialized = Initialize();
}

bool Centerpoint::Initialize() {
  if (!InitRuntime()) {
    FDERROR << "Failed to initialize ultra_infer backend." << std::endl;
    return false;
  }
  return true;
}

bool Centerpoint::Predict(const std::string point_dir,
                          PerceptionResult *result) {
  // Predict on a single point cloud file by delegating to BatchPredict and
  // returning the first result.
  std::vector<PerceptionResult> results;
  if (!BatchPredict({point_dir}, &results)) {
    return false;
  }
  if (!results.empty()) {
    *result = std::move(results[0]);
  }
  return true;
}

bool Centerpoint::BatchPredict(std::vector<std::string> points_dir,
                               std::vector<PerceptionResult> *results) {
  // Each point in the input file is expected to carry 5 values; the time-lag
  // channel is not used here.
  int64_t num_point_dim = 5;
  int with_timelag = 0;
  if (!preprocessor_.Run(points_dir, num_point_dim, with_timelag,
                         reused_input_tensors_)) {
    FDERROR << "Failed to preprocess the input point clouds." << std::endl;
    return false;
  }

  results->resize(reused_input_tensors_.size());
  for (size_t index = 0; index < reused_input_tensors_.size(); ++index) {
    // Run inference on one preprocessed point cloud at a time.
    std::vector<FDTensor> input_tensor;
    input_tensor.push_back(reused_input_tensors_[index]);
    input_tensor[0].name = InputInfoOfRuntime(0).name;

    if (!Infer(input_tensor, &reused_output_tensors_)) {
      FDERROR << "Failed to run inference by runtime." << std::endl;
      return false;
    }

    // Decode the raw output tensors into a PerceptionResult, reserving one
    // slot per predicted box.
    (*results)[index].Clear();
    (*results)[index].Reserve(reused_output_tensors_[0].shape[0]);
    if (!postprocessor_.Run(reused_output_tensors_, &((*results)[index]))) {
      FDERROR << "Failed to postprocess the inference results by runtime."
              << std::endl;
      return false;
    }
  }
  return true;
}

} // namespace perception
} // namespace vision
} // namespace ultra_infer
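
// Example usage: a minimal, hypothetical sketch of calling this model. The
// file paths below are placeholders, and the surrounding API (RuntimeOption,
// PerceptionResult, Initialized()) is assumed to follow the usual ultra_infer
// model interface; check the corresponding headers before relying on it.
//
//   ultra_infer::RuntimeOption option;
//   option.UseGpu();  // assumption: run on GPU with the Paddle Inference backend
//   ultra_infer::vision::perception::Centerpoint model(
//       "centerpoint/model.pdmodel", "centerpoint/model.pdiparams",
//       "centerpoint/infer_cfg.yml", option,
//       ultra_infer::ModelFormat::PADDLE);
//   ultra_infer::vision::PerceptionResult result;
//   if (model.Initialized() && model.Predict("lidar_frame.bin", &result)) {
//     // Use the detected 3D boxes in `result`.
//   }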