// voxelize_op.cu
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
  14. #include "paddle/extension.h"
  15. #define CHECK_INPUT_CUDA(x) \
  16. PD_CHECK(x.is_gpu() || x.is_gpu_pinned(), #x " must be a GPU Tensor.")
  17. #define CUDA_KERNEL_LOOP(i, n) \
  18. for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
  19. i += blockDim.x * gridDim.x)
  20. template <typename T, typename T_int>
  21. __global__ void init_num_point_grid(
  22. const T *points, const float point_cloud_range_x_min,
  23. const float point_cloud_range_y_min, const float point_cloud_range_z_min,
  24. const float voxel_size_x, const float voxel_size_y,
  25. const float voxel_size_z, const int grid_size_x, const int grid_size_y,
  26. const int grid_size_z, const int64_t num_points, const int num_point_dim,
  27. T_int *num_points_in_grid, int *points_valid) {
  28. int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  29. if (point_idx > num_points || point_idx == num_points) {
  30. return;
  31. }
  32. int coord_x =
  33. floor((points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
  34. voxel_size_x);
  35. int coord_y =
  36. floor((points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
  37. voxel_size_y);
  38. int coord_z =
  39. floor((points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
  40. voxel_size_z);
  41. if (coord_x < 0 || coord_x > grid_size_x || coord_x == grid_size_x) {
  42. return;
  43. }
  44. if (coord_y < 0 || coord_y > grid_size_y || coord_y == grid_size_y) {
  45. return;
  46. }
  47. if (coord_z < 0 || coord_z > grid_size_z || coord_z == grid_size_z) {
  48. return;
  49. }
  50. int grid_idx =
  51. coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
  52. num_points_in_grid[grid_idx] = 0;
  53. points_valid[grid_idx] = num_points;
  54. }
  55. template <typename T, typename T_int>
  56. __global__ void map_point_to_grid_kernel(
  57. const T *points, const float point_cloud_range_x_min,
  58. const float point_cloud_range_y_min, const float point_cloud_range_z_min,
  59. const float voxel_size_x, const float voxel_size_y,
  60. const float voxel_size_z, const int grid_size_x, const int grid_size_y,
  61. const int grid_size_z, const int64_t num_points, const int num_point_dim,
  62. const int max_num_points_in_voxel, T_int *points_to_grid_idx,
  63. T_int *points_to_num_idx, T_int *num_points_in_grid, int *points_valid) {
  64. int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  65. if (point_idx > num_points || point_idx == num_points) {
  66. return;
  67. }
  68. int coord_x =
  69. floor((points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
  70. voxel_size_x);
  71. int coord_y =
  72. floor((points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
  73. voxel_size_y);
  74. int coord_z =
  75. floor((points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
  76. voxel_size_z);
  77. if (coord_x < 0 || coord_x > grid_size_x || coord_x == grid_size_x) {
  78. return;
  79. }
  80. if (coord_y < 0 || coord_y > grid_size_y || coord_y == grid_size_y) {
  81. return;
  82. }
  83. if (coord_z < 0 || coord_z > grid_size_z || coord_z == grid_size_z) {
  84. return;
  85. }
  86. int grid_idx =
  87. coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
  88. T_int num = atomicAdd(num_points_in_grid + grid_idx, 1);
  89. if (num < max_num_points_in_voxel) {
  90. points_to_num_idx[point_idx] = num;
  91. points_to_grid_idx[point_idx] = grid_idx;
  92. atomicMin(points_valid + grid_idx, static_cast<int>(point_idx));
  93. }
  94. }
  95. template <typename T_int>
  96. __global__ void update_points_flag(const int *points_valid,
  97. const T_int *points_to_grid_idx,
  98. const int num_points, int *points_flag) {
  99. int tid = threadIdx.x + blockIdx.x * blockDim.x;
  100. for (int i = tid; i < num_points; i += gridDim.x * blockDim.x) {
  101. T_int grid_idx = points_to_grid_idx[i];
  102. if (grid_idx >= 0) {
  103. int id = points_valid[grid_idx];
  104. if (id != num_points && id == i) {
  105. points_flag[i] = 1;
  106. }
  107. }
  108. }
  109. }
  110. template <typename T_int>
  111. __global__ void
  112. get_voxel_idx_kernel(const int *points_flag, const T_int *points_to_grid_idx,
  113. const int *points_flag_prefix_sum, const int num_points,
  114. const int max_voxels, T_int *num_voxels,
  115. T_int *grid_idx_to_voxel_idx) {
  116. int tid = threadIdx.x + blockIdx.x * blockDim.x;
  117. for (int i = tid; i < num_points; i += gridDim.x * blockDim.x) {
  118. if (points_flag[i] == 1) {
  119. T_int grid_idx = points_to_grid_idx[i];
  120. int num = points_flag_prefix_sum[i];
  121. if (num < max_voxels) {
  122. grid_idx_to_voxel_idx[grid_idx] = num;
  123. }
  124. }
  125. if (i == num_points - 1) {
  126. int num = points_flag_prefix_sum[i] + points_flag[i];
  127. if (num < max_voxels) {
  128. num_voxels[0] = num;
  129. } else {
  130. num_voxels[0] = max_voxels;
  131. }
  132. }
  133. }
  134. }
  135. template <typename T>
  136. __global__ void init_voxels_kernel(const int64_t num, T *voxels) {
  137. int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  138. if (idx > num || idx == num) {
  139. return;
  140. }
  141. voxels[idx] = static_cast<T>(0);
  142. }
  143. template <typename T, typename T_int>
  144. __global__ void
  145. assign_voxels_kernel(const T *points, const T_int *points_to_grid_idx,
  146. const T_int *points_to_num_idx,
  147. const T_int *grid_idx_to_voxel_idx,
  148. const int64_t num_points, const int num_point_dim,
  149. const int max_num_points_in_voxel, T *voxels) {
  150. int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
  151. if (point_idx > num_points || point_idx == num_points) {
  152. return;
  153. }
  154. T_int grid_idx = points_to_grid_idx[point_idx];
  155. T_int num_idx = points_to_num_idx[point_idx];
  156. if (grid_idx > -1 && num_idx > -1) {
  157. T_int voxel_idx = grid_idx_to_voxel_idx[grid_idx];
  158. if (voxel_idx > -1) {
  159. for (int64_t i = 0; i < num_point_dim; ++i) {
  160. voxels[voxel_idx * max_num_points_in_voxel * num_point_dim +
  161. num_idx * num_point_dim + i] =
  162. points[point_idx * num_point_dim + i];
  163. }
  164. }
  165. }
  166. }
  167. template <typename T, typename T_int>
  168. __global__ void
  169. assign_coords_kernel(const T_int *grid_idx_to_voxel_idx,
  170. const T_int *num_points_in_grid, const int num_grids,
  171. const int grid_size_x, const int grid_size_y,
  172. const int grid_size_z, const int max_num_points_in_voxel,
  173. T *coords, T *num_points_per_voxel) {
  174. int64_t grid_idx = blockIdx.x * blockDim.x + threadIdx.x;
  175. if (grid_idx > num_grids || grid_idx == num_grids) {
  176. return;
  177. }
  178. T_int voxel_idx = grid_idx_to_voxel_idx[grid_idx];
  179. if (voxel_idx > -1) {
  180. T_int coord_z = grid_idx / grid_size_x / grid_size_y;
  181. T_int coord_y =
  182. (grid_idx - coord_z * grid_size_x * grid_size_y) / grid_size_x;
  183. T_int coord_x =
  184. grid_idx - coord_z * grid_size_x * grid_size_y - coord_y * grid_size_x;
  185. coords[voxel_idx * 3 + 0] = coord_z;
  186. coords[voxel_idx * 3 + 1] = coord_y;
  187. coords[voxel_idx * 3 + 2] = coord_x;
  188. num_points_per_voxel[voxel_idx] =
  189. min(num_points_in_grid[grid_idx], max_num_points_in_voxel);
  190. }
  191. }
  192. std::vector<paddle::Tensor>
  193. hard_voxelize_cuda(const paddle::Tensor &points,
  194. const std::vector<float> &voxel_size,
  195. const std::vector<float> &point_cloud_range,
  196. int max_num_points_in_voxel, int max_voxels) {
  197. // check device
  198. CHECK_INPUT_CUDA(points);
  199. int64_t num_points = points.shape()[0];
  200. int64_t num_point_dim = points.shape()[1];
  201. const float voxel_size_x = voxel_size[0];
  202. const float voxel_size_y = voxel_size[1];
  203. const float voxel_size_z = voxel_size[2];
  204. const float point_cloud_range_x_min = point_cloud_range[0];
  205. const float point_cloud_range_y_min = point_cloud_range[1];
  206. const float point_cloud_range_z_min = point_cloud_range[2];
  207. int grid_size_x = static_cast<int>(
  208. round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size_x));
  209. int grid_size_y = static_cast<int>(
  210. round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size_y));
  211. int grid_size_z = static_cast<int>(
  212. round((point_cloud_range[5] - point_cloud_range[2]) / voxel_size_z));
  213. int num_grids = grid_size_x * grid_size_y * grid_size_z;
  214. auto voxels =
  215. paddle::empty({max_voxels, max_num_points_in_voxel, num_point_dim},
  216. paddle::DataType::FLOAT32, paddle::GPUPlace());
  217. auto coords = paddle::full({max_voxels, 3}, 0, paddle::DataType::INT32,
  218. paddle::GPUPlace());
  219. auto *coords_data = coords.data<int>();
  220. auto num_points_per_voxel = paddle::full(
  221. {max_voxels}, 0, paddle::DataType::INT32, paddle::GPUPlace());
  222. auto *num_points_per_voxel_data = num_points_per_voxel.data<int>();
  223. auto points_to_grid_idx = paddle::full(
  224. {num_points}, -1, paddle::DataType::INT32, paddle::GPUPlace());
  225. auto *points_to_grid_idx_data = points_to_grid_idx.data<int>();
  226. auto points_to_num_idx = paddle::full(
  227. {num_points}, -1, paddle::DataType::INT32, paddle::GPUPlace());
  228. auto *points_to_num_idx_data = points_to_num_idx.data<int>();
  229. auto num_points_in_grid =
  230. paddle::empty({grid_size_z, grid_size_y, grid_size_x},
  231. paddle::DataType::INT32, paddle::GPUPlace());
  232. auto *num_points_in_grid_data = num_points_in_grid.data<int>();
  233. auto grid_idx_to_voxel_idx =
  234. paddle::full({grid_size_z, grid_size_y, grid_size_x}, -1,
  235. paddle::DataType::INT32, paddle::GPUPlace());
  236. auto *grid_idx_to_voxel_idx_data = grid_idx_to_voxel_idx.data<int>();
  237. auto num_voxels =
  238. paddle::full({1}, 0, paddle::DataType::INT32, paddle::GPUPlace());
  239. auto *num_voxels_data = num_voxels.data<int>();
  240. auto points_valid =
  241. paddle::empty({grid_size_z, grid_size_y, grid_size_x},
  242. paddle::DataType::INT32, paddle::GPUPlace());
  243. int *points_valid_data = points_valid.data<int>();
  244. auto points_flag = paddle::full({num_points}, 0, paddle::DataType::INT32,
  245. paddle::GPUPlace());
  246. // 1. Find the grid index for each point, compute the
  247. // number of points in each grid
  248. int64_t threads = 512;
  249. int64_t blocks = (num_points + threads - 1) / threads;
  250. PD_DISPATCH_FLOATING_TYPES(
  251. points.type(), "init_num_point_grid", ([&] {
  252. init_num_point_grid<data_t, int>
  253. <<<blocks, threads, 0, points.stream()>>>(
  254. points.data<data_t>(), point_cloud_range_x_min,
  255. point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
  256. voxel_size_y, voxel_size_z, grid_size_x, grid_size_y,
  257. grid_size_z, num_points, num_point_dim, num_points_in_grid_data,
  258. points_valid_data);
  259. }));
  260. PD_DISPATCH_FLOATING_TYPES(
  261. points.type(), "map_point_to_grid_kernel", ([&] {
  262. map_point_to_grid_kernel<data_t, int>
  263. <<<blocks, threads, 0, points.stream()>>>(
  264. points.data<data_t>(), point_cloud_range_x_min,
  265. point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
  266. voxel_size_y, voxel_size_z, grid_size_x, grid_size_y,
  267. grid_size_z, num_points, num_point_dim, max_num_points_in_voxel,
  268. points_to_grid_idx_data, points_to_num_idx_data,
  269. num_points_in_grid_data, points_valid_data);
  270. }));
  271. // 2. Find the number of non-zero voxels
  272. int *points_flag_data = points_flag.data<int>();
  273. threads = 512;
  274. blocks = (num_points + threads - 1) / threads;
  275. update_points_flag<int><<<blocks, threads, 0, points.stream()>>>(
  276. points_valid_data, points_to_grid_idx_data, num_points, points_flag_data);
  277. auto points_flag_prefix_sum =
  278. paddle::experimental::cumsum(points_flag, 0, false, true, false);
  279. int *points_flag_prefix_sum_data = points_flag_prefix_sum.data<int>();
  280. get_voxel_idx_kernel<int><<<blocks, threads, 0, points.stream()>>>(
  281. points_flag_data, points_to_grid_idx_data, points_flag_prefix_sum_data,
  282. num_points, max_voxels, num_voxels_data, grid_idx_to_voxel_idx_data);
  283. // 3. Store points to voxels coords and num_points_per_voxel
  284. int64_t num = max_voxels * max_num_points_in_voxel * num_point_dim;
  285. threads = 512;
  286. blocks = (num + threads - 1) / threads;
  287. PD_DISPATCH_FLOATING_TYPES(points.type(), "init_voxels_kernel", ([&] {
  288. init_voxels_kernel<data_t>
  289. <<<blocks, threads, 0, points.stream()>>>(
  290. num, voxels.data<data_t>());
  291. }));
  292. threads = 512;
  293. blocks = (num_points + threads - 1) / threads;
  294. PD_DISPATCH_FLOATING_TYPES(
  295. points.type(), "assign_voxels_kernel", ([&] {
  296. assign_voxels_kernel<data_t, int>
  297. <<<blocks, threads, 0, points.stream()>>>(
  298. points.data<data_t>(), points_to_grid_idx_data,
  299. points_to_num_idx_data, grid_idx_to_voxel_idx_data, num_points,
  300. num_point_dim, max_num_points_in_voxel, voxels.data<data_t>());
  301. }));
  302. // 4. Store coords, num_points_per_voxel
  303. blocks = (num_grids + threads - 1) / threads;
  304. assign_coords_kernel<int><<<blocks, threads, 0, points.stream()>>>(
  305. grid_idx_to_voxel_idx_data, num_points_in_grid_data, num_grids,
  306. grid_size_x, grid_size_y, grid_size_z, max_num_points_in_voxel,
  307. coords_data, num_points_per_voxel_data);
  308. return {voxels, coords, num_points_per_voxel, num_voxels};
  309. }