gather_scatter_along_axis.cc 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/function/gather_scatter_along_axis.h"
  15. #include "ultra_infer/function/tile.h"
  16. namespace ultra_infer {
  17. namespace function {
  // Reduce functor used by the gather/scatter kernels in this file: it simply
  // overwrites the destination element with the source element.
  class TensorAssign {
  public:
    // Copies `*src_data` into `*self_data`.
    //
    // The source is a pointer-to-const so read-only buffers can be passed
    // directly; calls that pass two mutable pointers of the same type still
    // deduce `tensor_t` exactly as before.
    template <typename tensor_t>
    void operator()(tensor_t *self_data, const tensor_t *src_data) const {
      *self_data = *src_data;
    }
  };
  // File-local instance handed to the kernels as the `reduce_op` argument.
  static TensorAssign tensor_assign;
  26. template <typename T, typename index_t = int64_t, bool is_scatter_like = true>
  27. struct GatherScatterFunctor {
  28. template <typename func_t>
  29. void operator()(const FDTensor &x, int axis, const FDTensor &index,
  30. FDTensor *result, const func_t &reduce_op) {
  31. if (index.Numel() == 0) {
  32. return;
  33. }
  34. result->Allocate(index.Shape(), x.Dtype());
  35. const T *x_data = reinterpret_cast<const T *>(x.Data());
  36. const index_t *index_data = reinterpret_cast<const index_t *>(index.Data());
  37. T *result_data = reinterpret_cast<T *>(result->Data());
  38. int64_t x_size = x.Numel();
  39. int64_t index_size = index.Numel();
  40. int64_t result_size = result->Numel();
  41. auto x_dims = x.Shape();
  42. auto index_dims = index.Shape();
  43. auto result_dims = result->Shape();
  44. if (x_size == 0 || result_size == 0 || index_size == 0) {
  45. FDASSERT(false, "zero size input found, self_size, result_size, "
  46. "index_size cannot be 0");
  47. return;
  48. }
  49. int select_dim_size = index_dims[axis];
  50. // index matrix has different shape with self matrix or src matrix.
  51. int replaced_select_dim_size =
  52. is_scatter_like ? result_dims[axis] : x_dims[axis];
  53. int64_t inner_dim_size = 1;
  54. int64_t outer_dim_size = 1;
  55. for (int64_t i = 0; i < axis; ++i) {
  56. inner_dim_size *= index_dims[i];
  57. }
  58. for (int i = axis + 1; i < index_dims.size(); i++) {
  59. outer_dim_size *= index_dims[i];
  60. }
  61. int64_t index_idx = 0;
  62. int64_t self_idx, src_idx;
  63. // N layer loop squeezed into 3 layers loop
  64. for (int64_t i = 0; i < inner_dim_size; i++) {
  65. for (int64_t j = 0; j < select_dim_size; j++) {
  66. for (int64_t k = 0; k < outer_dim_size; k++) {
  67. int64_t index = index_data[index_idx];
  68. // This index might out of bound of index matrix's index, so here
  69. // multiply the replaced_select_dim_size.
  70. int64_t replace_index = k + index * outer_dim_size +
  71. i * outer_dim_size * replaced_select_dim_size;
  72. self_idx = is_scatter_like ? replace_index : index_idx;
  73. src_idx = is_scatter_like ? index_idx : replace_index;
  74. reduce_op((T *)(result_data + self_idx), // NOLINT
  75. (T *)(x_data + src_idx)); // NOLINT
  76. index_idx++;
  77. }
  78. }
  79. }
  80. }
  81. };
  82. template <typename T> struct GatherFunctor {
  83. void operator()(const FDTensor &x, int axis, const FDTensor &index,
  84. FDTensor *result) {
  85. FD_VISIT_INT_TYPES(index.Dtype(), "GatherFunctor", [&]() {
  86. auto x_shape = x.Shape();
  87. auto index_shape = index.Shape();
  88. std::vector<int64_t> repeat_times(x_shape.size(), 1);
  89. for (int i = 0; i < x_shape.size(); ++i) {
  90. repeat_times[i] = x_shape[i] / index_shape[i];
  91. }
  92. repeat_times[axis] = 1;
  93. FDTensor gs_index;
  94. Tile(index, repeat_times, &gs_index);
  95. GatherScatterFunctor<T, data_t, /*is_scatter_like=*/false>()(
  96. x, axis, gs_index, result, tensor_assign);
  97. });
  98. }
  99. };
  100. void GatherAlongAxis(const FDTensor &x, const FDTensor &index, FDTensor *result,
  101. int axis) {
  102. int rank = x.Shape().size();
  103. FDASSERT(axis >= -rank && axis < rank,
  104. "axis should be in range [-%d, %d - 1].", rank, rank - 1);
  105. if (axis < 0) {
  106. axis += rank;
  107. }
  108. FD_VISIT_ALL_TYPES(x.Dtype(), "GatherAlongAxis", [&]() {
  109. GatherFunctor<data_t>()(x, axis, index, result);
  110. });
  111. }
  112. } // namespace function
  113. } // namespace ultra_infer