split.cc 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/function/split.h"
  15. #include "ultra_infer/utils/utils.h"
  16. #include <cstring>
  17. namespace ultra_infer {
  18. namespace function {
  19. /*
  20. * All tensors' dimension should be the same and the values of
  21. * each dimension must be the same, except the axis dimension.
  22. */
  23. template <typename T> struct SplitFunctor {
  24. public:
  25. void operator()(const FDTensor &input,
  26. const std::vector<const FDTensor *> &ref_inputs, int axis,
  27. std::vector<FDTensor> *outputs) {
  28. if (input.Numel() == 0) {
  29. return;
  30. }
  31. size_t num = outputs->size();
  32. int input_rows = 1;
  33. auto dim_0 = ref_inputs[0]->Shape();
  34. for (int i = 0; i < axis; ++i) {
  35. input_rows *= dim_0[i];
  36. }
  37. int input_cols = 0;
  38. std::vector<int64_t> output_cols(outputs->size());
  39. for (size_t i = 0; i < num; ++i) {
  40. int t_cols = ref_inputs[i]->Numel() / input_rows;
  41. input_cols += t_cols;
  42. output_cols[i] = t_cols;
  43. }
  44. // computation
  45. for (int k = 0; k < input_rows; ++k) {
  46. const T *src_ptr =
  47. reinterpret_cast<const T *>(input.Data()) + k * input_cols;
  48. int col_idx = 0;
  49. for (size_t j = 0; j < num; ++j) {
  50. int col_len = output_cols[j];
  51. auto *out_tensor = &(outputs->at(j));
  52. if (out_tensor != nullptr) {
  53. T *dst_ptr = reinterpret_cast<T *>(out_tensor->Data()) + k * col_len;
  54. std::memcpy(dst_ptr, src_ptr + col_idx, sizeof(T) * col_len);
  55. }
  56. col_idx += col_len;
  57. }
  58. }
  59. }
  60. };
  61. inline int GetSplitAxisValue(const FDTensor &x, int axis) {
  62. int rank = x.Shape().size();
  63. FDASSERT(axis >= -rank && axis < rank,
  64. "The axis is expected to be in range of [%d, %d), but got %d", -rank,
  65. rank, axis);
  66. if (axis < 0) {
  67. axis = axis + rank;
  68. }
  69. return axis;
  70. }
  71. void CreateSplitOutputs(const FDTensor &x,
  72. const std::vector<int> &sections_data,
  73. std::vector<FDTensor> *outs, int axis) {
  74. axis = GetSplitAxisValue(x, axis);
  75. auto input_axis_dim = x.Shape().at(axis);
  76. std::vector<int> sections_vec;
  77. const int unknow_dim_val = -1;
  78. int unknow_dim_idx = -1;
  79. int num_of_unknow = 0;
  80. int sum_of_section = 0;
  81. for (size_t i = 0; i < sections_data.size(); ++i) {
  82. sections_vec.push_back(sections_data[i]);
  83. if (sections_data[i] == unknow_dim_val) {
  84. num_of_unknow++;
  85. unknow_dim_idx = i;
  86. } else {
  87. sum_of_section += sections_data[i];
  88. }
  89. }
  90. FDASSERT(num_of_unknow <= 1,
  91. "Only one dimension value of Attr(num_or_sections) "
  92. "in SplitOp can be -1. "
  93. "But received Attr(num_or_sections) = [%s].",
  94. Str(sections_data).c_str());
  95. if (unknow_dim_idx != -1) {
  96. // for example, input shape = [4 ,5], axis = 1, sections = [2, 3, -1].
  97. // input_axis_dim = 5, sum_of_sections = 5.
  98. // the following check will fail.
  99. FDASSERT(sum_of_section < input_axis_dim,
  100. "Sum of Attr(num_or_sections) other than unknown section "
  101. "must be less than the input's "
  102. "size "
  103. "along the split dimension. But received Attr(num_or_sections) "
  104. "= [%s], input(X)'s shape = [%s], Attr(dim) = %d.",
  105. Str(sections_data).c_str(), Str(x.Shape()).c_str(), axis);
  106. sections_vec[unknow_dim_idx] = input_axis_dim - sum_of_section;
  107. } else {
  108. FDASSERT(sum_of_section == input_axis_dim,
  109. "Sum of Attr(num_or_sections) must be equal to the input's "
  110. "size "
  111. "along the split dimension. But received Attr(num_or_sections)"
  112. " = [%s], input(X)'s shape = [%s], Attr(dim) = %d.",
  113. Str(sections_data).c_str(), Str(x.Shape()).c_str(), axis);
  114. }
  115. // fill out dims
  116. std::vector<std::vector<int64_t>> out_dims(sections_vec.size(), x.Shape());
  117. for (size_t i = 0; i < sections_vec.size(); ++i) {
  118. out_dims[i][axis] = sections_vec[i];
  119. }
  120. for (size_t i = 0; i < sections_vec.size(); ++i) {
  121. (*outs)[i].Allocate(out_dims[i], x.Dtype());
  122. }
  123. }
  124. template <typename T>
  125. void SplitKernel(const FDTensor &x, const std::vector<int> &section,
  126. std::vector<FDTensor> *outs, int axis) {
  127. size_t out_number = section.size();
  128. outs->resize(out_number);
  129. CreateSplitOutputs(x, section, outs, axis);
  130. std::vector<const FDTensor *> shape_refer;
  131. for (size_t j = 0; j < outs->size(); ++j) {
  132. shape_refer.emplace_back(&((*outs)[j]));
  133. }
  134. SplitFunctor<T> functor;
  135. functor(x, shape_refer, axis, outs);
  136. }
  137. void Split(const FDTensor &x, const std::vector<int> &num_or_sections,
  138. std::vector<FDTensor> *out, int axis) {
  139. FD_VISIT_ALL_TYPES(x.Dtype(), "Split", ([&] {
  140. SplitKernel<data_t>(x, num_or_sections, out, axis);
  141. }));
  142. }
  143. } // namespace function
  144. } // namespace ultra_infer