fusion_ops.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:
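
    # Fallback used when the fused swiglu kernel is not available in this paddle
    # build: split x along the last axis when y is not given, then gate with SiLU.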
    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y


from paddle.utils import try_import


def get_env_device():
    """
    Return the device name of the running environment.
    """
    if paddle.is_compiled_with_cuda():
        return "gpu"
    elif "npu" in paddle.device.get_all_custom_device_type():
        return "npu"
    elif "mlu" in paddle.device.get_all_custom_device_type():
        return "mlu"
    elif "gcu" in paddle.device.get_all_custom_device_type():
        return "gcu"
    elif "intel_hpu" in paddle.device.get_all_custom_device_type():
        return "intel_hpu"
    elif "iluvatar_gpu" in paddle.device.get_all_custom_device_type():
        return "iluvatar_gpu"
    elif paddle.is_compiled_with_rocm():
        return "rocm"
    elif paddle.is_compiled_with_xpu():
        return "xpu"
    return "cpu"
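
# Illustrative only: a CUDA build of paddle reports "gpu", a CPU-only build "cpu",
# and custom-device builds report their device type (e.g. "npu").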

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    if get_env_device() in ["npu", "mlu", "gcu", "iluvatar_gpu"]:
        from paddle.base import core
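
        # Register every custom-device shared library found under CUSTOM_DEVICE_ROOT
        # so that the custom ops used below (e.g. "fused_rope") are available.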
        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None


def fusion_rope(
    query_states,
    key_states,
    value_states,
    hidden_states,
    position_ids,
    past_key_value,
    rotary_emb,
    context_parallel_degree=-1,
):
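    """Apply rotary position embedding to the query/key states using the fused
    kernel available on the current device (custom NPU/GCU ops, the intel_hpu
    incubate kernel, or paddle's fused_rotary_position_embedding elsewhere)."""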
    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
        assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    if context_parallel_degree > 1:
        assert (
            get_env_device() == "gpu"
        ), "context parallel only support cuda device for now"
        kv_seq_len *= context_parallel_degree
    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    elif get_env_device() == "intel_hpu":
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-3]
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
        cos = cos.squeeze().unsqueeze(0).unsqueeze(0)
        sin = sin.squeeze().unsqueeze(0).unsqueeze(0)
        query_states, _, _ = (
            paddle.incubate.nn.functional.fused_rotary_position_embedding(
                paddle.transpose(query_states, [0, 2, 1, 3]),
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
            )
        )
        key_states, _, _ = (
            paddle.incubate.nn.functional.fused_rotary_position_embedding(
                paddle.transpose(key_states, [0, 2, 1, 3]),
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
            )
        )
        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
    elif get_env_device() == "gcu":
        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
        query_states, key_states = core.eager._run_custom_op(
            "fused_rotary_embedding_gcu",
            query_states,
            key_states,
            cos_sin,
            position_ids,
            True,
        )
    else:
        # paddle version > 2.6 or develop supports q and k/v with different num_heads
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (
            num_heads != num_key_value_heads
        ):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states


def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
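    """RMSNorm computed by the external fast_ln / fused_ln extension kernels."""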
    if use_fast_ln:
        fast_ln = try_import("fast_ln")
        return fast_ln.fast_rms_norm(x_in, w, eps)[0]
    else:
        fused_ln = try_import("fused_ln")
        return fused_ln.fused_rms_norm(x_in, w, eps)[0]


def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
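    """Dispatch RMSNorm to the custom kernel registered for the current device
    (npu/mlu/gcu/intel_hpu), falling back to rms_norm_fused elsewhere."""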
    if get_env_device() == "npu":
        return core.eager._run_custom_op(
            "rms_norm_npu", hidden_states, weight, variance_epsilon
        )[0]
    if get_env_device() == "mlu":
        return core.eager._run_custom_op(
            "rms_norm_mlu", hidden_states, weight, variance_epsilon
        )[0]
    elif get_env_device() == "gcu":
        return core.eager._run_custom_op(
            "rms_norm_gcu", hidden_states, weight, variance_epsilon
        )[0]
    elif get_env_device() == "intel_hpu":
        return paddle.incubate.nn.functional.fused_rms_norm(
            hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
        )[0]
    return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln)
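

# Minimal usage sketch (illustrative, not part of the original module). It assumes a
# working paddle install where swiglu (the fused kernel or the fallback above) runs
# on the local device; the fused RoPE / RMSNorm paths need device-specific extensions.
if __name__ == "__main__":
    print("running on:", get_env_device())
    x = paddle.randn([2, 8, 64])
    # With y omitted, swiglu splits the last axis in half and gates with SiLU,
    # so the output shape is [2, 8, 32].
    print(swiglu(x).shape)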