clip.cc 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/function/clip.h"
  15. #include <algorithm>
  16. namespace ultra_infer {
  17. namespace function {
  /// Element-wise functor that clamps a single value into [min, max].
  ///
  /// A value below the lower bound yields the lower bound, and a value above
  /// the upper bound yields the upper bound. Any other value is returned
  /// unchanged. That includes a value that compares false against both bounds,
  /// such as NaN for floating-point T.
  template <typename T> class ClipFunctor {
  public:
    explicit ClipFunctor(const T min, const T max) : lo_(min), hi_(max) {}

    T operator()(const T x) const {
      if (x < lo_) {
        return lo_;
      }
      if (x > hi_) {
        return hi_;
      }
      return x;
    }

  private:
    T lo_;  // lower clamp bound
    T hi_;  // upper clamp bound
  };
  28. template <typename T>
  29. void ClipKernel(const FDTensor &x, double min, double max, FDTensor *out) {
  30. T max_ = static_cast<T>(max);
  31. T min_ = static_cast<T>(min);
  32. FDASSERT(min_ < max_,
  33. "max should be greater than or equal to min. But received min = %f, "
  34. "max = %f",
  35. static_cast<float>(min_), static_cast<float>(max_));
  36. FDTensor tmp;
  37. tmp.Allocate(x.Shape(), x.Dtype());
  38. const T *x_data = reinterpret_cast<const T *>(x.Data());
  39. int64_t numel = x.Numel();
  40. T *out_data = reinterpret_cast<T *>(tmp.Data());
  41. std::transform(x_data, x_data + numel, out_data, ClipFunctor<T>(min_, max_));
  42. *out = std::move(tmp);
  43. }
  44. void Clip(const FDTensor &x, double min, double max, FDTensor *out) {
  45. FD_VISIT_INT_FLOAT_TYPES(x.dtype, "ClipKernel",
  46. ([&] { ClipKernel<data_t>(x, min, max, out); }));
  47. }
  48. } // namespace function
  49. } // namespace ultra_infer