quantile.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ultra_infer/function/quantile.h"

#include "ultra_infer/core/fd_scalar.h"
#include "ultra_infer/function/cast.h"
#include "ultra_infer/function/concat.h"
#include "ultra_infer/function/elementwise.h"
#include "ultra_infer/function/gather_scatter_along_axis.h"
#include "ultra_infer/function/isfinite.h"
#include "ultra_infer/function/math.h"
#include "ultra_infer/function/reduce.h"
#include "ultra_infer/function/sort.h"
#include "ultra_infer/function/transpose.h"

#include <algorithm>
#include <cmath>
#include <numeric>

namespace ultra_infer {
namespace function {

template <typename T>
void QuantileKernel(const FDTensor &x, const std::vector<double> &q,
                    const std::vector<int> &axis, FDTensor *out) {
  FDASSERT(q.size() > 0, "q should not be empty.");
  FDASSERT(axis.size() > 0, "axis should not be empty.");
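
  // Normalize each requested axis into [0, rank) and mark it as reduced
  // (extent 1) in the output shape.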
  std::vector<int64_t> axis_src;
  std::vector<int64_t> out_shape = x.Shape();
  int64_t rank = x.Shape().size();
  for (auto axis_single : axis) {
    FDASSERT(axis_single >= -rank && axis_single < rank,
             "The axis is expected to be in range of [%d, %d), but got %d",
             -rank, rank, axis_single);
    if (axis_single < 0) {
      axis_single += rank;
    }
    axis_src.push_back(axis_single);
    out_shape[axis_single] = 1;
  }
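
  // Build a permutation that keeps the non-reduced axes in their original
  // order and moves every reduced axis to the end.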
  std::vector<int64_t> axis_dst;
  for (int64_t i = 0; i < rank; ++i) {
    if (std::find(axis_src.begin(), axis_src.end(), i) == axis_src.end()) {
      axis_dst.push_back(i);
    }
  }
  axis_dst.insert(axis_dst.end(), axis_src.begin(), axis_src.end());
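
  // Transpose accordingly, then flatten the reduced axes into a single
  // trailing axis so the quantile is taken along one dimension.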
  FDTensor y;
  Transpose(x, &y, axis_dst);
  std::vector<int64_t> y_shape(rank - axis_src.size(), 0);
  y_shape.push_back(-1);
  y.Reshape({y_shape});
  int64_t target_axis = rank - 1;
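
  // NaN handling: count the valid (non-NaN) elements of every slice along
  // the target axis, and remember which slices contain any NaN at all.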
  FDTensor mask, valid_counts, mask_any;
  IsNan(y, &mask);
  Any(mask, &mask_any, {target_axis}, true);
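  // Flip the mask in place so that `mask` now flags valid (non-NaN) entries.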
  bool *mask_data = reinterpret_cast<bool *>(mask.Data());
  std::transform(mask_data, mask_data + mask.Numel(), mask_data,
                 [](const bool &val) { return !val; });
  Cast(mask_any, &mask_any, FDDataType::FP64);
  Cast(mask, &mask, FDDataType::FP64);
  Sum(mask, &valid_counts, {target_axis}, true);
  FDTensor one_tensor(Scalar(static_cast<double>(1.0)));
  std::vector<FDTensor> indices;
  FDTensor last_index(Scalar(static_cast<double>(x.Shape()[target_axis])));
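
  // Each requested quantile maps to the fractional index q * (n_valid - 1);
  // slices containing NaN are redirected to the last position instead.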
  for (auto q_num : q) {
    FDASSERT(q_num >= 0 && q_num <= 1, "q should be in range [0, 1]");
    FDTensor q_tensor(static_cast<double>(q_num));
    FDTensor index = q_tensor * (valid_counts - one_tensor);
    index = mask_any * last_index + (one_tensor - mask_any) * index;
    indices.push_back(index);
  }
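
  // Sort once along the target axis; each quantile is then read off by
  // linear interpolation between the two nearest ranks.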
  std::vector<FDTensor> outputs;
  FDTensor sorted_tensor, sorted_indices_tensor;
  Sort(y, &sorted_tensor, &sorted_indices_tensor, target_axis);
  Cast(sorted_tensor, &sorted_tensor, FDDataType::FP64);
  FDTensor indices_below, indices_upper;
  for (auto &&index : indices) {
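    // Bracket the fractional index with its floor and ceil, gather both
    // neighboring values, and interpolate linearly between them.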
    Floor(index, &indices_below);
    Ceil(index, &indices_upper);
    Cast(indices_below, &indices_below, FDDataType::INT32);
    Cast(indices_upper, &indices_upper, FDDataType::INT32);
    FDTensor tensor_below, tensor_upper;
    GatherAlongAxis(sorted_tensor, indices_below, &tensor_below, target_axis);
    GatherAlongAxis(sorted_tensor, indices_upper, &tensor_upper, target_axis);
    // Cast back to FP64 so the weight can be computed against the FP64 index.
    Cast(indices_below, &indices_below, FDDataType::FP64);
    FDTensor weight = index - indices_below;
    FDTensor out = tensor_below + weight * (tensor_upper - tensor_below);
    out.Squeeze(target_axis);
    if (out.Dtype() != x.Dtype()) {
      Cast(out, &out, x.Dtype());
    }
    outputs.push_back(std::move(out));
  }
  if (outputs.size() > 1) {
    // Stack the per-quantile results: add a leading unit dim to each, then
    // concatenate along it.
    for (auto &output : outputs) {
      output.ExpandDim(0);
    }
    Concat(outputs, out, 0);
  } else {
    *out = std::move(outputs[0]);
  }
}
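
// Computes the q-th quantiles of `x` along the given axes, dispatching the
// kernel on the tensor's floating-point dtype. Usage sketch (assumes an
// FP32/FP64 FDTensor `t`; multiple quantiles are stacked along a new
// leading axis):
//   FDTensor result;
//   Quantile(t, {0.25, 0.5, 0.75}, {0}, &result);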
void Quantile(const FDTensor &x, const std::vector<double> &q,
              const std::vector<int> &axis, FDTensor *out) {
  FD_VISIT_FLOAT_TYPES(x.dtype, "QuantileKernel",
                       ([&] { QuantileKernel<data_t>(x, q, axis, out); }));
}

}  // namespace function
}  // namespace ultra_infer