fusion_ops.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        # Fallback when the fused swiglu kernel is unavailable: split the packed
        # projection into gate/value halves and apply SiLU gating.
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y
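
# Usage sketch for the SwiGLU fallback above (illustrative shapes): a single packed
# tensor is split in half along the last axis, or the gate/value tensors can be
# passed separately; both forms compute silu(gate) * value.
#     x = paddle.randn([2, 16, 256])                  # packed, last dim = 2 * 128
#     out = swiglu(x)                                 # -> [2, 16, 128]
#     out = swiglu(x[..., :128], x[..., 128:])        # equivalent explicit form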

from paddle.utils import try_import


def get_env_device():
    """
    Return the device name of the running environment.
    """
    if paddle.is_compiled_with_cuda():
        return "gpu"
    elif "npu" in paddle.device.get_all_custom_device_type():
        return "npu"
    elif "mlu" in paddle.device.get_all_custom_device_type():
        return "mlu"
    elif "gcu" in paddle.device.get_all_custom_device_type():
        return "gcu"
    elif "intel_hpu" in paddle.device.get_all_custom_device_type():
        return "intel_hpu"
    elif paddle.is_compiled_with_rocm():
        return "rocm"
    elif paddle.is_compiled_with_xpu():
        return "xpu"
    return "cpu"
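
# Minimal usage sketch of the helper above (the returned strings are exactly the
# ones this module branches on):
#     device = get_env_device()        # e.g. "gpu", "npu", "xpu", or "cpu"
#     if device in ("npu", "mlu", "gcu"):
#         ...                          # the custom-device kernels registered below are used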

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    if get_env_device() in ["npu", "mlu", "gcu"]:
        from paddle.base import core

        # Register the custom-device operators shipped as shared libraries under
        # CUSTOM_DEVICE_ROOT before importing flash attention.
        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None
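
# Sketch of the expected caller pattern (the call sites are not defined in this
# module): because the import above can fail, callers are assumed to guard on the
# sentinel, e.g.
#     if flash_attention is not None:
#         ...                          # use the fused flash attention kernel
#     else:
#         ...                          # fall back to the eager attention path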


def fusion_rope(
    query_states,
    key_states,
    value_states,
    hidden_states,
    position_ids,
    past_key_value,
    rotary_emb,
    context_parallel_degree=-1,
):
    """Apply rotary position embedding to the query/key states, using the fused
    kernel that matches the current device."""
    if get_env_device() not in ["gcu", "intel_hpu"]:
        assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    if context_parallel_degree > 1:
        assert get_env_device() == "gpu", "context parallel only support cuda device for now"
        kv_seq_len *= context_parallel_degree
    if get_env_device() not in ["gcu", "intel_hpu"]:
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    elif get_env_device() == "intel_hpu":
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-3]
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
        cos = cos.squeeze().unsqueeze(0).unsqueeze(0)
        sin = sin.squeeze().unsqueeze(0).unsqueeze(0)
        # The fused op is fed [batch, num_heads, seq_len, head_dim], so transpose in
        # and back out around the two calls.
        query_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
            paddle.transpose(query_states, [0, 2, 1, 3]),
            None,
            None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
        )
        key_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
            paddle.transpose(key_states, [0, 2, 1, 3]),
            None,
            None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
        )
        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
    elif get_env_device() == "gcu":
        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
        query_states, key_states = core.eager._run_custom_op(
            "fused_rotary_embedding_gcu",
            query_states,
            key_states,
            cos_sin,
            position_ids,
            True,
        )
    else:
        # Paddle > 2.6 (and develop, reported as 0.0) supports q and k/v with different
        # num_heads in a single fused call; older releases need two separate calls.
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states
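
# Call sketch for fusion_rope (hypothetical sizes; the layout follows the shape
# unpacking above: query/key/value are [batch, seq_len, num_heads, head_dim], and
# rotary_emb is the model's rotary embedding module that yields (cos, sin)):
#     q = paddle.randn([1, 128, 32, 64], dtype="float16")
#     k = paddle.randn([1, 128, 32, 64], dtype="float16")
#     v = paddle.randn([1, 128, 32, 64], dtype="float16")
#     q, k = fusion_rope(q, k, v, hidden_states=None, position_ids=None,
#                        past_key_value=None, rotary_emb=rotary_emb)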


def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
    if use_fast_ln:
        fast_ln = try_import("fast_ln")
        return fast_ln.fast_rms_norm(x_in, w, eps)[0]
    else:
        fused_ln = try_import("fused_ln")
        return fused_ln.fused_rms_norm(x_in, w, eps)[0]
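
# Reference semantics (a sketch, not the kernel code): both external ops above are
# assumed to compute standard RMSNorm over the last axis,
#     y = x / sqrt(mean(x ** 2, axis=-1, keepdim=True) + eps) * w,
# i.e. the same result as an eager Paddle implementation such as
#     variance = x_in.astype("float32").pow(2).mean(-1, keepdim=True)
#     y = (x_in * paddle.rsqrt(variance + eps)) * w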


def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "mlu":
        return core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "gcu":
        return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "intel_hpu":
        return paddle.incubate.nn.functional.fused_rms_norm(
            hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
        )[0]
    return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln)
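
# Usage sketch (illustrative shapes; assumes the device-specific fused op, or the
# fused_ln extension on GPU, is actually available in the running environment):
#     hidden = paddle.randn([2, 128, 4096], dtype="bfloat16")
#     weight = paddle.ones([4096], dtype="bfloat16")
#     out = fusion_rms_norm(hidden, weight, variance_epsilon=1e-6)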