# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import paddle
import paddle.nn.functional as F

try:
    from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
    fused_rotary_position_embedding = None

try:
    from paddle.incubate.nn.functional import swiglu
except ImportError:
    # Fallback SwiGLU: silu(x) * y, splitting the last axis in half when only
    # a single packed tensor is given.
    def swiglu(x, y=None):
        if y is None:
            x, y = paddle.chunk(x, chunks=2, axis=-1)
        return F.silu(x) * y

from paddle.utils import try_import


def get_env_device():
    """
    Return the device name of the running environment.
    """
    if paddle.is_compiled_with_cuda():
        return "gpu"
    elif "npu" in paddle.device.get_all_custom_device_type():
        return "npu"
    elif "mlu" in paddle.device.get_all_custom_device_type():
        return "mlu"
    elif "gcu" in paddle.device.get_all_custom_device_type():
        return "gcu"
    elif "intel_hpu" in paddle.device.get_all_custom_device_type():
        return "intel_hpu"
    elif "iluvatar_gpu" in paddle.device.get_all_custom_device_type():
        return "iluvatar_gpu"
    elif paddle.is_compiled_with_rocm():
        return "rocm"
    elif paddle.is_compiled_with_xpu():
        return "xpu"
    return "cpu"


# On custom devices (npu/mlu/gcu/iluvatar), register the custom ops shipped with
# the device plugin before importing flash attention.
try:
    if get_env_device() in ["npu", "mlu", "gcu", "iluvatar_gpu"]:
        from paddle.base import core

        for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
            if lib.endswith(".so"):
                paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
    from paddle.nn.functional.flash_attention import flash_attention
except:
    flash_attention = None


def fusion_rope(
    query_states,
    key_states,
    value_states,
    hidden_states,
    position_ids,
    past_key_value,
    rotary_emb,
    context_parallel_degree=-1,
):
    """Apply rotary position embedding to query/key states using the fused
    kernel available on the current device."""
    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
        assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
    _, kv_seq_len, num_key_value_heads, _ = key_states.shape
    if context_parallel_degree > 1:
        assert get_env_device() == "gpu", "context parallel only support cuda device for now"
        kv_seq_len *= context_parallel_degree
    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
    elif get_env_device() == "intel_hpu":
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-3]
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
        cos = cos.squeeze().unsqueeze(0).unsqueeze(0)
        sin = sin.squeeze().unsqueeze(0).unsqueeze(0)
        query_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
            paddle.transpose(query_states, [0, 2, 1, 3]),
            None,
            None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
        )
        key_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
            paddle.transpose(key_states, [0, 2, 1, 3]),
            None,
            None,
            sin=sin,
            cos=cos,
            position_ids=position_ids,
        )
        query_states = paddle.transpose(query_states, [0, 2, 1, 3])
        key_states = paddle.transpose(key_states, [0, 2, 1, 3])
    elif get_env_device() == "gcu":
        cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
        query_states, key_states = core.eager._run_custom_op(
            "fused_rotary_embedding_gcu",
            query_states,
            key_states,
            cos_sin,
            position_ids,
            True,
        )
    else:
        # Paddle > 2.6 (and develop) supports q and k/v with different num_heads
        # in a single fused call; older versions need two separate calls when
        # num_heads != num_key_value_heads.
        paddle_version = float(paddle.__version__[:3])
        if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
            query_states, _, _ = fused_rotary_position_embedding(
                query_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
            key_states, _, _ = fused_rotary_position_embedding(
                key_states,
                None,
                None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
        else:
            query_states, key_states, _ = fused_rotary_position_embedding(
                query_states,
                key_states,
                v=None,
                sin=sin,
                cos=cos,
                position_ids=position_ids,
                use_neox_rotary_style=False,
            )
    return query_states, key_states
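
# Usage sketch (illustrative, not part of the original module): fusion_rope expects
# query/key/value tensors laid out as [batch, seq_len, num_heads, head_dim] and a
# rotary embedding module whose call returns (cos, sin). The tensor shapes and the
# rotary_emb object below are hypothetical; on a CUDA build this would roughly be:
#
#     q = paddle.randn([1, 128, 32, 128], dtype="float16")
#     k = paddle.randn([1, 128, 32, 128], dtype="float16")
#     v = paddle.randn([1, 128, 32, 128], dtype="float16")
#     position_ids = paddle.arange(128).unsqueeze(0)
#     q, k = fusion_rope(q, k, v, None, position_ids, None, rotary_emb)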


def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
    if use_fast_ln:
        fast_ln = try_import("fast_ln")
        return fast_ln.fast_rms_norm(x_in, w, eps)[0]
    else:
        fused_ln = try_import("fused_ln")
        return fused_ln.fused_rms_norm(x_in, w, eps)[0]


def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
    if get_env_device() == "npu":
        return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "mlu":
        return core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "gcu":
        return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
    elif get_env_device() == "intel_hpu":
        return paddle.incubate.nn.functional.fused_rms_norm(
            hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
        )[0]
    return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln)
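
# Usage sketch (illustrative, not part of the original module): fusion_rms_norm
# dispatches to a device-specific fused RMSNorm kernel and otherwise falls back to
# the fused_ln/fast_ln extensions via rms_norm_fused. A hypothetical call:
#
#     x = paddle.randn([2, 128, 4096], dtype="float16")
#     w = paddle.ones([4096], dtype="float16")
#     y = fusion_rms_norm(x, w, variance_epsilon=1e-6)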
|