# flash_attn_utils.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
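"""Utilities for detecting and selecting a flash attention implementation in Paddle.

On NPU devices this registers the custom flash attention op shipped with the
custom device package and exposes it as paddle.nn.functional.flash_attention_npu;
on other devices it probes the built-in paddle.nn.functional.flash_attention kernel.
"""
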
import paddle


def is_flash_attn_available():
    """Probe whether a usable flash attention kernel exists on the current device."""
    try:
        import os

        if "npu" in paddle.get_device():  # NOTE: flash attn has not been tested yet
            # Register the custom ops shipped with the NPU custom device package.
            for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
                if lib.endswith(".so"):
                    paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
                        lib
                    )

            from paddle.base import core

            def flash_attention_npu(
                query,
                key,
                value,
                dropout=0.0,
                causal=False,
                return_softmax=False,
                *,
                fixed_seed_offset=None,
                rng_name="",
                training=True,
                name=None,
                attn_mask=None,
                is_varlen=False,
                batch_size=None,
                seq_length=None,
            ):
                # Wrapper around the custom "flash_attention_npu" op that mirrors the
                # signature of paddle.nn.functional.flash_attention.flash_attention.
                is_triangle_upper_mask = True
                if is_varlen:
                    if len(query.shape) == 4:
                        # Flatten [batch, seq, heads, dim] into [batch * seq, heads, dim].
                        B, S, N, D = query.shape
                        query = query.reshape([B * S, N, D])
                        key = key.reshape([B * S, N, D])
                        value = value.reshape([B * S, N, D])
                    else:
                        # Inputs are already flattened; the caller must supply the shape.
                        assert batch_size is not None
                        assert seq_length is not None
                        B = batch_size
                        S = seq_length
                    # Cumulative sequence lengths S, 2S, ..., B*S for the varlen kernel.
                    actual_seq_q_len = actual_seq_kv_len = list(range(S, B * S + S, S))
                else:
                    actual_seq_q_len = actual_seq_kv_len = []
                out = core.eager._run_custom_op(
                    "flash_attention_npu",
                    query,
                    key,
                    value,
                    fixed_seed_offset,
                    attn_mask,
                    actual_seq_q_len,
                    actual_seq_kv_len,
                    dropout,
                    causal,
                    return_softmax,
                    not training,
                    is_triangle_upper_mask,
                    is_varlen,
                )[0]
                return out

            # Smoke-test the custom op, then expose it on paddle.nn.functional.
            q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
            _ = flash_attention_npu(q, q, q, 0.9, False, False)
            paddle.nn.functional.flash_attention_npu = flash_attention_npu
            return True

        # Non-NPU devices: smoke-test the built-in flash attention kernel.
        q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
        _ = paddle.nn.functional.flash_attention.flash_attention(
            q, q, q, 0.9, False, False
        )
        return True
    except:
        return False


HAS_FLASH_ATTN = is_flash_attn_available()


def has_flash_attn_func():
    """Return (flash_attn_func, flash_attn_varlen_func), or (None, None) if unavailable."""
    if HAS_FLASH_ATTN:
        try:
            if "npu" in paddle.get_device():
                # The NPU wrapper handles both the padded and the varlen case itself.
                flash_attn_func_npu = paddle.nn.functional.flash_attention_npu
                return flash_attn_func_npu, flash_attn_func_npu
            else:
                from paddle.nn.functional.flash_attention import (
                    flash_attention as flash_attn_func,
                )
                from paddle.nn.functional.flash_attention import (
                    flash_attn_unpadded as flash_attn_varlen_func,
                )

                return flash_attn_func, flash_attn_varlen_func
        except:
            return None, None
    else:
        return None, None
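

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption for demonstration, not part of the
# original module): shows how a caller might consume has_flash_attn_func().
# The tensor shape and dtype mirror the bfloat16 probe used above.
if __name__ == "__main__":
    flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
    if flash_attn_func is None:
        print("flash attention is not available on this device")
    else:
        q = paddle.rand((1, 4, 2, 8)).astype("bfloat16")
        result = flash_attn_func(q, q, q, 0.0, True, False)
        # The built-in kernel returns (out, softmax); the NPU wrapper returns a tensor.
        out = result[0] if isinstance(result, tuple) else result
        print("flash attention output shape:", out.shape)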