isfinite.cc 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include "ultra_infer/function/isfinite.h"

  #include <algorithm>
  #include <cmath>
  #include <type_traits>
  #include <utility>

  #include "ultra_infer/core/float16.h"
  18. namespace ultra_infer {
  19. namespace function {
  20. template <typename T, typename OutT, class Enable = void> struct IsNanFunctor {
  21. OutT operator()(const T &a) const { return static_cast<OutT>(std::isnan(a)); }
  22. };
  23. template <typename T, typename OutT>
  24. struct IsNanFunctor<T, OutT,
  25. typename std::enable_if<std::is_integral<T>::value>::type> {
  26. OutT operator()(const T &a) const { return static_cast<OutT>(false); }
  27. };
  28. template <typename OutT> struct IsNanFunctor<ultra_infer::float16, OutT, void> {
  29. OutT operator()(const ultra_infer::float16 &a) const {
  30. return static_cast<OutT>(ultra_infer::isnan(a));
  31. }
  32. };
  33. template <typename T, typename OutT, class Enable = void> struct IsInfFunctor {
  34. OutT operator()(const T &a) const { return static_cast<OutT>(std::isinf(a)); }
  35. };
  36. template <typename T, typename OutT>
  37. struct IsInfFunctor<T, OutT,
  38. typename std::enable_if<std::is_integral<T>::value>::type> {
  39. OutT operator()(const T &a) const { return static_cast<OutT>(false); }
  40. };
  41. template <typename OutT> struct IsInfFunctor<ultra_infer::float16, OutT, void> {
  42. OutT operator()(const ultra_infer::float16 &a) const {
  43. return static_cast<OutT>(ultra_infer::isinf(a));
  44. }
  45. };
  46. template <typename T, typename OutT, class Enable = void>
  47. struct IsFiniteFunctor {
  48. OutT operator()(const T &a) const {
  49. return static_cast<OutT>(std::isfinite(a));
  50. }
  51. };
  52. template <typename T, typename OutT>
  53. struct IsFiniteFunctor<
  54. T, OutT, typename std::enable_if<std::is_integral<T>::value>::type> {
  55. OutT operator()(const T &a) const { return static_cast<OutT>(true); }
  56. };
  57. template <typename OutT>
  58. struct IsFiniteFunctor<ultra_infer::float16, OutT, void> {
  59. OutT operator()(const ultra_infer::float16 &a) const {
  60. return static_cast<OutT>(ultra_infer::isfinite(a));
  61. }
  62. };
  63. #define DEFINE_ISFINITE_KERNEL(isfinite_kernel, functor) \
  64. template <typename T> \
  65. void isfinite_kernel(const FDTensor &x, FDTensor *out, FDDataType dtype) { \
  66. FD_VISIT_ALL_TYPES(dtype, #isfinite_kernel, ([&] { \
  67. out->Allocate(x.Shape(), dtype); \
  68. functor<T, data_t> unary_func; \
  69. data_t *out_ptr = \
  70. reinterpret_cast<data_t *>(out->Data()); \
  71. const T *input_ptr = \
  72. reinterpret_cast<const T *>(x.Data()); \
  73. std::transform(input_ptr, input_ptr + x.Numel(), \
  74. out_ptr, unary_func); \
  75. })); \
  76. }
  77. DEFINE_ISFINITE_KERNEL(IsNanKernel, IsNanFunctor)
  78. DEFINE_ISFINITE_KERNEL(IsInfKernel, IsInfFunctor)
  79. DEFINE_ISFINITE_KERNEL(IsFiniteKernel, IsFiniteFunctor)
  80. #undef DEFINE_ISFINITE_KERNEL
  81. void IsNan(const FDTensor &x, FDTensor *out, FDDataType dtype) {
  82. FD_VISIT_FLOAT_TYPES(x.dtype, "IsNanKernel",
  83. ([&] { IsNanKernel<data_t>(x, out, dtype); }));
  84. }
  85. void IsInf(const FDTensor &x, FDTensor *out, FDDataType dtype) {
  86. FD_VISIT_FLOAT_TYPES(x.dtype, "IsInfKernel",
  87. ([&] { IsInfKernel<data_t>(x, out, dtype); }));
  88. }
  89. void IsFinite(const FDTensor &x, FDTensor *out, FDDataType dtype) {
  90. FD_VISIT_FLOAT_TYPES(x.dtype, "IsFiniteKernel",
  91. ([&] { IsFiniteKernel<data_t>(x, out, dtype); }));
  92. }
  93. } // namespace function
  94. } // namespace ultra_infer