utils.h 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include <cuda_runtime_api.h>
  16. #include <algorithm>
  17. #include <iostream>
  18. #include <map>
  19. #include <memory>
  20. #include <numeric>
  21. #include <string>
  22. #include <vector>
  23. #include "NvInfer.h"
  24. #include "ultra_infer/core/allocate.h"
  25. #include "ultra_infer/core/fd_tensor.h"
  26. #include "ultra_infer/utils/utils.h"
  27. namespace ultra_infer {
  // Deleter functor meant to be plugged into std::unique_ptr (see FDUniquePtr
  // below). The pointee is released with a plain `delete` expression; a null
  // pointer is ignored. The older `obj->destroy()` call is deliberately not
  // used.
  struct FDInferDeleter {
    template <typename T> void operator()(T *obj) const {
      if (obj == nullptr) {
        return;
      }
      delete obj;
    }
  };

  // std::unique_ptr whose pointee is released through FDInferDeleter.
  template <typename T>
  using FDUniquePtr = std::unique_ptr<T, FDInferDeleter>;
  37. int64_t Volume(const nvinfer1::Dims &d);
  38. nvinfer1::Dims ToDims(const std::vector<int> &vec);
  39. nvinfer1::Dims ToDims(const std::vector<int64_t> &vec);
  40. size_t TrtDataTypeSize(const nvinfer1::DataType &dtype);
  41. FDDataType GetFDDataType(const nvinfer1::DataType &dtype);
  42. nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype);
  43. FDDataType ReaderDtypeToFDDtype(int reader_dtype);
  44. std::vector<int> ToVec(const nvinfer1::Dims &dim);
  // Streams `vec` as a bracketed, comma-separated list, e.g. "[1, 2, 3]".
  //
  // An empty vector is written as "[]". (Previously the closing bracket was
  // emitted only together with the last element, so an empty vector produced
  // an unbalanced "[".)
  //
  // \param out  Destination stream.
  // \param vec  Elements to print; T must itself be streamable.
  // \return     `out`, to allow chaining.
  template <typename T>
  std::ostream &operator<<(std::ostream &out, const std::vector<T> &vec) {
    out << "[";
    for (size_t i = 0; i < vec.size(); ++i) {
      // The separator goes in front of every element except the first.
      if (i != 0) {
        out << ", ";
      }
      out << vec[i];
    }
    // Close unconditionally so the brackets are balanced for any size.
    out << "]";
    return out;
  }
  57. template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
  58. public:
  59. //!
  60. //! \brief Construct an empty buffer.
  61. //!
  62. explicit FDGenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT)
  63. : mSize(0), mCapacity(0), mType(type), mBuffer(nullptr),
  64. mExternal_buffer(nullptr) {}
  65. //!
  66. //! \brief Construct a buffer with the specified allocation size in bytes.
  67. //!
  68. FDGenericBuffer(size_t size, nvinfer1::DataType type)
  69. : mSize(size), mCapacity(size), mType(type) {
  70. if (!allocFn(&mBuffer, this->nbBytes())) {
  71. throw std::bad_alloc();
  72. }
  73. }
  74. //!
  75. //! \brief This use to skip memory copy step.
  76. //!
  77. FDGenericBuffer(size_t size, nvinfer1::DataType type, void *buffer)
  78. : mSize(size), mCapacity(size), mType(type) {
  79. mExternal_buffer = buffer;
  80. }
  81. FDGenericBuffer(FDGenericBuffer &&buf)
  82. : mSize(buf.mSize), mCapacity(buf.mCapacity), mType(buf.mType),
  83. mBuffer(buf.mBuffer) {
  84. buf.mSize = 0;
  85. buf.mCapacity = 0;
  86. buf.mType = nvinfer1::DataType::kFLOAT;
  87. buf.mBuffer = nullptr;
  88. }
  89. FDGenericBuffer &operator=(FDGenericBuffer &&buf) {
  90. if (this != &buf) {
  91. freeFn(mBuffer);
  92. mSize = buf.mSize;
  93. mCapacity = buf.mCapacity;
  94. mType = buf.mType;
  95. mBuffer = buf.mBuffer;
  96. // Reset buf.
  97. buf.mSize = 0;
  98. buf.mCapacity = 0;
  99. buf.mBuffer = nullptr;
  100. }
  101. return *this;
  102. }
  103. //!
  104. //! \brief Returns pointer to underlying array.
  105. //!
  106. void *data() {
  107. if (mExternal_buffer != nullptr)
  108. return mExternal_buffer;
  109. return mBuffer;
  110. }
  111. //!
  112. //! \brief Returns pointer to underlying array.
  113. //!
  114. const void *data() const {
  115. if (mExternal_buffer != nullptr)
  116. return mExternal_buffer;
  117. return mBuffer;
  118. }
  119. //!
  120. //! \brief Returns the size (in number of elements) of the buffer.
  121. //!
  122. size_t size() const { return mSize; }
  123. //!
  124. //! \brief Returns the size (in bytes) of the buffer.
  125. //!
  126. size_t nbBytes() const { return this->size() * TrtDataTypeSize(mType); }
  127. //!
  128. //! \brief Returns the dtype of the buffer.
  129. //!
  130. nvinfer1::DataType dtype() const { return mType; }
  131. //!
  132. //! \brief Set user memory buffer for TRT Buffer
  133. //!
  134. void SetExternalData(size_t size, nvinfer1::DataType type, void *buffer) {
  135. mSize = mCapacity = size;
  136. mType = type;
  137. mExternal_buffer = const_cast<void *>(buffer);
  138. }
  139. //!
  140. //! \brief Set user memory buffer for TRT Buffer
  141. //!
  142. void SetExternalData(const nvinfer1::Dims &dims, const void *buffer) {
  143. mSize = mCapacity = Volume(dims);
  144. mExternal_buffer = const_cast<void *>(buffer);
  145. }
  146. //!
  147. //! \brief Resizes the buffer. This is a no-op if the new size is smaller than
  148. //! or equal to the current capacity.
  149. //!
  150. void resize(size_t newSize) {
  151. mExternal_buffer = nullptr;
  152. mSize = newSize;
  153. if (mCapacity < newSize) {
  154. freeFn(mBuffer);
  155. if (!allocFn(&mBuffer, this->nbBytes())) {
  156. throw std::bad_alloc{};
  157. }
  158. mCapacity = newSize;
  159. }
  160. }
  161. //!
  162. //! \brief Overload of resize that accepts Dims
  163. //!
  164. void resize(const nvinfer1::Dims &dims) { return this->resize(Volume(dims)); }
  165. ~FDGenericBuffer() {
  166. mExternal_buffer = nullptr;
  167. freeFn(mBuffer);
  168. }
  169. private:
  170. size_t mSize{0}, mCapacity{0};
  171. nvinfer1::DataType mType;
  172. void *mBuffer;
  173. void *mExternal_buffer;
  174. AllocFunc allocFn;
  175. FreeFunc freeFn;
  176. };
  177. using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;
  178. using FDDeviceHostBuffer =
  179. FDGenericBuffer<FDDeviceHostAllocator, FDDeviceHostFree>;
  180. class FDTrtLogger : public nvinfer1::ILogger {
  181. public:
  182. static FDTrtLogger *logger;
  183. static FDTrtLogger *Get() {
  184. if (logger != nullptr) {
  185. return logger;
  186. }
  187. logger = new FDTrtLogger();
  188. return logger;
  189. }
  190. void SetLog(bool enable_info = false, bool enable_warning = false) {
  191. enable_info_ = enable_info;
  192. enable_warning_ = enable_warning;
  193. }
  194. void log(nvinfer1::ILogger::Severity severity,
  195. const char *msg) noexcept override {
  196. if (severity == nvinfer1::ILogger::Severity::kINFO) {
  197. if (enable_info_) {
  198. FDINFO << msg << std::endl;
  199. }
  200. } else if (severity == nvinfer1::ILogger::Severity::kWARNING) {
  201. if (enable_warning_) {
  202. FDWARNING << msg << std::endl;
  203. }
  204. } else if (severity == nvinfer1::ILogger::Severity::kERROR) {
  205. FDERROR << msg << std::endl;
  206. } else if (severity == nvinfer1::ILogger::Severity::kINTERNAL_ERROR) {
  207. FDASSERT(false, "%s", msg);
  208. }
  209. }
  210. private:
  211. bool enable_info_ = false;
  212. bool enable_warning_ = false;
  213. };
  214. struct ShapeRangeInfo {
  215. explicit ShapeRangeInfo(const std::vector<int64_t> &new_shape) {
  216. shape.assign(new_shape.begin(), new_shape.end());
  217. min.resize(new_shape.size());
  218. max.resize(new_shape.size());
  219. is_static.resize(new_shape.size());
  220. for (size_t i = 0; i < new_shape.size(); ++i) {
  221. if (new_shape[i] > 0) {
  222. min[i] = new_shape[i];
  223. max[i] = new_shape[i];
  224. is_static[i] = 1;
  225. } else {
  226. min[i] = -1;
  227. max[i] = -1;
  228. is_static[i] = 0;
  229. }
  230. }
  231. }
  232. std::string name;
  233. std::vector<int64_t> shape;
  234. std::vector<int64_t> min;
  235. std::vector<int64_t> max;
  236. std::vector<int64_t> opt;
  237. std::vector<int8_t> is_static;
  238. // return
  239. // -1: new shape is inillegal
  240. // 0 : new shape is able to inference
  241. // 1 : new shape is out of range, need to update engine
  242. int Update(const std::vector<int64_t> &new_shape);
  243. int Update(const std::vector<int> &new_shape) {
  244. std::vector<int64_t> new_shape_int64(new_shape.begin(), new_shape.end());
  245. return Update(new_shape_int64);
  246. }
  247. friend std::ostream &operator<<(std::ostream &out,
  248. const ShapeRangeInfo &info) {
  249. out << "Input name: " << info.name << ", shape=" << info.shape
  250. << ", min=" << info.min << ", max=" << info.max << std::endl;
  251. return out;
  252. }
  253. };
  254. } // namespace ultra_infer