slice.cc 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/function/slice.h"
  15. #include "ultra_infer/function/eigen.h"
  16. #include <algorithm>
  17. namespace ultra_infer {
  18. namespace function {
  19. std::vector<int64_t> GetSliceDims(const std::vector<int64_t> &in_dims,
  20. const std::vector<int64_t> &axes,
  21. const std::vector<int64_t> &starts,
  22. const std::vector<int64_t> &ends,
  23. std::vector<int64_t> *steps = nullptr) {
  24. std::vector<int64_t> slice_dims(in_dims);
  25. for (size_t i = 0; i < axes.size(); ++i) {
  26. int64_t axis = axes[i];
  27. if (in_dims[axis] == -1) {
  28. continue;
  29. }
  30. int64_t start = starts[i];
  31. int64_t end = ends[i];
  32. int64_t step = steps == nullptr ? 1 : (*steps)[i];
  33. if (step > 0) {
  34. slice_dims[axis] = (end - start + step - 1) / step;
  35. } else {
  36. slice_dims[axis] = (end - start + step + 1) / step;
  37. }
  38. }
  39. return slice_dims;
  40. }
  41. void CheckAndUpdateSliceAttrs(const std::vector<int64_t> &in_dims,
  42. const std::vector<int64_t> &axes,
  43. std::vector<int64_t> *starts,
  44. std::vector<int64_t> *ends,
  45. std::vector<int64_t> *steps = nullptr) {
  46. for (size_t i = 0; i < axes.size(); ++i) {
  47. int64_t axis = axes[i];
  48. FDASSERT(axis < in_dims.size(),
  49. "The axis value should be less than the rank of input, "
  50. "but received axes[%d] = %d, rank of input is %d.",
  51. i, axis, in_dims.size());
  52. int64_t dim_value = in_dims[axis];
  53. if (dim_value > 0) {
  54. int64_t step = steps == nullptr ? 1 : (*steps)[i];
  55. FDASSERT(step != 0, "Step should not be 0, but received step = %d.",
  56. step);
  57. int64_t start =
  58. (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
  59. start = (std::max)(start, static_cast<int64_t>(0));
  60. int64_t end =
  61. 0 < step && (*ends)[i] < 0 ? ((*ends)[i] + dim_value) : (*ends)[i];
  62. end = (std::min)(end, dim_value);
  63. if (step > 0) {
  64. start = (std::min)(start, dim_value);
  65. end = (std::max)(end, static_cast<int64_t>(0));
  66. FDASSERT(end > start,
  67. "When step > 0, end should be greater than start, but "
  68. "received end = %d, start = %d.",
  69. end, start)
  70. } else {
  71. start = (std::min)(start, dim_value - 1);
  72. if (end < -1) {
  73. end += dim_value;
  74. }
  75. end = (std::max)(end, static_cast<int64_t>(-1));
  76. FDASSERT(start >= end,
  77. "When step < 0, start should be greater than end, but "
  78. "received start = %d, end = %d.",
  79. start, end);
  80. }
  81. (*starts)[i] = start;
  82. (*ends)[i] = end;
  83. } else if (dim_value == 0) {
  84. (*starts)[i] = 0;
  85. (*ends)[i] = 0;
  86. }
  87. }
  88. }
  89. template <typename T, size_t D>
  90. void SliceKernel(const FDTensor &x, const std::vector<int64_t> &axes,
  91. const std::vector<int64_t> &starts,
  92. const std::vector<int64_t> &ends, FDTensor *out) {
  93. FDASSERT(starts.size() == axes.size(),
  94. "The size of starts must be equal to the size of axes.");
  95. FDASSERT(ends.size() == axes.size(),
  96. "The size of ends must be equal to the size of axes.");
  97. auto starts_idx = starts;
  98. auto end_idx = ends;
  99. auto in_dims = x.Shape();
  100. CheckAndUpdateSliceAttrs(in_dims, axes, &starts_idx, &end_idx);
  101. auto slice_dims = GetSliceDims(in_dims, axes, starts, ends);
  102. auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
  103. auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
  104. for (size_t i = 0; i < D; ++i) {
  105. offsets[i] = 0;
  106. extents[i] = slice_dims[i];
  107. }
  108. for (size_t i = 0; i < axes.size(); ++i) {
  109. offsets[axes[i]] = starts[i];
  110. }
  111. out->Allocate(slice_dims, x.Dtype());
  112. auto in_t = EigenTensor<T, D>::From(x, in_dims);
  113. auto out_t = EigenTensor<T, D>::From(*out, slice_dims);
  114. const auto &dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
  115. out_t.device(dev) = in_t.slice(offsets, extents);
  116. }
  117. void Slice(const FDTensor &x, const std::vector<int64_t> &axes,
  118. const std::vector<int64_t> &starts, const std::vector<int64_t> &ends,
  119. FDTensor *out) {
  120. FD_VISIT_ALL_TYPES(
  121. x.dtype, "SliceKernel", ([&] {
  122. int rank = x.Shape().size();
  123. switch (rank) {
  124. case 1:
  125. SliceKernel<data_t, 1>(x, axes, starts, ends, out);
  126. break;
  127. case 2:
  128. SliceKernel<data_t, 2>(x, axes, starts, ends, out);
  129. break;
  130. case 3:
  131. SliceKernel<data_t, 3>(x, axes, starts, ends, out);
  132. break;
  133. case 4:
  134. SliceKernel<data_t, 4>(x, axes, starts, ends, out);
  135. break;
  136. case 5:
  137. SliceKernel<data_t, 5>(x, axes, starts, ends, out);
  138. break;
  139. case 6:
  140. SliceKernel<data_t, 6>(x, axes, starts, ends, out);
  141. break;
  142. default:
  143. FDASSERT(false,
  144. "The rank of input should be less than 7, but received %d.",
  145. rank);
  146. }
  147. }));
  148. }
  149. void Slice(const FDTensor &x, const std::vector<int64_t> &axes,
  150. const std::vector<int64_t> &index, FDTensor *out) {
  151. std::vector<int64_t> ends = index;
  152. for (int i = 0; i < ends.size(); ++i) {
  153. ends[i] += 1;
  154. }
  155. Slice(x, axes, index, ends, out);
  156. for (int i = 0; i < axes.size(); ++i) {
  157. if (out->Shape().size() <= 1) {
  158. break;
  159. }
  160. out->Squeeze(axes[i]);
  161. }
  162. }
  163. } // namespace function
  164. } // namespace ultra_infer