fusion_ops.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:

    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y
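
# Note (illustrative, not from the original module): the fallback above computes
# SwiGLU as SiLU(x) * y, and when called with a single packed tensor it first
# splits the last axis in half, which is intended to match the calling
# convention of the fused paddle.incubate.nn.functional.swiglu kernel it
# stands in for.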

from paddle.utils import try_import

from paddlenlp.utils.tools import get_env_device

try:
    if get_env_device() in ["npu", "mlu", "gcu"]:
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None

from paddlenlp.transformers.refined_recompute import no_recompute
from paddlenlp.transformers.ring_flash_attention import RingFlashAttention


def fusion_rope(
    query_states,
    key_states,
    value_states,
    hidden_states,
    position_ids,
    past_key_value,
    rotary_emb,
    context_parallel_degree=-1,
):
    if get_env_device() not in ["gcu", "intel_hpu"]:
        assert past_key_value is None, "fused rotary does not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    if context_parallel_degree > 1:
        assert get_env_device() == "gpu", "context parallel only supports cuda devices for now"
        kv_seq_len *= context_parallel_degree
    if get_env_device() not in ["gcu", "intel_hpu"]:
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)

    # Dispatch to the device-specific fused RoPE kernel.
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    elif get_env_device() == "intel_hpu":
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-3]
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
        cos = cos.squeeze().unsqueeze(0).unsqueeze(0)
        sin = sin.squeeze().unsqueeze(0).unsqueeze(0)
        # The HPU fused op expects [batch, num_heads, seq_len, head_dim], so transpose in and out.
        query_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
            paddle.transpose(query_states, [0, 2, 1, 3]),
            None,
            None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
        )
        key_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
            paddle.transpose(key_states, [0, 2, 1, 3]),
            None,
            None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
        )
        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
    elif get_env_device() == "gcu":
        # The GCU kernel consumes a packed cos/sin tensor.
        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
        query_states, key_states = core.eager._run_custom_op(
            "fused_rotary_embedding_gcu",
            query_states,
            key_states,
            cos_sin,
            position_ids,
            True,
        )
    else:
        # Paddle > 2.6 (and develop) supports q and k/v with different num_heads in a
        # single fused call; older versions need separate calls for q and k.
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states
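

# Illustrative reference (not part of the original module): a minimal, unfused
# sketch of rotate-half RoPE, useful for checking what the fused kernels above
# are expected to compute on the default GPU path. The helper name and the
# assumption that `cos`/`sin` are already expanded to the full head_dim (first
# half duplicated into the second half, as in PaddleNLP's eager rotary path)
# are assumptions for illustration only.
def _rope_reference(x, cos, sin):
    # x: [batch, seq_len, num_heads, head_dim]; cos/sin broadcastable to x.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = paddle.concat([-x2, x1], axis=-1)  # rotate-half
    return x * cos + rotated * sin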


def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
    if use_fast_ln:
        fast_ln = try_import("fast_ln")
        return fast_ln.fast_rms_norm(x_in, w, eps)[0]
    else:
        fused_ln = try_import("fused_ln")
        return fused_ln.fused_rms_norm(x_in, w, eps)[0]


def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
    # Prefer a device-specific fused RMSNorm when available; otherwise fall back
    # to the fused_ln/fast_ln extension kernels.
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "mlu":
        return core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "gcu":
        return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "intel_hpu":
        return paddle.incubate.nn.functional.fused_rms_norm(
            hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
        )[0]
    return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln)
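

# Illustrative reference (not part of the original module): the fused RMSNorm
# kernels dispatched above all compute x * rsqrt(mean(x^2) + eps) * weight over
# the last axis. A minimal eager sketch; the float32 accumulation here is an
# assumption, not a guarantee about any particular kernel.
def _rms_norm_reference(x, weight, eps):
    variance = x.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    normed = x.astype("float32") * paddle.rsqrt(variance + eps)
    return (normed * weight.astype("float32")).astype(x.dtype)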


def fusion_flash_attention(
    query_states,
    config,
    key_states,
    value_states,
    attention_mask,
    output_attentions,
    alibi=None,
    attn_mask_startend_row_indices=None,
    sequence_parallel=False,
    reshard_layer=None,
    npu_is_casual=False,
    skip_recompute=False,
):
    # Note: query_states and key_states must share the same head_dim; the head_dim
    # of value_states is the one used for the output reshape.
    bsz, q_len, num_heads, _ = query_states.shape
    _, kv_seq_len, _, head_dim = value_states.shape
    version = paddle.version.full_version
    if version != "0.0.0" and version <= "2.5.2":
        # Older Paddle (<= 2.5.2) only exposes the plain flash_attention API;
        # alibi and context parallel are unsupported on this path.
        if alibi is not None:
            raise ValueError("Flash Attention doesn't support alibi")
        if config.context_parallel_degree > 1:
            raise ValueError(f"Context parallel is not implemented in version {version}")
        attn_output, attn_weights = flash_attention(
            query_states,
            key_states,
            value_states,
            causal=True,
            return_softmax=output_attentions,
        )
    else:
        if alibi is not None:
            # Fold the alibi bias into the attention mask.
            alibi = alibi.reshape([bsz, num_heads, 1, -1])
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            if config.context_parallel_degree > 1:
                raise ValueError("Context parallel is not implemented for npu")
            attn_output = core.eager._run_custom_op(
                "flash_attention_npu",
                query_states,
                key_states,
                value_states,
                None,
                attention_mask,
                None,
                None,
                0.0,
                attention_mask is None,
                True,
                False,
                npu_is_casual,
                False,
            )[0]
        elif get_env_device() == "gcu":
            if config.context_parallel_degree > 1:
                raise ValueError("Context parallel is not implemented for gcu")
            attn_output = core.eager._run_custom_op(
                "fused_sdp_flash_attention_gcu",
                query_states,
                key_states,
                value_states,
                attention_mask,
                0.0,
                attention_mask is None,
                True,
            )[0]
        elif get_env_device() == "intel_hpu":
            if config.context_parallel_degree > 1:
                raise ValueError("Context parallel is not implemented for intel_hpu")
            scaling_factor = query_states.shape[3] ** -0.5
            attention_mask = attention_mask.astype(query_states.dtype)
            attn_output = paddle.incubate.nn.functional.fused_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attention_mask,
                0.0,
                attention_mask is None,
                scaling_factor,
                False,
            )
        else:
            if config.context_parallel_degree > 1:
                # Context parallel: shard the sequence across ranks with ring flash attention.
                attn_output = RingFlashAttention.apply(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=None,
                    is_causal=True,
                )
            else:
                if attn_mask_startend_row_indices is not None:
                    assert (
                        alibi is None
                    ), "flashmask_attention or flash_attention_with_sparse_mask does not support alibi"
                    if len(attn_mask_startend_row_indices.shape) == 2:
                        attn_mask_startend_row_indices = paddle.unsqueeze(
                            attn_mask_startend_row_indices, axis=1
                        )
                    if hasattr(F, "flashmask_attention"):
                        attn_output = no_recompute(
                            F.flashmask_attention,
                            query_states,
                            key_states,
                            value_states,
                            startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
                            causal=True,
                            enable=skip_recompute,
                        )
                    else:
                        attn_output = no_recompute(
                            F.flash_attention_with_sparse_mask,
                            query_states,
                            key_states,
                            value_states,
                            attn_mask_start_row_indices=attn_mask_startend_row_indices,
                            is_causal=True,
                            enable=skip_recompute,
                        )
                else:
                    attn_output = no_recompute(
                        F.scaled_dot_product_attention,
                        query_states,
                        key_states,
                        value_states,
                        attn_mask=attention_mask,
                        is_causal=query_states.shape[1] != 1,
                        enable=skip_recompute,
                    )
        attn_weights = None

    if reshard_layer is not None:
        # attn_output shape: [bs, seqlen, num_head/sep, head_dim]
        attn_output = reshard_layer(
            attn_output,
            split_axis=1,
            concat_axis=2,
        )
        # attn_output shape: [bs, seqlen/sep, num_head, head_dim]
        assert (
            config.sep_parallel_degree > 1 and q_len % config.sep_parallel_degree == 0
        ), f"q_len:{q_len}, config.sep_parallel_degree:{config.sep_parallel_degree}"
        q_len = q_len // config.sep_parallel_degree
        num_heads = num_heads * config.sep_parallel_degree

    if sequence_parallel:
        attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
    else:
        attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
    return (attn_output, attn_weights) if output_attentions else attn_output
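

# Illustrative reference (not part of the original module): an eager
# softmax(Q K^T / sqrt(d)) V equivalent of the fused attention paths above, for
# the simple causal, mask-free case. Shapes follow the module's convention of
# [batch, seq_len, num_heads, head_dim]; the helper name is hypothetical.
def _sdpa_reference(q, k, v, causal=True):
    q_t = paddle.transpose(q, [0, 2, 1, 3])  # [b, h, s, d]
    k_t = paddle.transpose(k, [0, 2, 1, 3])
    v_t = paddle.transpose(v, [0, 2, 1, 3])
    scores = paddle.matmul(q_t, k_t, transpose_y=True) / (q.shape[-1] ** 0.5)
    if causal:
        seq = q.shape[1]
        neg_inf = paddle.full([seq, seq], float("-inf"), dtype=scores.dtype)
        scores = scores + paddle.triu(neg_inf, diagonal=1)
    out = paddle.matmul(F.softmax(scores, axis=-1), v_t)
    return paddle.transpose(out, [0, 2, 1, 3])  # back to [b, s, h, d]


if __name__ == "__main__":
    # Tiny CPU smoke test of the illustrative reference helper only (assumed shapes).
    q = paddle.randn([1, 8, 2, 16])
    k = paddle.randn([1, 8, 2, 16])
    v = paddle.randn([1, 8, 2, 16])
    print(_sdpa_reference(q, k, v).shape)  # expect [1, 8, 2, 16]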