verify_flash_attention.py

import torch
import pkg_resources


def check_flash_attention():
    print("🔍 Flash Attention status check")
    print("=" * 50)

    # Check installed packages
    try:
        flash_attn_version = pkg_resources.get_distribution("flash-attn").version
        print(f"✅ flash-attn: {flash_attn_version}")
    except pkg_resources.DistributionNotFound:
        print("❌ flash-attn: not installed")

    try:
        flashinfer_version = pkg_resources.get_distribution("flashinfer").version
        print(f"✅ flashinfer: {flashinfer_version}")
    except pkg_resources.DistributionNotFound:
        print("❌ flashinfer: not installed")

    # Check CUDA availability
    print("\n🔧 CUDA status:")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # Check that Flash Attention can actually be imported
    try:
        import flash_attn
        print("\n✅ Flash Attention is importable")
        print(f"Flash Attention version: {flash_attn.__version__}")
    except ImportError as e:
        print(f"\n❌ Flash Attention import failed: {e}")


if __name__ == "__main__":
    check_flash_attention()
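
Note: pkg_resources is deprecated in recent setuptools releases. The snippet below is a minimal sketch of the same package check using the standard-library importlib.metadata (Python 3.8+); it is not part of the original script, and the distribution names ("flash-attn", "flashinfer") are assumed to match the ones used above.

# Sketch: package-presence check via importlib.metadata instead of pkg_resources
from importlib.metadata import version, PackageNotFoundError

for dist in ("flash-attn", "flashinfer"):
    try:
        # version() raises PackageNotFoundError if the distribution is absent
        print(f"✅ {dist}: {version(dist)}")
    except PackageNotFoundError:
        print(f"❌ {dist}: not installed")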