iou3d_nms_api.cpp

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <paddle/extension.h>

#include <vector>

#include "iou3d_cpu.h"
#include "iou3d_nms.h"
// Output dtype/shape inference functions for the custom ops registered below.

// boxes_iou_bev_cpu: the (N_a, N_b) IoU matrix shares the dtype of boxes_a.
std::vector<paddle::DataType> BoxesIouBevCpuInferDtype(
    paddle::DataType boxes_a_dtype, paddle::DataType boxes_b_dtype) {
  return {boxes_a_dtype};
}

std::vector<std::vector<int64_t>> BoxesIouBevCpuInferShape(
    std::vector<int64_t> boxes_a_shape, std::vector<int64_t> boxes_b_shape) {
  return {{boxes_a_shape[0], boxes_b_shape[0]}};
}

// nms_gpu: returns INT64 keep indices (one slot per input box) and an
// INT64 tensor of shape [1] holding the number of boxes actually kept.
std::vector<paddle::DataType> NmsInferDtype(paddle::DataType boxes_dtype) {
  return {paddle::DataType::INT64, paddle::DataType::INT64};
}

std::vector<std::vector<int64_t>> NmsInferShape(
    std::vector<int64_t> boxes_shape) {
  return {{boxes_shape[0]}, {1}};
}

// nms_normal_gpu: same output contract as nms_gpu.
std::vector<paddle::DataType> NmsNormalInferDtype(
    paddle::DataType boxes_dtype) {
  return {paddle::DataType::INT64, paddle::DataType::INT64};
}

std::vector<std::vector<int64_t>> NmsNormalInferShape(
    std::vector<int64_t> boxes_shape) {
  return {{boxes_shape[0]}, {1}};
}

// boxes_iou_bev_gpu: the (N_a, N_b) IoU matrix shares the dtype of boxes_a.
std::vector<paddle::DataType> BoxesIouBevGpuInferDtype(
    paddle::DataType boxes_a_dtype, paddle::DataType boxes_b_dtype) {
  return {boxes_a_dtype};
}

std::vector<std::vector<int64_t>> BoxesIouBevGpuInferShape(
    std::vector<int64_t> boxes_a_shape, std::vector<int64_t> boxes_b_shape) {
  return {{boxes_a_shape[0], boxes_b_shape[0]}};
}

// boxes_overlap_bev_gpu: the (N_a, N_b) overlap-area matrix shares the dtype
// of boxes_a.
std::vector<paddle::DataType> BoxesOverlapBevGpuInferDtype(
    paddle::DataType boxes_a_dtype, paddle::DataType boxes_b_dtype) {
  return {boxes_a_dtype};
}

std::vector<std::vector<int64_t>> BoxesOverlapBevGpuInferShape(
    std::vector<int64_t> boxes_a_shape, std::vector<int64_t> boxes_b_shape) {
  return {{boxes_a_shape[0], boxes_b_shape[0]}};
}
// Register the custom ops with Paddle. The kernel functions referenced by
// PD_KERNEL are declared in iou3d_cpu.h / iou3d_nms.h.
PD_BUILD_OP(boxes_iou_bev_cpu)
    .Inputs({"boxes_a_tensor", "boxes_b_tensor"})
    .Outputs({"ans_iou_tensor"})
    .SetKernelFn(PD_KERNEL(boxes_iou_bev_cpu))
    .SetInferDtypeFn(PD_INFER_DTYPE(BoxesIouBevCpuInferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(BoxesIouBevCpuInferShape));

PD_BUILD_OP(boxes_iou_bev_gpu)
    .Inputs({"boxes_a_tensor", "boxes_b_tensor"})
    .Outputs({"ans_iou_tensor"})
    .SetKernelFn(PD_KERNEL(boxes_iou_bev_gpu))
    .SetInferDtypeFn(PD_INFER_DTYPE(BoxesIouBevGpuInferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(BoxesIouBevGpuInferShape));

PD_BUILD_OP(boxes_overlap_bev_gpu)
    .Inputs({"boxes_a", "boxes_b"})
    .Outputs({"ans_overlap"})
    .SetKernelFn(PD_KERNEL(boxes_overlap_bev_gpu))
    .SetInferDtypeFn(PD_INFER_DTYPE(BoxesOverlapBevGpuInferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(BoxesOverlapBevGpuInferShape));

PD_BUILD_OP(nms_gpu)
    .Inputs({"boxes"})
    .Outputs({"keep", "num_to_keep"})
    .Attrs({"nms_overlap_thresh: float"})
    .SetKernelFn(PD_KERNEL(nms_gpu))
    .SetInferDtypeFn(PD_INFER_DTYPE(NmsInferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(NmsInferShape));

PD_BUILD_OP(nms_normal_gpu)
    .Inputs({"boxes"})
    .Outputs({"keep", "num_to_keep"})
    .Attrs({"nms_overlap_thresh: float"})
    .SetKernelFn(PD_KERNEL(nms_normal_gpu))
    .SetInferDtypeFn(PD_INFER_DTYPE(NmsNormalInferDtype))
    .SetInferShapeFn(PD_INFER_SHAPE(NmsNormalInferShape));
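
Since this file only registers the operators, a short usage sketch may help. The snippet below JIT-compiles the extension with Paddle's paddle.utils.cpp_extension.load and calls nms_gpu. The companion source file names and the 7-value box layout are assumptions inferred from the included headers (iou3d_cpu.h, iou3d_nms.h) and the usual OpenPCDet convention; adjust them to the actual repository layout.

# Minimal usage sketch (assumptions noted inline; not part of iou3d_nms_api.cpp).
import paddle
from paddle.utils.cpp_extension import load

# JIT-compile the custom ops. Only iou3d_nms_api.cpp is shown above; the other
# source names are assumed from the included headers.
iou3d = load(
    name="iou3d_nms",
    sources=[
        "iou3d_nms_api.cpp",
        "iou3d_cpu.cpp",        # assumed: CPU kernels behind iou3d_cpu.h
        "iou3d_nms.cpp",        # assumed: host wrappers behind iou3d_nms.h
        "iou3d_nms_kernel.cu",  # assumed: CUDA kernels
    ],
)

paddle.set_device("gpu")

# Assumed box layout (N, 7): [x, y, z, dx, dy, dz, heading]; check the kernel
# implementation for the exact convention.
boxes = paddle.rand([128, 7], dtype="float32")

# nms_gpu takes the boxes tensor plus the nms_overlap_thresh attribute and
# returns INT64 keep indices and an INT64 count, matching NmsInferDtype /
# NmsInferShape above. Only the first num_to_keep entries of keep are valid.
keep, num_to_keep = iou3d.nms_gpu(boxes, 0.7)
keep = keep[: int(num_to_keep)]

The design of the registration itself is worth noting: each op pairs a kernel function with explicit dtype and shape inference functions, so Paddle can build the static graph outputs (for example, the fixed INT64 keep/num_to_keep pair) without running the CUDA kernels.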