# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


def is_flash_attn_available():
    """Probe whether a usable flash-attention kernel exists on the current device."""
    try:
        import os

        if "npu" in paddle.get_device():  # NOTE: flash attn has not been tested yet
            # Register the custom NPU operators shipped with the custom-device package.
            for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
                if lib.endswith(".so"):
                    paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
            from paddle.base import core

            def flash_attention_npu(
                query,
                key,
                value,
                dropout=0.0,
                causal=False,
                return_softmax=False,
                *,
                fixed_seed_offset=None,
                rng_name="",
                training=True,
                name=None,
                attn_mask=None,
                is_varlen=False,
                batch_size=None,
                seq_length=None,
            ):
                is_triangle_upper_mask = True
                if is_varlen:
                    # Variable-length mode expects packed [total_tokens, num_heads, head_dim] inputs.
                    if len(query.shape) == 4:
                        B, S, N, D = query.shape
                        query = query.reshape([B * S, N, D])
                        key = key.reshape([B * S, N, D])
                        value = value.reshape([B * S, N, D])
                    else:
                        assert batch_size is not None
                        assert seq_length is not None
                        B = batch_size
                        S = seq_length
                    # Cumulative sequence lengths: S, 2S, ..., B*S.
                    actual_seq_q_len = actual_seq_kv_len = list(range(S, B * S + S, S))
                else:
                    actual_seq_q_len = actual_seq_kv_len = []
                out = core.eager._run_custom_op(
                    "flash_attention_npu",
                    query,
                    key,
                    value,
                    fixed_seed_offset,
                    attn_mask,
                    actual_seq_q_len,
                    actual_seq_kv_len,
                    dropout,
                    causal,
                    return_softmax,
                    not training,
                    is_triangle_upper_mask,
                    is_varlen,
                )[0]
                return out

            # Smoke-test the custom op, then expose it on paddle.nn.functional.
            q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
            _ = flash_attention_npu(q, q, q, 0.9, False, False)
            paddle.nn.functional.flash_attention_npu = flash_attention_npu
            return True

        # Non-NPU devices: smoke-test the built-in flash attention kernel.
        q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
        _ = paddle.nn.functional.flash_attention.flash_attention(q, q, q, 0.9, False, False)
        return True
    except Exception:
        return False


# Evaluated once at import time.
HAS_FLASH_ATTN = is_flash_attn_available()


def has_flash_attn_func():
    """Return (dense_func, varlen_func), or (None, None) when flash attention is unavailable."""
    if HAS_FLASH_ATTN:
        try:
            if "npu" in paddle.get_device():
                # The NPU wrapper handles both the dense and the variable-length path.
                flash_attn_func_npu = paddle.nn.functional.flash_attention_npu
                return flash_attn_func_npu, flash_attn_func_npu
            else:
                from paddle.nn.functional.flash_attention import (
                    flash_attention as flash_attn_func,
                )
                from paddle.nn.functional.flash_attention import (
                    flash_attn_unpadded as flash_attn_varlen_func,
                )

                return flash_attn_func, flash_attn_varlen_func
        except Exception:
            return None, None
    else:
        return None, None
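

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of how the helpers above might be consumed. The tensor
# layout ([batch, seq_len, num_heads, head_dim]) and the positional call
# signature (q, k, v, dropout, causal, return_softmax) are assumptions taken
# from the probe calls in is_flash_attn_available(), not a documented contract.
if __name__ == "__main__":
    flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
    if flash_attn_func is None:
        print("flash attention is not available on this device")
    else:
        q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
        out = flash_attn_func(q, q, q, 0.0, True, False)
        # The built-in kernel may return (out, softmax) while the NPU wrapper
        # returns a single tensor, so handle both defensively.
        out = out[0] if isinstance(out, tuple) else out
        print("flash attention output shape:", out.shape)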