// tile.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ultra_infer/function/tile.h"

#include "ultra_infer/function/eigen.h"

namespace ultra_infer {
namespace function {
  18. template <typename T, int Rank>
  19. void TileFunctor(const FDTensor &x,
  20. const std::vector<int64_t> &origin_repeat_times,
  21. FDTensor *out) {
  22. auto x_shape = x.Shape();
  23. auto repeat_times = origin_repeat_times;
  24. for (size_t i = 0; i < repeat_times.size(); ++i) {
  25. FDASSERT(repeat_times[i] > 0,
  26. "All elements of the input 'repeat_times' "
  27. "for tile op must be positive integers, but "
  28. "the value received is %d.",
  29. repeat_times[i]);
  30. }
  31. if (repeat_times.size() < x_shape.size()) {
  32. int diff = x_shape.size() - repeat_times.size();
  33. repeat_times.insert(repeat_times.begin(), diff, 1);
  34. } else {
  35. int diff = repeat_times.size() - x_shape.size();
  36. x_shape.insert(x_shape.begin(), diff, 1);
  37. }
  38. FDASSERT(repeat_times.size() == x_shape.size(),
  39. "The rank (%d) of the input 'x' and the rank (%d) of the input "
  40. "'repeat_times' for tile op must match after promotion.",
  41. x_shape.size(), repeat_times.size());
  42. if (Rank == 0) {
  43. // Deep copy
  44. *out = x;
  45. return;
  46. }
  47. FDTensor out_tmp;
  48. Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
  49. for (size_t i = 0; i < repeat_times.size(); ++i) {
  50. bcast_dims[i] = repeat_times[i];
  51. }
  52. std::vector<int64_t> out_shape(x_shape);
  53. for (size_t i = 0; i < repeat_times.size(); ++i) {
  54. out_shape[i] *= repeat_times[i];
  55. }
  56. out_tmp.Allocate(out_shape, x.Dtype());
  57. auto eigen_x = EigenTensor<T, Rank>::From(x, x_shape);
  58. auto eigen_out = EigenTensor<T, Rank>::From(out_tmp, out_shape);
  59. const auto &dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
  60. eigen_out.device(dev) = eigen_x.broadcast(bcast_dims);
  61. *out = std::move(out_tmp);
  62. }
  63. template <typename T>
  64. void TileKernel(const FDTensor &x, const std::vector<int64_t> &repeat_times,
  65. FDTensor *out) {
  66. auto rank = x.Shape().size();
  67. auto repeat_times_size = repeat_times.size();
  68. rank = (std::max)(rank, repeat_times_size);
  69. switch (rank) {
  70. case 0:
  71. *out = x;
  72. break;
  73. case 1:
  74. TileFunctor<T, 1>(x, repeat_times, out);
  75. break;
  76. case 2:
  77. TileFunctor<T, 2>(x, repeat_times, out);
  78. break;
  79. case 3:
  80. TileFunctor<T, 3>(x, repeat_times, out);
  81. break;
  82. case 4:
  83. TileFunctor<T, 4>(x, repeat_times, out);
  84. break;
  85. case 5:
  86. TileFunctor<T, 5>(x, repeat_times, out);
  87. break;
  88. case 6:
  89. TileFunctor<T, 6>(x, repeat_times, out);
  90. break;
  91. }
  92. }
// Public entry point for the tile operation: repeats `x` along each dimension
// by the factors in `repeat_times` and stores the result in `out`.
// FD_VISIT_ALL_TYPES switches on the runtime dtype of `x` and invokes the
// lambda with `data_t` bound to the matching C++ element type, which then
// calls the typed TileKernel.
void Tile(const FDTensor &x, const std::vector<int64_t> &repeat_times,
          FDTensor *out) {
  FD_VISIT_ALL_TYPES(x.dtype, "TileKernel",
                     ([&] { TileKernel<data_t>(x, repeat_times, out); }));
}
} // namespace function
} // namespace ultra_infer