cast.cc 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "ultra_infer/function/cast.h"
  15. #include <algorithm>
  16. namespace ultra_infer {
  17. namespace function {
  /// Unary functor that converts one `InT` value to `OutT` with static_cast.
  /// CastKernel passes it to std::transform to convert tensor elements one
  /// at a time.
  template <typename InT, typename OutT>
  struct CastOpTransformFunctor {
    OutT operator()(InT value) const {
      return static_cast<OutT>(value);
    }
  };
  21. template <typename InT>
  22. void CastKernel(const FDTensor &x, FDTensor *out, FDDataType output_dtype) {
  23. FD_VISIT_ALL_TYPES(output_dtype, "CastOpTransformFunctor", ([&] {
  24. auto *in_begin = reinterpret_cast<const InT *>(x.Data());
  25. auto *in_end = in_begin + x.Numel();
  26. FDTensor out_tmp;
  27. out_tmp.Allocate(x.Shape(), output_dtype);
  28. auto *out_begin =
  29. reinterpret_cast<data_t *>(out_tmp.Data());
  30. std::transform(in_begin, in_end, out_begin,
  31. CastOpTransformFunctor<InT, data_t>());
  32. *out = std::move(out_tmp);
  33. }));
  34. }
  35. void Cast(const FDTensor &x, FDTensor *out, FDDataType output_dtype) {
  36. FD_VISIT_ALL_TYPES(x.dtype, "CastKernel",
  37. ([&] { CastKernel<data_t>(x, out, output_dtype); }));
  38. }
  39. } // namespace function
  40. } // namespace ultra_infer