// reduce.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ultra_infer/function/reduce.h"

#include <algorithm>  // std::remove, used for the erase-remove of squeezed dims
#include <limits>
#include <set>

#include "ultra_infer/function/eigen.h"
#include "ultra_infer/function/reduce_functor.h"
#include "ultra_infer/function/transpose.h"
#include "ultra_infer/utils/utils.h"

namespace ultra_infer {
namespace function {
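// Reduces a rank-D input over R_D axes with the given Eigen reduction functor,
// writing the result into `output` and squeezing the reduced axes when
// keep_dim is false.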
template <typename T, size_t D, size_t R_D, typename Functor>
void ReduceFunctor(const FDTensor &input, FDTensor *output,
                   const std::vector<int64_t> &dims, bool keep_dim) {
  auto x = EigenTensor<T, D>::From(input);
  auto x_rank = static_cast<int>(x.dimensions().size());
  auto reduce_dim = Eigen::array<int, R_D>();
  std::vector<int64_t> dims_ref = dims;

  auto out_dims = input.shape;
  for (size_t i = 0; i < dims_ref.size(); ++i) {
    if (dims_ref[i] < 0)
      dims_ref[i] = x_rank + dims_ref[i];
    reduce_dim[i] = dims_ref[i];
    out_dims[dims_ref[i]] = 1;
  }
  auto origin_output_dims = out_dims;
  output->Allocate(origin_output_dims, TypeToDataType<T>::dtype);
  // construct the squeezed output tensor
  if (x_rank > 1) {
    const int kDelFlag = -2;
    for (size_t i = 0; i < dims_ref.size(); ++i) {
      out_dims[dims_ref[i]] = kDelFlag;
    }
    out_dims.erase(std::remove(out_dims.begin(), out_dims.end(), kDelFlag),
                   out_dims.end());
  }

  auto &place = *EigenDeviceWrapper::GetInstance()->GetDevice();
  Functor functor;

  if (D == 1) {
    auto out = EigenScalar<T>::From(*output);
    functor(place, &x, &out, reduce_dim);
  } else {
    auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
    functor(place, &x, &out, reduce_dim);
    if (!keep_dim) {
      output->shape = std::move(out_dims);
    }
  }
}
#define HANDLE_REDUCE_DIM(NDIM, RDIM)                                         \
  if (ndim == NDIM && rdim == RDIM) {                                         \
    ReduceFunctor<OutT, NDIM, RDIM, Functor>(input, output, dims, keep_dim);  \
  }
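// Builds a permutation (`perm_axis`) and the permuted shape (`dst_dims`) that
// move every reduced axis to the tail of the tensor while keeping the relative
// order of the remaining axes.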
inline void GetShuffledDim(const std::vector<int64_t> &src_dims,
                           std::vector<int64_t> *dst_dims,
                           const std::vector<int64_t> &reduced_dims,
                           std::vector<int64_t> *perm_axis) {
  // check if it's a reduced dim
  std::vector<bool> src_dims_check(src_dims.size(), false);
  size_t src_size = src_dims.size();
  size_t reduce_size = reduced_dims.size();
  std::vector<int64_t> regular_reduced_dims = reduced_dims;
  for (size_t i = 0; i < regular_reduced_dims.size(); i++) {
    if (regular_reduced_dims[i] < 0) {
      regular_reduced_dims[i] = src_size + regular_reduced_dims[i];
    }
  }

  for (size_t i = 0; i < reduce_size; ++i) {
    dst_dims->at(src_size - reduce_size + i) =
        src_dims[regular_reduced_dims[i]];
    (*perm_axis)[src_size - reduce_size + i] = regular_reduced_dims[i];
    src_dims_check[regular_reduced_dims[i]] = true;
  }

  size_t offset = 0;
  for (size_t i = 0; i < src_dims_check.size(); ++i) {
    bool is_reduced = src_dims_check[i];
    if (!is_reduced) {
      (*perm_axis)[offset] = i;
      dst_dims->at(offset++) = src_dims[i];
    }
  }
}
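// Transposes `input` into `shuffled_input` so that all reduced axes become the
// trailing axes (see GetShuffledDim).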
template <typename OutT>
void GetShuffledInput(const FDTensor &input, FDTensor *shuffled_input,
                      const std::vector<int64_t> &dims) {
  auto shuffled_dims = input.shape;
  std::vector<int64_t> perm_axis(input.shape.size());
  GetShuffledDim(input.shape, &shuffled_dims, dims, &perm_axis);

  shuffled_input->Allocate(shuffled_dims, input.dtype);
  Transpose(input, shuffled_input, perm_axis);
}
//////////////// HandleLargeDim
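// Handles inputs whose rank exceeds the compile-time dispatch limit: the
// reduced axes are shuffled to the end, the data is viewed as a 2-D
// {unreduced, reduced} matrix, and the reduction runs over its last axis.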
template <typename OutT, typename Functor>
void HandleLargeDim(const FDTensor &input, FDTensor *output,
                    const std::vector<int64_t> &dims, bool keep_dim) {
  auto out_dims = input.shape;
  std::vector<int64_t> dims_ref = dims;
  auto x_rank = input.shape.size();
  for (size_t i = 0; i < dims_ref.size(); ++i) {
    if (dims_ref[i] < 0)
      dims_ref[i] = x_rank + dims_ref[i];
    out_dims[dims_ref[i]] = 1;
  }
  if (!keep_dim) {
    const int kDelFlag = -2;
    for (size_t i = 0; i < dims_ref.size(); ++i) {
      out_dims[dims_ref[i]] = kDelFlag;
    }
    out_dims.erase(std::remove(out_dims.begin(), out_dims.end(), kDelFlag),
                   out_dims.end());
  }
  output->Allocate(out_dims, TypeToDataType<OutT>::dtype);

  // shuffle the reduced dims to the end
  FDTensor shuffled_input;
  GetShuffledInput<OutT>(input, &shuffled_input, dims);

  // view the shuffled data as a 2-D tensor of shape {unreduced, reduced}
  const int64_t unreduced = output->Numel();
  const int64_t reduced = shuffled_input.Numel() / unreduced;
  shuffled_input.Allocate({unreduced, reduced}, TypeToDataType<OutT>::dtype);
  output->shape = {unreduced};
  ReduceFunctor<OutT, 2, 1, Functor>(shuffled_input, output, {1}, keep_dim);
  output->shape = out_dims;
}
////////////// ReduceKernel
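// Dispatches a reduction: reduce_all flattens the input and reduces it to a
// scalar; otherwise ranks up to 4 are handled by the HANDLE_REDUCE_DIM
// instantiations and larger ranks fall back to HandleLargeDim.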
template <typename OutT, typename Functor>
void ReduceKernelImpl(const FDTensor &input, FDTensor *output,
                      const std::vector<int64_t> &dims, bool keep_dim,
                      bool reduce_all) {
  output->Allocate({1}, TypeToDataType<OutT>::dtype);
  const auto &dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
  if (reduce_all) {
    // Flatten the input and reduce every element to a scalar
    auto x = EigenVector<OutT>::Flatten(input);
    auto out = EigenScalar<OutT>::From(*output);
    auto reduce_dim = Eigen::array<int, 1>({{0}});
    Functor functor;
    functor(dev, &x, &out, reduce_dim);
  } else {
    int ndim = input.shape.size();
    int rdim = dims.size();
    if (ndim > 4) {
      HandleLargeDim<OutT, Functor>(input, output, dims, keep_dim);
    } else {
      HANDLE_REDUCE_DIM(4, 3);
      HANDLE_REDUCE_DIM(4, 2);
      HANDLE_REDUCE_DIM(4, 1);
      HANDLE_REDUCE_DIM(3, 2);
      HANDLE_REDUCE_DIM(3, 1);
      HANDLE_REDUCE_DIM(2, 1);
      HANDLE_REDUCE_DIM(1, 1);
    }
  }
}
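// Boolean reductions (All/Any): if `dims` already covers every axis, the call
// is promoted to reduce_all before dispatching to ReduceKernelImpl<bool>.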
template <typename OutT, typename Functor>
void BoolReduceKernel(const FDTensor &input, FDTensor *output,
                      const std::vector<int64_t> &dims, bool keep_dim,
                      bool reduce_all) {
  // If dims covers every dimension of the input, force reduce_all to true
  const auto &input_dim_size = input.shape.size();
  std::set<int> dims_set(dims.begin(), dims.end());
  bool full_dim = true;
  for (size_t i = 0; i < input_dim_size; i++) {
    if (dims_set.find(static_cast<int>(i)) == dims_set.end()) {
      full_dim = false;
      break;
    }
  }
  reduce_all = (reduce_all || full_dim);

  ReduceKernelImpl<bool, Functor>(input, output, dims, keep_dim, reduce_all);
}
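// Generic reduction entry point: promotes to reduce_all when `dims` covers
// every axis (negative indices included) and dispatches ReduceKernelImpl over
// the input's integer/floating-point dtype.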
template <typename Functor>
void Reduce(const FDTensor &x, FDTensor *out, const std::vector<int64_t> &dims,
            bool keep_dim, bool reduce_all) {
  // If dims covers every dimension of the input, force reduce_all to true
  const int input_dim_size = static_cast<int>(x.shape.size());
  std::set<int> dims_set(dims.begin(), dims.end());
  bool full_dim = true;
  for (int i = 0; i < input_dim_size; ++i) {
    if (dims_set.find(i) == dims_set.end() &&
        dims_set.find(i - input_dim_size) == dims_set.end()) {
      full_dim = false;
      break;
    }
  }
  reduce_all = (reduce_all || full_dim);

  FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ReduceKernelImpl", ([&] {
                             ReduceKernelImpl<data_t, Functor>(
                                 x, out, dims, keep_dim, reduce_all);
                           }));
}
enum ArgMinMaxType { kArgMin, kArgMax };
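// ArgMinMaxFunctor is specialized per (input dtype, output dtype, rank, op)
// via DECLARE_ARG_MIN_MAX_FUNCTOR, mapping onto Eigen's argmin/argmax along a
// single axis.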
template <typename T, typename Tout, int64_t Rank, ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};

#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value)      \
  template <typename T, typename Tout, int64_t Rank>                          \
  struct ArgMinMaxFunctor<T, Tout, Rank, enum_argminmax_value> {              \
    void operator()(const FDTensor &in, FDTensor *out,                        \
                    const std::vector<int64_t> &x_dims, int64_t axis,         \
                    bool keepdims, bool flatten) {                            \
      const auto &dev = *EigenDeviceWrapper::GetInstance()->GetDevice();      \
      auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims);                 \
      if (keepdims) {                                                         \
        if (!flatten) {                                                       \
          auto out_eigen = EigenTensor<Tout, Rank>::From(*out);               \
          out_eigen.device(dev) =                                             \
              in_eigen.eigen_op_type(axis).template cast<Tout>();             \
        } else {                                                              \
          auto out_eigen = EigenScalar<Tout>::From(*out);                     \
          out_eigen.device(dev) =                                             \
              in_eigen.eigen_op_type(axis).template cast<Tout>();             \
        }                                                                     \
      } else {                                                                \
        auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out);             \
        out_eigen.device(dev) =                                               \
            in_eigen.eigen_op_type(axis).template cast<Tout>();               \
      }                                                                       \
    }                                                                         \
  }

DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
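// Rank dispatcher for argmin/argmax: flattening collapses the input to 1-D
// (axis 0); otherwise a negative axis is normalized and the functor for the
// matching rank (1..6) is invoked.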
template <typename T, typename Tout, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMaxKernel(const FDTensor &x, FDTensor *out, int64_t axis,
                     bool keepdims, bool flatten) {
  bool new_keepdims = keepdims || flatten;
  // if flatten, construct new dims for the computation
  std::vector<int64_t> x_dims;
  int new_axis = axis;
  if (flatten) {
    x_dims = {x.Numel()};
    // when flattened, always reduce along axis 0
    new_axis = 0;
  } else {
    x_dims = x.shape;
    if (axis < 0)
      new_axis = axis + x_dims.size();
  }

#define CALL_ARG_MINMAX_FUNCTOR(rank)                                         \
  ArgMinMaxFunctor<T, Tout, rank, EnumArgMinMaxValue> functor##rank;          \
  functor##rank(x, out, x_dims, new_axis, new_keepdims, flatten)

  switch (x_dims.size()) {
  case 1:
    CALL_ARG_MINMAX_FUNCTOR(1);
    break;
  case 2:
    CALL_ARG_MINMAX_FUNCTOR(2);
    break;
  case 3:
    CALL_ARG_MINMAX_FUNCTOR(3);
    break;
  case 4:
    CALL_ARG_MINMAX_FUNCTOR(4);
    break;
  case 5:
    CALL_ARG_MINMAX_FUNCTOR(5);
    break;
  case 6:
    CALL_ARG_MINMAX_FUNCTOR(6);
    break;
  default:
    FDASSERT(x_dims.size() <= 6,
             "%s operator doesn't support tensors whose rank is greater "
             "than 6.",
             (EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
    break;
#undef CALL_ARG_MINMAX_FUNCTOR
  }
}
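// Validates the axis and the requested output dtype, computes the output
// shape, allocates `out`, and dispatches ArgMinMaxKernel over the output's
// integer dtype.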
template <typename T, ArgMinMaxType EnumArgMinMaxValue>
void ArgMinMax(const FDTensor &x, FDTensor *out, int64_t axis,
               FDDataType output_dtype, bool keepdims, bool flatten) {
  const auto &x_dims = x.shape;
  int64_t x_rank = x_dims.size();
  FDASSERT(axis >= -x_rank,
           "'axis'(%lld) must be greater than or equal to -Rank(X)(%lld).",
           axis, -x_rank);
  FDASSERT(axis < x_rank, "'axis'(%lld) must be less than Rank(X)(%lld).",
           axis, x_rank);
  FDASSERT(
      output_dtype == FDDataType::INT32 || output_dtype == FDDataType::INT64 ||
          output_dtype == FDDataType::UINT8,
      "The attribute of dtype in argmin/argmax must be [%s], [%s] or [%s], "
      "but received [%s].",
      Str(FDDataType::INT32).c_str(), Str(FDDataType::INT64).c_str(),
      Str(FDDataType::UINT8).c_str(), Str(output_dtype).c_str());
  if (axis < 0)
    axis += x_rank;
  if (output_dtype == FDDataType::INT32) {
    int64_t all_element_num = 0;
    if (flatten) {
      all_element_num = x.Numel();
    } else {
      all_element_num = x_dims[axis];
    }
    FDASSERT(all_element_num <= (std::numeric_limits<int>::max)(),
             "The element num of the argmin/argmax input along the reduced "
             "axis is %lld, which is larger than the int32 maximum value %d; "
             "set the dtype of argmin/argmax to 'int64'.",
             all_element_num, (std::numeric_limits<int>::max)());
  }
  std::vector<int64_t> vec;
  if (flatten) {
    vec.emplace_back(static_cast<int64_t>(1));
  } else {
    for (int64_t i = 0; i < axis; i++)
      vec.emplace_back(x_dims[i]);
    if (keepdims) {
      vec.emplace_back(static_cast<int64_t>(1));
    }
    for (int64_t i = axis + 1; i < x_rank; i++)
      vec.emplace_back(x_dims[i]);
  }
  out->Allocate(vec, output_dtype);

  FD_VISIT_INT_TYPES(output_dtype, "ArgMinMaxKernel", ([&] {
                       ArgMinMaxKernel<T, data_t, EnumArgMinMaxValue>(
                           x, out, axis, keepdims, flatten);
                     }));
}
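// Usage sketch (illustrative only) for the public helpers below: reducing a
// 2x3 float tensor along axis 1. Filling the input buffer and the exact dtype
// enumerator (FDDataType::FP32 is assumed here) depend on the wider FDTensor
// API, which is outside this file; the result shapes follow from the reduce
// semantics above.
//
//   FDTensor x, y;
//   x.Allocate({2, 3}, FDDataType::FP32);   // then fill x's buffer with 6 floats
//   Sum(x, &y, {1}, /*keep_dim=*/false);    // y.shape == {2}
//   Sum(x, &y, {1}, /*keep_dim=*/true);     // y.shape == {2, 1}
//
//   FDTensor idx;
//   ArgMax(x, &idx, /*axis=*/1, FDDataType::INT64,
//          /*keep_dim=*/false, /*flatten=*/false);  // idx.shape == {2}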
void Max(const FDTensor &x, FDTensor *out, const std::vector<int64_t> &dims,
         bool keep_dim, bool reduce_all) {
  Reduce<MaxFunctor>(x, out, dims, keep_dim, reduce_all);
}

void Min(const FDTensor &x, FDTensor *out, const std::vector<int64_t> &dims,
         bool keep_dim, bool reduce_all) {
  Reduce<MinFunctor>(x, out, dims, keep_dim, reduce_all);
}

void Sum(const FDTensor &x, FDTensor *out, const std::vector<int64_t> &dims,
         bool keep_dim, bool reduce_all) {
  Reduce<SumFunctor>(x, out, dims, keep_dim, reduce_all);
}

void All(const FDTensor &x, FDTensor *out, const std::vector<int64_t> &dims,
         bool keep_dim, bool reduce_all) {
  BoolReduceKernel<bool, AllFunctor>(x, out, dims, keep_dim, reduce_all);
}

void Any(const FDTensor &x, FDTensor *out, const std::vector<int64_t> &dims,
         bool keep_dim, bool reduce_all) {
  BoolReduceKernel<bool, AnyFunctor>(x, out, dims, keep_dim, reduce_all);
}

void Mean(const FDTensor &x, FDTensor *out, const std::vector<int64_t> &dims,
          bool keep_dim, bool reduce_all) {
  Reduce<MeanFunctor>(x, out, dims, keep_dim, reduce_all);
}

void Prod(const FDTensor &x, FDTensor *out, const std::vector<int64_t> &dims,
          bool keep_dim, bool reduce_all) {
  Reduce<ProdFunctor>(x, out, dims, keep_dim, reduce_all);
}

void ArgMax(const FDTensor &x, FDTensor *out, int64_t axis,
            FDDataType output_dtype, bool keep_dim, bool flatten) {
  FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ArgMaxKernel", ([&] {
                             ArgMinMax<data_t, kArgMax>(
                                 x, out, axis, output_dtype, keep_dim, flatten);
                           }));
}

void ArgMin(const FDTensor &x, FDTensor *out, int64_t axis,
            FDDataType output_dtype, bool keep_dim, bool flatten) {
  FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ArgMinKernel", ([&] {
                             ArgMinMax<data_t, kArgMin>(
                                 x, out, axis, output_dtype, keep_dim, flatten);
                           }));
}

}  // namespace function
}  // namespace ultra_infer