math_functor.h 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include "ultra_infer/function/eigen.h"
  16. namespace ultra_infer {
  17. namespace function {
  18. // log(x) = natural logarithm of x
  19. template <typename T> struct LogFunctor {
  20. template <typename Device, typename X, typename Out>
  21. void operator()(Device d, X x, Out out) const {
  22. out.device(d) = x.log();
  23. }
  24. };
  25. // exp functor
  26. // exp(x) = e^x
  27. template <typename T> struct ExpFunctor {
  28. template <typename Device, typename X, typename Out>
  29. void operator()(Device d, X x, Out out) const {
  30. out.device(d) = x.exp();
  31. }
  32. };
  33. // round(x) = [x]
  34. template <typename T> struct RoundFunctor {
  35. template <typename Device, typename X, typename Out>
  36. void operator()(Device d, X x, Out out) const {
  37. out.device(d) = x.round();
  38. }
  39. };
  40. // sqrt(x) = x^(1/2)
  41. template <typename T> struct SqrtFunctor {
  42. template <typename Device, typename X, typename Out>
  43. void operator()(Device d, X x, Out out) const {
  44. out.device(d) = x.sqrt();
  45. }
  46. };
  47. // abs(x) = x if x > 0 else -x
  48. template <typename T> struct AbsFunctor {
  49. template <typename Device, typename X, typename Out>
  50. void operator()(Device d, X x, Out out) const {
  51. out.device(d) =
  52. x.unaryExpr([](T v) { return v > static_cast<T>(0) ? v : -v; });
  53. }
  54. };
  55. // ceil(x) = ceiling(x)
  56. template <typename T> struct CeilFunctor {
  57. template <typename Device, typename X, typename Out>
  58. void operator()(Device d, X x, Out out) const {
  59. out.device(d) = x.ceil();
  60. }
  61. };
  62. // floor(x) = flooring(x)
  63. template <typename T> struct FloorFunctor {
  64. template <typename Device, typename X, typename Out>
  65. void operator()(Device d, X x, Out out) const {
  66. out.device(d) = x.floor();
  67. }
  68. };
  69. } // namespace function
  70. } // namespace ultra_infer