conversion_utils.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. def fuse_param_func():
  17. def fn(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None):
  18. concat_fn = np.concatenate
  19. split_fn = np.split
  20. if isinstance(fuse_params[0], paddle.Tensor):
  21. concat_fn = paddle.concat
  22. split_fn = paddle.split
  23. if is_qkv:
  24. assert (
  25. num_heads
  26. ), f"num_heads should be number of heads for Q, but got {num_heads}"
  27. assert (
  28. num_key_value_heads
  29. ), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}"
  30. assert (
  31. len(fuse_params) == 3
  32. ), f"fuse_params length is not equal 3, it should be Q K V list. but got length {len(fuse_params)}"
  33. num_query_groups = num_heads // num_key_value_heads
  34. q_list = split_fn(fuse_params[0], num_heads, axis=-1)
  35. k_list = split_fn(fuse_params[1], num_key_value_heads, axis=-1)
  36. v_list = split_fn(fuse_params[2], num_key_value_heads, axis=-1)
  37. qkv_pairs = []
  38. for i in range(num_key_value_heads):
  39. qkv_pairs += q_list[i * num_query_groups : (i + 1) * num_query_groups]
  40. qkv_pairs.append(k_list[i])
  41. qkv_pairs.append(v_list[i])
  42. return concat_fn(qkv_pairs, axis=-1)
  43. else:
  44. return concat_fn(fuse_params, axis=-1)
  45. return fn
  46. def split_param_func():
  47. def fn(
  48. fused_param,
  49. split_nums=2,
  50. is_qkv=False,
  51. num_heads=None,
  52. num_key_value_heads=None,
  53. ):
  54. concat_fn = np.concatenate
  55. split_fn = np.split
  56. if isinstance(fused_param, paddle.Tensor):
  57. concat_fn = paddle.concat
  58. split_fn = paddle.split
  59. if is_qkv:
  60. assert (
  61. num_heads
  62. ), f"num_heads should be number of heads for Q, but got {num_heads}"
  63. assert (
  64. num_key_value_heads
  65. ), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}"
  66. num_query_groups = num_heads // num_key_value_heads
  67. q_list, k_list, v_list = [], [], []
  68. split_heads = split_fn(
  69. fused_param, num_heads + 2 * num_key_value_heads, axis=-1
  70. )
  71. for i in range(num_key_value_heads):
  72. q_list += split_heads[
  73. i * (num_query_groups + 2) : (i + 1) * (num_query_groups + 2) - 2
  74. ]
  75. k_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 2])
  76. v_list.append(split_heads[(i + 1) * (num_query_groups + 2) - 1])
  77. return (
  78. concat_fn(q_list, axis=-1),
  79. concat_fn(k_list, axis=-1),
  80. concat_fn(v_list, axis=-1),
  81. )
  82. else:
  83. return split_fn(fused_param, split_nums, axis=-1)
  84. return fn
  85. def split_or_fuse_func(is_fuse=True):
  86. return fuse_param_func() if is_fuse else split_param_func()