cuda_cast.cu 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifdef WITH_GPU
  15. #include "ultra_infer/function/cuda_cast.h"
  16. namespace ultra_infer {
  17. namespace function {
  18. template <typename T_IN, typename T_OUT>
  19. __global__ void CudaCastKernel(const T_IN *in, T_OUT *out, int edge) {
  20. int position = blockDim.x * blockIdx.x + threadIdx.x;
  21. if (position >= edge)
  22. return;
  23. out[position] = (T_OUT)in[position];
  24. }
  25. void CudaCast(const FDTensor &in, FDTensor *out, cudaStream_t stream) {
  26. int jobs = in.Numel();
  27. int threads = 256;
  28. int blocks = ceil(jobs / (float)threads);
  29. if (in.dtype == FDDataType::INT64 && out->dtype == FDDataType::INT32) {
  30. CudaCastKernel<int64_t, int32_t><<<blocks, threads, 0, stream>>>(
  31. reinterpret_cast<int64_t *>(const_cast<void *>(in.Data())),
  32. reinterpret_cast<int32_t *>(out->MutableData()), jobs);
  33. } else if (in.dtype == FDDataType::INT32 && out->dtype == FDDataType::INT64) {
  34. CudaCastKernel<int32_t, int64_t><<<blocks, threads, 0, stream>>>(
  35. reinterpret_cast<int32_t *>(const_cast<void *>(in.Data())),
  36. reinterpret_cast<int64_t *>(out->MutableData()), jobs);
  37. } else {
  38. FDASSERT(false, "CudaCast only support input INT64, output INT32.");
  39. }
  40. }
  41. } // namespace function
  42. } // namespace ultra_infer
  43. #endif