elementwise_functor.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include "ultra_infer/function/eigen.h"
  16. #include "ultra_infer/function/elementwise.h"
  17. #include "ultra_infer/function/elementwise_base.h"
  18. #include <algorithm>
  19. namespace ultra_infer {
  20. namespace function {
  21. template <typename Functor> struct SameDimsElementwiseCompute {
  22. void operator()(const FDTensor &x, const FDTensor &y, FDTensor *z) {
  23. z->Allocate(x.Shape(), x.Dtype());
  24. Functor()(x, y, z);
  25. }
  26. };
  27. template <typename T> struct SameDimsAddFunctor {
  28. void operator()(const FDTensor &x, const FDTensor &y, FDTensor *z) {
  29. const auto &dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
  30. auto eigen_x = EigenVector<T>::Flatten(x);
  31. auto eigen_y = EigenVector<T>::Flatten(y);
  32. auto eigen_z = EigenVector<T>::Flatten(*z);
  33. eigen_z.device(dev) = eigen_x + eigen_y;
  34. }
  35. };
  36. template <typename T> struct SameDimsSubtractFunctor {
  37. void operator()(const FDTensor &x, const FDTensor &y, FDTensor *z) {
  38. const auto &dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
  39. auto eigen_x = EigenVector<T>::Flatten(x);
  40. auto eigen_y = EigenVector<T>::Flatten(y);
  41. auto eigen_z = EigenVector<T>::Flatten(*z);
  42. eigen_z.device(dev) = eigen_x - eigen_y;
  43. }
  44. };
  45. template <typename T> struct SameDimsMultiplyFunctor {
  46. void operator()(const FDTensor &x, const FDTensor &y, FDTensor *z) {
  47. const auto &dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
  48. auto eigen_x = EigenVector<T>::Flatten(x);
  49. auto eigen_y = EigenVector<T>::Flatten(y);
  50. auto eigen_z = EigenVector<T>::Flatten(*z);
  51. eigen_z.device(dev) = eigen_x * eigen_y;
  52. }
  53. };
  54. template <typename T> struct SameDimsDivideFunctor {
  55. void operator()(const FDTensor &x, const FDTensor &y, FDTensor *z) {
  56. const auto &dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
  57. auto eigen_x = EigenVector<T>::Flatten(x);
  58. auto eigen_y = EigenVector<T>::Flatten(y);
  59. auto eigen_z = EigenVector<T>::Flatten(*z);
  60. eigen_z.device(dev) = eigen_x / eigen_y;
  61. }
  62. };
  // Add
  /// Scalar binary functor: returns lhs + rhs.
  template <typename T> struct AddFunctor {
    T operator()(const T lhs, const T rhs) const {
      return lhs + rhs;
    }
  };
  /// Scalar binary functor with swapped operands: returns rhs + lhs.
  /// The operand order matters for types whose `+` is not commutative.
  template <typename T> struct InverseAddFunctor {
    T operator()(const T lhs, const T rhs) const {
      return rhs + lhs;
    }
  };
  // Subtract
  /// Scalar binary functor: returns lhs - rhs.
  template <typename T> struct SubtractFunctor {
    T operator()(const T lhs, const T rhs) const {
      return lhs - rhs;
    }
  };
  /// Scalar binary functor with swapped operands: returns rhs - lhs.
  template <typename T> struct InverseSubtractFunctor {
    T operator()(const T lhs, const T rhs) const {
      return rhs - lhs;
    }
  };
  // Multiply
  /// Scalar binary functor: returns lhs * rhs.
  template <typename T> struct MultiplyFunctor {
    T operator()(const T lhs, const T rhs) const {
      return lhs * rhs;
    }
  };
  /// bool specialization: "multiplying" two booleans is their logical AND,
  /// computed without going through arithmetic on bool.
  template <> struct MultiplyFunctor<bool> {
    bool operator()(const bool lhs, const bool rhs) const {
      return lhs ? rhs : false;
    }
  };
  /// Scalar binary functor with swapped operands: returns rhs * lhs.
  template <typename T> struct InverseMultiplyFunctor {
    T operator()(const T lhs, const T rhs) const {
      return rhs * lhs;
    }
  };
  /// bool specialization: logical AND of the swapped operands.
  template <> struct InverseMultiplyFunctor<bool> {
    bool operator()(const bool lhs, const bool rhs) const {
      return rhs ? lhs : false;
    }
  };
  // Divide
  // Message passed to FDASSERT when an integral divisor is zero.
  #define DIV_ERROR_INFO                                                      \
    "InvalidArgumentError: Integer division by zero encountered in "          \
    "(floor) divide. Please check the input value."

  /// Scalar binary functor: returns a / b.
  /// Primary template, selected for non-integral T (e.g. floating point);
  /// no divisor check is performed.
  template <typename T, typename Enable = void> struct DivideFunctor {
    inline T operator()(const T a, const T b) const { return a / b; }
  };

  /// Integral specialization of DivideFunctor (covers every T for which
  /// std::is_integral holds, not only int32/int64). Integer division by zero
  /// is undefined behavior in C++, so the divisor `b` is checked first.
  /// NOTE(review): for signed T, lowest() / -1 also overflows and is not
  /// checked here.
  template <typename T>
  struct DivideFunctor<
      T, typename std::enable_if<std::is_integral<T>::value>::type> {
    inline T operator()(const T a, const T b) const {
      FDASSERT(b != 0, DIV_ERROR_INFO);
      return a / b;
    }
  };

  /// Scalar binary functor with swapped operands: returns b / a.
  /// Primary template, selected for non-integral T; no divisor check.
  template <typename T, typename Enable = void> struct InverseDivideFunctor {
    inline T operator()(const T a, const T b) const { return b / a; }
  };

  /// Integral specialization of InverseDivideFunctor. Here the divisor is `a`,
  /// and it is checked for zero for the same reason as in DivideFunctor.
  /// Previously this case fell through to the primary template and performed
  /// an unchecked integer division.
  template <typename T>
  struct InverseDivideFunctor<
      T, typename std::enable_if<std::is_integral<T>::value>::type> {
    inline T operator()(const T a, const T b) const {
      FDASSERT(a != 0, DIV_ERROR_INFO);
      return b / a;
    }
  };
  // Maximum
  /// Scalar binary functor returning the larger operand. `lhs` is returned
  /// only when `lhs > rhs`; on ties, and whenever that comparison is false
  /// (e.g. a NaN operand), `rhs` is returned.
  template <typename T> struct MaximumFunctor {
    T operator()(const T lhs, const T rhs) const {
      if (lhs > rhs) {
        return lhs;
      }
      return rhs;
    }
  };
  113. } // namespace function
  114. } // namespace ultra_infer