voxelize_op.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif

namespace ultra_infer {
namespace paddle_custom_ops {

#define CHECK_INPUT_CUDA(x) \
  PD_CHECK(x.is_gpu() || x.is_gpu_pinned(), #x " must be a GPU Tensor.")

#define CUDA_KERNEL_LOOP(i, n)                                  \
  for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)
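
// One thread per point: compute the grid cell containing the point and reset
// that cell's bookkeeping. num_points_in_grid is zeroed and points_valid is
// set to the sentinel value num_points (meaning "no valid point yet").
// Out-of-range points are skipped.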
template <typename T, typename T_int>
__global__ void init_num_point_grid(
    const T *points, const float point_cloud_range_x_min,
    const float point_cloud_range_y_min, const float point_cloud_range_z_min,
    const float voxel_size_x, const float voxel_size_y,
    const float voxel_size_z, const int grid_size_x, const int grid_size_y,
    const int grid_size_z, const int64_t num_points, const int num_point_dim,
    T_int *num_points_in_grid, int *points_valid) {
  int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (point_idx >= num_points) {
    return;
  }
  int coord_x =
      floor((points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
            voxel_size_x);
  int coord_y =
      floor((points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
            voxel_size_y);
  int coord_z =
      floor((points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
            voxel_size_z);
  if (coord_x < 0 || coord_x >= grid_size_x) {
    return;
  }
  if (coord_y < 0 || coord_y >= grid_size_y) {
    return;
  }
  if (coord_z < 0 || coord_z >= grid_size_z) {
    return;
  }
  int grid_idx =
      coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
  num_points_in_grid[grid_idx] = 0;
  points_valid[grid_idx] = num_points;
}
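
// One thread per point: map every in-range point to its grid cell. The
// per-cell counter is bumped atomically; points arriving after the cell
// already holds max_num_points_in_voxel points are dropped. atomicMin keeps
// the smallest accepted point index per cell in points_valid, which later
// serves as the cell's unique representative point.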
template <typename T, typename T_int>
__global__ void map_point_to_grid_kernel(
    const T *points, const float point_cloud_range_x_min,
    const float point_cloud_range_y_min, const float point_cloud_range_z_min,
    const float voxel_size_x, const float voxel_size_y,
    const float voxel_size_z, const int grid_size_x, const int grid_size_y,
    const int grid_size_z, const int64_t num_points, const int num_point_dim,
    const int max_num_points_in_voxel, T_int *points_to_grid_idx,
    T_int *points_to_num_idx, T_int *num_points_in_grid, int *points_valid) {
  int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (point_idx >= num_points) {
    return;
  }
  int coord_x =
      floor((points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
            voxel_size_x);
  int coord_y =
      floor((points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
            voxel_size_y);
  int coord_z =
      floor((points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
            voxel_size_z);
  if (coord_x < 0 || coord_x >= grid_size_x) {
    return;
  }
  if (coord_y < 0 || coord_y >= grid_size_y) {
    return;
  }
  if (coord_z < 0 || coord_z >= grid_size_z) {
    return;
  }
  int grid_idx =
      coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
  T_int num = atomicAdd(num_points_in_grid + grid_idx, 1);
  if (num < max_num_points_in_voxel) {
    points_to_num_idx[point_idx] = num;
    points_to_grid_idx[point_idx] = grid_idx;
    atomicMin(points_valid + grid_idx, static_cast<int>(point_idx));
  }
}
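
// Grid-stride loop over points: set points_flag[i] to 1 only for the single
// representative point of each occupied grid cell (the point whose index was
// recorded in points_valid), so the flag array contains exactly one 1 per
// non-empty cell.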
template <typename T_int>
__global__ void update_points_flag(const int *points_valid,
                                   const T_int *points_to_grid_idx,
                                   const int num_points, int *points_flag) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < num_points; i += gridDim.x * blockDim.x) {
    T_int grid_idx = points_to_grid_idx[i];
    if (grid_idx >= 0) {
      int id = points_valid[grid_idx];
      if (id != num_points && id == i) {
        points_flag[i] = 1;
      }
    }
  }
}
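
// Turn the per-point flags and their exclusive prefix sum into compact voxel
// ids: the prefix-sum value at a flagged point becomes that cell's voxel
// index, and the last element yields the total number of voxels, clamped to
// max_voxels.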
template <typename T_int>
__global__ void
get_voxel_idx_kernel(const int *points_flag, const T_int *points_to_grid_idx,
                     const int *points_flag_prefix_sum, const int num_points,
                     const int max_voxels, T_int *num_voxels,
                     T_int *grid_idx_to_voxel_idx) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  for (int i = tid; i < num_points; i += gridDim.x * blockDim.x) {
    if (points_flag[i] == 1) {
      T_int grid_idx = points_to_grid_idx[i];
      int num = points_flag_prefix_sum[i];
      if (num < max_voxels) {
        grid_idx_to_voxel_idx[grid_idx] = num;
      }
    }
    if (i == num_points - 1) {
      int num = points_flag_prefix_sum[i] + points_flag[i];
      if (num < max_voxels) {
        num_voxels[0] = num;
      } else {
        num_voxels[0] = max_voxels;
      }
    }
  }
}
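
// Zero-fill the output voxel buffer, one thread per element.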
template <typename T>
__global__ void init_voxels_kernel(const int64_t num, T *voxels) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num) {
    return;
  }
  voxels[idx] = static_cast<T>(0);
}
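
// One thread per point: copy each accepted point's features into its slot in
// the voxel tensor, addressed by (voxel index, intra-voxel slot, feature dim).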
template <typename T, typename T_int>
__global__ void
assign_voxels_kernel(const T *points, const T_int *points_to_grid_idx,
                     const T_int *points_to_num_idx,
                     const T_int *grid_idx_to_voxel_idx,
                     const int64_t num_points, const int num_point_dim,
                     const int max_num_points_in_voxel, T *voxels) {
  int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (point_idx >= num_points) {
    return;
  }
  T_int grid_idx = points_to_grid_idx[point_idx];
  T_int num_idx = points_to_num_idx[point_idx];
  if (grid_idx > -1 && num_idx > -1) {
    T_int voxel_idx = grid_idx_to_voxel_idx[grid_idx];
    if (voxel_idx > -1) {
      for (int64_t i = 0; i < num_point_dim; ++i) {
        voxels[voxel_idx * max_num_points_in_voxel * num_point_dim +
               num_idx * num_point_dim + i] =
            points[point_idx * num_point_dim + i];
      }
    }
  }
}
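
// One thread per grid cell: for every cell that was assigned a voxel, decode
// its (z, y, x) coordinate from the flat grid index and record the clamped
// point count for that voxel.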
template <typename T, typename T_int>
__global__ void
assign_coords_kernel(const T_int *grid_idx_to_voxel_idx,
                     const T_int *num_points_in_grid, const int num_grids,
                     const int grid_size_x, const int grid_size_y,
                     const int grid_size_z, const int max_num_points_in_voxel,
                     T *coords, T *num_points_per_voxel) {
  int64_t grid_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (grid_idx >= num_grids) {
    return;
  }
  T_int voxel_idx = grid_idx_to_voxel_idx[grid_idx];
  if (voxel_idx > -1) {
    T_int coord_z = grid_idx / grid_size_x / grid_size_y;
    T_int coord_y =
        (grid_idx - coord_z * grid_size_x * grid_size_y) / grid_size_x;
    T_int coord_x =
        grid_idx - coord_z * grid_size_x * grid_size_y - coord_y * grid_size_x;
    coords[voxel_idx * 3 + 0] = coord_z;
    coords[voxel_idx * 3 + 1] = coord_y;
    coords[voxel_idx * 3 + 2] = coord_x;
    num_points_per_voxel[voxel_idx] =
        min(num_points_in_grid[grid_idx], max_num_points_in_voxel);
  }
}
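
// Hard voxelization entry point. Given an (N, num_point_dim) point tensor,
// the voxel size and the point-cloud range, it returns
//   voxels:               (max_voxels, max_num_points_in_voxel, num_point_dim)
//   coords:               (max_voxels, 3), stored as (z, y, x)
//   num_points_per_voxel: (max_voxels)
//   num_voxels:           (1)
// All intermediate buffers live on the GPU and every kernel is launched on
// the input tensor's stream.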
std::vector<paddle::Tensor>
hard_voxelize_cuda(const paddle::Tensor &points,
                   const std::vector<float> &voxel_size,
                   const std::vector<float> &point_cloud_range,
                   int max_num_points_in_voxel, int max_voxels) {
  // check device
  CHECK_INPUT_CUDA(points);
  int64_t num_points = points.shape()[0];
  int64_t num_point_dim = points.shape()[1];
  const float voxel_size_x = voxel_size[0];
  const float voxel_size_y = voxel_size[1];
  const float voxel_size_z = voxel_size[2];
  const float point_cloud_range_x_min = point_cloud_range[0];
  const float point_cloud_range_y_min = point_cloud_range[1];
  const float point_cloud_range_z_min = point_cloud_range[2];
  int grid_size_x = static_cast<int>(
      round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size_x));
  int grid_size_y = static_cast<int>(
      round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size_y));
  int grid_size_z = static_cast<int>(
      round((point_cloud_range[5] - point_cloud_range[2]) / voxel_size_z));
  int num_grids = grid_size_x * grid_size_y * grid_size_z;
  auto voxels =
      paddle::empty({max_voxels, max_num_points_in_voxel, num_point_dim},
                    paddle::DataType::FLOAT32, paddle::GPUPlace());
  auto coords = paddle::full({max_voxels, 3}, 0, paddle::DataType::INT32,
                             paddle::GPUPlace());
  auto *coords_data = coords.data<int>();
  auto num_points_per_voxel = paddle::full(
      {max_voxels}, 0, paddle::DataType::INT32, paddle::GPUPlace());
  auto *num_points_per_voxel_data = num_points_per_voxel.data<int>();
  auto points_to_grid_idx = paddle::full(
      {num_points}, -1, paddle::DataType::INT32, paddle::GPUPlace());
  auto *points_to_grid_idx_data = points_to_grid_idx.data<int>();
  auto points_to_num_idx = paddle::full(
      {num_points}, -1, paddle::DataType::INT32, paddle::GPUPlace());
  auto *points_to_num_idx_data = points_to_num_idx.data<int>();
  auto num_points_in_grid =
      paddle::empty({grid_size_z, grid_size_y, grid_size_x},
                    paddle::DataType::INT32, paddle::GPUPlace());
  auto *num_points_in_grid_data = num_points_in_grid.data<int>();
  auto grid_idx_to_voxel_idx =
      paddle::full({grid_size_z, grid_size_y, grid_size_x}, -1,
                   paddle::DataType::INT32, paddle::GPUPlace());
  auto *grid_idx_to_voxel_idx_data = grid_idx_to_voxel_idx.data<int>();
  auto num_voxels =
      paddle::full({1}, 0, paddle::DataType::INT32, paddle::GPUPlace());
  auto *num_voxels_data = num_voxels.data<int>();
  auto points_valid =
      paddle::empty({grid_size_z, grid_size_y, grid_size_x},
                    paddle::DataType::INT32, paddle::GPUPlace());
  int *points_valid_data = points_valid.data<int>();
  auto points_flag = paddle::full({num_points}, 0, paddle::DataType::INT32,
                                  paddle::GPUPlace());
  // 1. Find the grid index for each point, compute the
  // number of points in each grid
  int64_t threads = 512;
  int64_t blocks = (num_points + threads - 1) / threads;
  PD_DISPATCH_FLOATING_TYPES(
      points.type(), "init_num_point_grid", ([&] {
        init_num_point_grid<data_t, int>
            <<<blocks, threads, 0, points.stream()>>>(
                points.data<data_t>(), point_cloud_range_x_min,
                point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
                voxel_size_y, voxel_size_z, grid_size_x, grid_size_y,
                grid_size_z, num_points, num_point_dim,
                num_points_in_grid_data, points_valid_data);
      }));
  PD_DISPATCH_FLOATING_TYPES(
      points.type(), "map_point_to_grid_kernel", ([&] {
        map_point_to_grid_kernel<data_t, int>
            <<<blocks, threads, 0, points.stream()>>>(
                points.data<data_t>(), point_cloud_range_x_min,
                point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
                voxel_size_y, voxel_size_z, grid_size_x, grid_size_y,
                grid_size_z, num_points, num_point_dim,
                max_num_points_in_voxel, points_to_grid_idx_data,
                points_to_num_idx_data, num_points_in_grid_data,
                points_valid_data);
      }));
  // 2. Find the number of non-zero voxels
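  // Each occupied grid cell is represented by exactly one point: the smallest
  // accepted point index, recorded in points_valid via atomicMin. The
  // update_points_flag kernel marks those representatives, and an exclusive
  // prefix sum over the resulting 0/1 flags assigns every occupied cell a
  // consecutive voxel index.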
  int *points_flag_data = points_flag.data<int>();
  threads = 512;
  blocks = (num_points + threads - 1) / threads;
  update_points_flag<int><<<blocks, threads, 0, points.stream()>>>(
      points_valid_data, points_to_grid_idx_data, num_points,
      points_flag_data);
  auto points_flag_prefix_sum =
      paddle::experimental::cumsum(points_flag, 0, false, true, false);
  int *points_flag_prefix_sum_data = points_flag_prefix_sum.data<int>();
  get_voxel_idx_kernel<int><<<blocks, threads, 0, points.stream()>>>(
      points_flag_data, points_to_grid_idx_data, points_flag_prefix_sum_data,
      num_points, max_voxels, num_voxels_data, grid_idx_to_voxel_idx_data);
  // 3. Store points to voxels coords and num_points_per_voxel
  int64_t num = max_voxels * max_num_points_in_voxel * num_point_dim;
  threads = 512;
  blocks = (num + threads - 1) / threads;
  PD_DISPATCH_FLOATING_TYPES(points.type(), "init_voxels_kernel", ([&] {
                               init_voxels_kernel<data_t>
                                   <<<blocks, threads, 0, points.stream()>>>(
                                       num, voxels.data<data_t>());
                             }));
  threads = 512;
  blocks = (num_points + threads - 1) / threads;
  PD_DISPATCH_FLOATING_TYPES(
      points.type(), "assign_voxels_kernel", ([&] {
        assign_voxels_kernel<data_t, int>
            <<<blocks, threads, 0, points.stream()>>>(
                points.data<data_t>(), points_to_grid_idx_data,
                points_to_num_idx_data, grid_idx_to_voxel_idx_data, num_points,
                num_point_dim, max_num_points_in_voxel, voxels.data<data_t>());
      }));
  // 4. Store coords, num_points_per_voxel
  blocks = (num_grids + threads - 1) / threads;
  assign_coords_kernel<int><<<blocks, threads, 0, points.stream()>>>(
      grid_idx_to_voxel_idx_data, num_points_in_grid_data, num_grids,
      grid_size_x, grid_size_y, grid_size_z, max_num_points_in_voxel,
      coords_data, num_points_per_voxel_data);
  return {voxels, coords, num_points_per_voxel, num_voxels};
}

} // namespace paddle_custom_ops
} // namespace ultra_infer
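
// ---------------------------------------------------------------------------
// Hedged sketch (not part of the original file): hard_voxelize_cuda is only
// the CUDA implementation; the custom-op registration that exposes it to
// Paddle normally lives elsewhere in the project. Assuming the standard
// paddle custom-op macros, a registration could look roughly like the lines
// below. The op name "hard_voxelize" and the input/output/attribute names are
// illustrative assumptions, not the project's actual registration.
//
// PD_BUILD_OP(hard_voxelize)
//     .Inputs({"POINTS"})
//     .Outputs({"VOXELS", "COORS", "NUM_POINTS_PER_VOXEL", "num_voxels"})
//     .Attrs({"voxel_size: std::vector<float>",
//             "point_cloud_range: std::vector<float>",
//             "max_num_points_in_voxel: int",
//             "max_voxels: int"})
//     .SetKernelFn(PD_KERNEL(
//         ultra_infer::paddle_custom_ops::hard_voxelize_cuda));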