sort.cc 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/function/sort.h"
  15. #include "ultra_infer/function/eigen.h"
  16. #include "ultra_infer/function/transpose.h"
  17. #include <algorithm>
  18. #include <cmath>
  19. #include <numeric>
  20. namespace ultra_infer {
  21. namespace function {
  22. template <typename T, typename Type>
  23. static void FullSort(Type input_height, Type input_width, int input_dim,
  24. const FDTensor *input, FDTensor *out, FDTensor *indices,
  25. bool descending) {
  26. out->Allocate(input->Shape(), input->Dtype());
  27. indices->Allocate(input->Shape(), TypeToDataType<Type>::dtype);
  28. T *t_out = reinterpret_cast<T *>(out->Data());
  29. Type *t_indices = reinterpret_cast<Type *>(indices->Data());
  30. for (Type i = 0; i < input_height; ++i) {
  31. std::vector<std::pair<T, Type>> col_vec;
  32. col_vec.reserve(input_width);
  33. if (input_dim == 1) {
  34. auto e_input = EigenVector<T>::Flatten(*input);
  35. for (Type j = 0; j < input_width; ++j) {
  36. col_vec.push_back(std::pair<T, Type>(e_input(j), j));
  37. }
  38. } else {
  39. auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
  40. for (Type j = 0; j < input_width; ++j) {
  41. col_vec.push_back(std::pair<T, Type>(e_input(i, j), j));
  42. }
  43. }
  44. std::sort(col_vec.begin(), col_vec.end(),
  45. [&](const std::pair<T, Type> &l, const std::pair<T, Type> &r) {
  46. if (descending)
  47. return (std::isnan(static_cast<double>(l.first)) &&
  48. !std::isnan(static_cast<double>(r.first))) ||
  49. (l.first > r.first);
  50. else
  51. return (!std::isnan(static_cast<double>(l.first)) &&
  52. std::isnan(static_cast<double>(r.first))) ||
  53. (l.first < r.first);
  54. });
  55. for (Type j = 0; j < input_width; ++j) {
  56. t_out[i * input_width + j] = col_vec[j].first;
  57. t_indices[i * input_width + j] = col_vec[j].second;
  58. }
  59. }
  60. }
  61. template <typename T>
  62. void SortKernel(const FDTensor &x, FDTensor *out, FDTensor *indices,
  63. FDDataType indices_type, bool descending, int axis) {
  64. auto input_shape = x.Shape();
  65. int rank = input_shape.size();
  66. axis = (axis < 0) ? (rank + axis) : axis;
  67. // Do full sort
  68. if (axis == -1 || axis + 1 == rank) {
  69. int64_t numel = x.Numel();
  70. int64_t input_width = input_shape[axis];
  71. int64_t input_height = numel / input_width;
  72. FD_VISIT_INT_TYPES(indices_type, "FullSort", ([&] {
  73. FullSort<T, data_t>(input_height, input_width, rank,
  74. &x, out, indices, descending);
  75. }));
  76. } else {
  77. // If not full sort do transpose
  78. std::vector<int64_t> trans;
  79. for (int i = 0; i < axis; i++) {
  80. trans.push_back(i);
  81. }
  82. trans.push_back(rank - 1);
  83. for (int i = axis + 1; i < rank - 1; i++) {
  84. trans.push_back(i);
  85. }
  86. trans.push_back(axis);
  87. FDTensor trans_inp;
  88. Transpose(x, &trans_inp, trans);
  89. int64_t numel = x.Numel();
  90. int64_t input_width = input_shape[axis];
  91. int64_t input_height = numel / input_width;
  92. FD_VISIT_INT_TYPES(indices_type, "FullSort", ([&] {
  93. FullSort<T, data_t>(input_height, input_width, rank,
  94. &trans_inp, out, indices,
  95. descending);
  96. }));
  97. // transpose back
  98. Transpose(*out, out, trans);
  99. Transpose(*indices, indices, trans);
  100. }
  101. }
  102. void Sort(const FDTensor &x, FDTensor *out, FDTensor *indices, int axis,
  103. bool descending, FDDataType indices_type) {
  104. FD_VISIT_INT_FLOAT_TYPES(x.dtype, "SortKernel", ([&] {
  105. SortKernel<data_t>(x, out, indices, indices_type,
  106. descending, axis);
  107. }));
  108. }
  109. } // namespace function
  110. } // namespace ultra_infer