// elementwise_base.h
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

#include "ultra_infer/core/fd_tensor.h"
#include "ultra_infer/function/eigen.h"
  18. namespace ultra_infer {
  19. namespace function {
  20. #define DEFINE_ELEMENTWISE_OP(name) \
  21. template <typename T> struct name##RawKernel { \
  22. void operator()(const FDTensor &x, const FDTensor &y, int axis, \
  23. FDTensor *out) { \
  24. if (x.Shape() == y.Shape()) { \
  25. SameDimsElementwiseCompute<SameDims##name##Functor<T>>()(x, y, out); \
  26. } else { \
  27. auto x_dims = x.Shape(); \
  28. auto y_dims = y.Shape(); \
  29. if (x_dims.size() >= y_dims.size()) { \
  30. ElementwiseCompute<name##Functor<T>, T>(x, y, axis, \
  31. name##Functor<T>(), out); \
  32. } else { \
  33. ElementwiseCompute<Inverse##name##Functor<T>, T>( \
  34. x, y, axis, Inverse##name##Functor<T>(), out); \
  35. } \
  36. } \
  37. } \
  38. }
  39. inline void GetMidDims(const std::vector<int64_t> &x_dims,
  40. const std::vector<int64_t> &y_dims, const int axis,
  41. int *pre, int *n, int *post,
  42. int *is_run_common_broadcast) {
  43. *pre = 1;
  44. *n = 1;
  45. *post = 1;
  46. *is_run_common_broadcast = 0;
  47. for (int i = 0; i < axis; ++i) {
  48. (*pre) *= x_dims[i];
  49. }
  50. for (int i = 0; i < y_dims.size(); ++i) {
  51. if (x_dims[i + axis] != y_dims[i]) {
  52. FDASSERT(y_dims[i] == 1 || x_dims[i + axis] == 1,
  53. "Broadcast dimension mismatch. Operands "
  54. "could not be broadcast together with the shape of "
  55. "X = [%s] and the shape of Y = [%s]. Received [%d] "
  56. "in X is not equal to [%d] in Y.",
  57. Str(x_dims).c_str(), Str(y_dims).c_str(), x_dims[i + axis],
  58. y_dims[i]);
  59. *is_run_common_broadcast = 1;
  60. return;
  61. }
  62. (*n) *= y_dims[i];
  63. }
  64. for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
  65. (*post) *= x_dims[i];
  66. }
  67. }
  68. inline std::vector<int64_t>
  69. TrimTrailingSingularDims(const std::vector<int64_t> &dims) {
  70. // Remove trailing dimensions of size 1 for y
  71. auto actual_dims_size = dims.size();
  72. for (; actual_dims_size != 0; --actual_dims_size) {
  73. if (dims[actual_dims_size - 1] != 1)
  74. break;
  75. }
  76. if (actual_dims_size == dims.size())
  77. return dims;
  78. std::vector<int64_t> trim_dims;
  79. trim_dims.resize(actual_dims_size);
  80. for (int i = 0; i < actual_dims_size; ++i) {
  81. trim_dims[i] = dims[i];
  82. }
  83. return trim_dims;
  84. }
  85. inline int GetElementwiseIndex(const int64_t *x_dims_array, const int max_dim,
  86. const int64_t *index_array) {
  87. int index_ = 0;
  88. for (int i = 0; i < max_dim; i++) {
  89. if (x_dims_array[i] > 1) {
  90. index_ = index_ * x_dims_array[i] + index_array[i];
  91. }
  92. }
  93. return index_;
  94. }
  95. inline void UpdateElementwiseIndexArray(const int64_t *out_dims_array,
  96. const int max_dim,
  97. int64_t *index_array) {
  98. for (int i = max_dim - 1; i >= 0; --i) {
  99. ++index_array[i];
  100. if (index_array[i] >= out_dims_array[i]) {
  101. index_array[i] -= out_dims_array[i];
  102. } else {
  103. break;
  104. }
  105. }
  106. }
  107. inline void GetBroadcastDimsArrays(const std::vector<int64_t> &x_dims,
  108. const std::vector<int64_t> &y_dims,
  109. int64_t *x_dims_array, int64_t *y_dims_array,
  110. int64_t *out_dims_array, const int max_dim,
  111. const int axis) {
  112. FDASSERT(axis >= 0,
  113. "Axis should be great than or equal to 0, but received axis is %d.",
  114. axis);
  115. FDASSERT(axis < max_dim,
  116. "Axis should be less than %d, but received axis is %d.", max_dim,
  117. axis);
  118. if (x_dims.size() > y_dims.size()) {
  119. std::fill(y_dims_array, y_dims_array + axis, 1);
  120. if (axis + y_dims.size() < max_dim) {
  121. std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1);
  122. }
  123. std::copy(x_dims.data(), x_dims.data() + x_dims.size(), x_dims_array);
  124. std::copy(y_dims.data(), y_dims.data() + y_dims.size(),
  125. y_dims_array + axis);
  126. } else {
  127. std::fill(x_dims_array, x_dims_array + axis, 1);
  128. if (axis + x_dims.size() < max_dim) {
  129. std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1);
  130. }
  131. std::copy(x_dims.data(), x_dims.data() + x_dims.size(),
  132. x_dims_array + axis);
  133. std::copy(y_dims.data(), y_dims.data() + y_dims.size(), y_dims_array);
  134. }
  135. for (int i = 0; i < max_dim; i++) {
  136. FDASSERT(x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
  137. y_dims_array[i] <= 1,
  138. "Broadcast dimension mismatch. Operands "
  139. "could not be broadcast together with the shape of "
  140. "X = [%s] and the shape of Y = [%s]. Received [%d] "
  141. "in X is not equal to [%d] in Y.",
  142. Str(x_dims).c_str(), Str(y_dims).c_str(), x_dims[i + axis],
  143. y_dims[i]);
  144. if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
  145. (x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
  146. out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]);
  147. } else {
  148. out_dims_array[i] = -1;
  149. }
  150. }
  151. }
  152. template <typename Functor, typename T, typename OutType = T>
  153. void CommonForwardBroadcastCPU(const FDTensor &x, const FDTensor &y,
  154. FDTensor *z, int64_t *x_dims_array,
  155. int64_t *y_dims_array, int64_t *out_dims_array,
  156. int max_dim, Functor func,
  157. const bool is_xsize_larger = true) {
  158. std::vector<int64_t> index_array(max_dim, 0);
  159. const T *x_data = reinterpret_cast<const T *>(x.Data());
  160. const T *y_data = reinterpret_cast<const T *>(y.Data());
  161. FDASSERT(x_data != nullptr, "The input X should not be empty.");
  162. FDASSERT(y_data != nullptr, "The input X should not be empty.");
  163. OutType *out_data = reinterpret_cast<OutType *>(z->Data());
  164. const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
  165. 1, std::multiplies<int64_t>());
  166. int x_index, y_index;
  167. for (int out_index = 0; out_index < out_size; ++out_index) {
  168. x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
  169. y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data());
  170. if (is_xsize_larger) {
  171. out_data[out_index] = func(x_data[x_index], y_data[y_index]);
  172. } else {
  173. out_data[out_index] = func(y_data[y_index], x_data[x_index]);
  174. }
  175. UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data());
  176. }
  177. }
  178. template <typename Functor, typename T, typename OutType = T>
  179. void CommonElementwiseBroadcastForward(const FDTensor &x, const FDTensor &y,
  180. FDTensor *z,
  181. const std::vector<int64_t> &x_dims,
  182. const std::vector<int64_t> &y_dims,
  183. Functor func, int axis,
  184. const bool is_xsize_larger = true) {
  185. int x_dims_size = x_dims.size();
  186. int y_dims_size = y_dims.size();
  187. int max_dim = (std::max)(x_dims_size, y_dims_size);
  188. axis = (axis == -1 ? std::abs(x_dims_size - y_dims_size) : axis);
  189. FDASSERT(axis >= 0,
  190. "Axis should be great than or equal to 0, but received axis is %d.",
  191. axis);
  192. FDASSERT(axis < max_dim,
  193. "Axis should be less than %d, but received axis is %d.", max_dim,
  194. axis);
  195. std::vector<int64_t> x_dims_array(max_dim);
  196. std::vector<int64_t> y_dims_array(max_dim);
  197. std::vector<int64_t> out_dims_array(max_dim);
  198. GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
  199. y_dims_array.data(), out_dims_array.data(), max_dim,
  200. axis);
  201. FDTensor tmp;
  202. tmp.Allocate(out_dims_array, TypeToDataType<OutType>::dtype);
  203. CommonForwardBroadcastCPU<Functor, T, OutType>(
  204. x, y, &tmp, x_dims_array.data(), y_dims_array.data(),
  205. out_dims_array.data(), max_dim, func, is_xsize_larger);
  206. *z = std::move(tmp);
  207. }
  208. template <typename Functor, typename T, typename OutType = T>
  209. void ElementwiseCompute(const FDTensor &x, const FDTensor &y, int axis,
  210. Functor func, FDTensor *z) {
  211. auto x_dims = x.Shape();
  212. auto y_dims = y.Shape();
  213. bool is_xsize_larger = true;
  214. int max_dim = x_dims.size();
  215. if (x_dims.size() < y_dims.size()) {
  216. is_xsize_larger = false;
  217. max_dim = y_dims.size();
  218. }
  219. int diff_size = x_dims.size() - y_dims.size();
  220. axis = (axis == -1 ? std::abs(diff_size) : axis);
  221. FDASSERT(axis >= 0,
  222. "Axis should be great than or equal to 0, but received axis is %d.",
  223. axis);
  224. FDASSERT(axis < max_dim,
  225. "Axis should be less than %d, but received axis is %d.", max_dim,
  226. axis);
  227. int pre, n, post, is_run_common_broadcast, axis_trim = 0;
  228. if (is_xsize_larger) {
  229. auto y_dims_trimed = TrimTrailingSingularDims(y_dims);
  230. axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
  231. GetMidDims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post,
  232. &is_run_common_broadcast);
  233. } else {
  234. auto x_dims_trimed = TrimTrailingSingularDims(x_dims);
  235. axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
  236. GetMidDims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
  237. &is_run_common_broadcast);
  238. }
  239. // special case for common implementation.
  240. // case 1: x=[2,3,1,5], y=[2,1,4,1]
  241. // case 2: x=[2,3,4], y=[1,1,4]
  242. CommonElementwiseBroadcastForward<Functor, T, OutType>(
  243. x, y, z, x_dims, y_dims, func, axis, is_xsize_larger);
  244. }
  245. } // namespace function
  246. } // namespace ultra_infer