@@ -49,6 +49,8 @@ def get_env_device():
         return "gcu"
     elif "intel_hpu" in paddle.device.get_all_custom_device_type():
         return "intel_hpu"
+    elif "iluvatar_gpu" in paddle.device.get_all_custom_device_type():
+        return "iluvatar_gpu"
     elif paddle.is_compiled_with_rocm():
         return "rocm"
     elif paddle.is_compiled_with_xpu():
@@ -61,7 +63,7 @@ try:
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() in ["npu", "mlu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu", "iluvatar_gpu"]:
         from paddle.base import core

         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -84,7 +86,7 @@ def fusion_rope(
     rotary_emb,
     context_parallel_degree=-1,
 ):
-    if get_env_device() not in ["gcu", "intel_hpu"]:
+    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
         assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
@@ -93,7 +95,7 @@ def fusion_rope(
             get_env_device() == "gpu"
         ), "context parallel only support cuda device for now"
         kv_seq_len *= context_parallel_degree
-    if get_env_device() not in ["gcu", "intel_hpu"]:
+    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
         cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[