elementwise.cc 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/function/elementwise.h"
  15. #include "ultra_infer/function/eigen.h"
  16. #include "ultra_infer/function/elementwise_base.h"
  17. #include "ultra_infer/function/elementwise_functor.h"
  18. #include "ultra_infer/utils/utils.h"
  19. #include <algorithm>
  20. namespace ultra_infer {
  21. namespace function {
  22. DEFINE_ELEMENTWISE_OP(Add);
  23. DEFINE_ELEMENTWISE_OP(Multiply);
  24. DEFINE_ELEMENTWISE_OP(Subtract);
  25. DEFINE_ELEMENTWISE_OP(Divide);
  26. void Add(const FDTensor &x, const FDTensor &y, FDTensor *out) {
  27. FD_VISIT_ALL_TYPES(x.dtype, "AddRawKernel",
  28. ([&] { AddRawKernel<data_t>()(x, y, -1, out); }));
  29. }
  30. void Subtract(const FDTensor &x, const FDTensor &y, FDTensor *out) {
  31. FD_VISIT_ALL_TYPES(x.dtype, "SubtractRawKernel",
  32. ([&] { SubtractRawKernel<data_t>()(x, y, -1, out); }));
  33. }
  34. void Multiply(const FDTensor &x, const FDTensor &y, FDTensor *out) {
  35. FD_VISIT_ALL_TYPES(x.dtype, "MultiplyRawKernel",
  36. ([&] { MultiplyRawKernel<data_t>()(x, y, -1, out); }));
  37. }
  38. void Divide(const FDTensor &x, const FDTensor &y, FDTensor *out) {
  39. FD_VISIT_ALL_TYPES(x.dtype, "DivideRawKernel",
  40. ([&] { DivideRawKernel<data_t>()(x, y, -1, out); }));
  41. }
  42. template <typename T> struct MaximumRawKernel {
  43. void operator()(const FDTensor &x, const FDTensor &y, int axis,
  44. FDTensor *out) {
  45. ElementwiseCompute<MaximumFunctor<T>, T>(x, y, axis, MaximumFunctor<T>(),
  46. out);
  47. }
  48. };
  49. void Maximum(const FDTensor &x, const FDTensor &y, FDTensor *out) {
  50. FD_VISIT_ALL_TYPES(x.dtype, "MaximumRawKernel",
  51. ([&] { MaximumRawKernel<data_t>()(x, y, -1, out); }));
  52. }
  53. } // namespace function
  54. FDTensor operator+(const FDTensor &x, const FDTensor &y) {
  55. FDTensor out;
  56. function::Add(x, y, &out);
  57. return out;
  58. }
  59. FDTensor operator-(const FDTensor &x, const FDTensor &y) {
  60. FDTensor out;
  61. function::Subtract(x, y, &out);
  62. return out;
  63. }
  64. FDTensor operator*(const FDTensor &x, const FDTensor &y) {
  65. FDTensor out;
  66. function::Multiply(x, y, &out);
  67. return out;
  68. }
  69. FDTensor operator/(const FDTensor &x, const FDTensor &y) {
  70. FDTensor out;
  71. function::Divide(x, y, &out);
  72. return out;
  73. }
  74. #define INSTANTIATE_OPERATOR(operation_type) \
  75. template FDTensor operator operation_type(const FDTensor &x, bool y); \
  76. template FDTensor operator operation_type(const FDTensor &x, uint8_t y); \
  77. template FDTensor operator operation_type(const FDTensor &x, int16_t y); \
  78. template FDTensor operator operation_type(const FDTensor &x, int y); \
  79. template FDTensor operator operation_type(const FDTensor &x, int64_t y); \
  80. template FDTensor operator operation_type(const FDTensor &x, float y); \
  81. template FDTensor operator operation_type(const FDTensor &x, double y); \
  82. template FDTensor operator operation_type(bool x, const FDTensor &y); \
  83. template FDTensor operator operation_type(uint8_t x, const FDTensor &y); \
  84. template FDTensor operator operation_type(int16_t x, const FDTensor &y); \
  85. template FDTensor operator operation_type(int x, const FDTensor &y); \
  86. template FDTensor operator operation_type(int64_t x, const FDTensor &y); \
  87. template FDTensor operator operation_type(float x, const FDTensor &y); \
  88. template FDTensor operator operation_type(double x, const FDTensor &y)
  89. INSTANTIATE_OPERATOR(+);
  90. INSTANTIATE_OPERATOR(-);
  91. INSTANTIATE_OPERATOR(*);
  92. INSTANTIATE_OPERATOR(/);
  93. } // namespace ultra_infer