utils.cc 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/runtime/backends/ort/utils.h"
  15. #include "ultra_infer/utils/utils.h"
  16. namespace ultra_infer {
  17. ONNXTensorElementDataType GetOrtDtype(const FDDataType &fd_dtype) {
  18. if (fd_dtype == FDDataType::FP32) {
  19. return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  20. } else if (fd_dtype == FDDataType::FP64) {
  21. return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
  22. } else if (fd_dtype == FDDataType::INT32) {
  23. return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  24. } else if (fd_dtype == FDDataType::INT64) {
  25. return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  26. } else if (fd_dtype == FDDataType::UINT8) {
  27. return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  28. } else if (fd_dtype == FDDataType::INT8) {
  29. return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
  30. } else if (fd_dtype == FDDataType::FP16) {
  31. return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
  32. }
  33. FDERROR << "Unrecognized fastdeply data type:" << Str(fd_dtype) << "."
  34. << std::endl;
  35. return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  36. }
  37. FDDataType GetFdDtype(const ONNXTensorElementDataType &ort_dtype) {
  38. if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
  39. return FDDataType::FP32;
  40. } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
  41. return FDDataType::FP64;
  42. } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
  43. return FDDataType::INT32;
  44. } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
  45. return FDDataType::INT64;
  46. } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
  47. return FDDataType::FP16;
  48. } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
  49. return FDDataType::UINT8;
  50. } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
  51. return FDDataType::INT8;
  52. }
  53. FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl;
  54. return FDDataType::FP32;
  55. }
// Wraps an FDTensor's existing buffer in an Ort::Value without copying.
// NOTE(review): Ort::Value::CreateTensor with a MemoryInfo borrows the
// caller-owned buffer, so `tensor` must outlive the returned value — confirm
// at call sites.
Ort::Value CreateOrtValue(FDTensor &tensor, bool is_backend_cuda) {
FDASSERT(tensor.device == Device::GPU || tensor.device == Device::CPU,
"Only support tensor which device is CPU or GPU for OrtBackend.");
if (tensor.device == Device::GPU && is_backend_cuda) {
// Device-resident tensor on a CUDA-enabled backend: describe the buffer as
// CUDA device memory so ORT consumes it in place.
Ort::MemoryInfo memory_info("Cuda", OrtDeviceAllocator, 0,
OrtMemTypeDefault);
auto ort_value = Ort::Value::CreateTensor(
memory_info, tensor.MutableData(), tensor.Nbytes(), tensor.shape.data(),
tensor.shape.size(), GetOrtDtype(tensor.dtype));
return ort_value;
}
// Everything else is exposed as host ("Cpu") memory.
// NOTE(review): a GPU tensor with is_backend_cuda == false also takes this
// path, presenting a device pointer as host memory — verify callers never
// hit that combination.
Ort::MemoryInfo memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
auto ort_value = Ort::Value::CreateTensor(
memory_info, tensor.Data(), tensor.Nbytes(), tensor.shape.data(),
tensor.shape.size(), GetOrtDtype(tensor.dtype));
return ort_value;
}
  73. } // namespace ultra_infer